From d09e76c8f951b02ad27daf162fab0a45a6ca6da4 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Thu, 14 May 2026 14:09:30 +0200 Subject: [PATCH] fix matmul rewriting/lowering fix reshape lowering add support for grouped-convolution lowering quieter verifier with capped error messages --- src/PIM/Common/Support/Diagnostics.hpp | 24 ++ .../ONNXToSpatial/Common/ShapeTilingUtils.cpp | 21 +- .../ONNXToSpatial/Common/ShapeTilingUtils.hpp | 8 +- .../ONNXToSpatial/HostFoldability.cpp | 181 ++++++++ .../ONNXToSpatial/HostFoldability.hpp | 3 + .../Conversion/ONNXToSpatial/HostLegality.cpp | 13 +- .../ONNXToSpatial/ONNXToSpatialPass.cpp | 60 ++- .../ONNXToSpatial/Patterns/Math/Conv.cpp | 390 ++++++++++++------ .../ONNXToSpatial/Patterns/Math/MatMul.cpp | 133 +++++- .../ONNXToSpatial/Patterns/Tensor/Reshape.cpp | 33 ++ .../Conversion/ONNXToSpatial/PrePatterns.cpp | 2 - src/PIM/Pass/PimCodegen/VerificationPass.cpp | 124 ++++-- 12 files changed, 766 insertions(+), 226 deletions(-) diff --git a/src/PIM/Common/Support/Diagnostics.hpp b/src/PIM/Common/Support/Diagnostics.hpp index 0e9e884..5d84111 100644 --- a/src/PIM/Common/Support/Diagnostics.hpp +++ b/src/PIM/Common/Support/Diagnostics.hpp @@ -7,10 +7,34 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" +#include #include namespace onnx_mlir::pim { +struct CappedDiagnosticReporter { + explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {} + + template + void report(mlir::Operation* op, EmitFn&& emit) { + numFailures++; + if (numFailures <= maxReportedFailures) + emit(op); + } + + void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const { + if (numFailures > maxReportedFailures) + op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " + << failureDescription; + } + + bool hasFailure() const { return numFailures != 0; } + +private: + int64_t maxReportedFailures; + int64_t numFailures = 0; +}; + /// Emits a consistent diagnostic for target paths that require static shapes. mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription); diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp index b263a23..6f518eb 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp @@ -100,18 +100,27 @@ DenseMap>> tileMatrix( return tiles; } -tensor::SplatOp -broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) { +Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) { auto oldType = cast(scalarToBroadcast.getType()); Type elementType = oldType.getElementType(); int64_t shape[2] = {1, length}; Type type = oldType.cloneWith(ArrayRef(shape), elementType); - auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); - SmallVector index(oldType.getRank(), zero); - auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult(); + auto buildBroadcast = [&](Value input) -> Value { + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); + SmallVector index(oldType.getRank(), zero); + auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult(); + return tensor::SplatOp::create(rewriter, loc, type, elementValue); + }; - return tensor::SplatOp::create(rewriter, loc, type, elementValue); + if (isHostFoldableValue(scalarToBroadcast)) + return buildBroadcast(scalarToBroadcast); + + auto broadcastCompute = + createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) { + spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input)); + }); + return broadcastCompute.getResult(0); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp index b6d6182..ab03d75 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp @@ -136,9 +136,9 @@ tileMatrix(mlir::Value& matrixToTile, mlir::ConversionPatternRewriter& rewriter, mlir::Location& loc); -mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast, - int64_t length, - mlir::ConversionPatternRewriter& rewriter, - mlir::Location loc); +mlir::Value broadcastToVector(mlir::Value scalarToBroadcast, + int64_t length, + mlir::ConversionPatternRewriter& rewriter, + mlir::Location loc); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/HostFoldability.cpp b/src/PIM/Conversion/ONNXToSpatial/HostFoldability.cpp index 4c73c52..c42e186 100644 --- a/src/PIM/Conversion/ONNXToSpatial/HostFoldability.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/HostFoldability.cpp @@ -1,8 +1,12 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -18,6 +22,11 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) { return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; }); } +static bool hasConstantIndices(tensor::ExtractOp extractOp) { + return llvm::all_of(extractOp.getIndices(), + [](Value index) { return isa_and_nonnull(index.getDefiningOp()); }); +} + static bool isStaticTensorResult(Operation* op) { return llvm::all_of(op->getResultTypes(), [](Type type) { auto shapedType = dyn_cast(type); @@ -25,6 +34,167 @@ static bool isStaticTensorResult(Operation* op) { }); } +static SmallVector computeRowMajorStrides(ArrayRef shape) { + SmallVector strides(shape.size(), 1); + for (int64_t dim = static_cast(shape.size()) - 2; dim >= 0; --dim) + strides[dim] = strides[dim + 1] * shape[dim + 1]; + return strides; +} + +static FailureOr transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef perms) { + auto tensorType = dyn_cast(denseAttr.getType()); + if (!tensorType) + return failure(); + + int64_t rank = tensorType.getRank(); + if (static_cast(perms.size()) != rank) + return failure(); + + llvm::SmallBitVector seen(rank); + SmallVector transposedShape; + transposedShape.reserve(rank); + for (int64_t perm : perms) { + if (perm < 0 || perm >= rank || seen.test(perm)) + return failure(); + seen.set(perm); + transposedShape.push_back(tensorType.getShape()[perm]); + } + + auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding()); + if (denseAttr.isSplat()) + return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue()); + + SmallVector originalValues(denseAttr.getValues()); + SmallVector transposedValues(originalValues.size()); + SmallVector originalStrides = computeRowMajorStrides(tensorType.getShape()); + SmallVector transposedStrides = computeRowMajorStrides(transposedShape); + SmallVector originalIndices(rank); + + for (auto [linearIndex, value] : llvm::enumerate(originalValues)) { + int64_t remaining = static_cast(linearIndex); + for (int64_t dim = 0; dim < rank; ++dim) { + originalIndices[dim] = remaining / originalStrides[dim]; + remaining %= originalStrides[dim]; + } + + int64_t transposedLinearIndex = 0; + for (int64_t dim = 0; dim < rank; ++dim) + transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim]; + + transposedValues[transposedLinearIndex] = value; + } + + return DenseElementsAttr::get(transposedType, transposedValues); +} + +static FailureOr reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) { + auto sourceType = dyn_cast(denseAttr.getType()); + if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements()) + return failure(); + + if (denseAttr.isSplat()) + return DenseElementsAttr::get(resultType, denseAttr.getSplatValue()); + + SmallVector values(denseAttr.getValues()); + return DenseElementsAttr::get(resultType, values); +} + +static FailureOr extractSliceDenseElements(DenseElementsAttr denseAttr, + tensor::ExtractSliceOp extractSliceOp) { + auto sourceType = dyn_cast(denseAttr.getType()); + auto resultType = dyn_cast(extractSliceOp.getType()); + if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()) + return failure(); + + ArrayRef offsets = extractSliceOp.getStaticOffsets(); + ArrayRef sizes = extractSliceOp.getStaticSizes(); + ArrayRef strides = extractSliceOp.getStaticStrides(); + if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); }) + || llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); }) + || llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; })) + return failure(); + + if (denseAttr.isSplat()) + return DenseElementsAttr::get(resultType, denseAttr.getSplatValue()); + + SmallVector sourceValues(denseAttr.getValues()); + SmallVector sourceStrides = computeRowMajorStrides(sourceType.getShape()); + SmallVector resultStrides = computeRowMajorStrides(resultType.getShape()); + SmallVector resultValues; + resultValues.reserve(resultType.getNumElements()); + + for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) { + int64_t remaining = linearIndex; + int64_t sourceLinearIndex = 0; + for (int64_t dim = 0; dim < resultType.getRank(); ++dim) { + const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[dim]; + remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim]; + sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim]; + } + resultValues.push_back(sourceValues[sourceLinearIndex]); + } + + return DenseElementsAttr::get(resultType, resultValues); +} + +static DenseElementsAttr getDirectDenseConstantAttr(Value value) { + if (auto constantOp = value.getDefiningOp()) + return dyn_cast(constantOp.getValue()); + if (auto constantOp = value.getDefiningOp()) + return dyn_cast_or_null(constantOp.getValueAttr()); + return nullptr; +} + +static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl& visited) { + auto* definingOp = value.getDefiningOp(); + if (!definingOp || !visited.insert(definingOp).second) + return nullptr; + + // Rebuild dense attributes through view-only host-foldable chains so later + // lowering stages can still recognize grouped/sliced constants. + if (auto denseAttr = getDirectDenseConstantAttr(value)) + return denseAttr; + + if (auto transposeOp = dyn_cast(definingOp)) { + auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited); + if (!inputAttr) + return nullptr; + + SmallVector perm; + perm.reserve(transposeOp.getPermAttr().size()); + for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange()) + perm.push_back(attr.getInt()); + auto transposedAttr = transposeDenseElements(inputAttr, perm); + return succeeded(transposedAttr) ? *transposedAttr : nullptr; + } + + if (auto collapseShapeOp = dyn_cast(definingOp)) { + auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited); + if (!inputAttr) + return nullptr; + auto reshapedAttr = reshapeDenseElements(inputAttr, cast(collapseShapeOp.getType())); + return succeeded(reshapedAttr) ? *reshapedAttr : nullptr; + } + + if (auto expandShapeOp = dyn_cast(definingOp)) { + auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited); + if (!inputAttr) + return nullptr; + auto reshapedAttr = reshapeDenseElements(inputAttr, cast(expandShapeOp.getType())); + return succeeded(reshapedAttr) ? *reshapedAttr : nullptr; + } + + if (auto extractSliceOp = dyn_cast(definingOp)) { + auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited); + if (!inputAttr) + return nullptr; + auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp); + return succeeded(slicedAttr) ? *slicedAttr : nullptr; + } + + return nullptr; +} + static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl& visited) { if (!op || !visited.insert(op).second) return false; @@ -32,6 +202,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl(op)) return true; + if (auto extractOp = dyn_cast(op)) + return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor()); + if (!isStaticTensorResult(op)) return false; @@ -47,6 +220,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl(op)) return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource()); + if (auto splatOp = dyn_cast(op)) + return isHostFoldableValue(splatOp.getInput()); + if (auto extractRowsOp = dyn_cast(op)) return isHostFoldableValue(extractRowsOp.getInput()); @@ -72,4 +248,9 @@ bool isHostFoldableOp(Operation* op) { return isHostFoldableOpImpl(op, visited); } +DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) { + llvm::SmallPtrSet visited; + return getHostFoldableDenseElementsAttrImpl(value, visited); +} + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp b/src/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp index 0479987..3e3437c 100644 --- a/src/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp @@ -1,5 +1,6 @@ #pragma once +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" @@ -9,4 +10,6 @@ bool isHostFoldableValue(mlir::Value value); bool isHostFoldableOp(mlir::Operation* op); +mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value); + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp b/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp index b16d513..252631a 100644 --- a/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp @@ -2,6 +2,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -11,7 +12,7 @@ using namespace mlir; namespace onnx_mlir { LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) { - bool hasFailure = false; + pim::CappedDiagnosticReporter diagnostics; for (Operation& op : funcOp.getFunctionBody().front()) { if (isa(&op)) @@ -19,11 +20,15 @@ LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) { if (isHostFoldableOp(&op)) continue; - op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute"); - hasFailure = true; + diagnostics.report(&op, [](Operation* illegalOp) { + illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside " + "spat.compute"); + }); } - return success(!hasFailure); + diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures"); + + return success(!diagnostics.hasFailure()); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 651a4a7..c004bea 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -6,13 +6,14 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/SmallVector.h" #include "Common/Common.hpp" #include "Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp" @@ -84,6 +85,30 @@ static void populateEmptyFunction(func::FuncOp funcOp) { returnOp.setOperand(index, computeResult); } +static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) { + IRRewriter rewriter(funcOp.getContext()); + Block& entryBlock = funcOp.getFunctionBody().front(); + + for (Operation& op : llvm::make_early_inc_range(entryBlock)) { + auto transposeOp = dyn_cast(&op); + if (!transposeOp || isHostFoldableOp(transposeOp)) + continue; + + // Transpose stays globally legal because constant/view-only cases are + // allowed on the host. Any residual runtime transpose must be sunk into + // spat.compute before the host legality check. + auto resultType = transposeOp.getResult().getType(); + rewriter.setInsertionPoint(transposeOp); + auto computeOp = createSpatCompute<1>( + rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) { + Value transposed = + ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr()); + spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed); + }); + rewriter.replaceOp(transposeOp, computeOp.getResult(0)); + } +} + void ONNXToSpatialPass::runOnOperation() { ModuleOp moduleOp = getOperation(); MLIRContext* ctx = &getContext(); @@ -94,7 +119,7 @@ void ONNXToSpatialPass::runOnOperation() { tensor::TensorDialect, arith::ArithDialect, scf::SCFDialect>(); - preTarget.addIllegalOp(); + preTarget.addIllegalOp(); RewritePatternSet prePatterns(ctx); populatePrePatterns(prePatterns, ctx); @@ -111,6 +136,21 @@ void ONNXToSpatialPass::runOnOperation() { return; } + RewritePatternSet matmulPatterns(ctx); + populateMatMulRewritePatterns(matmulPatterns, ctx); + walkAndApplyPatterns(moduleOp, std::move(matmulPatterns)); + + bool hasUnloweredMatMul = false; + moduleOp.walk([&](ONNXMatMulOp matmulOp) { + hasUnloweredMatMul = true; + matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion"); + }); + if (hasUnloweredMatMul) { + moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion"); + signalPassFailure(); + return; + } + ConversionTarget target(*ctx); target.addLegalDialectgetFunctionBody().front().getOperations()) - if (isa(op)) - computeOpsCount++; - - if (computeOpsCount > coresCount) { - entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count (" - << coresCount << ")"; - signalPassFailure(); - return; - } - } - PassManager cleanupPM(ctx); cleanupPM.addPass(createCanonicalizerPass()); if (failed(cleanupPM.run(moduleOp))) @@ -201,6 +227,8 @@ void ONNXToSpatialPass::runOnOperation() { return; } + wrapTopLevelRuntimeTransposes(*entryFunc); + if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) { moduleOp.emitError("ONNX-to-Spatial host legality verification failed"); signalPassFailure(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 4c0633e..0a855c4 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -11,6 +11,7 @@ #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -27,16 +28,6 @@ struct ConvToGemm : OpConversionPattern { ConversionPatternRewriter& rewriter) const override; }; -static DenseElementsAttr getDenseConstantAttr(Value value) { - if (auto constantOp = value.getDefiningOp()) - return dyn_cast(constantOp.getValue()); - - if (auto constantOp = value.getDefiningOp()) - return dyn_cast_or_null(constantOp.getValueAttr()); - - return nullptr; -} - static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast(arr[idx]).getInt(); } static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) { @@ -355,49 +346,22 @@ static Value createCollectedConvOutput(ValueRange gemmRows, return collectComputeOp.getResult(0); } -} // namespace - -LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, - ONNXConvOpAdaptor convOpAdaptor, - ConversionPatternRewriter& rewriter) const { - Location loc = convOp.getLoc(); - Value x = convOpAdaptor.getX(); - Value w = convOpAdaptor.getW(); - Value b = convOpAdaptor.getB(); - - auto xType = cast(x.getType()); - auto wType = cast(w.getType()); - auto outType = cast(convOp.getY().getType()); - - if (!xType.hasStaticShape()) { - pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input"); - return failure(); - } - if (!wType.hasStaticShape()) { - pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight"); - return failure(); - } - if (!outType.hasStaticShape()) { - pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result"); - return failure(); - } - if (xType.getRank() != 4) { - pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4}); - return failure(); - } - if (wType.getRank() != 4) { - pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4}); - return failure(); - } - if (outType.getRank() != 4) { - pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4}); - return failure(); - } - if (convOp.getGroup() != 1) { - convOp.emitOpError("only group=1 convolution is supported for Spatial lowering"); - return failure(); - } - +static Value lowerSingleConvGroup(Value x, + Value w, + Value b, + RankedTensorType xType, + RankedTensorType wType, + RankedTensorType outType, + int64_t padHeightBegin, + int64_t padHeightEnd, + int64_t padWidthBegin, + int64_t padWidthEnd, + int64_t strideHeight, + int64_t strideWidth, + int64_t dilationHeight, + int64_t dilationWidth, + ConversionPatternRewriter& rewriter, + Location loc) { const int64_t batchSize = xType.getDimSize(0); const int64_t numChannelsIn = xType.getDimSize(1); const int64_t xHeight = xType.getDimSize(2); @@ -408,71 +372,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, const int64_t outHeight = outType.getDimSize(2); const int64_t outWidth = outType.getDimSize(3); - // Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0) - const auto stridesAttr = convOp.getStrides(); - const auto dilationsAttr = convOp.getDilations(); - const auto padsAttr = convOp.getPads(); - - if (stridesAttr && stridesAttr->size() != 2) { - convOp.emitOpError("requires exactly two stride values for Spatial lowering"); - return failure(); - } - if (dilationsAttr && dilationsAttr->size() != 2) { - convOp.emitOpError("requires exactly two dilation values for Spatial lowering"); - return failure(); - } - if (padsAttr && padsAttr->size() != 4) { - convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering"); - return failure(); - } - - const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1; - const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1; - const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1; - const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1; - - int64_t padHeightBegin = 0; - int64_t padHeightEnd = 0; - int64_t padWidthBegin = 0; - int64_t padWidthEnd = 0; - - if (padsAttr) { - padHeightBegin = getI64FromArrayAttr(*padsAttr, 0); - padWidthBegin = getI64FromArrayAttr(*padsAttr, 1); - padHeightEnd = getI64FromArrayAttr(*padsAttr, 2); - padWidthEnd = getI64FromArrayAttr(*padsAttr, 3); - } - else { - // Compute padding from auto_pad attribute - const auto autoPad = convOp.getAutoPad(); - if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { - const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1; - const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1; - const int64_t totalPadH = - std::max(static_cast(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight); - const int64_t totalPadW = - std::max(static_cast(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth); - - if (autoPad == "SAME_UPPER") { - padHeightBegin = totalPadH / 2; - padHeightEnd = totalPadH - padHeightBegin; - padWidthBegin = totalPadW / 2; - padWidthEnd = totalPadW - padWidthBegin; - } - else { // SAME_LOWER - padHeightEnd = totalPadH / 2; - padHeightBegin = totalPadH - padHeightEnd; - padWidthEnd = totalPadW / 2; - padWidthBegin = totalPadW - padWidthEnd; - } - } - else if (autoPad != "NOTSET" && autoPad != "VALID") { - convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering"; - return failure(); - } - // "NOTSET" or "VALID" -> all pads stay 0 - } - // im2col layout (flipped with respect to the standard, so filters sit in B = crossbar): // A (im2col): [numPatches, patchSize] -- one row per output spatial position // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns @@ -492,7 +391,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, const int64_t xbarSize = static_cast(crossbarSize.getValue()); const int64_t wMaxDim = std::max(patchSize, numChannelsOut); const int64_t maxParallelPixels = std::max(1, xbarSize / wMaxDim); - auto wDenseAttr = getDenseConstantAttr(w); + auto wDenseAttr = getHostFoldableDenseElementsAttr(w); // Prepare weight matrix W for crossbar storage: // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut] @@ -513,7 +412,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, DenseElementsAttr biasDenseAttr; if (hasB) { gemmBias = b; - biasDenseAttr = getDenseConstantAttr(b); + biasDenseAttr = getHostFoldableDenseElementsAttr(b); biasMatrix = expandBiasIfNeeded(b, rewriter, loc); } const bool canPackWeightsAsConstants = static_cast(wDenseAttr); @@ -589,17 +488,246 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, rewriter.getBoolAttr(false)) .getY(); - rewriter.replaceOp(convOp, - createCollectedConvOutput(ValueRange {gemmRows}, - convOp.getType(), - gemmOutType, - nhwcType, - outType, - numPatches, - numChannelsOut, - effectiveMaxParallelPixels, - rewriter, - loc)); + return createCollectedConvOutput(ValueRange {gemmRows}, + outType, + gemmOutType, + nhwcType, + outType, + numPatches, + numChannelsOut, + effectiveMaxParallelPixels, + rewriter, + loc); +} + +} // namespace + +LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, + ONNXConvOpAdaptor convOpAdaptor, + ConversionPatternRewriter& rewriter) const { + Location loc = convOp.getLoc(); + Value x = convOpAdaptor.getX(); + Value w = convOpAdaptor.getW(); + Value b = convOpAdaptor.getB(); + + auto xType = cast(x.getType()); + auto wType = cast(w.getType()); + auto outType = cast(convOp.getY().getType()); + + if (!xType.hasStaticShape()) { + pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input"); + return failure(); + } + if (!wType.hasStaticShape()) { + pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight"); + return failure(); + } + if (!outType.hasStaticShape()) { + pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result"); + return failure(); + } + if (xType.getRank() != 4) { + pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4}); + return failure(); + } + if (wType.getRank() != 4) { + pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4}); + return failure(); + } + if (outType.getRank() != 4) { + pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4}); + return failure(); + } + if (convOp.getGroup() < 1) { + convOp.emitOpError("requires group >= 1 for Spatial lowering"); + return failure(); + } + + const int64_t batchSize = xType.getDimSize(0); + const int64_t numChannelsIn = xType.getDimSize(1); + const int64_t xHeight = xType.getDimSize(2); + const int64_t xWidth = xType.getDimSize(3); + const int64_t numChannelsOut = wType.getDimSize(0); + const int64_t wHeight = wType.getDimSize(2); + const int64_t wWidth = wType.getDimSize(3); + const int64_t outHeight = outType.getDimSize(2); + const int64_t outWidth = outType.getDimSize(3); + const int64_t group = convOp.getGroup(); + const bool hasB = !isa(b.getDefiningOp()); + + if (numChannelsIn % group != 0) { + convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group + << " for Spatial lowering"; + return failure(); + } + if (numChannelsOut % group != 0) { + convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group + << " for Spatial lowering"; + return failure(); + } + + const int64_t numChannelsInPerGroup = numChannelsIn / group; + const int64_t numChannelsOutPerGroup = numChannelsOut / group; + if (wType.getDimSize(1) != numChannelsInPerGroup) { + convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1) + << " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering"; + return failure(); + } + if (wType.getDimSize(0) != numChannelsOut) { + convOp.emitOpError() << "requires weight output channels " << wType.getDimSize(0) << " to match result channels " + << numChannelsOut << " for Spatial lowering"; + return failure(); + } + + // Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0) + const auto stridesAttr = convOp.getStrides(); + const auto dilationsAttr = convOp.getDilations(); + const auto padsAttr = convOp.getPads(); + + if (stridesAttr && stridesAttr->size() != 2) { + convOp.emitOpError("requires exactly two stride values for Spatial lowering"); + return failure(); + } + if (dilationsAttr && dilationsAttr->size() != 2) { + convOp.emitOpError("requires exactly two dilation values for Spatial lowering"); + return failure(); + } + if (padsAttr && padsAttr->size() != 4) { + convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering"); + return failure(); + } + + const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1; + const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1; + const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1; + const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1; + + int64_t padHeightBegin = 0; + int64_t padHeightEnd = 0; + int64_t padWidthBegin = 0; + int64_t padWidthEnd = 0; + + if (padsAttr) { + padHeightBegin = getI64FromArrayAttr(*padsAttr, 0); + padWidthBegin = getI64FromArrayAttr(*padsAttr, 1); + padHeightEnd = getI64FromArrayAttr(*padsAttr, 2); + padWidthEnd = getI64FromArrayAttr(*padsAttr, 3); + } + else { + // Compute padding from auto_pad attribute + const auto autoPad = convOp.getAutoPad(); + if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { + const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1; + const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1; + const int64_t totalPadH = + std::max(static_cast(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight); + const int64_t totalPadW = + std::max(static_cast(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth); + + if (autoPad == "SAME_UPPER") { + padHeightBegin = totalPadH / 2; + padHeightEnd = totalPadH - padHeightBegin; + padWidthBegin = totalPadW / 2; + padWidthEnd = totalPadW - padWidthBegin; + } + else { // SAME_LOWER + padHeightEnd = totalPadH / 2; + padHeightBegin = totalPadH - padHeightEnd; + padWidthEnd = totalPadW / 2; + padWidthBegin = totalPadW - padWidthEnd; + } + } + else if (autoPad != "NOTSET" && autoPad != "VALID") { + convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering"; + return failure(); + } + // "NOTSET" or "VALID" -> all pads stay 0 + } + + if (group == 1) { + rewriter.replaceOp(convOp, + lowerSingleConvGroup(x, + w, + b, + xType, + wType, + outType, + padHeightBegin, + padHeightEnd, + padWidthBegin, + padWidthEnd, + strideHeight, + strideWidth, + dilationHeight, + dilationWidth, + rewriter, + loc)); + return success(); + } + + SmallVector xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc); + SmallVector wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc); + SmallVector bSlices; + if (hasB) { + auto biasType = cast(b.getType()); + int64_t biasAxis = -1; + if (biasType.getRank() == 1) + biasAxis = 0; + else if (biasType.getRank() == 2) + biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1; + else { + convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank " + << biasType.getRank(); + return failure(); + } + bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc); + } + + if (xSlices.size() != static_cast(group) || wSlices.size() != static_cast(group) + || (hasB && bSlices.size() != static_cast(group))) { + convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering"); + return failure(); + } + + SmallVector groupResults; + groupResults.reserve(group); + auto groupOutType = + RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType()); + Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); + for (int64_t groupId = 0; groupId < group; groupId++) { + Value groupX = xSlices[groupId]; + Value groupW = wSlices[groupId]; + Value groupB = hasB ? bSlices[groupId] : noBias; + groupResults.push_back(lowerSingleConvGroup(groupX, + groupW, + groupB, + cast(groupX.getType()), + cast(groupW.getType()), + groupOutType, + padHeightBegin, + padHeightEnd, + padWidthBegin, + padWidthEnd, + strideHeight, + strideWidth, + dilationHeight, + dilationWidth, + rewriter, + loc)); + } + + Value result; + if (llvm::all_of(groupResults, isHostFoldableValue)) { + result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults); + } + else { + auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) { + spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args)); + }); + result = concatCompute.getResult(0); + } + + rewriter.replaceOp(convOp, result); return success(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index 5dd9a2d..dd1d227 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -2,8 +2,12 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include +#include + #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" @@ -19,6 +23,79 @@ static bool haveStaticPositiveShape(ArrayRef shape) { return llvm::all_of(shape, [](int64_t dim) { return dim > 0; }); } +static int64_t getStaticShapeElementCount(ArrayRef shape) { + return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies {}); +} + +static FailureOr> inferSupportedBatchShape(ArrayRef lhsBatchShape, + ArrayRef rhsBatchShape) { + if (lhsBatchShape.empty()) + return SmallVector(rhsBatchShape.begin(), rhsBatchShape.end()); + if (rhsBatchShape.empty()) + return SmallVector(lhsBatchShape.begin(), lhsBatchShape.end()); + if (!llvm::equal(lhsBatchShape, rhsBatchShape)) + return failure(); + return SmallVector(lhsBatchShape.begin(), lhsBatchShape.end()); +} + +static Value collapseBatchDims(Value value, + int64_t batchSize, + int64_t rows, + int64_t cols, + PatternRewriter& rewriter, + Location loc) { + auto type = cast(value.getType()); + if (type.getRank() == 2 || type.getRank() == 3) + return value; + + auto collapsedType = + RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding()); + SmallVector reassociation = { + ReassociationIndices {}, + ReassociationIndices {static_cast(type.getRank() - 2)}, + ReassociationIndices {static_cast(type.getRank() - 1)} + }; + for (int64_t dim = 0; dim < type.getRank() - 2; ++dim) + reassociation.front().push_back(dim); + + auto buildCollapsed = [&](Value input) -> Value { + return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation); + }; + + if (isHostFoldableValue(value)) + return buildCollapsed(value); + + auto collapseCompute = + createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) { + spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input)); + }); + return collapseCompute.getResult(0); +} + +static Value expandBatchDims(Value value, + RankedTensorType outputType, + size_t batchRank, + PatternRewriter& rewriter, + Location loc) { + if (cast(value.getType()) == outputType) + return value; + + SmallVector reassociation = { + ReassociationIndices {}, + ReassociationIndices {static_cast(batchRank)}, + ReassociationIndices {static_cast(batchRank + 1)} + }; + for (size_t dim = 0; dim < batchRank; ++dim) + reassociation.front().push_back(static_cast(dim)); + + auto expandCompute = + createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) { + Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation); + spatial::SpatYieldOp::create(rewriter, loc, expanded); + }); + return expandCompute.getResult(0); +} + static Value extractBatchMatrix(Value value, int64_t batchIndex, int64_t batchSize, @@ -62,13 +139,29 @@ static Value extractBatchMatrix(Value value, static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) { auto type = cast(value.getType()); auto shape = type.getShape(); + RankedTensorType transposedType; + SmallVector perm; if (type.getRank() == 2) { - auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType()); - return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0})); + transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType()); + perm = {1, 0}; + } + else { + transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType()); + perm = {0, 2, 1}; } - auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType()); - return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1})); + auto buildTranspose = [&](Value input) -> Value { + return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm)); + }; + + if (isHostFoldableValue(value)) + return buildTranspose(value); + + auto transposeCompute = + createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) { + spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input)); + }); + return transposeCompute.getResult(0); } static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) { @@ -120,24 +213,25 @@ struct MatMulToGemm : OpRewritePattern { if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape() || !outType.hasStaticShape()) return failure(); - if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3) - || (outType.getRank() != 2 && outType.getRank() != 3)) + if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2) return failure(); if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape()) || !haveStaticPositiveShape(outType.getShape())) return failure(); - const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1; - const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1; - const int64_t batch = std::max(lhsBatch, rhsBatch); - - if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch)) + SmallVector lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2); + SmallVector rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2); + auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape); + if (failed(batchShape)) return failure(); + const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape); + const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape); + const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape); - const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0); - const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1); - const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0); - const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1); + const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2); + const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1); + const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2); + const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1); if (k != rhsK) return failure(); @@ -146,15 +240,17 @@ struct MatMulToGemm : OpRewritePattern { return failure(); } else { - if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n) + SmallVector outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2); + if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m + || outType.getDimSize(outType.getRank() - 1) != n) return failure(); } Location loc = matmulOp.getLoc(); bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB()); - Value lhs = matmulOp.getA(); - Value rhs = matmulOp.getB(); + Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc); + Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc); int64_t lhsBatchForGemm = lhsBatch; int64_t rhsBatchForGemm = rhsBatch; int64_t gemmM = m; @@ -239,6 +335,7 @@ struct MatMulToGemm : OpRewritePattern { } Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc); + result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc); rewriter.replaceOp(matmulOp, result); return success(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp index d670838..f44c524 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp @@ -80,6 +80,22 @@ static bool inferExpandReassociation(ArrayRef sourceShape, return sourceIdx == sourceShape.size() && resultIdx == resultShape.size(); } +static SmallVector getCollapseTo1DReassociation(size_t rank) { + SmallVector reassociation(1); + reassociation.front().reserve(rank); + for (size_t dim = 0; dim < rank; ++dim) + reassociation.front().push_back(static_cast(dim)); + return reassociation; +} + +static SmallVector getExpandFrom1DReassociation(size_t rank) { + SmallVector reassociation(1); + reassociation.front().reserve(rank); + for (size_t dim = 0; dim < rank; ++dim) + reassociation.front().push_back(static_cast(dim)); + return reassociation; +} + struct Reshape : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -126,6 +142,23 @@ struct Reshape : OpConversionPattern { return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation); }); + if (sourceType.getNumElements() != resultType.getNumElements()) + return failure(); + + return replaceWithReshape([&](Value data) -> Value { + Value reshaped = data; + if (sourceType.getRank() != 1) { + auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType()); + reshaped = tensor::CollapseShapeOp::create( + rewriter, reshapeOp.getLoc(), flatType, reshaped, getCollapseTo1DReassociation(sourceType.getRank())); + } + if (resultType.getRank() == 1) + return reshaped; + return tensor::ExpandShapeOp::create( + rewriter, reshapeOp.getLoc(), resultType, reshaped, getExpandFrom1DReassociation(resultType.getRank())) + .getResult(); + }); + return failure(); } }; diff --git a/src/PIM/Conversion/ONNXToSpatial/PrePatterns.cpp b/src/PIM/Conversion/ONNXToSpatial/PrePatterns.cpp index 783733e..bf524d2 100644 --- a/src/PIM/Conversion/ONNXToSpatial/PrePatterns.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/PrePatterns.cpp @@ -17,9 +17,7 @@ void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* c patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); - patterns.add(ctx); patterns.add(ctx); - populateMatMulRewritePatterns(patterns, ctx); } } // namespace onnx_mlir diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index 8c8d4c6..d5b3b43 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -7,6 +7,7 @@ #include "llvm/ADT/STLExtras.h" #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" +#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -152,14 +153,15 @@ struct VerificationPass : PassWrapper> void runOnOperation() override { ModuleOp moduleOp = getOperation(); - bool hasFailure = false; + pim::CappedDiagnosticReporter diagnostics; moduleOp.walk([&](Operation* op) { if (op->getDialect()->getNamespace() != "spat") return; - op->emitError("illegal Spatial operation reached PIM codegen verification"); - hasFailure = true; + diagnostics.report(op, [](Operation* illegalOp) { + illegalOp->emitError("illegal Spatial operation reached PIM codegen verification"); + }); }); for (func::FuncOp funcOp : moduleOp.getOps()) { @@ -168,36 +170,36 @@ struct VerificationPass : PassWrapper> for (Operation& op : funcOp.getBody().front().getOperations()) { if (auto coreOp = dyn_cast(&op)) { - if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp))) - hasFailure = true; + (void) verifyCoreWeights(moduleOp, coreOp, diagnostics); + (void) verifyCoreOperands(coreOp, diagnostics); continue; } if (auto coreBatchOp = dyn_cast(&op)) { - if (failed(verifyCoreWeights(moduleOp, coreBatchOp)) || failed(verifyCoreOperands(coreBatchOp))) - hasFailure = true; + (void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics); + (void) verifyCoreOperands(coreBatchOp, diagnostics); continue; } if (auto returnOp = dyn_cast(&op)) { - if (failed(verifyReturnOp(returnOp))) - hasFailure = true; + (void) verifyReturnOp(returnOp, diagnostics); continue; } if (!isAddressOnlyHostOp(&op)) { - op.emitOpError("illegal host-side runtime op remains after PIM bufferization; " - "fold it to constants or lower it into pim.core"); - hasFailure = true; + diagnostics.report(&op, [](Operation* illegalOp) { + illegalOp->emitOpError("illegal host-side runtime op remains after PIM bufferization; " + "fold it to constants or lower it into pim.core"); + }); continue; } - if (failed(verifyAddressOnlyHostOp(&op))) - hasFailure = true; + (void) verifyAddressOnlyHostOp(&op, diagnostics); } } - if (hasFailure) { + if (diagnostics.hasFailure()) { + diagnostics.emitSuppressedSummary(moduleOp, "verification failures"); moduleOp.emitError("PIM codegen verification failed; see diagnostics above"); signalPassFailure(); } @@ -205,14 +207,19 @@ struct VerificationPass : PassWrapper> private: template - static LogicalResult verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp) { + static LogicalResult + verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) { bool hasFailure = false; - for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) { + for (auto it : llvm::enumerate(coreOp.getWeights())) { + size_t weightIndex = it.index(); + Value weight = it.value(); auto getGlobalOp = weight.template getDefiningOp(); if (!getGlobalOp && !isConstantGlobalView(weight)) { - coreOp.emitOpError() << "weight #" << weightIndex - << " must be materialized as a constant memref.global or a static view of one before JSON " - "codegen"; + diagnostics.report(coreOp.getOperation(), [&](Operation*) { + coreOp.emitOpError() << "weight #" << weightIndex + << " must be materialized as a constant memref.global or a static view of one before " + "JSON codegen"; + }); hasFailure = true; continue; } @@ -222,14 +229,18 @@ private: auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); if (!globalOp) { - coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global"; + diagnostics.report(coreOp.getOperation(), [&](Operation*) { + coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global"; + }); hasFailure = true; continue; } if (!globalOp.getConstant() || !globalOp.getInitialValue()) { - coreOp.emitOpError() << "weight #" << weightIndex - << " must come from a constant memref.global with an initial value"; + diagnostics.report(coreOp.getOperation(), [&](Operation*) { + coreOp.emitOpError() << "weight #" << weightIndex + << " must come from a constant memref.global with an initial value"; + }); hasFailure = true; } } @@ -237,11 +248,15 @@ private: return success(!hasFailure); } - static LogicalResult verifyReturnOp(func::ReturnOp returnOp) { + static LogicalResult verifyReturnOp(func::ReturnOp returnOp, pim::CappedDiagnosticReporter& diagnostics) { bool hasFailure = false; - for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) { + for (auto it : llvm::enumerate(returnOp.getOperands())) { + size_t resultIndex = it.index(); + Value operand = it.value(); if (!isCodegenAddressableValue(operand)) { - returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage"; + diagnostics.report(returnOp.getOperation(), [&](Operation*) { + returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage"; + }); hasFailure = true; } } @@ -249,38 +264,50 @@ private: } template - static LogicalResult verifyCoreOperands(CoreOpTy coreOp) { + static LogicalResult verifyCoreOperands(CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) { return walkPimCoreBlock( - coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) { + coreOp.getBody().front(), StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) { bool hasFailure = false; if (!isSupportedCoreInstructionOp(&op)) { - op.emitOpError("unsupported executable op reached PIM codegen verification"); + diagnostics.report(&op, [](Operation* illegalOp) { + illegalOp->emitOpError("unsupported executable op reached PIM codegen verification"); + }); hasFailure = true; } - for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) { + for (auto it : llvm::enumerate(op.getOperands())) { + size_t operandIndex = it.index(); + Value operand = it.value(); if (!isa(operand.getType())) continue; auto resolvedAddress = resolveContiguousAddress(operand, knowledge); if (failed(resolvedAddress)) { - op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage"; + diagnostics.report(&op, [&](Operation* illegalOp) { + illegalOp->emitOpError() << "operand #" << operandIndex + << " is not backed by contiguous addressable storage"; + }); hasFailure = true; continue; } if (isExplicitHostOperand(&op, operandIndex)) { if (!isCodegenAddressableValue(operand, knowledge)) { - op.emitOpError() << "host operand #" << operandIndex - << " is not backed by contiguous addressable storage"; + diagnostics.report(&op, [&](Operation* illegalOp) { + illegalOp->emitOpError() << "host operand #" << operandIndex + << " is not backed by contiguous addressable storage"; + }); hasFailure = true; } continue; } if (!isa(resolvedAddress->base.getDefiningOp())) { - op.emitOpError() << "operand #" << operandIndex - << " must be backed by device-local memory; materialize host values with pim.memcp_hd"; + diagnostics.report(&op, [&](Operation* illegalOp) { + illegalOp->emitOpError() << "operand #" << operandIndex + << " must be backed by device-local memory; materialize host values with " + "pim.memcp_hd"; + }); hasFailure = true; } } @@ -288,18 +315,20 @@ private: }); } - static LogicalResult verifyAddressOnlyHostOp(Operation* op) { + static LogicalResult verifyAddressOnlyHostOp(Operation* op, pim::CappedDiagnosticReporter& diagnostics) { if (auto subviewOp = dyn_cast(op)) - return verifyAddressOnlyBase(op, subviewOp.getSource()); + return verifyAddressOnlyBase(op, subviewOp.getSource(), diagnostics); if (auto castOp = dyn_cast(op)) - return verifyAddressOnlySource(op, castOp.getSource()); + return verifyAddressOnlySource(op, castOp.getSource(), diagnostics); if (auto collapseOp = dyn_cast(op)) - return verifyAddressOnlySource(op, collapseOp.getSrc()); + return verifyAddressOnlySource(op, collapseOp.getSrc(), diagnostics); if (auto expandOp = dyn_cast(op)) - return verifyAddressOnlySource(op, expandOp.getSrc()); + return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics); if (auto copyOp = dyn_cast(op)) { if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) { - op->emitOpError("depends on a value that is not backed by addressable storage"); + diagnostics.report(op, [](Operation* illegalOp) { + illegalOp->emitOpError("depends on a value that is not backed by addressable storage"); + }); return failure(); } return success(); @@ -307,19 +336,24 @@ private: return success(); } - static LogicalResult verifyAddressOnlySource(Operation* op, Value source) { + static LogicalResult + verifyAddressOnlySource(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) { if (isCodegenAddressableValue(source)) return success(); - op->emitOpError("depends on a value that is not backed by contiguous addressable storage"); + diagnostics.report(op, [](Operation* illegalOp) { + illegalOp->emitOpError("depends on a value that is not backed by contiguous addressable storage"); + }); return failure(); } - static LogicalResult verifyAddressOnlyBase(Operation* op, Value source) { + static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) { if (isBaseAddressableValue(source)) return success(); - op->emitOpError("depends on a value that is not backed by addressable storage"); + diagnostics.report(op, [](Operation* illegalOp) { + illegalOp->emitOpError("depends on a value that is not backed by addressable storage"); + }); return failure(); } };