finish helper refactoring
Validate Operations / validate-operations (push) Has been cancelled

use uniqued constant helpers everywhere
materialize transposed constants directly
This commit is contained in:
NiccoloN
2026-05-29 17:05:45 +02:00
parent 819d8af0f7
commit 8bb0babf1b
32 changed files with 300 additions and 467 deletions
+16 -22
View File
@@ -14,13 +14,17 @@ using namespace mlir;
namespace onnx_mlir {
Block* getHostConstantBlock(Operation* anchorOp) {
Block* getConstantInsertionBlock(Operation* anchorOp) {
assert(anchorOp && "expected a valid anchor operation");
for (Operation* current = anchorOp; current; current = current->getParentOp())
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
return current->getBlock();
if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp))
return &funcOp.getBody().front();
if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
return moduleOp.getBody();
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
return &funcOp.getBody().front();
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
@@ -28,9 +32,9 @@ Block* getHostConstantBlock(Operation* anchorOp) {
return anchorOp->getBlock();
}
Value getOrCreateHostConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
Value getOrCreateConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getHostConstantBlock(anchorOp);
Block* hostBlock = getConstantInsertionBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
@@ -42,9 +46,9 @@ Value getOrCreateHostConstant(OperationFolder& folder, Operation* anchorOp, Attr
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
}
Value getOrCreateHostConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
Value getOrCreateConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getHostConstantBlock(anchorOp);
Block* hostBlock = getConstantInsertionBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
@@ -57,28 +61,18 @@ Value getOrCreateHostConstant(RewriterBase& rewriter, Operation* anchorOp, Attri
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
}
Value getOrCreateHostConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
return getOrCreateHostConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType());
Value getOrCreateConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
return getOrCreateConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType());
}
Value getOrCreateHostIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
Value getOrCreateIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType() );
return getOrCreateConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}
Value getOrCreateHostIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(folder, anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type() );
}
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(folder, anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type() );
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}
Value createAffineApplyOrFoldedConstant(
@@ -95,7 +89,7 @@ Value createAffineApplyOrFoldedConstant(
SmallVector<Attribute> foldedResults;
if (succeeded(map.constantFold(operandConstants, foldedResults))) {
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
return getOrCreateHostIndexConstant(rewriter, anchorOp, constantResult.getInt());
return getOrCreateIndexConstant(rewriter, anchorOp, constantResult.getInt());
}
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();