Bye Bye DCP
This commit is contained in:
@@ -4,7 +4,7 @@ add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||
|
||||
add_pim_library(OMONNXToSpatial
|
||||
ConversionPatterns.cpp
|
||||
HostFoldability.cpp
|
||||
CompileTime.cpp
|
||||
HostLegality.cpp
|
||||
PrePatterns.cpp
|
||||
PostPatterns.cpp
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -44,7 +44,7 @@ SmallVector<Value> sliceTensor(
|
||||
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
|
||||
|
||||
Value slice;
|
||||
if (isHostFoldableValue(tensorToSlice)) {
|
||||
if (isCompileTimeComputable(tensorToSlice)) {
|
||||
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||
}
|
||||
else {
|
||||
@@ -113,7 +113,7 @@ Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatte
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(scalarToBroadcast))
|
||||
if (isCompileTimeComputable(scalarToBroadcast))
|
||||
return buildBroadcast(scalarToBroadcast);
|
||||
|
||||
auto broadcastCompute =
|
||||
|
||||
+32
-23
@@ -8,7 +8,7 @@
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -145,7 +145,7 @@ static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp || !visited.insert(definingOp).second)
|
||||
return nullptr;
|
||||
@@ -156,7 +156,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
|
||||
return denseAttr;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getData(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
|
||||
@@ -169,7 +169,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
|
||||
}
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
|
||||
@@ -177,7 +177,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
|
||||
}
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
|
||||
@@ -185,7 +185,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
|
||||
}
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
|
||||
@@ -195,62 +195,71 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
if (!op || !visited.insert(op).second)
|
||||
static bool isCompileTimeOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
if (!op)
|
||||
return false;
|
||||
|
||||
if (!visited.insert(op).second)
|
||||
return true;
|
||||
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||
return true;
|
||||
|
||||
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
|
||||
return hasConstantIndices(extractOp) && isCompileTimeOpImpl(extractOp, visited);
|
||||
|
||||
if (!isStaticTensorResult(op))
|
||||
return false;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
||||
return isHostFoldableValue(transposeOp.getData());
|
||||
return isCompileTimeOpImpl(transposeOp, visited);
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
||||
return isHostFoldableValue(collapseShapeOp.getSrc());
|
||||
return isCompileTimeOpImpl(collapseShapeOp,visited);
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
|
||||
return isHostFoldableValue(expandShapeOp.getSrc());
|
||||
return isCompileTimeOpImpl(expandShapeOp, visited);
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
||||
return hasStaticUnitStrides(extractSliceOp) && isCompileTimeOpImpl(extractSliceOp, visited);
|
||||
|
||||
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||
return isHostFoldableValue(splatOp.getInput());
|
||||
return isCompileTimeOpImpl(splatOp, visited);
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||
return isHostFoldableValue(extractRowsOp.getInput());
|
||||
return isCompileTimeOpImpl(extractRowsOp, visited);
|
||||
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
|
||||
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)){
|
||||
bool res = true;
|
||||
for(auto operandValue : concatOp.getOperands()){
|
||||
res &= isCompileTimeOpImpl(operandValue.getDefiningOp(), visited);
|
||||
if(!res) break;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool isHostFoldableValue(Value value) {
|
||||
bool isCompileTimeComputable(Value value) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return false;
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return isHostFoldableOpImpl(definingOp, visited);
|
||||
return isCompileTimeOpImpl(definingOp, visited);
|
||||
}
|
||||
|
||||
bool isHostFoldableOp(Operation* op) {
|
||||
bool isCompileTimeOp(Operation* op) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return isHostFoldableOpImpl(op, visited);
|
||||
return isCompileTimeOpImpl(op, visited);
|
||||
}
|
||||
|
||||
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
|
||||
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return getHostFoldableDenseElementsAttrImpl(value, visited);
|
||||
return getHostConstantDenseElementsAttrImpl(value, visited);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isCompileTimeComputable(mlir::Value value);
|
||||
|
||||
bool isCompileTimeOp(mlir::Operation* op);
|
||||
|
||||
mlir::DenseElementsAttr getHostConstDenseElementsAttr(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,15 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isHostFoldableValue(mlir::Value value);
|
||||
|
||||
bool isHostFoldableOp(mlir::Operation* op);
|
||||
|
||||
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -3,7 +3,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -17,7 +17,7 @@ LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
for (Operation& op : funcOp.getFunctionBody().front()) {
|
||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||
continue;
|
||||
if (isHostFoldableOp(&op))
|
||||
if (isCompileTimeOp(&op))
|
||||
continue;
|
||||
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#include "Common/Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
@@ -92,7 +92,7 @@ static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
|
||||
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
|
||||
if (!transposeOp || isHostFoldableOp(transposeOp))
|
||||
if (!transposeOp || isCompileTimeOp(transposeOp))
|
||||
continue;
|
||||
|
||||
// Transpose stays globally legal because constant/view-only cases are
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -391,7 +391,7 @@ static Value lowerSingleConvGroup(Value x,
|
||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
||||
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
||||
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
|
||||
auto wDenseAttr = getHostConstDenseElementsAttr(w);
|
||||
|
||||
// Prepare weight matrix W for crossbar storage:
|
||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||
@@ -412,7 +412,7 @@ static Value lowerSingleConvGroup(Value x,
|
||||
DenseElementsAttr biasDenseAttr;
|
||||
if (hasB) {
|
||||
gemmBias = b;
|
||||
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
|
||||
biasDenseAttr = getHostConstDenseElementsAttr(b);
|
||||
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||
}
|
||||
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
||||
@@ -717,7 +717,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
}
|
||||
|
||||
Value result;
|
||||
if (llvm::all_of(groupResults, isHostFoldableValue)) {
|
||||
if (llvm::all_of(groupResults, isCompileTimeComputable)) {
|
||||
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
|
||||
}
|
||||
else {
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -55,7 +55,7 @@ static Value transposeForSpatial(Value value,
|
||||
ArrayRef<int64_t> permutation,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (isHostFoldableValue(value))
|
||||
if (isCompileTimeComputable(value))
|
||||
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||
@@ -67,7 +67,7 @@ static Value transposeForSpatial(Value value,
|
||||
|
||||
static Value
|
||||
expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (isHostFoldableValue(value))
|
||||
if (isCompileTimeComputable(value))
|
||||
return tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
resultType,
|
||||
@@ -121,7 +121,7 @@ static SmallVector<Value> materializeBatchRowSlices(Value matrix,
|
||||
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
|
||||
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
|
||||
|
||||
if (isHostFoldableValue(matrix)) {
|
||||
if (isCompileTimeComputable(matrix)) {
|
||||
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix);
|
||||
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -55,7 +55,7 @@ collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, Pa
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
if (isCompileTimeComputable(value))
|
||||
return buildCollapsed(value);
|
||||
|
||||
auto collapseCompute =
|
||||
@@ -114,7 +114,7 @@ static Value extractBatchMatrix(Value value,
|
||||
});
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
if (isCompileTimeComputable(value))
|
||||
return buildMatrix(value);
|
||||
|
||||
auto batchMatrixCompute =
|
||||
@@ -142,7 +142,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
if (isCompileTimeComputable(value))
|
||||
return buildTranspose(value);
|
||||
|
||||
auto transposeCompute =
|
||||
@@ -182,7 +182,7 @@ static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewr
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
if (llvm::all_of(inputs, isCompileTimeComputable))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
@@ -235,7 +235,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
}
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||
bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
|
||||
|
||||
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
|
||||
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -91,7 +91,7 @@ static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewr
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
if (llvm::all_of(inputs, isCompileTimeComputable))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
@@ -135,7 +135,7 @@ static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
}
|
||||
|
||||
auto reassociation = buildCollapseReassociation(reducedAxes);
|
||||
if (isHostFoldableValue(keepdimsValue))
|
||||
if (isCompileTimeComputable(keepdimsValue))
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
|
||||
|
||||
auto squeezeCompute =
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -20,7 +20,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
||||
auto inputs = adaptor.getInputs();
|
||||
int64_t axis = adaptor.getAxis();
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue)) {
|
||||
if (llvm::all_of(inputs, isCompileTimeComputable)) {
|
||||
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -115,7 +115,7 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
}
|
||||
|
||||
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
|
||||
if (isHostFoldableValue(adaptor.getData())) {
|
||||
if (isCompileTimeComputable(adaptor.getData())) {
|
||||
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -61,7 +61,7 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||
sliceSizes.push_back(resultType.getShape()[axis]);
|
||||
}
|
||||
|
||||
if (isHostFoldableValue(adaptor.getInput())) {
|
||||
if (isCompileTimeComputable(adaptor.getInput())) {
|
||||
for (int64_t sliceSize : sliceSizes) {
|
||||
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
|
||||
offset += sliceSize;
|
||||
|
||||
Reference in New Issue
Block a user