410 lines
18 KiB
C++
410 lines
18 KiB
C++
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/SmallPtrSet.h"
|
|
|
|
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Physical-layout tag of ordinary dense NCHW values. Dense blueprints and dense-to-dense
/// materialize_layout ops are lowered by simply forwarding their input.
static constexpr StringLiteral kDenseLayout("dense_nchw");
/// Physical-layout tag of the row-strip layout produced by the selected row-strip Conv2D/Relu lowerings.
static constexpr StringLiteral kRowStripLayout("nchw_row_strip");
|
|
|
|
struct RowStripPhysicalValue {
|
|
Value physicalValue;
|
|
RankedTensorType logicalType;
|
|
SmallVector<int64_t, 16> fragmentOffsets;
|
|
SmallVector<int64_t, 16> fragmentSizes;
|
|
std::string indexMap;
|
|
};
|
|
|
|
/// Looks up the row-strip description recorded for `value`.
///
/// The map is only read, so it is taken by const reference. Callers holding a mutable map still bind to it.
///
/// @param rowStripValues Side table of lowered row-strip values, keyed by logical value (in this pass, the
///                       result of a row-strip blueprint op).
/// @param value          Logical value to look up.
/// @return A copy of the recorded description, or failure() if `value` has no row-strip lowering.
static FailureOr<RowStripPhysicalValue>
getRowStripValue(const llvm::DenseMap<Value, RowStripPhysicalValue>& rowStripValues, Value value) {
  auto it = rowStripValues.find(value);
  if (it == rowStripValues.end())
    return failure();
  return it->second;
}
|
|
|
|
/// Builds the row-strip description of `physicalValue` from the fragment metadata carried by `blueprint`.
///
/// @param blueprint     Row-strip blueprint whose logical output type, fragment offsets/sizes and index map
///                      are copied into the description.
/// @param physicalValue Lowered physical value that now backs the blueprint's logical result.
/// @return The description, or failure() (after emitting an error on `blueprint`) when the blueprint's
///         logical output type is not a ranked tensor.
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatBlueprintOp blueprint,
                                                           Value physicalValue) {
  auto rankedLogicalType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
  if (!rankedLogicalType) {
    blueprint.emitOpError("requires ranked logical output type");
    return failure();
  }

  auto offsets = blueprint.getFragmentOffsets();
  auto sizes = blueprint.getFragmentSizes();

  RowStripPhysicalValue description;
  description.physicalValue = physicalValue;
  description.logicalType = rankedLogicalType;
  // The vectors start empty, so assign() yields the same contents as append().
  description.fragmentOffsets.assign(offsets.begin(), offsets.end());
  description.fragmentSizes.assign(sizes.begin(), sizes.end());
  description.indexMap = blueprint.getIndexMap().str();
  return description;
}
|
|
|
|
/// Lowers a Relu plan whose input is already in the row-strip layout.
///
/// Relu is applied directly to the packed physical tensor, with no repacking, inside a compute op built by
/// createSpatCompute. The result has the same packed physical type as the input.
///
/// @param input    Row-strip description of the plan's input.
/// @param planOp   The Relu plan being lowered; supplies the location of the new ops.
/// @param rewriter Rewriter already positioned where the new ops must be inserted.
/// @return The packed Relu result, or failure() if the physical input is not a ranked tensor. Previously an
///         unchecked cast asserted in that case, although the caller already handles failure().
static FailureOr<Value>
lowerRowStripRelu(const RowStripPhysicalValue& input, spatial::SpatReluPlanOp planOp, PatternRewriter& rewriter) {
  auto packedType = dyn_cast<RankedTensorType>(input.physicalValue.getType());
  if (!packedType)
    return failure();

  Location loc = planOp.getLoc();
  auto computeOp =
      createSpatCompute<1>(rewriter, loc, TypeRange {packedType}, {}, input.physicalValue, [&](Value x) {
        auto relu = spatial::SpatReluOp::create(rewriter, loc, packedType, x);
        spatial::SpatYieldOp::create(rewriter, loc, relu.getResult());
      });
  return computeOp.getResult(0);
}
|
|
|
|
/// Materializes a row-strip value back into a dense NCHW tensor.
///
/// Only the canonical form described by the "packed_hwc_rows_to_nchw" index map is supported:
///  - the physical value is a static rank-3 tensor [rows, W, C] holding one packed row per fragment;
///  - the logical value is a static rank-4 tensor [1, C, rows, W];
///  - fragment i has logical offsets [0, 0, i, 0] and sizes [1, C, 1, W].
/// Each fragment is handled by one lane of a compute batch. Lane i:
///  1. extracts packed row i as [1, W, C];
///  2. expands it to [1, 1, W, C];
///  3. transposes it to [1, C, 1, W];
///  4. inserts it at row i of the dense result.
///
/// @param rowStripValue Row-strip description to materialize.
/// @param loc           Location for all created ops.
/// @param rewriter      Rewriter positioned where the new ops must be inserted.
/// @return The dense logical value, or failure() if the value is not in the supported form or the compute
///         batch could not be created.
static FailureOr<Value>
materializeRowStripToDense(const RowStripPhysicalValue& rowStripValue, Location loc, PatternRewriter& rewriter) {
  auto packedType = dyn_cast<RankedTensorType>(rowStripValue.physicalValue.getType());
  if (!packedType || packedType.getRank() != 3 || !packedType.hasStaticShape())
    return failure();
  const RankedTensorType logicalType = rowStripValue.logicalType;
  if (logicalType.getRank() != 4 || !logicalType.hasStaticShape())
    return failure();
  if (rowStripValue.indexMap != "packed_hwc_rows_to_nchw")
    return failure();

  // Fragment metadata is flattened fragment-major. Validate both vectors before indexing into them.
  // Previously only `fragmentOffsets` determined the fragment count, so a shorter `fragmentSizes` was read
  // out of bounds.
  const int64_t rank = logicalType.getRank();
  const int64_t metadataSize = static_cast<int64_t>(rowStripValue.fragmentOffsets.size());
  if (metadataSize % rank != 0 || static_cast<int64_t>(rowStripValue.fragmentSizes.size()) != metadataSize)
    return failure();

  const int64_t fragmentCount = metadataSize / rank;
  const int64_t packedWidth = packedType.getDimSize(1);
  const int64_t packedChannels = packedType.getDimSize(2);
  if (fragmentCount != packedType.getDimSize(0))
    return failure();

  // The fragments must tile the dense result [1, C, rows, W] exactly. Otherwise some logical rows would never
  // be written, or an insert would fall outside the result tensor.
  if (logicalType.getDimSize(0) != 1 || logicalType.getDimSize(1) != packedChannels
      || logicalType.getDimSize(2) != fragmentCount || logicalType.getDimSize(3) != packedWidth)
    return failure();

  const auto& offsets = rowStripValue.fragmentOffsets;
  const auto& sizes = rowStripValue.fragmentSizes;
  for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
    const int64_t base = fragmentIndex * rank;
    // Fragment i must be logical row i: offsets [0, 0, i, 0].
    if (offsets[base + 0] != 0 || offsets[base + 1] != 0 || offsets[base + 2] != fragmentIndex
        || offsets[base + 3] != 0)
      return failure();
    // ... covering all channels and the full width: sizes [1, C, 1, W].
    if (sizes[base + 0] != 1 || sizes[base + 1] != packedChannels || sizes[base + 2] != 1
        || sizes[base + 3] != packedWidth)
      return failure();
  }

  auto packedSliceType =
      RankedTensorType::get({1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
  auto expandedType =
      RankedTensorType::get({1, 1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
  auto logicalFragmentType =
      RankedTensorType::get({1, packedChannels, 1, packedWidth}, packedType.getElementType(), packedType.getEncoding());

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {rowStripValue.logicalType},
      fragmentCount,
      {},
      ValueRange {rowStripValue.physicalValue},
      [&](detail::SpatComputeBatchBodyArgs args) {
        // Packed row `lane`: [1, W, C].
        SmallVector<OpFoldResult> packedOffsets {args.lane, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> packedSizes {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(packedWidth), rewriter.getIndexAttr(packedChannels)};
        Value packedSlice = tensor::ExtractSliceOp::create(
            rewriter, loc, packedSliceType, args.inputs.front(), packedOffsets, packedSizes, getUnitStrides(rewriter, 3));

        // [1, W, C] -> [1, 1, W, C] -> (perm 0,3,1,2) -> [1, C, 1, W].
        Value expanded = tensor::ExpandShapeOp::create(rewriter,
                                                       loc,
                                                       expandedType,
                                                       packedSlice,
                                                       SmallVector<ReassociationIndices> {
                                                           {0, 1},
                                                           {2},
                                                           {3}
        });
        Value transposeInit =
            tensor::EmptyOp::create(rewriter, loc, logicalFragmentType.getShape(), logicalFragmentType.getElementType());
        Value logicalFragment =
            linalg::TransposeOp::create(rewriter, loc, expanded, transposeInit, SmallVector<int64_t> {0, 3, 1, 2})
                .getResult()[0];

        // Insert at logical row `lane` of the dense result.
        SmallVector<OpFoldResult> logicalOffsets {
            rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> logicalSizes {rewriter.getIndexAttr(1),
                                                rewriter.getIndexAttr(packedChannels),
                                                rewriter.getIndexAttr(1),
                                                rewriter.getIndexAttr(packedWidth)};
        createParallelInsertSliceIntoBatchOutput(rewriter,
                                                 loc,
                                                 logicalFragment,
                                                 args.outputs.front(),
                                                 logicalOffsets,
                                                 logicalSizes,
                                                 getUnitStrides(rewriter, 4));
        return success();
      });
  if (failed(batchOp))
    return failure();
  return batchOp->getResult(0);
}
|
|
|
|
/// Module pass that lowers the selected Spatial planning ops in the PIM entry function to low-level Spatial IR.
///
/// Stages:
///  1. One walk over the entry block lowers Conv2D/Relu plans, materialize_layout ops and blueprints.
///     - A plan whose result feeds an "nchw_row_strip" blueprint is lowered to a row-strip physical value.
///     - That value is recorded in a side table keyed by the blueprint result.
///     - The plan and the blueprint are scheduled for erasure.
///  2. Scheduled ops are erased once they have no remaining uses.
///  3. Helper ONNX Gemm/Transpose ops emitted by the plan lowerings are converted.
///     - First outside compute ops, whose bodies are treated as recursively legal.
///     - Then inside each compute / compute-batch op.
///  4. The pass checks that no planning or ONNX op remains and that the logical Spatial graph invariants
///     still hold.
struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerSpatialPlansPass)

  /// Lowered row-strip values, keyed by the result of the row-strip blueprint they were built from.
  using RowStripValueMap = llvm::DenseMap<Value, RowStripPhysicalValue>;
  /// Row-strip plan/blueprint ops to erase once all their uses are gone.
  using PendingEraseSet = llvm::SmallPtrSet<Operation*, 16>;

  StringRef getArgument() const override { return "lower-spatial-plans"; }

  StringRef getDescription() const override { return "Lower selected Spatial planning ops to low-level Spatial IR."; }

  /// Dialects whose ops this pass (directly, or through the helper conversion patterns) may create.
  /// They must be loaded before the pass runs.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<spatial::SpatialDialect,
                    tensor::TensorDialect,
                    linalg::LinalgDialect,
                    affine::AffineDialect,
                    arith::ArithDialect,
                    scf::SCFDialect>();
  }

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    MLIRContext* ctx = moduleOp.getContext();
    auto entryFunc = getPimEntryFunc(moduleOp);
    if (failed(entryFunc)) {
      moduleOp.emitError("failed to locate the PIM entry function during LowerSpatialPlans");
      signalPassFailure();
      return;
    }
    func::FuncOp funcOp = *entryFunc;

    // Verifies the logical Spatial graph invariants. On failure, reports `stage` and marks the pass as failed.
    auto verifyLogicalPhase = [&](StringRef stage) -> bool {
      if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc)))
        return true;
      moduleOp.emitError() << "logical Spatial graph verification failed " << stage;
      signalPassFailure();
      return false;
    };

    if (!verifyLogicalPhase("at the start of LowerSpatialPlans"))
      return;

    PatternRewriter rewriter(ctx);
    RowStripValueMap rowStripValues;
    PendingEraseSet eraseAfterLowering;
    if (failed(lowerPlanningOps(funcOp, rewriter, rowStripValues, eraseAfterLowering))
        || failed(eraseLoweredRowStripOps(funcOp, rewriter, eraseAfterLowering))
        || failed(lowerHelperOnnxOps(moduleOp, funcOp))) {
      signalPassFailure();
      return;
    }

    if (!verifyLogicalPhase("after nested helper conversions"))
      return;

    // Stop here on leftovers; previously the pass kept running after signalPassFailure().
    if (reportRemainingIllegalOps(moduleOp)) {
      signalPassFailure();
      return;
    }
    dumpModule(moduleOp, "spatial1_premerge");

    // Final check. A failure is already reported and signalled by the lambda.
    (void) verifyLogicalPhase("at the end of LowerSpatialPlans");
  }

