fix reshape lowering add support for grouped-convolution lowering quieter verifier with capped error messages
This commit is contained in:
@@ -7,10 +7,34 @@
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <system_error>
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
struct CappedDiagnosticReporter {
|
||||
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
|
||||
|
||||
template <typename EmitFn>
|
||||
void report(mlir::Operation* op, EmitFn&& emit) {
|
||||
numFailures++;
|
||||
if (numFailures <= maxReportedFailures)
|
||||
emit(op);
|
||||
}
|
||||
|
||||
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
||||
if (numFailures > maxReportedFailures)
|
||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
|
||||
<< failureDescription;
|
||||
}
|
||||
|
||||
bool hasFailure() const { return numFailures != 0; }
|
||||
|
||||
private:
|
||||
int64_t maxReportedFailures;
|
||||
int64_t numFailures = 0;
|
||||
};
|
||||
|
||||
/// Emits a consistent diagnostic for target paths that require static shapes.
|
||||
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
|
||||
|
||||
|
||||
@@ -100,18 +100,27 @@ DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
||||
return tiles;
|
||||
}
|
||||
|
||||
tensor::SplatOp
|
||||
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
||||
Type elementType = oldType.getElementType();
|
||||
int64_t shape[2] = {1, length};
|
||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
SmallVector<Value> index(oldType.getRank(), zero);
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
|
||||
auto buildBroadcast = [&](Value input) -> Value {
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
SmallVector<Value> index(oldType.getRank(), zero);
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
};
|
||||
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
if (isHostFoldableValue(scalarToBroadcast))
|
||||
return buildBroadcast(scalarToBroadcast);
|
||||
|
||||
auto broadcastCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
|
||||
});
|
||||
return broadcastCompute.getResult(0);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -136,9 +136,9 @@ tileMatrix(mlir::Value& matrixToTile,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location& loc);
|
||||
|
||||
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
int64_t length,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
int64_t length,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -18,6 +22,11 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
|
||||
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
|
||||
}
|
||||
|
||||
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
|
||||
return llvm::all_of(extractOp.getIndices(),
|
||||
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
|
||||
}
|
||||
|
||||
static bool isStaticTensorResult(Operation* op) {
|
||||
return llvm::all_of(op->getResultTypes(), [](Type type) {
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
@@ -25,6 +34,167 @@ static bool isStaticTensorResult(Operation* op) {
|
||||
});
|
||||
}
|
||||
|
||||
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!tensorType)
|
||||
return failure();
|
||||
|
||||
int64_t rank = tensorType.getRank();
|
||||
if (static_cast<int64_t>(perms.size()) != rank)
|
||||
return failure();
|
||||
|
||||
llvm::SmallBitVector seen(rank);
|
||||
SmallVector<int64_t> transposedShape;
|
||||
transposedShape.reserve(rank);
|
||||
for (int64_t perm : perms) {
|
||||
if (perm < 0 || perm >= rank || seen.test(perm))
|
||||
return failure();
|
||||
seen.set(perm);
|
||||
transposedShape.push_back(tensorType.getShape()[perm]);
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<Attribute> transposedValues(originalValues.size());
|
||||
SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
|
||||
SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
|
||||
SmallVector<int64_t> originalIndices(rank);
|
||||
|
||||
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
|
||||
int64_t remaining = static_cast<int64_t>(linearIndex);
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
originalIndices[dim] = remaining / originalStrides[dim];
|
||||
remaining %= originalStrides[dim];
|
||||
}
|
||||
|
||||
int64_t transposedLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
|
||||
|
||||
transposedValues[transposedLinearIndex] = value;
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
|
||||
return failure();
|
||||
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
|
||||
return DenseElementsAttr::get(resultType, values);
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
|
||||
tensor::ExtractSliceOp extractSliceOp) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
|
||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
|
||||
ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
|
||||
if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||
|| llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||
|| llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
|
||||
return failure();
|
||||
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
|
||||
SmallVector<Attribute> resultValues;
|
||||
resultValues.reserve(resultType.getNumElements());
|
||||
|
||||
for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
|
||||
int64_t remaining = linearIndex;
|
||||
int64_t sourceLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
|
||||
const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[dim];
|
||||
remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim];
|
||||
sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim];
|
||||
}
|
||||
resultValues.push_back(sourceValues[sourceLinearIndex]);
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(resultType, resultValues);
|
||||
}
|
||||
|
||||
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp || !visited.insert(definingOp).second)
|
||||
return nullptr;
|
||||
|
||||
// Rebuild dense attributes through view-only host-foldable chains so later
|
||||
// lowering stages can still recognize grouped/sliced constants.
|
||||
if (auto denseAttr = getDirectDenseConstantAttr(value))
|
||||
return denseAttr;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
|
||||
SmallVector<int64_t> perm;
|
||||
perm.reserve(transposeOp.getPermAttr().size());
|
||||
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
|
||||
perm.push_back(attr.getInt());
|
||||
auto transposedAttr = transposeDenseElements(inputAttr, perm);
|
||||
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
|
||||
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
|
||||
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
|
||||
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
if (!op || !visited.insert(op).second)
|
||||
return false;
|
||||
@@ -32,6 +202,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||
return true;
|
||||
|
||||
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
|
||||
|
||||
if (!isStaticTensorResult(op))
|
||||
return false;
|
||||
|
||||
@@ -47,6 +220,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
||||
|
||||
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||
return isHostFoldableValue(splatOp.getInput());
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||
return isHostFoldableValue(extractRowsOp.getInput());
|
||||
|
||||
@@ -72,4 +248,9 @@ bool isHostFoldableOp(Operation* op) {
|
||||
return isHostFoldableOpImpl(op, visited);
|
||||
}
|
||||
|
||||
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return getHostFoldableDenseElementsAttrImpl(value, visited);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
@@ -9,4 +10,6 @@ bool isHostFoldableValue(mlir::Value value);
|
||||
|
||||
bool isHostFoldableOp(mlir::Operation* op);
|
||||
|
||||
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -11,7 +12,7 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
bool hasFailure = false;
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
for (Operation& op : funcOp.getFunctionBody().front()) {
|
||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||
@@ -19,11 +20,15 @@ LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
if (isHostFoldableOp(&op))
|
||||
continue;
|
||||
|
||||
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||
hasFailure = true;
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
|
||||
"spat.compute");
|
||||
});
|
||||
}
|
||||
|
||||
return success(!hasFailure);
|
||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
|
||||
|
||||
return success(!diagnostics.hasFailure());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -6,13 +6,14 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "Common/Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
@@ -84,6 +85,30 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
returnOp.setOperand(index, computeResult);
|
||||
}
|
||||
|
||||
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
Block& entryBlock = funcOp.getFunctionBody().front();
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
|
||||
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
|
||||
if (!transposeOp || isHostFoldableOp(transposeOp))
|
||||
continue;
|
||||
|
||||
// Transpose stays globally legal because constant/view-only cases are
|
||||
// allowed on the host. Any residual runtime transpose must be sunk into
|
||||
// spat.compute before the host legality check.
|
||||
auto resultType = transposeOp.getResult().getType();
|
||||
rewriter.setInsertionPoint(transposeOp);
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
|
||||
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
|
||||
});
|
||||
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::runOnOperation() {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = &getContext();
|
||||
@@ -94,7 +119,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
preTarget.addIllegalOp<ONNXConstantOp, ONNXMatMulOp, ONNXFlattenOp>();
|
||||
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
|
||||
|
||||
RewritePatternSet prePatterns(ctx);
|
||||
populatePrePatterns(prePatterns, ctx);
|
||||
@@ -111,6 +136,21 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
RewritePatternSet matmulPatterns(ctx);
|
||||
populateMatMulRewritePatterns(matmulPatterns, ctx);
|
||||
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
|
||||
|
||||
bool hasUnloweredMatMul = false;
|
||||
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
|
||||
hasUnloweredMatMul = true;
|
||||
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
|
||||
});
|
||||
if (hasUnloweredMatMul) {
|
||||
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
@@ -161,20 +201,6 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (coresCount != -1) {
|
||||
int computeOpsCount = 0;
|
||||
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
|
||||
if (isa<spatial::SpatCompute>(op))
|
||||
computeOpsCount++;
|
||||
|
||||
if (computeOpsCount > coresCount) {
|
||||
entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
|
||||
<< coresCount << ")";
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
PassManager cleanupPM(ctx);
|
||||
cleanupPM.addPass(createCanonicalizerPass());
|
||||
if (failed(cleanupPM.run(moduleOp)))
|
||||
@@ -201,6 +227,8 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
wrapTopLevelRuntimeTransposes(*entryFunc);
|
||||
|
||||
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||
signalPassFailure();
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -27,16 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||
|
||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
||||
|
||||
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
@@ -355,49 +346,22 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
return collectComputeOp.getResult(0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
if (!xType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||
return failure();
|
||||
}
|
||||
if (!wType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||
return failure();
|
||||
}
|
||||
if (xType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (wType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (outType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (convOp.getGroup() != 1) {
|
||||
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
static Value lowerSingleConvGroup(Value x,
|
||||
Value w,
|
||||
Value b,
|
||||
RankedTensorType xType,
|
||||
RankedTensorType wType,
|
||||
RankedTensorType outType,
|
||||
int64_t padHeightBegin,
|
||||
int64_t padHeightEnd,
|
||||
int64_t padWidthBegin,
|
||||
int64_t padWidthEnd,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
@@ -408,71 +372,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
if (stridesAttr && stridesAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (padsAttr && padsAttr->size() != 4) {
|
||||
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||
@@ -492,7 +391,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
||||
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
||||
auto wDenseAttr = getDenseConstantAttr(w);
|
||||
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
|
||||
|
||||
// Prepare weight matrix W for crossbar storage:
|
||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||
@@ -513,7 +412,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
DenseElementsAttr biasDenseAttr;
|
||||
if (hasB) {
|
||||
gemmBias = b;
|
||||
biasDenseAttr = getDenseConstantAttr(b);
|
||||
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
|
||||
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||
}
|
||||
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
||||
@@ -589,17 +488,246 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
|
||||
rewriter.replaceOp(convOp,
|
||||
createCollectedConvOutput(ValueRange {gemmRows},
|
||||
convOp.getType(),
|
||||
gemmOutType,
|
||||
nhwcType,
|
||||
outType,
|
||||
numPatches,
|
||||
numChannelsOut,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc));
|
||||
return createCollectedConvOutput(ValueRange {gemmRows},
|
||||
outType,
|
||||
gemmOutType,
|
||||
nhwcType,
|
||||
outType,
|
||||
numPatches,
|
||||
numChannelsOut,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
if (!xType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||
return failure();
|
||||
}
|
||||
if (!wType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||
return failure();
|
||||
}
|
||||
if (xType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (wType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (outType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (convOp.getGroup() < 1) {
|
||||
convOp.emitOpError("requires group >= 1 for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
const int64_t xWidth = xType.getDimSize(3);
|
||||
const int64_t numChannelsOut = wType.getDimSize(0);
|
||||
const int64_t wHeight = wType.getDimSize(2);
|
||||
const int64_t wWidth = wType.getDimSize(3);
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
const int64_t group = convOp.getGroup();
|
||||
const bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||
|
||||
if (numChannelsIn % group != 0) {
|
||||
convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group
|
||||
<< " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
if (numChannelsOut % group != 0) {
|
||||
convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group
|
||||
<< " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t numChannelsInPerGroup = numChannelsIn / group;
|
||||
const int64_t numChannelsOutPerGroup = numChannelsOut / group;
|
||||
if (wType.getDimSize(1) != numChannelsInPerGroup) {
|
||||
convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1)
|
||||
<< " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
if (wType.getDimSize(0) != numChannelsOut) {
|
||||
convOp.emitOpError() << "requires weight output channels " << wType.getDimSize(0) << " to match result channels "
|
||||
<< numChannelsOut << " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
if (stridesAttr && stridesAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (padsAttr && padsAttr->size() != 4) {
|
||||
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
if (group == 1) {
|
||||
rewriter.replaceOp(convOp,
|
||||
lowerSingleConvGroup(x,
|
||||
w,
|
||||
b,
|
||||
xType,
|
||||
wType,
|
||||
outType,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
rewriter,
|
||||
loc));
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value> xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc);
|
||||
SmallVector<Value> wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc);
|
||||
SmallVector<Value> bSlices;
|
||||
if (hasB) {
|
||||
auto biasType = cast<RankedTensorType>(b.getType());
|
||||
int64_t biasAxis = -1;
|
||||
if (biasType.getRank() == 1)
|
||||
biasAxis = 0;
|
||||
else if (biasType.getRank() == 2)
|
||||
biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1;
|
||||
else {
|
||||
convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
|
||||
<< biasType.getRank();
|
||||
return failure();
|
||||
}
|
||||
bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc);
|
||||
}
|
||||
|
||||
if (xSlices.size() != static_cast<size_t>(group) || wSlices.size() != static_cast<size_t>(group)
|
||||
|| (hasB && bSlices.size() != static_cast<size_t>(group))) {
|
||||
convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<Value> groupResults;
|
||||
groupResults.reserve(group);
|
||||
auto groupOutType =
|
||||
RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType());
|
||||
Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
for (int64_t groupId = 0; groupId < group; groupId++) {
|
||||
Value groupX = xSlices[groupId];
|
||||
Value groupW = wSlices[groupId];
|
||||
Value groupB = hasB ? bSlices[groupId] : noBias;
|
||||
groupResults.push_back(lowerSingleConvGroup(groupX,
|
||||
groupW,
|
||||
groupB,
|
||||
cast<RankedTensorType>(groupX.getType()),
|
||||
cast<RankedTensorType>(groupW.getType()),
|
||||
groupOutType,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
rewriter,
|
||||
loc));
|
||||
}
|
||||
|
||||
Value result;
|
||||
if (llvm::all_of(groupResults, isHostFoldableValue)) {
|
||||
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
|
||||
}
|
||||
else {
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
|
||||
});
|
||||
result = concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(convOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
@@ -2,8 +2,12 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
@@ -19,6 +23,79 @@ static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
||||
ArrayRef<int64_t> rhsBatchShape) {
|
||||
if (lhsBatchShape.empty())
|
||||
return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
|
||||
if (rhsBatchShape.empty())
|
||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||
if (!llvm::equal(lhsBatchShape, rhsBatchShape))
|
||||
return failure();
|
||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||
}
|
||||
|
||||
static Value collapseBatchDims(Value value,
|
||||
int64_t batchSize,
|
||||
int64_t rows,
|
||||
int64_t cols,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
if (type.getRank() == 2 || type.getRank() == 3)
|
||||
return value;
|
||||
|
||||
auto collapsedType =
|
||||
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||
SmallVector<ReassociationIndices> reassociation = {
|
||||
ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
|
||||
};
|
||||
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
|
||||
reassociation.front().push_back(dim);
|
||||
|
||||
auto buildCollapsed = [&](Value input) -> Value {
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
return buildCollapsed(value);
|
||||
|
||||
auto collapseCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
|
||||
});
|
||||
return collapseCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value expandBatchDims(Value value,
|
||||
RankedTensorType outputType,
|
||||
size_t batchRank,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (cast<RankedTensorType>(value.getType()) == outputType)
|
||||
return value;
|
||||
|
||||
SmallVector<ReassociationIndices> reassociation = {
|
||||
ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
|
||||
};
|
||||
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
|
||||
auto expandCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
return expandCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value extractBatchMatrix(Value value,
|
||||
int64_t batchIndex,
|
||||
int64_t batchSize,
|
||||
@@ -62,13 +139,29 @@ static Value extractBatchMatrix(Value value,
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
auto shape = type.getShape();
|
||||
RankedTensorType transposedType;
|
||||
SmallVector<int64_t> perm;
|
||||
if (type.getRank() == 2) {
|
||||
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
|
||||
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
perm = {1, 0};
|
||||
}
|
||||
else {
|
||||
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
auto buildTranspose = [&](Value input) -> Value {
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
return buildTranspose(value);
|
||||
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -120,24 +213,25 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||
|| !outType.hasStaticShape())
|
||||
return failure();
|
||||
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|
||||
|| (outType.getRank() != 2 && outType.getRank() != 3))
|
||||
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
||||
return failure();
|
||||
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||
|| !haveStaticPositiveShape(outType.getShape()))
|
||||
return failure();
|
||||
|
||||
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
|
||||
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
|
||||
const int64_t batch = std::max(lhsBatch, rhsBatch);
|
||||
|
||||
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
|
||||
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
||||
SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
|
||||
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
|
||||
if (failed(batchShape))
|
||||
return failure();
|
||||
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
|
||||
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
|
||||
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
|
||||
|
||||
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
|
||||
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
|
||||
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
|
||||
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
|
||||
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
|
||||
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
|
||||
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
|
||||
const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
|
||||
if (k != rhsK)
|
||||
return failure();
|
||||
|
||||
@@ -146,15 +240,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
return failure();
|
||||
}
|
||||
else {
|
||||
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
|
||||
SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
|
||||
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|
||||
|| outType.getDimSize(outType.getRank() - 1) != n)
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||
|
||||
Value lhs = matmulOp.getA();
|
||||
Value rhs = matmulOp.getB();
|
||||
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
|
||||
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
|
||||
int64_t lhsBatchForGemm = lhsBatch;
|
||||
int64_t rhsBatchForGemm = rhsBatch;
|
||||
int64_t gemmM = m;
|
||||
@@ -239,6 +335,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
}
|
||||
|
||||
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
||||
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -80,6 +80,22 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
|
||||
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||
}
|
||||
|
||||
static SmallVector<ReassociationIndices> getCollapseTo1DReassociation(size_t rank) {
|
||||
SmallVector<ReassociationIndices> reassociation(1);
|
||||
reassociation.front().reserve(rank);
|
||||
for (size_t dim = 0; dim < rank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
static SmallVector<ReassociationIndices> getExpandFrom1DReassociation(size_t rank) {
|
||||
SmallVector<ReassociationIndices> reassociation(1);
|
||||
reassociation.front().reserve(rank);
|
||||
for (size_t dim = 0; dim < rank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -126,6 +142,23 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||
});
|
||||
|
||||
if (sourceType.getNumElements() != resultType.getNumElements())
|
||||
return failure();
|
||||
|
||||
return replaceWithReshape([&](Value data) -> Value {
|
||||
Value reshaped = data;
|
||||
if (sourceType.getRank() != 1) {
|
||||
auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType());
|
||||
reshaped = tensor::CollapseShapeOp::create(
|
||||
rewriter, reshapeOp.getLoc(), flatType, reshaped, getCollapseTo1DReassociation(sourceType.getRank()));
|
||||
}
|
||||
if (resultType.getRank() == 1)
|
||||
return reshaped;
|
||||
return tensor::ExpandShapeOp::create(
|
||||
rewriter, reshapeOp.getLoc(), resultType, reshaped, getExpandFrom1DReassociation(resultType.getRank()))
|
||||
.getResult();
|
||||
});
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -17,9 +17,7 @@ void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* c
|
||||
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||
patterns.add<convAddToConvWithBiasRight>(ctx);
|
||||
patterns.add<matMulAddToGemm>(ctx);
|
||||
patterns.add<matMulToGemm>(ctx);
|
||||
patterns.add<removeFlattenSameShape>(ctx);
|
||||
populateMatMulRewritePatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -152,14 +153,15 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
bool hasFailure = false;
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
if (op->getDialect()->getNamespace() != "spat")
|
||||
return;
|
||||
|
||||
op->emitError("illegal Spatial operation reached PIM codegen verification");
|
||||
hasFailure = true;
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitError("illegal Spatial operation reached PIM codegen verification");
|
||||
});
|
||||
});
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
@@ -168,36 +170,36 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
||||
|
||||
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
||||
if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp)))
|
||||
hasFailure = true;
|
||||
(void) verifyCoreWeights(moduleOp, coreOp, diagnostics);
|
||||
(void) verifyCoreOperands(coreOp, diagnostics);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
||||
if (failed(verifyCoreWeights(moduleOp, coreBatchOp)) || failed(verifyCoreOperands(coreBatchOp)))
|
||||
hasFailure = true;
|
||||
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
|
||||
(void) verifyCoreOperands(coreBatchOp, diagnostics);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
|
||||
if (failed(verifyReturnOp(returnOp)))
|
||||
hasFailure = true;
|
||||
(void) verifyReturnOp(returnOp, diagnostics);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isAddressOnlyHostOp(&op)) {
|
||||
op.emitOpError("illegal host-side runtime op remains after PIM bufferization; "
|
||||
"fold it to constants or lower it into pim.core");
|
||||
hasFailure = true;
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("illegal host-side runtime op remains after PIM bufferization; "
|
||||
"fold it to constants or lower it into pim.core");
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (failed(verifyAddressOnlyHostOp(&op)))
|
||||
hasFailure = true;
|
||||
(void) verifyAddressOnlyHostOp(&op, diagnostics);
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFailure) {
|
||||
if (diagnostics.hasFailure()) {
|
||||
diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
|
||||
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
|
||||
signalPassFailure();
|
||||
}
|
||||
@@ -205,14 +207,19 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
||||
|
||||
private:
|
||||
template <typename CoreOpTy>
|
||||
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp) {
|
||||
static LogicalResult
|
||||
verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
bool hasFailure = false;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
|
||||
for (auto it : llvm::enumerate(coreOp.getWeights())) {
|
||||
size_t weightIndex = it.index();
|
||||
Value weight = it.value();
|
||||
auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp && !isConstantGlobalView(weight)) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must be materialized as a constant memref.global or a static view of one before JSON "
|
||||
"codegen";
|
||||
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must be materialized as a constant memref.global or a static view of one before "
|
||||
"JSON codegen";
|
||||
});
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
@@ -222,14 +229,18 @@ private:
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
|
||||
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
|
||||
});
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must come from a constant memref.global with an initial value";
|
||||
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must come from a constant memref.global with an initial value";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
@@ -237,11 +248,15 @@ private:
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
|
||||
static LogicalResult verifyReturnOp(func::ReturnOp returnOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
bool hasFailure = false;
|
||||
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
|
||||
for (auto it : llvm::enumerate(returnOp.getOperands())) {
|
||||
size_t resultIndex = it.index();
|
||||
Value operand = it.value();
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
|
||||
diagnostics.report(returnOp.getOperation(), [&](Operation*) {
|
||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
@@ -249,38 +264,50 @@ private:
|
||||
}
|
||||
|
||||
template <typename CoreOpTy>
|
||||
static LogicalResult verifyCoreOperands(CoreOpTy coreOp) {
|
||||
static LogicalResult verifyCoreOperands(CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
return walkPimCoreBlock(
|
||||
coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
coreOp.getBody().front(), StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
bool hasFailure = false;
|
||||
if (!isSupportedCoreInstructionOp(&op)) {
|
||||
op.emitOpError("unsupported executable op reached PIM codegen verification");
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("unsupported executable op reached PIM codegen verification");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
for (auto it : llvm::enumerate(op.getOperands())) {
|
||||
size_t operandIndex = it.index();
|
||||
Value operand = it.value();
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
||||
if (failed(resolvedAddress)) {
|
||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isExplicitHostOperand(&op, operandIndex)) {
|
||||
if (!isCodegenAddressableValue(operand, knowledge)) {
|
||||
op.emitOpError() << "host operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "host operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||
op.emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with "
|
||||
"pim.memcp_hd";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
@@ -288,18 +315,20 @@ private:
|
||||
});
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
||||
return verifyAddressOnlyBase(op, subviewOp.getSource());
|
||||
return verifyAddressOnlyBase(op, subviewOp.getSource(), diagnostics);
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(op))
|
||||
return verifyAddressOnlySource(op, castOp.getSource());
|
||||
return verifyAddressOnlySource(op, castOp.getSource(), diagnostics);
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, collapseOp.getSrc());
|
||||
return verifyAddressOnlySource(op, collapseOp.getSrc(), diagnostics);
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, expandOp.getSrc());
|
||||
return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics);
|
||||
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
|
||||
if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) {
|
||||
op->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
@@ -307,19 +336,24 @@ private:
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
|
||||
static LogicalResult
|
||||
verifyAddressOnlySource(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
if (isCodegenAddressableValue(source))
|
||||
return success();
|
||||
|
||||
op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("depends on a value that is not backed by contiguous addressable storage");
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source) {
|
||||
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
if (isBaseAddressableValue(source))
|
||||
return success();
|
||||
|
||||
op->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user