finish helper refactoring
Validate Operations / validate-operations (push) Has been cancelled

use uniqued constant helpers everywhere
materialize transposed constants directly
This commit is contained in:
NiccoloN
2026-05-29 17:05:45 +02:00
parent 819d8af0f7
commit 8bb0babf1b
32 changed files with 300 additions and 467 deletions
-6
View File
@@ -749,18 +749,12 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
} // namespace } // namespace
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) { llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge); return resolveIndexValueImpl(value, &knowledge);
} }
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); } llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); }
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value, llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge) { const StaticValueKnowledge& knowledge) {
return resolveContiguousAddressImpl(value, &knowledge); return resolveContiguousAddressImpl(value, &knowledge);
+2 -4
View File
@@ -77,14 +77,12 @@ mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::m
/// Resolves a value to contiguous backing storage when that storage can be /// Resolves a value to contiguous backing storage when that storage can be
/// proven statically from aliases, DPS ties, casts, and subviews. /// proven statically from aliases, DPS ties, casts, and subviews.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value, llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge); const StaticValueKnowledge& knowledge = {});
/// Statically evaluates index-like SSA values, including simple integer /// Statically evaluates index-like SSA values, including simple integer
/// arithmetic and loop facts recorded in `knowledge`. /// arithmetic and loop facts recorded in `knowledge`.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value); llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {});
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value); llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value);
/// Follows alias, view, and DPS chains to recover the backing value of a /// Follows alias, view, and DPS chains to recover the backing value of a
+14
View File
@@ -17,4 +17,18 @@ llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds,
return laneCoreIds; return laneCoreIds;
} }
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) {
if (mlir::isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 3;
if (mlir::isa<pim::PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1;
if (mlir::isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex) {
return mlir::isa<pim::PimMemCopyDevToHostOp>(op) && operandIndex == 2;
}
} // namespace onnx_mlir } // namespace onnx_mlir
+4
View File
@@ -11,4 +11,8 @@ llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane); llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex);
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex);
} // namespace onnx_mlir } // namespace onnx_mlir
+16 -22
View File
@@ -14,13 +14,17 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
Block* getHostConstantBlock(Operation* anchorOp) { Block* getConstantInsertionBlock(Operation* anchorOp) {
assert(anchorOp && "expected a valid anchor operation"); assert(anchorOp && "expected a valid anchor operation");
for (Operation* current = anchorOp; current; current = current->getParentOp()) for (Operation* current = anchorOp; current; current = current->getParentOp())
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current)) if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
return current->getBlock(); return current->getBlock();
if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp))
return &funcOp.getBody().front();
if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
return moduleOp.getBody();
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>()) if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
return &funcOp.getBody().front(); return &funcOp.getBody().front();
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>()) if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
@@ -28,9 +32,9 @@ Block* getHostConstantBlock(Operation* anchorOp) {
return anchorOp->getBlock(); return anchorOp->getBlock();
} }
Value getOrCreateHostConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) { Value getOrCreateConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation"); assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getHostConstantBlock(anchorOp); Block* hostBlock = getConstantInsertionBlock(anchorOp);
for (Operation& op : *hostBlock) { for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op); auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value) if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
@@ -42,9 +46,9 @@ Value getOrCreateHostConstant(OperationFolder& folder, Operation* anchorOp, Attr
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type); return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
} }
Value getOrCreateHostConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) { Value getOrCreateConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation"); assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getHostConstantBlock(anchorOp); Block* hostBlock = getConstantInsertionBlock(anchorOp);
for (Operation& op : *hostBlock) { for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op); auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value) if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
@@ -57,28 +61,18 @@ Value getOrCreateHostConstant(RewriterBase& rewriter, Operation* anchorOp, Attri
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult(); return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
} }
Value getOrCreateHostConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) { Value getOrCreateConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
return getOrCreateHostConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType()); return getOrCreateConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType());
} }
Value getOrCreateHostIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) { Value getOrCreateIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
Builder builder(anchorOp->getContext()); Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType() ); return getOrCreateConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
} }
Value getOrCreateHostIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) { Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
Builder builder(anchorOp->getContext()); Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType()); return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(folder, anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type() );
}
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(folder, anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type() );
} }
Value createAffineApplyOrFoldedConstant( Value createAffineApplyOrFoldedConstant(
@@ -95,7 +89,7 @@ Value createAffineApplyOrFoldedConstant(
SmallVector<Attribute> foldedResults; SmallVector<Attribute> foldedResults;
if (succeeded(map.constantFold(operandConstants, foldedResults))) { if (succeeded(map.constantFold(operandConstants, foldedResults))) {
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front())) if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
return getOrCreateHostIndexConstant(rewriter, anchorOp, constantResult.getInt()); return getOrCreateIndexConstant(rewriter, anchorOp, constantResult.getInt());
} }
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
+12 -16
View File
@@ -8,27 +8,23 @@
namespace onnx_mlir { namespace onnx_mlir {
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp); mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp);
mlir::Value getOrCreateHostConstant(mlir::OperationFolder& folder, mlir::Value getOrCreateConstant(mlir::OperationFolder& folder,
mlir::Operation* anchorOp, mlir::Operation* anchorOp,
mlir::Attribute value, mlir::Attribute value,
mlir::Type type); mlir::Type type);
mlir::Value getOrCreateHostConstant(mlir::RewriterBase& rewriter, mlir::Value getOrCreateConstant(mlir::RewriterBase& rewriter,
mlir::Operation* anchorOp, mlir::Operation* anchorOp,
mlir::Attribute value, mlir::Attribute value,
mlir::Type type); mlir::Type type);
mlir::Value getOrCreateHostConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp); mlir::Value getOrCreateConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp);
mlir::Value getOrCreateHostIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value); mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
mlir::Value getOrCreateHostIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value); mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
mlir::Value getOrCreateHostI32Constant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int32_t value);
mlir::Value getOrCreateHostI64Constant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter, mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter,
mlir::Location loc, mlir::Location loc,
-4
View File
@@ -178,10 +178,6 @@ std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::V
return std::nullopt; return std::nullopt;
} }
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp) {
return resolveWeightIndex(weightOwner, vmmOp.getWeight());
}
llvm::FailureOr<ResolvedWeightView> llvm::FailureOr<ResolvedWeightView>
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge) { resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge) {
llvm::SmallVector<mlir::Operation*> viewOps; llvm::SmallVector<mlir::Operation*> viewOps;
-1
View File
@@ -46,7 +46,6 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback); void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight); std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight);
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp);
llvm::FailureOr<ResolvedWeightView> llvm::FailureOr<ResolvedWeightView>
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {}); resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {});
@@ -40,14 +40,6 @@ static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr,
return normalizedAxes; return normalizedAxes;
} }
SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
return normalizeAxesImpl(std::optional<ArrayAttr>(axesAttr), rank);
}
SmallVector<int64_t> normalizeAxes(std::optional<ArrayAttr> axesAttr, int64_t rank) {
return normalizeAxesImpl(axesAttr, rank);
}
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) { FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank); SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
for (int64_t axis : normalizedAxes) for (int64_t axis : normalizedAxes)
@@ -56,11 +48,7 @@ FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> ax
return normalizedAxes; return normalizedAxes;
} }
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(ArrayAttr axesAttr, int64_t rank) { Value createAffineApplyOrFoldedConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
return normalizeAxesChecked(std::optional<ArrayAttr>(axesAttr), rank);
}
Value createAffineApplyOrConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr); AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp); return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
@@ -68,22 +56,22 @@ Value createAffineApplyOrConstant(PatternRewriter& rewriter, Location loc, Affin
Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) { Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) {
if (multiplier == 0) if (multiplier == 0)
return getOrCreateHostIndexConstant(rewriter, anchorOp, 0); return getOrCreateIndexConstant(rewriter, anchorOp, 0);
if (multiplier == 1) if (multiplier == 1)
return value; return value;
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value}); return createAffineApplyOrFoldedConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value});
} }
Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) { Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
if (divisor == 1) if (divisor == 1)
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(rewriter, loc, d0 % divisor, ValueRange {value}); return createAffineApplyOrFoldedConstant(rewriter, loc, d0 % divisor, ValueRange {value});
} }
Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) { Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
@@ -92,12 +80,12 @@ Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value val
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}); return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value});
} }
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, Location loc, OpFoldResult value) { Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) {
if (auto attr = dyn_cast<Attribute>(value)) if (auto attr = dyn_cast<Attribute>(value))
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt()); return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
return cast<Value>(value); return cast<Value>(value);
} }
@@ -19,18 +19,12 @@ mlir::FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank);
int64_t normalizeIndex(int64_t index, int64_t dimSize); int64_t normalizeIndex(int64_t index, int64_t dimSize);
llvm::SmallVector<int64_t> normalizeAxes(mlir::ArrayAttr axesAttr, int64_t rank);
llvm::SmallVector<int64_t> normalizeAxes(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(mlir::ArrayAttr axesAttr, int64_t rank);
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank); mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
mlir::Value createAffineApplyOrConstant(mlir::PatternRewriter& rewriter, mlir::Value createAffineApplyOrFoldedConstant(mlir::PatternRewriter& rewriter,
mlir::Location loc, mlir::Location loc,
mlir::AffineExpr expr, mlir::AffineExpr expr,
mlir::ValueRange operands); mlir::ValueRange operands);
mlir::Value mlir::Value
multiplyIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Operation* anchorOp, mlir::Value value, int64_t multiplier); multiplyIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Operation* anchorOp, mlir::Value value, int64_t multiplier);
@@ -40,6 +34,6 @@ mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location l
mlir::Value mlir::Value
floorDivIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor); floorDivIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::OpFoldResult value); mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::OpFoldResult value);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -10,6 +10,7 @@
#include "ShapeTilingUtils.hpp" #include "ShapeTilingUtils.hpp"
#include "IndexingUtils.hpp" #include "IndexingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
@@ -19,10 +20,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) {
return getOrMaterializeIndexValue(rewriter, loc, result);
}
static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) { static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) {
APInt lhsConst; APInt lhsConst;
if (matchPattern(lhs, m_ConstantInt(&lhsConst)) && lhsConst.isZero()) if (matchPattern(lhs, m_ConstantInt(&lhsConst)) && lhsConst.isZero())
@@ -43,11 +40,12 @@ static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatt
return arith::MulIOp::create(rewriter, loc, value, cast<Value>(factor)).getResult(); return arith::MulIOp::create(rewriter, loc, value, cast<Value>(factor)).getResult();
if (factorConst.isZero()) if (factorConst.isZero())
return arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
if (factorConst.isOne()) if (factorConst.isOne())
return value; return value;
auto factorValue = arith::ConstantIndexOp::create(rewriter, loc, factorConst.getSExtValue()).getResult(); auto factorValue =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), factorConst.getSExtValue());
return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult(); return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult();
} }
@@ -61,8 +59,6 @@ int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {}); return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
} }
int64_t getStaticShapeElementCount(RankedTensorType type) { return getStaticShapeElementCount(type.getShape()); }
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) { SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
SmallVector<int64_t> permutedShape; SmallVector<int64_t> permutedShape;
permutedShape.reserve(permutation.size()); permutedShape.reserve(permutation.size());
@@ -226,49 +222,6 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewri
return slicesPerCore; return slicesPerCore;
} }
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) {
assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile)));
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tiles;
SmallVector<Value> hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc);
size_t numHSlices = hSlices.size();
for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) {
Value hSlice = hSlices[hSliceId];
SmallVector<Value> vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc);
for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) {
size_t coreId = vSliceId / crossbarCountInCore;
Value vSlice = vSlices[vSliceId];
tiles[hSliceId][coreId].push_back(vSlice);
}
}
return tiles;
}
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
Type elementType = oldType.getElementType();
int64_t shape[2] = {1, length};
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
auto buildBroadcast = [&](Value input) -> Value {
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
};
if (isCompileTimeComputable(scalarToBroadcast))
return buildBroadcast(scalarToBroadcast);
auto broadcastCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
});
return broadcastCompute.getResult(0);
}
Value materializeContiguousTensorSlice(Value source, Value materializeContiguousTensorSlice(Value source,
RankedTensorType resultType, RankedTensorType resultType,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> offsets,
@@ -294,7 +247,7 @@ Value materializeContiguousTensorSlice(Value source,
Value init = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult(); Value init = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
SmallVector<Value> zeroIndices(resultType.getRank()); SmallVector<Value> zeroIndices(resultType.getRank());
for (Value& zeroIndex : zeroIndices) for (Value& zeroIndex : zeroIndices)
zeroIndex = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); zeroIndex = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
SmallVector<Value> resultIndices; SmallVector<Value> resultIndices;
resultIndices.reserve(resultType.getRank()); resultIndices.reserve(resultType.getRank());
@@ -304,7 +257,7 @@ Value materializeContiguousTensorSlice(Value source,
SmallVector<Value> sourceIndices; SmallVector<Value> sourceIndices;
sourceIndices.reserve(resultType.getRank()); sourceIndices.reserve(resultType.getRank());
for (unsigned idx = 0; idx < resultType.getRank(); ++idx) { for (unsigned idx = 0; idx < resultType.getRank(); ++idx) {
Value offsetValue = getIndexValue(offsets[idx], rewriter, loc); Value offsetValue = getOrMaterializeIndexValue(rewriter, offsets[idx]);
Value scaledIndex = multiplyIndexValue(resultIndices[idx], strides[idx], rewriter, loc); Value scaledIndex = multiplyIndexValue(resultIndices[idx], strides[idx], rewriter, loc);
sourceIndices.push_back(addIndexValues(offsetValue, scaledIndex, rewriter, loc)); sourceIndices.push_back(addIndexValues(offsetValue, scaledIndex, rewriter, loc));
} }
@@ -337,8 +290,8 @@ Value materializeContiguousTensorSlice(Value source,
} }
Value lower = zeroIndices[dim]; Value lower = zeroIndices[dim];
Value upper = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(dim)).getResult(); Value upper = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultType.getDimSize(dim));
Value step = arith::ConstantIndexOp::create(rewriter, loc, 1).getResult(); Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator}); auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator});
rewriter.setInsertionPointToStart(loop.getBody()); rewriter.setInsertionPointToStart(loop.getBody());
resultIndices.push_back(loop.getInductionVar()); resultIndices.push_back(loop.getInductionVar());
@@ -352,17 +305,6 @@ Value materializeContiguousTensorSlice(Value source,
return buildLoopNest(buildLoopNest, 0, init); return buildLoopNest(buildLoopNest, 0, init);
} }
Value extractStaticSlice(PatternRewriter& rewriter,
Location loc,
Value source,
RankedTensorType resultType,
ArrayRef<OpFoldResult> offsets) {
return tensor::ExtractSliceOp::create(
rewriter, loc, resultType, source, offsets, getStaticSizes(rewriter, resultType.getShape()),
getUnitStrides(rewriter, resultType.getRank()))
.getResult();
}
Value extractAxisSlice( Value extractAxisSlice(
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) { PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
auto sourceType = cast<RankedTensorType>(source.getType()); auto sourceType = cast<RankedTensorType>(source.getType());
@@ -18,41 +18,6 @@
namespace onnx_mlir { namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t; using HSliceId = size_t;
using CoreId = size_t; using CoreId = size_t;
@@ -89,17 +54,6 @@ bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1; return shape.size() == 2 && shape[0] == 1;
} }
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) { inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape(); return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
} }
@@ -117,8 +71,6 @@ bool hasStaticPositiveShape(mlir::RankedTensorType type);
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape); int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
int64_t getStaticShapeElementCount(mlir::RankedTensorType type);
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation); llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation); llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
@@ -156,20 +108,6 @@ llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore( llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
/// Tiles a matrix first across output columns and then across input rows so it
/// can be assigned to crossbars grouped by core.
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
mlir::Value materializeContiguousTensorSlice(mlir::Value source, mlir::Value materializeContiguousTensorSlice(mlir::Value source,
mlir::RankedTensorType resultType, mlir::RankedTensorType resultType,
llvm::ArrayRef<mlir::OpFoldResult> offsets, llvm::ArrayRef<mlir::OpFoldResult> offsets,
@@ -177,12 +115,6 @@ mlir::Value materializeContiguousTensorSlice(mlir::Value source,
mlir::ConversionPatternRewriter& rewriter, mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc); mlir::Location loc);
mlir::Value extractStaticSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::RankedTensorType resultType,
llvm::ArrayRef<mlir::OpFoldResult> offsets);
mlir::Value extractAxisSlice(mlir::PatternRewriter& rewriter, mlir::Value extractAxisSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc, mlir::Location loc,
mlir::Value source, mlir::Value source,
@@ -8,10 +8,10 @@
#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include <utility> #include <utility>
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -38,13 +38,6 @@ static bool isStaticTensorResult(Operation* op) {
}); });
} }
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) { static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType()); auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!tensorType) if (!tensorType)
@@ -61,9 +61,9 @@ static Value createPaddedRows(Value tensorValue,
padBlock->addArgument(rewriter.getIndexType(), loc); padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock); padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock); rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create( auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()),
rewriter, loc, tensorType.getElementType(), rewriter.getZeroAttr(tensorType.getElementType())); tensorType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero.getResult()); tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp); rewriter.setInsertionPointAfter(padOp);
return padOp.getResult(); return padOp.getResult();
} }
@@ -106,7 +106,7 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
} }
auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues); auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
} }
static Value createConvWeightMatrix(Value w, static Value createConvWeightMatrix(Value w,
@@ -158,7 +158,7 @@ static Value buildPackedBias(bool hasBias,
auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType()); auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType());
auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues); auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues);
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult(); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedBiasAttr, packedBiasType);
} }
static Value createIm2colRowComputes(Value x, static Value createIm2colRowComputes(Value x,
@@ -214,8 +214,8 @@ static Value createIm2colRowComputes(Value x,
padBlock->addArgument(rewriter.getIndexType(), loc); padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock); padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock); rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0)); auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getFloatAttr(elemType, 0.0), elemType);
tensor::YieldOp::create(rewriter, loc, zero.getResult()); tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp); rewriter.setInsertionPointAfter(padOp);
paddedInput = padOp.getResult(); paddedInput = padOp.getResult();
} }
@@ -223,13 +223,14 @@ static Value createIm2colRowComputes(Value x,
// Build im2col [numPatches, patchSize] incrementally to keep the IR small // Build im2col [numPatches, patchSize] incrementally to keep the IR small
// until the late PIM unrolling step. // until the late PIM unrolling step.
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType); Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); auto c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches); auto c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch); auto cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, numPatches);
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth); auto cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, numPatchesPerBatch);
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight); auto cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, outWidth);
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth); auto cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
auto cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit}); auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
rewriter.setInsertionPointToStart(im2colLoop.getBody()); rewriter.setInsertionPointToStart(im2colLoop.getBody());
@@ -83,7 +83,7 @@ static FailureOr<Value> materializeBroadcastedConstantTensor(Value value,
} }
auto broadcastedAttr = DenseElementsAttr::get(resultType, resultValues); auto broadcastedAttr = DenseElementsAttr::get(resultType, resultValues);
return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult(); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), broadcastedAttr, resultType);
} }
static FailureOr<Value> static FailureOr<Value>
@@ -121,7 +121,7 @@ static FailureOr<Value> materializeReciprocalTensor(Value value,
} }
auto reciprocalAttr = DenseFPElementsAttr::get(resultType, reciprocalValues); auto reciprocalAttr = DenseFPElementsAttr::get(resultType, reciprocalValues);
return arith::ConstantOp::create(rewriter, loc, resultType, reciprocalAttr).getResult(); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), reciprocalAttr, resultType);
} }
template <typename OnnxOp, typename SpatialOp> template <typename OnnxOp, typename SpatialOp>
@@ -50,38 +50,17 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
return failure(); return failure();
auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues); auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult(); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaledAttr, denseAttr.getType());
}
static Value transposeForSpatial(Value value,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
return transposeMaybeInCompute(value, resultType, permutation, rewriter, loc);
}
static Value
multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) {
return onnx_mlir::multiplyIndexByConstant(rewriter, value.getDefiningOp(), value, multiplier);
}
static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) {
return onnx_mlir::modIndexByConstant(rewriter, loc, value, divisor);
}
static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) {
return modIndexByConstant(lane, numOutRows, rewriter, loc);
} }
static Value createGemmBatchKOffset( static Value createGemmBatchKOffset(
Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) { Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
if (numKSlices == 1) if (numKSlices == 1)
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant( return createAffineApplyOrFoldedConstant(
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane}); rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
} }
@@ -92,11 +71,11 @@ static Value createGemmBatchHOffset(Value lane,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
Location loc) { Location loc) {
if (numOutHSlices == 1) if (numOutHSlices == 1)
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant( return createAffineApplyOrFoldedConstant(
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane}); rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane});
} }
@@ -115,9 +94,9 @@ createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatte
padBlock->addArgument(rewriter.getIndexType(), loc); padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock); padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock); rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create( auto zero = getOrCreateConstant(
rewriter, loc, sourceType.getElementType(), rewriter.getZeroAttr(sourceType.getElementType())); rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero.getResult()); tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp); rewriter.setInsertionPointAfter(padOp);
return padOp.getResult(); return padOp.getResult();
} }
@@ -149,7 +128,7 @@ static FailureOr<Value> materializePaddedConstantMatrix(Value value,
resultValues[row * resultShape[1] + col] = sourceValues[row * sourceShape[1] + col]; resultValues[row * resultShape[1] + col] = sourceValues[row * sourceShape[1] + col];
auto resultAttr = DenseElementsAttr::get(resultType, resultValues); auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
return arith::ConstantOp::create(rewriter, loc, resultType, resultAttr).getResult(); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
} }
static FailureOr<Value> materializePaddedBroadcastedConstantTensor(Value value, static FailureOr<Value> materializePaddedBroadcastedConstantTensor(Value value,
@@ -215,7 +194,7 @@ static FailureOr<Value> materializePaddedBroadcastedConstantTensor(Value value,
} }
auto resultAttr = DenseElementsAttr::get(resultType, resultValues); auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
return arith::ConstantOp::create(rewriter, loc, resultType, resultAttr).getResult(); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
} }
static FailureOr<Value> prepareBias(Value c, static FailureOr<Value> prepareBias(Value c,
@@ -274,7 +253,7 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
const int64_t laneCount = partialPiecesType.getDimSize(0); const int64_t laneCount = partialPiecesType.getDimSize(0);
auto batchOp = createSpatComputeBatch( auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) { rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
Value row = createGemmBatchRow(args.lane, numOutRows, rewriter, loc); Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows);
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc); Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc); Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
@@ -312,12 +291,7 @@ static Value createDynamicGemmBatchRow(
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}); return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
}
static Value createDynamicGemmBatchColumn(
Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
return modIndexByConstant(lane, numOutCols, rewriter, loc);
} }
static Value static Value
@@ -385,7 +359,7 @@ static Value createScalarTensorConstant(RankedTensorType scalarType,
auto elementType = scalarType.getElementType(); auto elementType = scalarType.getElementType();
auto scalarAttr = rewriter.getFloatAttr(elementType, value); auto scalarAttr = rewriter.getFloatAttr(elementType, value);
auto denseAttr = DenseElementsAttr::get(scalarType, scalarAttr); auto denseAttr = DenseElementsAttr::get(scalarType, scalarAttr);
return arith::ConstantOp::create(rewriter, loc, scalarType, denseAttr).getResult(); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), denseAttr, scalarType);
} }
static Value createBroadcastedBiasScalar(Value bias, static Value createBroadcastedBiasScalar(Value bias,
@@ -435,7 +409,7 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
auto batchOp = createSpatComputeBatch( auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) { rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc); Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
Value column = createDynamicGemmBatchColumn(args.lane, numOutCols, rewriter, loc); Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols);
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
@@ -475,16 +449,16 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
Value biasArg = bias ? blockArgs[1] : Value(); Value biasArg = bias ? blockArgs[1] : Value();
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult(); Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cLaneCount = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount); Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit}); auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(loop.getBody()); rewriter.setInsertionPointToStart(loop.getBody());
Value lane = loop.getInductionVar(); Value lane = loop.getInductionVar();
Value outputAcc = loop.getRegionIterArgs().front(); Value outputAcc = loop.getRegionIterArgs().front();
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc); Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc);
Value column = createDynamicGemmBatchColumn(lane, numOutCols, rewriter, loc); Value column = onnx_mlir::modIndexByConstant(rewriter, loc, lane, numOutCols);
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
@@ -522,7 +496,7 @@ static Value createPartialGroupOffset(Value hSlice,
Location loc) { Location loc) {
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant( return createAffineApplyOrFoldedConstant(
rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice}); rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
} }
@@ -604,7 +578,9 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value { auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
Value reduced = Value reduced =
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc); reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
Value hOffset = multiplyIndexByConstant(hSlice, crossbarSize.getValue(), rewriter, loc); Value hOffset =
onnx_mlir::multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice,
crossbarSize.getValue());
if (biasArg) { if (biasArg) {
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset}; SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
Value biasSlice = Value biasSlice =
@@ -620,13 +596,14 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
Value paddedOutput = outputInit; Value paddedOutput = outputInit;
if (numOutHSlices == 1) { if (numOutHSlices == 1) {
Value hSlice = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); Value hSlice = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
paddedOutput = buildOutputSlice(outputInit, hSlice); paddedOutput = buildOutputSlice(outputInit, hSlice);
} }
else { else {
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cOutHSlices = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices); Value cOutHSlices =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit}); auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(hLoop.getBody()); rewriter.setInsertionPointToStart(hLoop.getBody());
@@ -763,7 +740,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
if (gemmOpAdaptor.getTransB()) { if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape(); auto bShape = bType.getShape();
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType()); auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc); b = transposeMaybeInCompute(b, transposedType, {1, 0}, rewriter, loc);
bType = cast<RankedTensorType>(b.getType()); bType = cast<RankedTensorType>(b.getType());
} }
@@ -76,7 +76,7 @@ static Value computeLaneIndex(Value lane,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
Location loc) { Location loc) {
if (dimSize == 1) if (dimSize == 1)
return arith::ConstantIndexOp::create(rewriter, loc, 0); return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext(); MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d0 = getAffineDimExpr(0, context);
@@ -85,7 +85,7 @@ static Value computeLaneIndex(Value lane,
expr = expr.floorDiv(stride); expr = expr.floorDiv(stride);
if (dimSize != 1) if (dimSize != 1)
expr = expr % dimSize; expr = expr % dimSize;
return createAffineApplyOrConstant(rewriter, loc, expr, ValueRange {lane}); return createAffineApplyOrFoldedConstant(rewriter, loc, expr, ValueRange {lane});
} }
static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input, static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
@@ -236,7 +236,7 @@ static Value squeezeReducedAxes(Value keepdimsValue,
Location loc) { Location loc) {
if (resultType.getRank() == 0) { if (resultType.getRank() == 0) {
SmallVector<Value> indices(cast<RankedTensorType>(keepdimsValue.getType()).getRank(), SmallVector<Value> indices(cast<RankedTensorType>(keepdimsValue.getType()).getRank(),
arith::ConstantIndexOp::create(rewriter, loc, 0)); getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0));
Value element = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, indices); Value element = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, indices);
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element}); return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
} }
@@ -268,7 +268,7 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
return success(); return success();
} }
auto axes = normalizeAxesChecked(reduceMeanOp.getAxesAttr(), inputType.getRank()); auto axes = normalizeAxesChecked(std::optional<ArrayAttr>(reduceMeanOp.getAxesAttr()), inputType.getRank());
if (failed(axes)) if (failed(axes))
return failure(); return failure();
SmallVector<bool> reducedAxes = buildReducedAxesMask(*axes, inputType.getRank()); SmallVector<bool> reducedAxes = buildReducedAxesMask(*axes, inputType.getRank());
@@ -31,17 +31,18 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
static Value static Value
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) { createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
if (!useMinimumValue) if (!useMinimumValue)
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType)); return getOrCreateConstant(rewriter, anchorOp, rewriter.getZeroAttr(elementType), elementType);
if (auto floatType = dyn_cast<FloatType>(elementType)) { if (auto floatType = dyn_cast<FloatType>(elementType)) {
auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true); auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue)); return getOrCreateConstant(rewriter, anchorOp, rewriter.getFloatAttr(floatType, minValue), elementType);
} }
if (auto integerType = dyn_cast<IntegerType>(elementType)) { if (auto integerType = dyn_cast<IntegerType>(elementType)) {
auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth()); auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue)); return getOrCreateConstant(rewriter, anchorOp, rewriter.getIntegerAttr(integerType, minValue), elementType);
} }
llvm_unreachable("unsupported pool element type"); llvm_unreachable("unsupported pool element type");
@@ -148,7 +149,7 @@ static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewr
} }
auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues); auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult(); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaleAttr, scaleType);
} }
template <typename PoolOp> template <typename PoolOp>
@@ -265,13 +266,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight); createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()); Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount); Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth); Value cOutputPatchCount = getOrCreateIndexConstant(rewriter, anchorOp, outputPatchCount);
Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth); Value cOutputPixelsPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, outputHeight * outputWidth);
Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight); Value cOutputWidth = getOrCreateIndexConstant(rewriter, anchorOp, outputWidth);
Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth); Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit}); auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
rewriter.setInsertionPointToStart(outputLoop.getBody()); rewriter.setInsertionPointToStart(outputLoop.getBody());
@@ -296,14 +298,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
Value paddedInH = windowBaseH; Value paddedInH = windowBaseH;
if (kernelH * dilationHeight != 0) { if (kernelH * dilationHeight != 0) {
Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight); Value kernelHOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelH * dilationHeight);
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset); paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
} }
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
Value paddedInW = windowBaseW; Value paddedInW = windowBaseW;
if (kernelW * dilationWidth != 0) { if (kernelW * dilationWidth != 0) {
Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth); Value kernelWOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelW * dilationWidth);
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset); paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
} }
@@ -52,9 +52,10 @@ static Value buildLoopSoftmaxNest(Value input,
if (axis == inputType.getRank() - 1) if (axis == inputType.getRank() - 1)
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc); return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis)); Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
Value cUpper = getOrCreateIndexConstant(rewriter, anchorOp, inputType.getDimSize(axis));
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator}); auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
rewriter.setInsertionPointToStart(loop.getBody()); rewriter.setInsertionPointToStart(loop.getBody());
@@ -17,9 +17,10 @@ namespace {
static Value buildNearestAsymmetricIndex( static Value buildNearestAsymmetricIndex(
Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) { Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim); Value cInputDim = getOrCreateIndexConstant(rewriter, anchorOp, inputDim);
Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1); Value cOutputDim = getOrCreateIndexConstant(rewriter, anchorOp, outputDim);
Value cInputDimLast = getOrCreateIndexConstant(rewriter, anchorOp, inputDim - 1);
Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim); Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim);
Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim); Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim);
return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast); return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
@@ -37,12 +38,13 @@ static Value buildNearestResizeLoop(Value input,
SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1)); SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1)); SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0)); Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1)); Value cOutputN = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(0));
Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2)); Value cOutputC = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(1));
Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3)); Value cOutputH = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(2));
Value cOutputW = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(3));
Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType); Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
@@ -1,9 +1,11 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -30,6 +32,54 @@ static Value createTransposeInit(Value input,
return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult(); return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult();
} }
static FailureOr<Value> materializeTransposedConstant(Value input,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
auto denseAttr = getHostConstDenseElementsAttr(input);
if (!denseAttr)
return failure();
auto inputType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!inputType || !inputType.hasStaticShape() || !resultType.hasStaticShape()
|| inputType.getRank() != resultType.getRank()
|| static_cast<int64_t>(permutation.size()) != inputType.getRank()) {
return failure();
}
if (denseAttr.isSplat())
return getOrCreateConstant(rewriter,
rewriter.getInsertionBlock()->getParentOp(),
DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>()),
resultType);
SmallVector<Attribute> inputValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> resultValues(inputValues.size());
SmallVector<int64_t> inputStrides = computeRowMajorStrides(inputType.getShape());
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
SmallVector<int64_t> inputIndices(inputType.getRank(), 0);
for (auto [linearIndex, value] : llvm::enumerate(inputValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < inputType.getRank(); ++dim) {
inputIndices[dim] = inputStrides.empty() ? 0 : remaining / inputStrides[dim];
remaining = inputStrides.empty() ? 0 : remaining % inputStrides[dim];
}
int64_t resultLinearIndex = 0;
for (int64_t dim = 0; dim < resultType.getRank(); ++dim)
resultLinearIndex += inputIndices[permutation[dim]] * resultStrides[dim];
resultValues[resultLinearIndex] = value;
}
return getOrCreateConstant(rewriter,
rewriter.getInsertionBlock()->getParentOp(),
DenseElementsAttr::get(resultType, resultValues),
resultType);
}
struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> { struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@@ -44,6 +94,14 @@ struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
auto permutation = getTransposePermutationChecked(transposeOp.getPermAttr(), inputType.getRank()); auto permutation = getTransposePermutationChecked(transposeOp.getPermAttr(), inputType.getRank());
if (failed(permutation)) if (failed(permutation))
return failure(); return failure();
if (isCompileTimeComputable(adaptor.getData())) {
auto constantTranspose =
materializeTransposedConstant(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
if (succeeded(constantTranspose)) {
rewriter.replaceOp(transposeOp, *constantTranspose);
return success();
}
}
Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc()); Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
Value transposed = Value transposed =
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation) linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation)
@@ -7,6 +7,7 @@
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
@@ -18,15 +19,9 @@ using namespace onnx_mlir::pim;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
static bool isUsedOnlyAsExplicitHostOperand(Value value) { static bool isUsedOnlyAsExplicitHostOperand(Value value) {
return !value.use_empty() && llvm::all_of(value.getUses(), [](OpOperand& use) { return !value.use_empty() && llvm::all_of(value.getUses(), [](OpOperand& use) {
return isExplicitHostOperand(use.getOwner(), use.getOperandNumber()); return isExplicitDevToHostTargetOperand(use.getOwner(), use.getOperandNumber());
}); });
} }
@@ -55,7 +50,7 @@ static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value ba
if (scale == 1) if (scale == 1)
return base; return base;
auto scaleValue = arith::ConstantIndexOp::create(rewriter, loc, scale).getResult(); auto scaleValue = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scale);
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult(); return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
} }
@@ -77,7 +72,8 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
if (auto attr = dyn_cast<Attribute>(offset)) { if (auto attr = dyn_cast<Attribute>(offset)) {
auto intAttr = dyn_cast<IntegerAttr>(attr); auto intAttr = dyn_cast<IntegerAttr>(attr);
assert(intAttr && "expected integer offset attribute"); assert(intAttr && "expected integer offset attribute");
scaledOffset = arith::ConstantIndexOp::create(rewriter, loc, intAttr.getInt() * scale).getResult(); scaledOffset =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), intAttr.getInt() * scale);
} }
else { else {
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale); scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
@@ -88,7 +84,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
} }
if (!totalOffset) if (!totalOffset)
totalOffset = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); totalOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
return totalOffset; return totalOffset;
} }
@@ -214,7 +210,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc()); Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
auto hostTargetType = cast<ShapedType>(hostTarget.getType()); auto hostTargetType = cast<ShapedType>(hostTarget.getType());
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper); Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
Value zeroOffset = arith::ConstantIndexOp::create(rewriter, insertSlice.getLoc(), 0).getResult(); Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
pim::PimMemCopyDevToHostOp::create(rewriter, pim::PimMemCopyDevToHostOp::create(rewriter,
insertSlice.getLoc(), insertSlice.getLoc(),
hostTarget.getType(), hostTarget.getType(),
@@ -254,7 +250,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) { for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand)) if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
continue; continue;
if (isExplicitHostOperand(&op, operandIndex)) if (isExplicitDevToHostTargetOperand(&op, operandIndex))
continue; continue;
Operation* definingOp = operand.getDefiningOp(); Operation* definingOp = operand.getDefiningOp();
@@ -40,7 +40,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
continue; continue;
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) { if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
mapping.map(operand, getOrCreateHostConstantLike(constantFolder, constantOp)); mapping.map(operand, getOrCreateConstantLike(constantFolder, constantOp));
continue; continue;
} }
@@ -218,7 +218,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
continue; continue;
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) { if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantFolder, constantOp)); blockArg->replaceAllUsesWith(getOrCreateConstantLike(constantFolder, constantOp));
continue; continue;
} }
@@ -230,8 +230,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
PimMemCopyHostToDevOp::create(rewriter, PimMemCopyHostToDevOp::create(rewriter,
loc, loc,
outputBuffer.getType(), outputBuffer.getType(),
getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0), getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0),
getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0), getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0),
outputBuffer, outputBuffer,
input, input,
getTensorSizeInBytesAttr(rewriter, input)) getTensorSizeInBytesAttr(rewriter, input))
@@ -16,25 +16,9 @@ void populateInitialPatterns(RewritePatternSet& patterns) {
populateTransposeLoweringPatterns(patterns); populateTransposeLoweringPatterns(patterns);
} }
void populateGlobalTensorMaterializationPatternPhase(RewritePatternSet& patterns) {
populateGlobalTensorMaterializationPatterns(patterns);
}
void populateInitialTensorPackingPatterns(RewritePatternSet& patterns) {
populateTensorPackingPatterns(patterns);
}
void populateCoreBodyPatterns(RewritePatternSet& patterns) { void populateCoreBodyPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns); raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns); populateTransposeLoweringPatterns(patterns);
} }
void populateFinalTensorPackingPatterns(RewritePatternSet& patterns) {
populateTensorPackingPatterns(patterns);
}
void populateCommunicationPatterns(RewritePatternSet& patterns) {
populateChannelLoweringPatterns(patterns);
}
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -9,11 +9,7 @@
namespace onnx_mlir { namespace onnx_mlir {
void populateInitialPatterns(mlir::RewritePatternSet& patterns); void populateInitialPatterns(mlir::RewritePatternSet& patterns);
void populateGlobalTensorMaterializationPatternPhase(mlir::RewritePatternSet& patterns);
void populateInitialTensorPackingPatterns(mlir::RewritePatternSet& patterns);
void populateCoreBodyPatterns(mlir::RewritePatternSet& patterns); void populateCoreBodyPatterns(mlir::RewritePatternSet& patterns);
void populateFinalTensorPackingPatterns(mlir::RewritePatternSet& patterns);
void populateCommunicationPatterns(mlir::RewritePatternSet& patterns);
void populateTransposeLoweringPatterns(mlir::RewritePatternSet& patterns); void populateTransposeLoweringPatterns(mlir::RewritePatternSet& patterns);
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns); void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
@@ -326,7 +326,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
continue; continue;
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) { if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
mapping.map(operand, getOrCreateHostConstantLike(constantFolder, constantOp)); mapping.map(operand, getOrCreateConstantLike(constantFolder, constantOp));
continue; continue;
} }
@@ -370,8 +370,8 @@ static Value emitHostCopy(IRRewriter& rewriter,
OperationFolder& constantFolder) { OperationFolder& constantFolder) {
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp(); Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants"); assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
Value hostTargetOffsetValue = getOrCreateHostIndexConstant(constantFolder, anchorOp, hostTargetOffset); Value hostTargetOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, hostTargetOffset);
Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(constantFolder, anchorOp, deviceSourceOffset); Value deviceSourceOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, deviceSourceOffset);
return PimMemCopyDevToHostOp::create(rewriter, return PimMemCopyDevToHostOp::create(rewriter,
loc, loc,
outputTensor.getType(), outputTensor.getType(),
@@ -81,7 +81,7 @@ static Value createZeroedDeviceHVector(IRRewriter& rewriter,
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType); auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName()); auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
auto zeroIndex = getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0); auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0);
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType))); auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
if (outputBuffer->getParentOfType<PimCoreBatchOp>()) if (outputBuffer->getParentOfType<PimCoreBatchOp>())
@@ -160,7 +160,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
} }
RewritePatternSet globalTensorPatterns(ctx); RewritePatternSet globalTensorPatterns(ctx);
populateGlobalTensorMaterializationPatternPhase(globalTensorPatterns); populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns)); walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator()); auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
@@ -190,7 +190,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
} }
RewritePatternSet initialTensorPackingPatterns(ctx); RewritePatternSet initialTensorPackingPatterns(ctx);
populateInitialTensorPackingPatterns(initialTensorPackingPatterns); populateTensorPackingPatterns(initialTensorPackingPatterns);
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns)); walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
eraseUnusedTensorPackingOps(funcOp, rewriter); eraseUnusedTensorPackingOps(funcOp, rewriter);
@@ -250,7 +250,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
eraseOpsToRemove(); eraseOpsToRemove();
RewritePatternSet finalTensorPackingPatterns(ctx); RewritePatternSet finalTensorPackingPatterns(ctx);
populateFinalTensorPackingPatterns(finalTensorPackingPatterns); populateTensorPackingPatterns(finalTensorPackingPatterns);
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns)); walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
eraseUnusedTensorPackingOps(funcOp, rewriter); eraseUnusedTensorPackingOps(funcOp, rewriter);
@@ -270,7 +270,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
spatial::SpatExtractRowsOp>(); spatial::SpatExtractRowsOp>();
RewritePatternSet communicationPatterns(ctx); RewritePatternSet communicationPatterns(ctx);
populateCommunicationPatterns(communicationPatterns); populateChannelLoweringPatterns(communicationPatterns);
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) { if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops"); funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
signalPassFailure(); signalPassFailure();
@@ -333,8 +333,8 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
rewriter, rewriter,
loc, loc,
tensorType, tensorType,
getOrCreateHostIndexConstant(constantFolder, deviceTensor.getOperation(), 0), getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), 0),
getOrCreateHostIndexConstant(constantFolder, getOrCreateIndexConstant(constantFolder,
deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize) ), deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize) ),
deviceTensor, deviceTensor,
inputTensor, inputTensor,
+2 -11
View File
@@ -9,6 +9,7 @@
#include "llvm/Support/LogicalResult.h" #include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -19,16 +20,6 @@ namespace pim {
namespace { namespace {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<PimMemCopyHostToDevOp>(op))
return operandIndex == 3;
if (isa<PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1;
if (isa<PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
static Region* getParentRegion(Value value) { static Region* getParentRegion(Value value) {
if (auto blockArgument = dyn_cast<BlockArgument>(value)) if (auto blockArgument = dyn_cast<BlockArgument>(value))
return blockArgument.getParentRegion(); return blockArgument.getParentRegion();
@@ -63,7 +54,7 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
for (OpOperand& operand : op->getOpOperands()) { for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get(); Value value = operand.get();
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value) if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value)
|| isExplicitHostOperand(op, operand.getOperandNumber())) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber()))
continue; continue;
InFlightDiagnostic diagnostic = ownerOp->emitOpError() InFlightDiagnostic diagnostic = ownerOp->emitOpError()
@@ -618,10 +618,6 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali
llvm_unreachable("Cannot reach here"); llvm_unreachable("Cannot reach here");
} }
Value createIndexConstant(MaterializerState& state, Operation* anchor, int64_t value) {
return getOrCreateHostIndexConstant(state.constantFolder, anchor, value);
}
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// Tensor packing helpers. // Tensor packing helpers.
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@@ -681,7 +677,7 @@ Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value in
if (dim0Size == 1) if (dim0Size == 1)
return index; return index;
Value dim0SizeValue = createIndexConstant(state, anchor, dim0Size); Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, anchor, dim0Size);
return arith::MulIOp::create(state.rewriter, loc, index, dim0SizeValue).getResult(); return arith::MulIOp::create(state.rewriter, loc, index, dim0SizeValue).getResult();
} }
@@ -731,7 +727,7 @@ std::optional<Value> extractPackedProducerSlice(MaterializerState& state,
state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); state.rewriter.setInsertionPoint(materializedClass.body->getTerminator());
Value firstOffset = createIndexConstant(state, materializedClass.op, rowOffset); Value firstOffset = getOrCreateIndexConstant(state.constantFolder, materializedClass.op, rowOffset);
return createDim0ExtractSlice(state, materializedClass.op->getLoc(), packed, firstOffset, rowCount); return createDim0ExtractSlice(state, materializedClass.op->getLoc(), packed, firstOffset, rowCount);
} }
@@ -754,7 +750,7 @@ Value getPackedSliceForRunIndex(MaterializerState& state,
size_t index, size_t index,
Location loc) { Location loc) {
int64_t rowOffset = static_cast<int64_t>(index) * fragmentType.getDimSize(0); int64_t rowOffset = static_cast<int64_t>(index) * fragmentType.getDimSize(0);
Value firstOffset = createIndexConstant(state, anchor, rowOffset); Value firstOffset = getOrCreateIndexConstant(state.constantFolder, anchor, rowOffset);
return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0));
} }
@@ -939,7 +935,7 @@ Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, Arr
auto type = RankedTensorType::get({static_cast<int64_t>(values.size())}, state.rewriter.getIndexType()); auto type = RankedTensorType::get({static_cast<int64_t>(values.size())}, state.rewriter.getIndexType());
auto attr = DenseIntElementsAttr::get(type, elements); auto attr = DenseIntElementsAttr::get(type, elements);
return getOrCreateHostConstant(state.constantFolder, anchor, attr, type); return getOrCreateConstant(state.constantFolder, anchor, attr, type);
} }
bool allEqual(ArrayRef<int64_t> values) { bool allEqual(ArrayRef<int64_t> values) {
@@ -1041,7 +1037,7 @@ Value createIndexedIndexValue(
assert(!values.empty() && "expected at least one indexed value"); assert(!values.empty() && "expected at least one indexed value");
if (allEqual(values)) if (allEqual(values))
return createIndexConstant(state, anchor, values.front()); return getOrCreateIndexConstant(state.constantFolder, anchor, values.front());
if (std::optional<IndexedIndexPattern> pattern = getIndexedIndexPattern(values)) if (std::optional<IndexedIndexPattern> pattern = getIndexedIndexPattern(values))
return createAffineIndexValue(state, *pattern, index, loc); return createAffineIndexValue(state, *pattern, index, loc);
@@ -1110,7 +1106,7 @@ Value createOriginalLaneValue(MaterializerState& state,
Location loc) { Location loc) {
assert(!peers.empty() && "expected at least one peer instance"); assert(!peers.empty() && "expected at least one peer instance");
if (!materializedClass.isBatch) if (!materializedClass.isBatch)
return createIndexConstant(state, materializedClass.op, peers.front().laneStart); return getOrCreateIndexConstant(state.constantFolder, materializedClass.op, peers.front().laneStart);
auto batch = cast<SpatComputeBatch>(materializedClass.op); auto batch = cast<SpatComputeBatch>(materializedClass.op);
auto laneArg = batch.getLaneArgument(); auto laneArg = batch.getLaneArgument();
@@ -1465,9 +1461,9 @@ void appendScalarSend(MaterializerState& state,
assert(!sourceClass.isBatch && "scalar send helper expects a scalar source class"); assert(!sourceClass.isBatch && "scalar send helper expects a scalar source class");
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
Value channelIdValue = createIndexConstant(state, sourceClass.op, channelId); Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channelId);
Value sourceCoreIdValue = createIndexConstant(state, sourceClass.op, sourceCoreId); Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCoreId);
Value targetCoreIdValue = createIndexConstant(state, sourceClass.op, targetCoreId); Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCoreId);
SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload);
} }
@@ -1485,9 +1481,9 @@ void appendScalarSendLoop(MaterializerState& state,
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
Value lowerBound = createIndexConstant(state, sourceClass.op, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0);
Value upperBound = createIndexConstant(state, sourceClass.op, static_cast<int64_t>(channelIds.size())); Value upperBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size()));
Value step = createIndexConstant(state, sourceClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1);
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {});
@@ -1514,9 +1510,9 @@ Value buildProjectedPackedPayload(MaterializerState& state,
state.rewriter, loc, descriptor.payloadType.getShape(), descriptor.payloadType.getElementType()) state.rewriter, loc, descriptor.payloadType.getShape(), descriptor.payloadType.getElementType())
.getResult(); .getResult();
Value lowerBound = createIndexConstant(state, anchor, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, anchor, 0);
Value upperBound = createIndexConstant(state, anchor, descriptor.fragmentsPerLane); Value upperBound = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.fragmentsPerLane);
Value step = createIndexConstant(state, anchor, 1); Value step = getOrCreateIndexConstant(state.constantFolder, anchor, 1);
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init});
@@ -1531,7 +1527,7 @@ Value buildProjectedPackedPayload(MaterializerState& state,
Value fragmentIndex = loop.getInductionVar(); Value fragmentIndex = loop.getInductionVar();
Value acc = body->getArgument(1); Value acc = body->getArgument(1);
Value fragmentsPerLane = createIndexConstant(state, anchor, descriptor.fragmentsPerLane); Value fragmentsPerLane = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.fragmentsPerLane);
Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult(); Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult();
Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult();
@@ -1562,13 +1558,14 @@ void appendProjectedScalarSendLoop(MaterializerState& state,
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
if (channelIds.size() == 1) { if (channelIds.size() == 1) {
Value channelId = createIndexConstant(state, sourceClass.op, channelIds.front()); Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channelIds.front());
Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds.front()); Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCoreIds.front());
Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds.front()); Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCoreIds.front());
Value laneIndex = createIndexConstant(state, sourceClass.op, 0); Value laneIndex = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0);
Value sendPayload; Value sendPayload;
if (descriptor.fragmentsPerLane == 1) { if (descriptor.fragmentsPerLane == 1) {
Value offset = createIndexConstant(state, sourceClass.op, descriptor.laneMajorSourceDim0Offsets.front()); Value offset =
getOrCreateIndexConstant(state.constantFolder, sourceClass.op, descriptor.laneMajorSourceDim0Offsets.front());
sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0)); sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0));
} }
else { else {
@@ -1579,9 +1576,9 @@ void appendProjectedScalarSendLoop(MaterializerState& state,
return; return;
} }
Value lowerBound = createIndexConstant(state, sourceClass.op, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0);
Value upperBound = createIndexConstant(state, sourceClass.op, static_cast<int64_t>(channelIds.size())); Value upperBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size()));
Value step = createIndexConstant(state, sourceClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1);
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {});
@@ -1645,9 +1642,9 @@ Value appendScalarReceive(MaterializerState& state,
assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class"); assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class");
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
Value channelIdValue = createIndexConstant(state, targetClass.op, channelId); Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, channelId);
Value sourceCoreIdValue = createIndexConstant(state, targetClass.op, sourceCoreId); Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, sourceCoreId);
Value targetCoreIdValue = createIndexConstant(state, targetClass.op, targetCoreId); Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, targetCoreId);
return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue) return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue)
.getOutput(); .getOutput();
} }
@@ -2132,9 +2129,9 @@ FailureOr<Value> materializeDeferredLocalPackedScalarRunValue(MaterializerState&
Value init = Value init =
tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult(); tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult();
Value lowerBound = createIndexConstant(state, targetClass.op, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0);
Value upperBound = createIndexConstant(state, targetClass.op, static_cast<int64_t>(keys.size())); Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast<int64_t>(keys.size()));
Value step = createIndexConstant(state, targetClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init});
@@ -2198,9 +2195,9 @@ FailureOr<Value> insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerSt
SmallVector<size_t, 1> resultIndices {run.resultIndex}; SmallVector<size_t, 1> resultIndices {run.resultIndex};
Value lowerBound = createIndexConstant(state, targetClass.op, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0);
Value upperBound = createIndexConstant(state, targetClass.op, static_cast<int64_t>(keys.size())); Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast<int64_t>(keys.size()));
Value step = createIndexConstant(state, targetClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination});
@@ -2262,9 +2259,10 @@ FailureOr<Value> insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState&
if (outputOffsets.size() != run.channelIds.size()) if (outputOffsets.size() != run.channelIds.size())
return failure(); return failure();
Value lowerBound = createIndexConstant(state, targetClass.op, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0);
Value upperBound = createIndexConstant(state, targetClass.op, static_cast<int64_t>(run.channelIds.size())); Value upperBound =
Value step = createIndexConstant(state, targetClass.op, 1); getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast<int64_t>(run.channelIds.size()));
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination});
@@ -2343,9 +2341,9 @@ FailureOr<Value> insertPackedScalarRunIntoWholeBatch(MaterializerState& state,
slotRowOffsets.push_back(static_cast<int64_t>(slotKey->instance.laneStart) * plan.rowsPerLane); slotRowOffsets.push_back(static_cast<int64_t>(slotKey->instance.laneStart) * plan.rowsPerLane);
} }
Value lowerBound = createIndexConstant(state, targetClass.op, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0);
Value upperBound = createIndexConstant(state, targetClass.op, static_cast<int64_t>(run.slots.size())); Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast<int64_t>(run.slots.size()));
Value step = createIndexConstant(state, targetClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination});
@@ -2507,7 +2505,7 @@ FailureOr<Value> emitWholeBatchAssemblyPlan(MaterializerState& state,
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
int64_t rowOffset = static_cast<int64_t>(fragment.key.instance.laneStart) * plan.rowsPerLane; int64_t rowOffset = static_cast<int64_t>(fragment.key.instance.laneStart) * plan.rowsPerLane;
Value outputOffset = createIndexConstant(state, targetClass.op, rowOffset); Value outputOffset = getOrCreateIndexConstant(state.constantFolder, targetClass.op, rowOffset);
result = insertFragmentIntoWholeBatch(state, fragment.fragment, result, outputOffset, loc); result = insertFragmentIntoWholeBatch(state, fragment.fragment, result, outputOffset, loc);
} }
@@ -3050,7 +3048,7 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
const ComputeInstance& instance = run.front().peers.front(); const ComputeInstance& instance = run.front().peers.front();
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
Value laneValue = createIndexConstant(state, targetClass.op, instance.laneStart); Value laneValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, instance.laneStart);
return cloneBatchBodyForLane(state, targetClass, instance, laneValue, group.resultIndices); return cloneBatchBodyForLane(state, targetClass, instance, laneValue, group.resultIndices);
} }
@@ -3087,9 +3085,9 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
laneStarts.push_back(instance.laneStart); laneStarts.push_back(instance.laneStart);
} }
Value lowerBound = createIndexConstant(state, targetClass.op, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0);
Value upperBound = createIndexConstant(state, targetClass.op, static_cast<int64_t>(run.size())); Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast<int64_t>(run.size()));
Value step = createIndexConstant(state, targetClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange(initValues)); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange(initValues));
@@ -3563,9 +3561,9 @@ LogicalResult materializeBatchClassRun(MaterializerState& state,
if (failed(buildBatchRunSendPlans(state, targetClass, run, group, sendPlans))) if (failed(buildBatchRunSendPlans(state, targetClass, run, group, sendPlans)))
return failure(); return failure();
Value lowerBound = createIndexConstant(state, targetClass.op, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0);
Value upperBound = createIndexConstant(state, targetClass.op, static_cast<int64_t>(run.size())); Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast<int64_t>(run.size()));
Value step = createIndexConstant(state, targetClass.op, 1); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {});
@@ -3669,9 +3667,9 @@ Value createReceiveConcatLoop(MaterializerState& state,
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
assert(!channelIds.empty() && "expected at least one receive"); assert(!channelIds.empty() && "expected at least one receive");
Value lowerBound = createIndexConstant(state, anchor, 0); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, anchor, 0);
Value upperBound = createIndexConstant(state, anchor, static_cast<int64_t>(channelIds.size())); Value upperBound = getOrCreateIndexConstant(state.constantFolder, anchor, static_cast<int64_t>(channelIds.size()));
Value step = createIndexConstant(state, anchor, 1); Value step = getOrCreateIndexConstant(state.constantFolder, anchor, 1);
state.rewriter.setInsertionPoint(insertionPoint); state.rewriter.setInsertionPoint(insertionPoint);
Value init = Value init =
@@ -2,6 +2,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "../Common.hpp" #include "../Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -41,7 +42,7 @@ static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffs
} }
auto value = cast<Value>(baseOffset); auto value = cast<Value>(baseOffset);
auto cst = arith::ConstantIndexOp::create(rewriter, value.getLoc(), extraOffset); auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset);
return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult(); return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult();
} }
@@ -75,7 +76,7 @@ static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
Value remaining = linearIndex; Value remaining = linearIndex;
for (auto [_dim, stride] : llvm::enumerate(strides)) { for (auto [_dim, stride] : llvm::enumerate(strides)) {
auto cStride = arith::ConstantIndexOp::create(rewriter, linearIndex.getLoc(), stride); auto cStride = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride);
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride); Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
indices.push_back(index); indices.push_back(index);
remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride); remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
@@ -90,7 +91,7 @@ static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset,
if (integerAttr.getInt() == 0) if (integerAttr.getInt() == 0)
return extraOffset; return extraOffset;
auto cst = arith::ConstantIndexOp::create(rewriter, extraOffset.getLoc(), integerAttr.getInt()); auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), integerAttr.getInt());
return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult(); return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult();
} }
@@ -195,9 +196,9 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0 if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size()) && sourceType.getRank() == static_cast<int64_t>(copyShape.size())
&& dstType.getRank() == static_cast<int64_t>(copyShape.size())) { && dstType.getRank() == static_cast<int64_t>(copyShape.size())) {
auto c0 = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 0); auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0);
auto cUpper = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), numSlices); auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices);
auto cStep = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 1); auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1);
auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {}); auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
rewriter.setInsertionPointToStart(loop.getBody()); rewriter.setInsertionPointToStart(loop.getBody());
@@ -284,8 +285,8 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
rewriter, rewriter,
[&]( [&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
Value dstOffsetValue = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), dstByteOffset); Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
Value srcOffsetValue = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), srcByteOffset); Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
pim::PimMemCopyHostToDevOp::create(rewriter, pim::PimMemCopyHostToDevOp::create(rewriter,
copyOp.getLoc(), copyOp.getLoc(),
resultType, resultType,
@@ -324,8 +325,8 @@ struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDe
rewriter, rewriter,
[&]( [&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
Value dstOffset = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), dstByteOffset); Value dstOffset = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
Value srcOffset = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), srcByteOffset); Value srcOffset = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
pim::PimMemCopyDevToHostOp::create(rewriter, pim::PimMemCopyDevToHostOp::create(rewriter,
copyOp.getLoc(), copyOp.getLoc(),
resultType, resultType,
@@ -13,6 +13,7 @@
#include <type_traits> #include <type_traits>
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -22,16 +23,6 @@ namespace onnx_mlir {
namespace { namespace {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 3;
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
template <typename CoreOpTy> template <typename CoreOpTy>
static void materializeHostConstantsInCore(CoreOpTy coreOp, static void materializeHostConstantsInCore(CoreOpTy coreOp,
IRRewriter& rewriter, IRRewriter& rewriter,
@@ -51,7 +42,7 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
for (OpOperand& operand : op->getOpOperands()) { for (OpOperand& operand : op->getOpOperands()) {
Value originalValue = operand.get(); Value originalValue = operand.get();
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber())) if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber()))
continue; continue;
auto resolvedAddress = resolveContiguousAddress(originalValue); auto resolvedAddress = resolveContiguousAddress(originalValue);
@@ -113,8 +104,8 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
rewriter, rewriter,
op->getLoc(), op->getLoc(),
originalType, originalType,
getOrCreateHostIndexConstant(constantFolder, op, 0), getOrCreateIndexConstant(constantFolder, op, 0),
getOrCreateHostIndexConstant(constantFolder, op, static_cast<int64_t>(resolvedAddress->byteOffset) ), getOrCreateIndexConstant(constantFolder, op, static_cast<int64_t>(resolvedAddress->byteOffset) ),
deviceDst, deviceDst,
getGlobalOp.getResult(), getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes))) rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
+2 -11
View File
@@ -6,6 +6,7 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
@@ -119,16 +120,6 @@ static bool isConstantGlobalView(Value value) {
} }
} }
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 3;
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
static bool isCoreWeightBlockArgument(Value value) { static bool isCoreWeightBlockArgument(Value value) {
auto blockArgument = dyn_cast<BlockArgument>(value); auto blockArgument = dyn_cast<BlockArgument>(value);
if (!blockArgument) if (!blockArgument)
@@ -361,7 +352,7 @@ private:
continue; continue;
} }
if (isExplicitHostOperand(&op, operandIndex)) { if (isExplicitHostMemCopyOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand, knowledge)) { if (!isCodegenAddressableValue(operand, knowledge)) {
diagnostics.report(&op, [&](Operation* illegalOp) { diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << "host operand #" << operandIndex illegalOp->emitOpError() << "host operand #" << operandIndex