private:
  /// Returns the first user of `result` that is a row-strip blueprint, or a null op if there is none.
  static spatial::SpatBlueprintOp findRowStripBlueprintUser(Value result) {
    for (Operation* user : result.getUsers()) {
      auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(user);
      if (blueprint && blueprint.getPhysicalLayout() == kRowStripLayout)
        return blueprint;
    }
    return spatial::SpatBlueprintOp();
  }

  /// Walks the entry block once and lowers every planning op.
  ///
  /// Each handler sets the rewriter insertion point to the op it lowers. The early-increment range tolerates
  /// the current op being replaced or erased.
  static LogicalResult lowerPlanningOps(func::FuncOp funcOp,
                                        PatternRewriter& rewriter,
                                        RowStripValueMap& rowStripValues,
                                        PendingEraseSet& eraseAfterLowering) {
    for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
      LogicalResult result = success();
      if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(&op))
        result = lowerConv2DPlan(convPlan, rewriter, rowStripValues, eraseAfterLowering);
      else if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(&op))
        result = lowerReluPlan(reluPlan, rewriter, rowStripValues, eraseAfterLowering);
      else if (auto materializeOp = dyn_cast<spatial::SpatMaterializeLayoutOp>(&op))
        result = lowerMaterializeLayout(materializeOp, rewriter, rowStripValues);
      else if (auto blueprintOp = dyn_cast<spatial::SpatBlueprintOp>(&op))
        result = lowerBlueprint(blueprintOp, rewriter, eraseAfterLowering);
      if (failed(result))
        return failure();
    }
    return success();
  }

  /// Lowers a Conv2D plan.
  ///
  /// If its result feeds a row-strip blueprint, the plan is lowered straight into the row-strip layout. It
  /// consumes the input's row-strip physical value when one was recorded, and the plan/blueprint pair is
  /// scheduled for erasure. Otherwise the plan is lowered densely and replaced in place.
  static LogicalResult lowerConv2DPlan(spatial::SpatConv2DPlanOp planOp,
                                       PatternRewriter& rewriter,
                                       RowStripValueMap& rowStripValues,
                                       PendingEraseSet& eraseAfterLowering) {
    rewriter.setInsertionPoint(planOp);

    spatial::SpatBlueprintOp blueprint = findRowStripBlueprintUser(planOp.getResult());
    if (!blueprint) {
      FailureOr<Value> lowered =
          lowerSelectedConv2DPlan(planOp, std::nullopt, /*emitRowStripLayout=*/false, rewriter);
      if (failed(lowered))
        return planOp.emitOpError("failed to lower selected Spatial Conv plan");
      rewriter.replaceOp(planOp, *lowered);
      return success();
    }

    FailureOr<RowStripPhysicalValue> rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
    std::optional<Value> physicalInput;
    if (succeeded(rowStripInput))
      physicalInput = rowStripInput->physicalValue;

    FailureOr<Value> lowered = lowerSelectedConv2DPlan(planOp, physicalInput, /*emitRowStripLayout=*/true, rewriter);
    if (failed(lowered))
      return planOp.emitOpError("failed to lower selected row-strip Spatial Conv plan");

    FailureOr<RowStripPhysicalValue> output = buildRowStripValue(blueprint, *lowered);
    if (failed(output))
      return failure();
    rowStripValues[blueprint.getResult()] = std::move(*output);
    eraseAfterLowering.insert(planOp);
    eraseAfterLowering.insert(blueprint);
    return success();
  }

  /// Lowers a Relu plan.
  ///
  /// A row-strip input is lowered in the row-strip layout; its result must feed a row-strip blueprint, and the
  /// plan/blueprint pair is scheduled for erasure. Any other input gets a dense compute op that replaces the
  /// plan.
  static LogicalResult lowerReluPlan(spatial::SpatReluPlanOp planOp,
                                     PatternRewriter& rewriter,
                                     RowStripValueMap& rowStripValues,
                                     PendingEraseSet& eraseAfterLowering) {
    rewriter.setInsertionPoint(planOp);

    // Looked up once; previously the value was looked up (and copied) twice.
    FailureOr<RowStripPhysicalValue> input = getRowStripValue(rowStripValues, planOp.getInput());
    if (failed(input)) {
      auto computeOp = createSpatCompute<1>(
          rewriter, planOp.getLoc(), planOp.getOutput().getType(), {}, planOp.getInput(), [&](Value x) {
            auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), planOp.getOutput().getType(), x);
            spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
          });
      rewriter.replaceOp(planOp, computeOp.getResults());
      return success();
    }

    spatial::SpatBlueprintOp blueprint = findRowStripBlueprintUser(planOp.getResult());
    if (!blueprint)
      return planOp.emitOpError("row-strip Relu plan requires a row-strip blueprint result");

    FailureOr<Value> lowered = lowerRowStripRelu(*input, planOp, rewriter);
    if (failed(lowered))
      return planOp.emitOpError("failed to lower selected row-strip Spatial Relu plan");

    FailureOr<RowStripPhysicalValue> output = buildRowStripValue(blueprint, *lowered);
    if (failed(output))
      return failure();
    rowStripValues[blueprint.getResult()] = std::move(*output);
    eraseAfterLowering.insert(planOp);
    eraseAfterLowering.insert(blueprint);
    return success();
  }

  /// Lowers a materialize_layout op.
  ///
  /// dense->dense is forwarded to its input. row-strip->dense is expanded from the recorded row-strip value.
  /// Every other combination is rejected.
  static LogicalResult lowerMaterializeLayout(spatial::SpatMaterializeLayoutOp materializeOp,
                                              PatternRewriter& rewriter,
                                              RowStripValueMap& rowStripValues) {
    if (materializeOp.getSourcePhysicalLayout() == kDenseLayout
        && materializeOp.getTargetPhysicalLayout() == kDenseLayout) {
      rewriter.replaceOp(materializeOp, materializeOp.getInput());
      return success();
    }
    if (materializeOp.getSourcePhysicalLayout() != kRowStripLayout
        || materializeOp.getTargetPhysicalLayout() != kDenseLayout)
      return materializeOp.emitOpError("non-dense materialize_layout lowering is not supported yet");

    FailureOr<RowStripPhysicalValue> rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
    if (failed(rowStripValue))
      return materializeOp.emitOpError("expected a row-strip blueprint input during row-strip materialization");

    rewriter.setInsertionPoint(materializeOp);
    FailureOr<Value> dense = materializeRowStripToDense(*rowStripValue, materializeOp.getLoc(), rewriter);
    if (failed(dense))
      return materializeOp.emitOpError("failed to materialize selected row-strip layout back to dense NCHW");
    rewriter.replaceOp(materializeOp, *dense);
    return success();
  }

  /// Handles a blueprint op.
  ///
  /// Dense blueprints are forwarded to their input. Row-strip blueprints are left alone, but only if they were
  /// consumed by a plan lowering and scheduled for erasure.
  static LogicalResult lowerBlueprint(spatial::SpatBlueprintOp blueprintOp,
                                      PatternRewriter& rewriter,
                                      const PendingEraseSet& eraseAfterLowering) {
    if (blueprintOp.getPhysicalLayout() == kDenseLayout) {
      rewriter.replaceOp(blueprintOp, blueprintOp.getInput());
      return success();
    }
    if (blueprintOp.getPhysicalLayout() != kRowStripLayout)
      return blueprintOp.emitOpError("non-dense blueprint lowering is not supported yet");
    if (!eraseAfterLowering.contains(blueprintOp))
      return blueprintOp.emitOpError("unhandled row-strip blueprint remained during LowerSpatialPlans");
    return success();
  }

  /// Erases the scheduled ops in rounds, since erasing a consumer (e.g. a blueprint) can make its producer
  /// (the plan) dead.
  ///
  /// Fails, reporting every survivor, if some scheduled op still has uses once no further progress can be made.
  static LogicalResult
  eraseLoweredRowStripOps(func::FuncOp funcOp, PatternRewriter& rewriter, PendingEraseSet& eraseAfterLowering) {
    bool erasedAny = true;
    while (erasedAny) {
      erasedAny = false;
      for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
        if (!eraseAfterLowering.contains(&op) || !op.use_empty())
          continue;
        eraseAfterLowering.erase(&op);
        rewriter.eraseOp(&op);
        erasedAny = true;
      }
    }

    if (eraseAfterLowering.empty())
      return success();
    for (Operation& op : funcOp.getBody().front())
      if (eraseAfterLowering.contains(&op))
        op.emitOpError("selected row-strip planning op could not be fully eliminated during LowerSpatialPlans");
    return failure();
  }

  /// Shared legality for both helper conversions.
  ///
  /// The Spatial/tensor/linalg/affine/arith/scf/func dialects are legal. ONNX Gemm/Transpose must be lowered.
  static void configureHelperTarget(ConversionTarget& target) {
    target.addLegalDialect<spatial::SpatialDialect,
                           tensor::TensorDialect,
                           linalg::LinalgDialect,
                           affine::AffineDialect,
                           arith::ArithDialect,
                           scf::SCFDialect,
                           func::FuncDialect>();
    target.addIllegalOp<ONNXGemmOp, ONNXTransposeOp>();
  }

  /// Adds the Gemm and Transpose lowering patterns used for the helper ONNX ops.
  static void populateHelperPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
    populateGemmPatterns(patterns, ctx);
    populateTransposePatterns(patterns, ctx);
  }

  /// Lowers the helper ONNX Gemm/Transpose ops emitted by the plan lowerings, in two steps.
  ///
  /// 1. A partial conversion over the module treats compute / compute-batch ops as recursively legal, so their
  ///    bodies are untouched.
  /// 2. A full conversion then runs on each compute-like op found in `funcOp`.
  static LogicalResult lowerHelperOnnxOps(ModuleOp moduleOp, func::FuncOp funcOp) {
    MLIRContext* ctx = moduleOp.getContext();

    ConversionTarget helperTarget(*ctx);
    configureHelperTarget(helperTarget);
    helperTarget.addLegalOp<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
    helperTarget.markOpRecursivelyLegal<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
    RewritePatternSet helperPatterns(ctx);
    populateHelperPatterns(helperPatterns, ctx);
    if (failed(applyPartialConversion(moduleOp, helperTarget, std::move(helperPatterns))))
      return moduleOp.emitError("failed to lower helper ONNX ops emitted by selected Spatial plan lowering");

    ConversionTarget nestedHelperTarget(*ctx);
    configureHelperTarget(nestedHelperTarget);
    RewritePatternSet nestedPatterns(ctx);
    populateHelperPatterns(nestedPatterns, ctx);
    FrozenRewritePatternSet nestedHelperPatterns(std::move(nestedPatterns));

    // Collect first so the conversions do not mutate the IR under an active walk.
    SmallVector<Operation*> computeLikeOps;
    funcOp.walk([&](Operation* op) {
      if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(op))
        computeLikeOps.push_back(op);
    });
    for (Operation* op : computeLikeOps)
      if (failed(applyFullConversion(op, nestedHelperTarget, nestedHelperPatterns)))
        return op->emitOpError("failed to lower nested helper ONNX ops emitted by selected Spatial plan lowering");
    return success();
  }

  /// Emits an error on every planning op or ONNX op (other than the entry point) left in the module.
  ///
  /// @return true if any such op was found.
  static bool reportRemainingIllegalOps(ModuleOp moduleOp) {
    bool hasIllegalOps = false;
    moduleOp.walk([&](Operation* op) {
      if (isa<ONNXEntryPointOp>(op))
        return;
      // Operation::getDialect() is null for unregistered ops, so take the namespace from the op name instead
      // of dereferencing the dialect.
      const bool isOnnxOp = op->getName().getDialectNamespace() == "onnx";
      if (isOnnxOp
          || isa<spatial::SpatConv2DPlanOp,
                 spatial::SpatReluPlanOp,
                 spatial::SpatBlueprintOp,
                 spatial::SpatMaterializeLayoutOp>(op)) {
        op->emitOpError("operation must not remain after LowerSpatialPlans");
        hasIllegalOps = true;
      }
    });
    return hasIllegalOps;
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Creates a new instance of the pass that lowers selected Spatial planning ops to low-level Spatial IR
/// (registered under the argument "lower-spatial-plans").
std::unique_ptr<Pass> createLowerSpatialPlansPass() {
  std::unique_ptr<Pass> pass = std::make_unique<LowerSpatialPlansPass>();
  return pass;
}
|
|
|
|
} // namespace onnx_mlir
|