Bye Bye DCP
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -91,7 +91,7 @@ static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewr
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
if (llvm::all_of(inputs, isCompileTimeComputable))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
@@ -135,7 +135,7 @@ static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
}
|
||||
|
||||
auto reassociation = buildCollapseReassociation(reducedAxes);
|
||||
if (isHostFoldableValue(keepdimsValue))
|
||||
if (isCompileTimeComputable(keepdimsValue))
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
|
||||
|
||||
auto squeezeCompute =
|
||||
|
||||
Reference in New Issue
Block a user