fix matmul rewriting/lowering

fix reshape lowering
add support for grouped-convolution lowering
quieter verifier with capped error messages
This commit is contained in:
NiccoloN
2026-05-14 14:09:30 +02:00
parent c5e608fa5b
commit d09e76c8f9
12 changed files with 766 additions and 226 deletions
@@ -7,10 +7,34 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <system_error>
namespace onnx_mlir::pim {
struct CappedDiagnosticReporter {
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
template <typename EmitFn>
void report(mlir::Operation* op, EmitFn&& emit) {
numFailures++;
if (numFailures <= maxReportedFailures)
emit(op);
}
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
if (numFailures > maxReportedFailures)
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
<< failureDescription;
}
bool hasFailure() const { return numFailures != 0; }
private:
int64_t maxReportedFailures;
int64_t numFailures = 0;
};
/// Emits a consistent diagnostic for target paths that require static shapes.
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
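The capping contract is easiest to see in isolation. A minimal standalone sketch of the same pattern (plain C++ without MLIR; the callback shape and the default cap of 8 mirror the struct above, everything else is illustrative):

#include <cstdint>
#include <functional>
#include <iostream>

// Counts every failure but forwards only the first `maxReported` to the
// emit callback; the overflow is summarized once at the end.
struct CappedReporter {
  int64_t maxReported = 8;
  int64_t numFailures = 0;

  void report(const std::function<void()>& emit) {
    if (++numFailures <= maxReported)
      emit();
  }
  void emitSuppressedSummary() const {
    if (numFailures > maxReported)
      std::cerr << "suppressed " << (numFailures - maxReported)
                << " additional failures\n";
  }
};

int main() {
  CappedReporter reporter;
  for (int i = 0; i < 20; ++i)
    reporter.report([&] { std::cerr << "failure #" << i << "\n"; });
  reporter.emitSuppressedSummary(); // prints: suppressed 12 additional failures
}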
@@ -100,18 +100,27 @@ DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
return tiles;
}
tensor::SplatOp
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
Type elementType = oldType.getElementType();
int64_t shape[2] = {1, length};
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
auto buildBroadcast = [&](Value input) -> Value {
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
};
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
if (isHostFoldableValue(scalarToBroadcast))
return buildBroadcast(scalarToBroadcast);
auto broadcastCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
});
return broadcastCompute.getResult(0);
}
} // namespace onnx_mlir
@@ -136,9 +136,9 @@ tileMatrix(mlir::Value& matrixToTile,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
} // namespace onnx_mlir
@@ -1,8 +1,12 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -18,6 +22,11 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
}
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
return llvm::all_of(extractOp.getIndices(),
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
}
static bool isStaticTensorResult(Operation* op) {
return llvm::all_of(op->getResultTypes(), [](Type type) {
auto shapedType = dyn_cast<ShapedType>(type);
@@ -25,6 +34,167 @@ static bool isStaticTensorResult(Operation* op) {
});
}
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!tensorType)
return failure();
int64_t rank = tensorType.getRank();
if (static_cast<int64_t>(perms.size()) != rank)
return failure();
llvm::SmallBitVector seen(rank);
SmallVector<int64_t> transposedShape;
transposedShape.reserve(rank);
for (int64_t perm : perms) {
if (perm < 0 || perm >= rank || seen.test(perm))
return failure();
seen.set(perm);
transposedShape.push_back(tensorType.getShape()[perm]);
}
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
if (denseAttr.isSplat())
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> transposedValues(originalValues.size());
SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
SmallVector<int64_t> originalIndices(rank);
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < rank; ++dim) {
originalIndices[dim] = remaining / originalStrides[dim];
remaining %= originalStrides[dim];
}
int64_t transposedLinearIndex = 0;
for (int64_t dim = 0; dim < rank; ++dim)
transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
transposedValues[transposedLinearIndex] = value;
}
return DenseElementsAttr::get(transposedType, transposedValues);
}
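// Example: a [2, 3] tensor holding {0, 1, 2, 3, 4, 5} transposed with
// perms = {1, 0} becomes the [3, 2] tensor {0, 3, 1, 4, 2, 5}.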
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
return failure();
if (denseAttr.isSplat())
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
return DenseElementsAttr::get(resultType, values);
}
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
tensor::ExtractSliceOp extractSliceOp) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()
|| sourceType.getRank() != resultType.getRank())
return failure();
ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
|| llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
|| llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
return failure();
if (denseAttr.isSplat())
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
SmallVector<Attribute> resultValues;
resultValues.reserve(resultType.getNumElements());
for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
int64_t remaining = linearIndex;
int64_t sourceLinearIndex = 0;
for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
const int64_t resultIndex = remaining / resultStrides[dim];
remaining %= resultStrides[dim];
sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim];
}
resultValues.push_back(sourceValues[sourceLinearIndex]);
}
return DenseElementsAttr::get(resultType, resultValues);
}
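// Example: slicing a [4, 4] source at offsets {1, 1} with sizes {2, 2}
// gathers the source elements at linear indices 5, 6, 9, and 10.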
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
return nullptr;
}
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
auto* definingOp = value.getDefiningOp();
if (!definingOp || !visited.insert(definingOp).second)
return nullptr;
// Rebuild dense attributes through view-only host-foldable chains so later
// lowering stages can still recognize grouped/sliced constants.
if (auto denseAttr = getDirectDenseConstantAttr(value))
return denseAttr;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
if (!inputAttr)
return nullptr;
SmallVector<int64_t> perm;
perm.reserve(transposeOp.getPermAttr().size());
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
perm.push_back(attr.getInt());
auto transposedAttr = transposeDenseElements(inputAttr, perm);
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
}
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
if (!inputAttr)
return nullptr;
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
}
return nullptr;
}
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
if (!op || !visited.insert(op).second)
return false;
@@ -32,6 +202,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return true;
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
if (!isStaticTensorResult(op))
return false;
@@ -47,6 +220,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
return isHostFoldableValue(splatOp.getInput());
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return isHostFoldableValue(extractRowsOp.getInput());
@@ -72,4 +248,9 @@ bool isHostFoldableOp(Operation* op) {
return isHostFoldableOpImpl(op, visited);
}
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getHostFoldableDenseElementsAttrImpl(value, visited);
}
} // namespace onnx_mlir
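The visited-set threading above is a standard cycle-safe recursion over def-use chains. A minimal standalone analogue (plain C++ with a toy op graph standing in for MLIR operations; all names are illustrative):

#include <unordered_set>
#include <vector>

struct Op {
  bool isConstant = false;
  std::vector<const Op*> operands; // defining ops of this op's inputs
};

// Mirrors the structure of isHostFoldableOpImpl: constants fold, view-like
// ops fold iff all of their inputs fold, and the visited set breaks cycles
// while bounding work on shared subgraphs.
static bool isFoldable(const Op* op, std::unordered_set<const Op*>& visited) {
  if (!op || !visited.insert(op).second)
    return false;
  if (op->isConstant)
    return true;
  if (op->operands.empty())
    return false;
  for (const Op* operand : op->operands)
    if (!isFoldable(operand, visited))
      return false;
  return true;
}

static bool isFoldable(const Op* op) {
  std::unordered_set<const Op*> visited;
  return isFoldable(op, visited);
}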
@@ -1,5 +1,6 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
@@ -9,4 +10,6 @@ bool isHostFoldableValue(mlir::Value value);
bool isHostFoldableOp(mlir::Operation* op);
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
} // namespace onnx_mlir
@@ -2,6 +2,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -11,7 +12,7 @@ using namespace mlir;
namespace onnx_mlir {
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
bool hasFailure = false;
pim::CappedDiagnosticReporter diagnostics;
for (Operation& op : funcOp.getFunctionBody().front()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
@@ -19,11 +20,15 @@ LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
if (isHostFoldableOp(&op))
continue;
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
hasFailure = true;
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
"spat.compute");
});
}
return success(!hasFailure);
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
return success(!diagnostics.hasFailure());
}
} // namespace onnx_mlir
@@ -6,13 +6,14 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
@@ -84,6 +85,30 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
returnOp.setOperand(index, computeResult);
}
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
Block& entryBlock = funcOp.getFunctionBody().front();
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
if (!transposeOp || isHostFoldableOp(transposeOp))
continue;
// Transpose stays globally legal because constant/view-only cases are
// allowed on the host. Any residual runtime transpose must be sunk into
// spat.compute before the host legality check.
auto resultType = transposeOp.getResult().getType();
rewriter.setInsertionPoint(transposeOp);
auto computeOp = createSpatCompute<1>(
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
});
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
}
}
void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext();
@@ -94,7 +119,7 @@ void ONNXToSpatialPass::runOnOperation() {
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
preTarget.addIllegalOp<ONNXConstantOp, ONNXMatMulOp, ONNXFlattenOp>();
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
RewritePatternSet prePatterns(ctx);
populatePrePatterns(prePatterns, ctx);
@@ -111,6 +136,21 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
RewritePatternSet matmulPatterns(ctx);
populateMatMulRewritePatterns(matmulPatterns, ctx);
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
bool hasUnloweredMatMul = false;
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
hasUnloweredMatMul = true;
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
});
if (hasUnloweredMatMul) {
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
signalPassFailure();
return;
}
ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
@@ -161,20 +201,6 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
if (coresCount != -1) {
int computeOpsCount = 0;
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<spatial::SpatCompute>(op))
computeOpsCount++;
if (computeOpsCount > coresCount) {
entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
<< coresCount << ")";
signalPassFailure();
return;
}
}
PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass());
if (failed(cleanupPM.run(moduleOp)))
@@ -201,6 +227,8 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
wrapTopLevelRuntimeTransposes(*entryFunc);
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure();
@@ -11,6 +11,7 @@
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -27,16 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
ConversionPatternRewriter& rewriter) const override;
};
static DenseElementsAttr getDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
return nullptr;
}
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
@@ -355,49 +346,22 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
return collectComputeOp.getResult(0);
}
} // namespace
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
ONNXConvOpAdaptor convOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = convOp.getLoc();
Value x = convOpAdaptor.getX();
Value w = convOpAdaptor.getW();
Value b = convOpAdaptor.getB();
auto xType = cast<RankedTensorType>(x.getType());
auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType());
if (!xType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
return failure();
}
if (!wType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
return failure();
}
if (xType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
return failure();
}
if (wType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
return failure();
}
if (outType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
return failure();
}
if (convOp.getGroup() != 1) {
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
return failure();
}
static Value lowerSingleConvGroup(Value x,
Value w,
Value b,
RankedTensorType xType,
RankedTensorType wType,
RankedTensorType outType,
int64_t padHeightBegin,
int64_t padHeightEnd,
int64_t padWidthBegin,
int64_t padWidthEnd,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1);
const int64_t xHeight = xType.getDimSize(2);
@@ -408,71 +372,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const int64_t outHeight = outType.getDimSize(2);
const int64_t outWidth = outType.getDimSize(3);
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
const auto stridesAttr = convOp.getStrides();
const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads();
if (stridesAttr && stridesAttr->size() != 2) {
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
return failure();
}
if (dilationsAttr && dilationsAttr->size() != 2) {
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
return failure();
}
if (padsAttr && padsAttr->size() != 4) {
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0;
int64_t padWidthBegin = 0;
int64_t padWidthEnd = 0;
if (padsAttr) {
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
}
else {
// Compute padding from auto_pad attribute
const auto autoPad = convOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
const int64_t totalPadW =
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
if (autoPad == "SAME_UPPER") {
padHeightBegin = totalPadH / 2;
padHeightEnd = totalPadH - padHeightBegin;
padWidthBegin = totalPadW / 2;
padWidthEnd = totalPadW - padWidthBegin;
}
else { // SAME_LOWER
padHeightEnd = totalPadH / 2;
padHeightBegin = totalPadH - padHeightEnd;
padWidthEnd = totalPadW / 2;
padWidthBegin = totalPadW - padWidthEnd;
}
}
else if (autoPad != "NOTSET" && autoPad != "VALID") {
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
return failure();
}
// "NOTSET" or "VALID" -> all pads stay 0
}
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
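// Example (illustrative sizes): a 3x3 kernel over 8 input channels gives
// patchSize = 8 * 3 * 3 = 72; a 16x16 output map gives numPatches = 256.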
@@ -492,7 +391,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
auto wDenseAttr = getDenseConstantAttr(w);
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
@@ -513,7 +412,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
DenseElementsAttr biasDenseAttr;
if (hasB) {
gemmBias = b;
biasDenseAttr = getDenseConstantAttr(b);
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
}
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
@@ -589,17 +488,246 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
rewriter.getBoolAttr(false))
.getY();
rewriter.replaceOp(convOp,
createCollectedConvOutput(ValueRange {gemmRows},
convOp.getType(),
gemmOutType,
nhwcType,
outType,
numPatches,
numChannelsOut,
effectiveMaxParallelPixels,
rewriter,
loc));
return createCollectedConvOutput(ValueRange {gemmRows},
outType,
gemmOutType,
nhwcType,
outType,
numPatches,
numChannelsOut,
effectiveMaxParallelPixels,
rewriter,
loc);
}
} // namespace
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
ONNXConvOpAdaptor convOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = convOp.getLoc();
Value x = convOpAdaptor.getX();
Value w = convOpAdaptor.getW();
Value b = convOpAdaptor.getB();
auto xType = cast<RankedTensorType>(x.getType());
auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType());
if (!xType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
return failure();
}
if (!wType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
return failure();
}
if (xType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
return failure();
}
if (wType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
return failure();
}
if (outType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
return failure();
}
if (convOp.getGroup() < 1) {
convOp.emitOpError("requires group >= 1 for Spatial lowering");
return failure();
}
const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1);
const int64_t xHeight = xType.getDimSize(2);
const int64_t xWidth = xType.getDimSize(3);
const int64_t numChannelsOut = wType.getDimSize(0);
const int64_t wHeight = wType.getDimSize(2);
const int64_t wWidth = wType.getDimSize(3);
const int64_t outHeight = outType.getDimSize(2);
const int64_t outWidth = outType.getDimSize(3);
const int64_t group = convOp.getGroup();
const bool hasB = !isa_and_nonnull<ONNXNoneOp>(b.getDefiningOp());
if (numChannelsIn % group != 0) {
convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group
<< " for Spatial lowering";
return failure();
}
if (numChannelsOut % group != 0) {
convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group
<< " for Spatial lowering";
return failure();
}
const int64_t numChannelsInPerGroup = numChannelsIn / group;
const int64_t numChannelsOutPerGroup = numChannelsOut / group;
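// Example: numChannelsIn = 8, numChannelsOut = 4, group = 2 give 4 input
// and 2 output channels per group, lowered as independent convs below.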
if (wType.getDimSize(1) != numChannelsInPerGroup) {
convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1)
<< " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering";
return failure();
}
if (outType.getDimSize(1) != numChannelsOut) {
convOp.emitOpError() << "requires weight output channels " << numChannelsOut << " to match result channels "
<< outType.getDimSize(1) << " for Spatial lowering";
return failure();
}
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
const auto stridesAttr = convOp.getStrides();
const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads();
if (stridesAttr && stridesAttr->size() != 2) {
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
return failure();
}
if (dilationsAttr && dilationsAttr->size() != 2) {
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
return failure();
}
if (padsAttr && padsAttr->size() != 4) {
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0;
int64_t padWidthBegin = 0;
int64_t padWidthEnd = 0;
if (padsAttr) {
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
}
else {
// Compute padding from auto_pad attribute
const auto autoPad = convOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
const int64_t totalPadW =
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
if (autoPad == "SAME_UPPER") {
padHeightBegin = totalPadH / 2;
padHeightEnd = totalPadH - padHeightBegin;
padWidthBegin = totalPadW / 2;
padWidthEnd = totalPadW - padWidthBegin;
}
else { // SAME_LOWER
padHeightEnd = totalPadH / 2;
padHeightBegin = totalPadH - padHeightEnd;
padWidthEnd = totalPadW / 2;
padWidthBegin = totalPadW - padWidthEnd;
}
}
else if (autoPad != "NOTSET" && autoPad != "VALID") {
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
return failure();
}
// "NOTSET" or "VALID" -> all pads stay 0
}
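// Example: SAME_UPPER with xHeight = 5, wHeight = 3, strideHeight = 2 and
// dilationHeight = 1 has outHeight = 3, so totalPadH = 2 * 2 + 3 - 5 = 2
// and padHeightBegin = padHeightEnd = 1.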
if (group == 1) {
rewriter.replaceOp(convOp,
lowerSingleConvGroup(x,
w,
b,
xType,
wType,
outType,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
rewriter,
loc));
return success();
}
SmallVector<Value> xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc);
SmallVector<Value> wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc);
SmallVector<Value> bSlices;
if (hasB) {
auto biasType = cast<RankedTensorType>(b.getType());
int64_t biasAxis = -1;
if (biasType.getRank() == 1)
biasAxis = 0;
else if (biasType.getRank() == 2)
biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1;
else {
convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
<< biasType.getRank();
return failure();
}
bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc);
}
if (xSlices.size() != static_cast<size_t>(group) || wSlices.size() != static_cast<size_t>(group)
|| (hasB && bSlices.size() != static_cast<size_t>(group))) {
convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering");
return failure();
}
SmallVector<Value> groupResults;
groupResults.reserve(group);
auto groupOutType =
RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType());
Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
for (int64_t groupId = 0; groupId < group; groupId++) {
Value groupX = xSlices[groupId];
Value groupW = wSlices[groupId];
Value groupB = hasB ? bSlices[groupId] : noBias;
groupResults.push_back(lowerSingleConvGroup(groupX,
groupW,
groupB,
cast<RankedTensorType>(groupX.getType()),
cast<RankedTensorType>(groupW.getType()),
groupOutType,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
rewriter,
loc));
}
Value result;
if (llvm::all_of(groupResults, isHostFoldableValue)) {
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
}
else {
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
});
result = concatCompute.getResult(0);
}
rewriter.replaceOp(convOp, result);
return success();
}
@@ -2,8 +2,12 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <functional>
#include <numeric>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
@@ -19,6 +23,79 @@ static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
ArrayRef<int64_t> rhsBatchShape) {
if (lhsBatchShape.empty())
return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
if (rhsBatchShape.empty())
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
if (!llvm::equal(lhsBatchShape, rhsBatchShape))
return failure();
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
}
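// Example: lhs batch [2, 3] with an empty rhs batch yields [2, 3], while
// mismatched non-empty batches such as [2] vs [3] are rejected.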
static Value collapseBatchDims(Value value,
int64_t batchSize,
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2 || type.getRank() == 3)
return value;
auto collapsedType =
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
};
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
reassociation.front().push_back(dim);
auto buildCollapsed = [&](Value input) -> Value {
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
};
if (isHostFoldableValue(value))
return buildCollapsed(value);
auto collapseCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
});
return collapseCompute.getResult(0);
}
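// Example: a [2, 3, 4, 5] operand collapses to [6, 4, 5], folding the
// leading batch dims into batchSize = 6 ahead of the per-batch GEMMs.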
static Value expandBatchDims(Value value,
RankedTensorType outputType,
size_t batchRank,
PatternRewriter& rewriter,
Location loc) {
if (cast<RankedTensorType>(value.getType()) == outputType)
return value;
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(batchRank)},
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
};
for (size_t dim = 0; dim < batchRank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
auto expandCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
return expandCompute.getResult(0);
}
static Value extractBatchMatrix(Value value,
int64_t batchIndex,
int64_t batchSize,
@@ -62,13 +139,29 @@ static Value extractBatchMatrix(Value value,
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
RankedTensorType transposedType;
SmallVector<int64_t> perm;
if (type.getRank() == 2) {
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
perm = {1, 0};
}
else {
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
perm = {0, 2, 1};
}
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
auto buildTranspose = [&](Value input) -> Value {
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
};
if (isHostFoldableValue(value))
return buildTranspose(value);
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
});
return transposeCompute.getResult(0);
}
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
@@ -120,24 +213,25 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape())
return failure();
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|| (outType.getRank() != 2 && outType.getRank() != 3))
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
return failure();
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|| !haveStaticPositiveShape(outType.getShape()))
return failure();
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
const int64_t batch = std::max(lhsBatch, rhsBatch);
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
if (failed(batchShape))
return failure();
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
if (k != rhsK)
return failure();
@@ -146,15 +240,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
return failure();
}
else {
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|| outType.getDimSize(outType.getRank() - 1) != n)
return failure();
}
Location loc = matmulOp.getLoc();
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
Value lhs = matmulOp.getA();
Value rhs = matmulOp.getB();
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
int64_t lhsBatchForGemm = lhsBatch;
int64_t rhsBatchForGemm = rhsBatch;
int64_t gemmM = m;
@@ -239,6 +335,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
}
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
rewriter.replaceOp(matmulOp, result);
return success();
}
@@ -80,6 +80,22 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}
static SmallVector<ReassociationIndices> getCollapseTo1DReassociation(size_t rank) {
SmallVector<ReassociationIndices> reassociation(1);
reassociation.front().reserve(rank);
for (size_t dim = 0; dim < rank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
return reassociation;
}
static SmallVector<ReassociationIndices> getExpandFrom1DReassociation(size_t rank) {
SmallVector<ReassociationIndices> reassociation(1);
reassociation.front().reserve(rank);
for (size_t dim = 0; dim < rank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
return reassociation;
}
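// Both helpers build the single group {0, ..., rank - 1}: a collapse groups
// all source dims into the 1-D result, an expand maps the 1-D source onto
// all result dims, so the reassociation is identical in both directions.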
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
using OpConversionPattern::OpConversionPattern;
@@ -126,6 +142,23 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
});
if (sourceType.getNumElements() != resultType.getNumElements())
return failure();
return replaceWithReshape([&](Value data) -> Value {
Value reshaped = data;
if (sourceType.getRank() != 1) {
auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType());
reshaped = tensor::CollapseShapeOp::create(
rewriter, reshapeOp.getLoc(), flatType, reshaped, getCollapseTo1DReassociation(sourceType.getRank()));
}
if (resultType.getRank() == 1)
return reshaped;
return tensor::ExpandShapeOp::create(
rewriter, reshapeOp.getLoc(), resultType, reshaped, getExpandFrom1DReassociation(resultType.getRank()))
.getResult();
});
return failure();
}
};
@@ -17,9 +17,7 @@ void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* c
patterns.add<convAddToConvWithBiasLeft>(ctx);
patterns.add<convAddToConvWithBiasRight>(ctx);
patterns.add<matMulAddToGemm>(ctx);
patterns.add<matMulToGemm>(ctx);
patterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(patterns, ctx);
}
} // namespace onnx_mlir
@@ -7,6 +7,7 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -152,14 +153,15 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
bool hasFailure = false;
pim::CappedDiagnosticReporter diagnostics;
moduleOp.walk([&](Operation* op) {
if (op->getDialect()->getNamespace() != "spat")
return;
op->emitError("illegal Spatial operation reached PIM codegen verification");
hasFailure = true;
diagnostics.report(op, [](Operation* illegalOp) {
illegalOp->emitError("illegal Spatial operation reached PIM codegen verification");
});
});
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
@@ -168,36 +170,36 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
for (Operation& op : funcOp.getBody().front().getOperations()) {
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp)))
hasFailure = true;
(void) verifyCoreWeights(moduleOp, coreOp, diagnostics);
(void) verifyCoreOperands(coreOp, diagnostics);
continue;
}
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
if (failed(verifyCoreWeights(moduleOp, coreBatchOp)) || failed(verifyCoreOperands(coreBatchOp)))
hasFailure = true;
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
(void) verifyCoreOperands(coreBatchOp, diagnostics);
continue;
}
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
if (failed(verifyReturnOp(returnOp)))
hasFailure = true;
(void) verifyReturnOp(returnOp, diagnostics);
continue;
}
if (!isAddressOnlyHostOp(&op)) {
op.emitOpError("illegal host-side runtime op remains after PIM bufferization; "
"fold it to constants or lower it into pim.core");
hasFailure = true;
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("illegal host-side runtime op remains after PIM bufferization; "
"fold it to constants or lower it into pim.core");
});
continue;
}
if (failed(verifyAddressOnlyHostOp(&op)))
hasFailure = true;
(void) verifyAddressOnlyHostOp(&op, diagnostics);
}
}
if (hasFailure) {
if (diagnostics.hasFailure()) {
diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
signalPassFailure();
}
@@ -205,14 +207,19 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
private:
template <typename CoreOpTy>
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp) {
static LogicalResult
verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
bool hasFailure = false;
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
for (auto it : llvm::enumerate(coreOp.getWeights())) {
size_t weightIndex = it.index();
Value weight = it.value();
auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp && !isConstantGlobalView(weight)) {
coreOp.emitOpError() << "weight #" << weightIndex
<< " must be materialized as a constant memref.global or a static view of one before JSON "
"codegen";
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
coreOp.emitOpError() << "weight #" << weightIndex
<< " must be materialized as a constant memref.global or a static view of one before "
"JSON codegen";
});
hasFailure = true;
continue;
}
@@ -222,14 +229,18 @@ private:
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
});
hasFailure = true;
continue;
}
if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
coreOp.emitOpError() << "weight #" << weightIndex
<< " must come from a constant memref.global with an initial value";
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
coreOp.emitOpError() << "weight #" << weightIndex
<< " must come from a constant memref.global with an initial value";
});
hasFailure = true;
}
}
@@ -237,11 +248,15 @@ private:
return success(!hasFailure);
}
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
static LogicalResult verifyReturnOp(func::ReturnOp returnOp, pim::CappedDiagnosticReporter& diagnostics) {
bool hasFailure = false;
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
for (auto it : llvm::enumerate(returnOp.getOperands())) {
size_t resultIndex = it.index();
Value operand = it.value();
if (!isCodegenAddressableValue(operand)) {
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
diagnostics.report(returnOp.getOperation(), [&](Operation*) {
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
});
hasFailure = true;
}
}
@@ -249,38 +264,50 @@ private:
}
template <typename CoreOpTy>
static LogicalResult verifyCoreOperands(CoreOpTy coreOp) {
static LogicalResult verifyCoreOperands(CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
return walkPimCoreBlock(
coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
coreOp.getBody().front(), StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
bool hasFailure = false;
if (!isSupportedCoreInstructionOp(&op)) {
op.emitOpError("unsupported executable op reached PIM codegen verification");
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("unsupported executable op reached PIM codegen verification");
});
hasFailure = true;
}
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
for (auto it : llvm::enumerate(op.getOperands())) {
size_t operandIndex = it.index();
Value operand = it.value();
if (!isa<BaseMemRefType>(operand.getType()))
continue;
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
if (failed(resolvedAddress)) {
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << "operand #" << operandIndex
<< " is not backed by contiguous addressable storage";
});
hasFailure = true;
continue;
}
if (isExplicitHostOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand, knowledge)) {
op.emitOpError() << "host operand #" << operandIndex
<< " is not backed by contiguous addressable storage";
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << "host operand #" << operandIndex
<< " is not backed by contiguous addressable storage";
});
hasFailure = true;
}
continue;
}
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
op.emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with "
"pim.memcp_hd";
});
hasFailure = true;
}
}
@@ -288,18 +315,20 @@ private:
});
}
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
static LogicalResult verifyAddressOnlyHostOp(Operation* op, pim::CappedDiagnosticReporter& diagnostics) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
return verifyAddressOnlyBase(op, subviewOp.getSource());
return verifyAddressOnlyBase(op, subviewOp.getSource(), diagnostics);
if (auto castOp = dyn_cast<memref::CastOp>(op))
return verifyAddressOnlySource(op, castOp.getSource());
return verifyAddressOnlySource(op, castOp.getSource(), diagnostics);
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
return verifyAddressOnlySource(op, collapseOp.getSrc());
return verifyAddressOnlySource(op, collapseOp.getSrc(), diagnostics);
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
return verifyAddressOnlySource(op, expandOp.getSrc());
return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics);
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) {
op->emitOpError("depends on a value that is not backed by addressable storage");
diagnostics.report(op, [](Operation* illegalOp) {
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
});
return failure();
}
return success();
@@ -307,19 +336,24 @@ private:
return success();
}
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
static LogicalResult
verifyAddressOnlySource(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
if (isCodegenAddressableValue(source))
return success();
op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
diagnostics.report(op, [](Operation* illegalOp) {
illegalOp->emitOpError("depends on a value that is not backed by contiguous addressable storage");
});
return failure();
}
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source) {
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
if (isBaseAddressableValue(source))
return success();
op->emitOpError("depends on a value that is not backed by addressable storage");
diagnostics.report(op, [](Operation* illegalOp) {
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
});
return failure();
}
};