Bye Bye DCP

This commit is contained in:
ilgeco
2026-05-25 15:44:30 +02:00
parent 4855a2e105
commit eea9261c7b
42 changed files with 176 additions and 3994 deletions
@@ -10,7 +10,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -55,7 +55,7 @@ collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, Pa
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
};
if (isHostFoldableValue(value))
if (isCompileTimeComputable(value))
return buildCollapsed(value);
auto collapseCompute =
@@ -114,7 +114,7 @@ static Value extractBatchMatrix(Value value,
});
};
if (isHostFoldableValue(value))
if (isCompileTimeComputable(value))
return buildMatrix(value);
auto batchMatrixCompute =
@@ -142,7 +142,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
};
if (isHostFoldableValue(value))
if (isCompileTimeComputable(value))
return buildTranspose(value);
auto transposeCompute =
@@ -182,7 +182,7 @@ static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewr
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
if (llvm::all_of(inputs, isCompileTimeComputable))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
@@ -235,7 +235,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
}
Location loc = matmulOp.getLoc();
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);