Bye Bye DCP

This commit is contained in:
ilgeco
2026-05-25 15:44:30 +02:00
parent 4855a2e105
commit eea9261c7b
42 changed files with 176 additions and 3994 deletions
@@ -4,7 +4,7 @@ add_public_tablegen_target(ONNXToSpatialIncGen)
add_pim_library(OMONNXToSpatial
ConversionPatterns.cpp
HostFoldability.cpp
CompileTime.cpp
HostLegality.cpp
PrePatterns.cpp
PostPatterns.cpp
@@ -6,7 +6,7 @@
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
using namespace mlir;
@@ -44,7 +44,7 @@ SmallVector<Value> sliceTensor(
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
Value slice;
if (isHostFoldableValue(tensorToSlice)) {
if (isCompileTimeComputable(tensorToSlice)) {
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
}
else {
@@ -113,7 +113,7 @@ Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatte
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
};
if (isHostFoldableValue(scalarToBroadcast))
if (isCompileTimeComputable(scalarToBroadcast))
return buildBroadcast(scalarToBroadcast);
auto broadcastCompute =
@@ -8,7 +8,7 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -145,7 +145,7 @@ static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
return nullptr;
}
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
auto* definingOp = value.getDefiningOp();
if (!definingOp || !visited.insert(definingOp).second)
return nullptr;
@@ -156,7 +156,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
return denseAttr;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getData(), visited);
if (!inputAttr)
return nullptr;
@@ -169,7 +169,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
@@ -177,7 +177,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
auto inputAttr = getHostConstantDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
@@ -185,7 +185,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
}
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
auto inputAttr = getHostConstantDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
if (!inputAttr)
return nullptr;
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
@@ -195,62 +195,71 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
return nullptr;
}
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
if (!op || !visited.insert(op).second)
static bool isCompileTimeOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
if (!op)
return false;
if (!visited.insert(op).second)
return true;
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return true;
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
return hasConstantIndices(extractOp) && isCompileTimeOpImpl(extractOp, visited);
if (!isStaticTensorResult(op))
return false;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return isHostFoldableValue(transposeOp.getData());
return isCompileTimeOpImpl(transposeOp, visited);
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return isHostFoldableValue(collapseShapeOp.getSrc());
return isCompileTimeOpImpl(collapseShapeOp,visited);
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
return isHostFoldableValue(expandShapeOp.getSrc());
return isCompileTimeOpImpl(expandShapeOp, visited);
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
return hasStaticUnitStrides(extractSliceOp) && isCompileTimeOpImpl(extractSliceOp, visited);
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
return isHostFoldableValue(splatOp.getInput());
return isCompileTimeOpImpl(splatOp, visited);
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return isHostFoldableValue(extractRowsOp.getInput());
return isCompileTimeOpImpl(extractRowsOp, visited);
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)){
bool res = true;
for(auto operandValue : concatOp.getOperands()){
res &= isCompileTimeOpImpl(operandValue.getDefiningOp(), visited);
if(!res) break;
}
return res;
}
return false;
}
} // namespace
bool isHostFoldableValue(Value value) {
bool isCompileTimeComputable(Value value) {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(definingOp, visited);
return isCompileTimeOpImpl(definingOp, visited);
}
bool isHostFoldableOp(Operation* op) {
bool isCompileTimeOp(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(op, visited);
return isCompileTimeOpImpl(op, visited);
}
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getHostFoldableDenseElementsAttrImpl(value, visited);
return getHostConstantDenseElementsAttrImpl(value, visited);
}
} // namespace onnx_mlir
@@ -0,0 +1,15 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
bool isCompileTimeComputable(mlir::Value value);
bool isCompileTimeOp(mlir::Operation* op);
mlir::DenseElementsAttr getHostConstDenseElementsAttr(mlir::Value value);
} // namespace onnx_mlir
@@ -1,15 +0,0 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
bool isHostFoldableValue(mlir::Value value);
bool isHostFoldableOp(mlir::Operation* op);
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
} // namespace onnx_mlir
@@ -3,7 +3,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -17,7 +17,7 @@ LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
for (Operation& op : funcOp.getFunctionBody().front()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isHostFoldableOp(&op))
if (isCompileTimeOp(&op))
continue;
diagnostics.report(&op, [](Operation* illegalOp) {
@@ -13,7 +13,7 @@
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
@@ -92,7 +92,7 @@ static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
if (!transposeOp || isHostFoldableOp(transposeOp))
if (!transposeOp || isCompileTimeOp(transposeOp))
continue;
// Transpose stays globally legal because constant/view-only cases are
@@ -11,7 +11,7 @@
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -391,7 +391,7 @@ static Value lowerSingleConvGroup(Value x,
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
auto wDenseAttr = getHostConstDenseElementsAttr(w);
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
@@ -412,7 +412,7 @@ static Value lowerSingleConvGroup(Value x,
DenseElementsAttr biasDenseAttr;
if (hasB) {
gemmBias = b;
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
biasDenseAttr = getHostConstDenseElementsAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
}
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
@@ -717,7 +717,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
}
Value result;
if (llvm::all_of(groupResults, isHostFoldableValue)) {
if (llvm::all_of(groupResults, isCompileTimeComputable)) {
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
}
else {
@@ -11,7 +11,7 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -55,7 +55,7 @@ static Value transposeForSpatial(Value value,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
if (isHostFoldableValue(value))
if (isCompileTimeComputable(value))
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
@@ -67,7 +67,7 @@ static Value transposeForSpatial(Value value,
static Value
expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
if (isHostFoldableValue(value))
if (isCompileTimeComputable(value))
return tensor::ExpandShapeOp::create(rewriter,
loc,
resultType,
@@ -121,7 +121,7 @@ static SmallVector<Value> materializeBatchRowSlices(Value matrix,
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
if (isHostFoldableValue(matrix)) {
if (isCompileTimeComputable(matrix)) {
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix);
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
}
@@ -10,7 +10,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -55,7 +55,7 @@ collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, Pa
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
};
if (isHostFoldableValue(value))
if (isCompileTimeComputable(value))
return buildCollapsed(value);
auto collapseCompute =
@@ -114,7 +114,7 @@ static Value extractBatchMatrix(Value value,
});
};
if (isHostFoldableValue(value))
if (isCompileTimeComputable(value))
return buildMatrix(value);
auto batchMatrixCompute =
@@ -142,7 +142,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
};
if (isHostFoldableValue(value))
if (isCompileTimeComputable(value))
return buildTranspose(value);
auto transposeCompute =
@@ -182,7 +182,7 @@ static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewr
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
if (llvm::all_of(inputs, isCompileTimeComputable))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
@@ -235,7 +235,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
}
Location loc = matmulOp.getLoc();
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
@@ -7,7 +7,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -91,7 +91,7 @@ static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewr
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
if (llvm::all_of(inputs, isCompileTimeComputable))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
@@ -135,7 +135,7 @@ static Value squeezeReducedAxes(Value keepdimsValue,
}
auto reassociation = buildCollapseReassociation(reducedAxes);
if (isHostFoldableValue(keepdimsValue))
if (isCompileTimeComputable(keepdimsValue))
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
auto squeezeCompute =
@@ -3,7 +3,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -20,7 +20,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
auto inputs = adaptor.getInputs();
int64_t axis = adaptor.getAxis();
if (llvm::all_of(inputs, isHostFoldableValue)) {
if (llvm::all_of(inputs, isCompileTimeComputable)) {
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
return success();
}
@@ -5,7 +5,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -115,7 +115,7 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
}
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
if (isHostFoldableValue(adaptor.getData())) {
if (isCompileTimeComputable(adaptor.getData())) {
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
return success();
}
@@ -3,7 +3,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -61,7 +61,7 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
sliceSizes.push_back(resultType.getShape()[axis]);
}
if (isHostFoldableValue(adaptor.getInput())) {
if (isCompileTimeComputable(adaptor.getInput())) {
for (int64_t sliceSize : sliceSizes) {
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
offset += sliceSize;