use affine dialect to express simple constant progressions
Validate Operations / validate-operations (push) Has been cancelled

run dce at the end of MaterializeMergeSchedule to get rid of unused constants
This commit is contained in:
NiccoloN
2026-05-23 14:25:34 +02:00
parent 76a37e198f
commit b79ce8eeaa
3 changed files with 180 additions and 61 deletions
@@ -1,3 +1,5 @@
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -143,7 +145,8 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
OperationFolder constantFolder(&getContext()); OperationFolder constantFolder(&getContext());
ConversionTarget target(*ctx); ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect, target.addLegalDialect<affine::AffineDialect,
PimDialect,
tensor::TensorDialect, tensor::TensorDialect,
arith::ArithDialect, arith::ArithDialect,
bufferization::BufferizationDialect, bufferization::BufferizationDialect,
@@ -217,12 +220,31 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
RewritePatternSet coreBodyPatterns(ctx); RewritePatternSet coreBodyPatterns(ctx);
populateWithGenerated(coreBodyPatterns); populateWithGenerated(coreBodyPatterns);
populateAffineToStdConversionPatterns(coreBodyPatterns);
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns)); FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
ConversionTarget coreBodyTarget(*ctx);
coreBodyTarget.addLegalDialect<PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
bufferization::BufferizationDialect,
func::FuncDialect,
memref::MemRefDialect,
scf::SCFDialect,
BuiltinDialect>();
coreBodyTarget.addLegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelReceiveTensorBatchOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatChannelSendTensorBatchOp,
spatial::SpatExtractRowsOp>();
SmallVector<pim::PimCoreOp> coreOps; SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); }); funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) { for (auto coreOp : coreOps) {
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) { if (failed(applyFullConversion(coreOp.getOperation(), coreBodyTarget, frozenCoreBodyPatterns))) {
coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core"); coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core");
signalPassFailure(); signalPassFailure();
return; return;
@@ -232,7 +254,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
SmallVector<pim::PimCoreBatchOp> coreBatchOps; SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); }); funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) { for (auto coreBatchOp : coreBatchOps) {
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) { if (failed(applyFullConversion(coreBatchOp.getOperation(), coreBodyTarget, frozenCoreBodyPatterns))) {
coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch"); coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch");
signalPassFailure(); signalPassFailure();
return; return;
@@ -1,5 +1,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
@@ -119,10 +121,45 @@ static bool isConstantIndexLike(Value value) {
return matchPattern(value, m_ConstantInt(&constantValue)); return matchPattern(value, m_ConstantInt(&constantValue));
} }
/// Returns true when `expr` is built only from constructs accepted for lane
/// offsets:
///   - constants and dimension identifiers;
///   - sums of two supported sub-expressions;
///   - products where at least one operand is a constant and the other operand
///     is supported;
///   - floordiv / ceildiv / mod of a supported left-hand side by a strictly
///     positive constant.
/// Symbols are always rejected.
static bool isSupportedLaneAffineExpr(AffineExpr expr) {
  // Division and modulo by zero or by a negative constant are not well-defined
  // for affine expressions, so only strictly positive constant divisors are
  // accepted here.
  auto isPositiveConstant = [](AffineExpr operand) {
    return isa<AffineConstantExpr>(operand) && cast<AffineConstantExpr>(operand).getValue() > 0;
  };

  switch (expr.getKind()) {
  case AffineExprKind::Constant:
  case AffineExprKind::DimId:
    return true;
  case AffineExprKind::SymbolId:
    return false;
  case AffineExprKind::Add: {
    auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
    return isSupportedLaneAffineExpr(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS());
  }
  case AffineExprKind::Mul: {
    // Only scaling by a constant is accepted; the constant may be on either side.
    auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
    return (isa<AffineConstantExpr>(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS()))
        || (isa<AffineConstantExpr>(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS()));
  }
  case AffineExprKind::FloorDiv:
  case AffineExprKind::CeilDiv:
  case AffineExprKind::Mod: {
    auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
    return isPositiveConstant(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS());
  }
  }

  llvm_unreachable("unexpected affine expression kind");
}
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) { static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
if (value == laneArg || isConstantIndexLike(value)) if (value == laneArg || isConstantIndexLike(value))
return true; return true;
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
if (affineApply) {
if (affineApply.getAffineMap().getNumResults() != 1 || affineApply.getAffineMap().getNumSymbols() != 0)
return false;
if (!llvm::all_of(affineApply.getMapOperands(),
[&](Value operand) { return isSupportedLaneOffsetExpr(operand, laneArg); })) {
return false;
}
return isSupportedLaneAffineExpr(affineApply.getAffineMap().getResult(0));
}
auto extractOp = value.getDefiningOp<tensor::ExtractOp>(); auto extractOp = value.getDefiningOp<tensor::ExtractOp>();
if (extractOp) { if (extractOp) {
auto constantTensor = extractOp.getTensor().getDefiningOp<arith::ConstantOp>(); auto constantTensor = extractOp.getTensor().getDefiningOp<arith::ConstantOp>();
@@ -1,3 +1,4 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -6,6 +7,7 @@
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
@@ -15,8 +17,6 @@
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include <algorithm> #include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits> #include <limits>
#include <optional> #include <optional>
#include <utility> #include <utility>
@@ -593,40 +593,90 @@ bool allEqual(ArrayRef<int32_t> values) {
return true; return true;
} }
// IndexedIndexPattern: closed-form description of a sequence of index values.
// It is filled by matchAffineSequence / matchTiledAffineSequence and consumed
// by createAffineIndexValue, which interprets the fields as follows:
//   - isTiled == false: value(i) = base + step * i
//   - isTiled == true:  value(i) = base + outerStep * (i floordiv period)
//                                       + innerStep * (i mod period)
// Defaults: base, step, innerStep and outerStep are 0, period is 1 and
// isTiled is false.
Value createLaneIndexedIndexValue(MaterializerState& state, struct IndexedIndexPattern {
MaterializedClass& materializedClass, int64_t base = 0;
ArrayRef<int64_t> values, int64_t step = 0;
Location loc) { int64_t period = 1;
assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); int64_t innerStep = 0;
assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); int64_t outerStep = 0;
bool isTiled = false;
};
// matchAffineSequence: checks whether `values` is an arithmetic progression.
// It sets pattern.base to the first value, pattern.step to
// values[1] - values[0] (0 for a single-element sequence) and pattern.isTiled
// to false, then returns true only if every values[i] equals base + step * i.
// `pattern` is written even when the function returns false.
// Asserts that `values` is non-empty.
if (allEqual(values)) bool matchAffineSequence(ArrayRef<int64_t> values, IndexedIndexPattern& pattern) {
return createIndexConstant(state, materializedClass.op, values.front()); assert(!values.empty() && "expected at least one value");
auto batch = cast<SpatComputeBatch>(materializedClass.op); pattern.base = values.front();
auto laneArg = batch.getLaneArgument(); pattern.step = values.size() == 1 ? 0 : values[1] - values[0];
assert(laneArg && "expected compute_batch lane argument"); pattern.isTiled = false;
Value table = createIndexTensorConstant(state, materializedClass.op, values); for (auto [index, value] : llvm::enumerate(values)) {
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult(); int64_t expected = pattern.base + pattern.step * static_cast<int64_t>(index);
if (value != expected)
return false;
} }
Value createLaneIndexedIndexValue(MaterializerState& state, return true;
MaterializedClass& materializedClass, }
ArrayRef<int32_t> values,
Location loc) {
assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class");
assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane");
// matchTiledAffineSequence: checks whether `values` follows a two-level
// ("tiled") progression
//   values[i] = base + outerStep * (i / period) + innerStep * (i % period)
// for some period in [2, values.size() / 2]. Periods are tried in increasing
// order and the first one that matches every element wins.
// On success it writes base, period, innerStep, outerStep and isTiled = true
// into `pattern` and returns true; pattern.step is not written here.
// On failure it returns false without writing to `pattern`. Sequences with
// fewer than four values never match because the period loop does not run.
// Asserts that `values` is non-empty.
if (allEqual(values)) bool matchTiledAffineSequence(ArrayRef<int64_t> values, IndexedIndexPattern& pattern) {
return createIndexConstant(state, materializedClass.op, values.front()); assert(!values.empty() && "expected at least one value");
auto batch = cast<SpatComputeBatch>(materializedClass.op); for (int64_t period = 2; period <= static_cast<int64_t>(values.size() / 2); ++period) {
auto laneArg = batch.getLaneArgument(); int64_t base = values.front();
assert(laneArg && "expected compute_batch lane argument"); int64_t innerStep = values[1] - values[0];
int64_t outerStep = values[period] - values[0];
Value table = createIndexTensorConstant(state, materializedClass.op, values); bool matches = true;
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult(); for (auto [index, value] : llvm::enumerate(values)) {
int64_t i = static_cast<int64_t>(index);
// Candidate closed form for element i under the current period.
int64_t expected = base + outerStep * (i / period) + innerStep * (i % period);
if (value != expected) {
matches = false;
break;
}
}
// This period does not describe the whole sequence; try the next one.
if (!matches)
continue;
pattern.base = base;
pattern.period = period;
pattern.innerStep = innerStep;
pattern.outerStep = outerStep;
pattern.isTiled = true;
return true;
}
return false;
}
/// Tries to describe `values` with a closed-form pattern.
/// The plain affine matcher runs first and the tiled matcher is consulted only
/// when it fails. Returns the matched pattern, or std::nullopt when neither
/// matcher accepts the sequence. Asserts that `values` is non-empty.
std::optional<IndexedIndexPattern> getIndexedIndexPattern(ArrayRef<int64_t> values) {
  assert(!values.empty() && "expected at least one value");

  IndexedIndexPattern candidate;
  // Short-circuit evaluation keeps the original matcher order.
  const bool matched = matchAffineSequence(values, candidate) || matchTiledAffineSequence(values, candidate);
  if (!matched)
    return std::nullopt;
  return candidate;
}
// createAffineIndexValue: materializes `pattern` as an affine.apply of `index`.
// Builds a one-dimensional, zero-symbol affine map whose single result is
//   base + d0 * step                                          (non-tiled)
//   base + (d0 floordiv period) * outerStep
//        + (d0 mod period) * innerStep                        (tiled)
// and creates the op with state.rewriter at `loc`, using `index` as the only
// map operand. Returns the result value of the created op.
Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) {
MLIRContext* context = state.func.getContext();
// d0 stands for `index` inside the affine map.
AffineExpr d0 = getAffineDimExpr(0, context);
AffineExpr expr;
if (!pattern.isTiled) {
expr = getAffineConstantExpr(pattern.base, context) + d0 * pattern.step;
}
else {
expr = getAffineConstantExpr(pattern.base, context) + d0.floorDiv(pattern.period) * pattern.outerStep
+ (d0 % pattern.period) * pattern.innerStep;
}
AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
return affine::AffineApplyOp::create(state.rewriter, loc, map, ValueRange {index}).getResult();
} }
Value createIndexedIndexValue( Value createIndexedIndexValue(
@@ -636,6 +686,9 @@ Value createIndexedIndexValue(
if (allEqual(values)) if (allEqual(values))
return createIndexConstant(state, anchor, values.front()); return createIndexConstant(state, anchor, values.front());
if (std::optional<IndexedIndexPattern> pattern = getIndexedIndexPattern(values))
return createAffineIndexValue(state, *pattern, index, loc);
Value table = createIndexTensorConstant(state, anchor, values); Value table = createIndexTensorConstant(state, anchor, values);
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult(); return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult();
} }
@@ -644,11 +697,41 @@ Value createIndexedIndexValue(
MaterializerState& state, Operation* anchor, ArrayRef<int32_t> values, Value index, Location loc) { MaterializerState& state, Operation* anchor, ArrayRef<int32_t> values, Value index, Location loc) {
assert(!values.empty() && "expected at least one indexed value"); assert(!values.empty() && "expected at least one indexed value");
if (allEqual(values)) SmallVector<int64_t, 8> widened;
return createIndexConstant(state, anchor, values.front()); widened.reserve(values.size());
for (int32_t value : values)
widened.push_back(value);
Value table = createIndexTensorConstant(state, anchor, values); return createIndexedIndexValue(state, anchor, ArrayRef<int64_t>(widened), index, loc);
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult(); }
/// Builds an index value selected by the lane argument of a materialized
/// batch class.
/// `values` must hold exactly one entry per materialized lane (asserted
/// against materializedClass.cpus.size()), and `materializedClass` must be a
/// batch class whose op is a SpatComputeBatch with a lane argument.
/// The construction itself is delegated to createIndexedIndexValue, with the
/// lane argument as the index.
Value createLaneIndexedIndexValue(MaterializerState& state,
                                  MaterializedClass& materializedClass,
                                  ArrayRef<int64_t> values,
                                  Location loc) {
  assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class");
  assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane");

  Operation* anchorOp = materializedClass.op;
  auto laneArg = cast<SpatComputeBatch>(anchorOp).getLaneArgument();
  assert(laneArg && "expected compute_batch lane argument");

  return createIndexedIndexValue(state, anchorOp, values, *laneArg, loc);
}
// createLaneIndexedIndexValue (int32_t overload): checks the same
// preconditions as the int64_t overload (batch class, one value per
// materialized lane), widens every entry to int64_t and forwards to the
// int64_t overload.
Value createLaneIndexedIndexValue(MaterializerState& state,
MaterializedClass& materializedClass,
ArrayRef<int32_t> values,
Location loc) {
assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class");
assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane");
// Widened copy of `values`; it only needs to outlive the forwarded call below.
SmallVector<int64_t, 8> widened;
widened.reserve(values.size());
for (int32_t value : values)
widened.push_back(value);
return createLaneIndexedIndexValue(state, materializedClass, ArrayRef<int64_t>(widened), loc);
} }
FailureOr<SmallVector<ComputeInstance, 8>> FailureOr<SmallVector<ComputeInstance, 8>>
@@ -675,38 +758,13 @@ Value createOriginalLaneValue(MaterializerState& state,
auto batch = cast<SpatComputeBatch>(materializedClass.op); auto batch = cast<SpatComputeBatch>(materializedClass.op);
auto laneArg = batch.getLaneArgument(); auto laneArg = batch.getLaneArgument();
assert(laneArg && "expected materialized compute_batch lane argument"); assert(laneArg && "expected materialized compute_batch lane argument");
bool identity = true;
for (auto [lane, peer] : llvm::enumerate(peers)) {
if (peer.laneCount != 1 || peer.laneStart != lane) {
identity = false;
break;
}
}
if (identity)
return *laneArg;
bool affineWithBase = true;
int64_t base = static_cast<int64_t>(peers.front().laneStart);
for (auto [lane, peer] : llvm::enumerate(peers)) {
if (peer.laneCount != 1 || static_cast<int64_t>(peer.laneStart) != base + static_cast<int64_t>(lane)) {
affineWithBase = false;
break;
}
}
if (affineWithBase) {
if (base == 0)
return *laneArg;
Value baseValue = createIndexConstant(state, materializedClass.op, base);
return arith::AddIOp::create(state.rewriter, loc, *laneArg, baseValue).getResult();
}
SmallVector<int64_t, 8> laneValues; SmallVector<int64_t, 8> laneValues;
laneValues.reserve(peers.size()); laneValues.reserve(peers.size());
for (const ComputeInstance& peer : peers) for (const ComputeInstance& peer : peers)
laneValues.push_back(peer.laneStart); laneValues.push_back(peer.laneStart);
Value table = createIndexTensorConstant(state, materializedClass.op, laneValues); return createIndexedIndexValue(state, materializedClass.op, ArrayRef<int64_t>(laneValues), *laneArg, loc);
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
} }
bool hasLiveExternalUse(Value value, const DenseSet<Operation*>& oldComputeOps) { bool hasLiveExternalUse(Value value, const DenseSet<Operation*>& oldComputeOps) {
@@ -1686,6 +1744,8 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch
if (failed(eraseOldComputeOps(state))) if (failed(eraseOldComputeOps(state)))
return failure(); return failure();
LogicalResult _ = runRegionDCE(state.rewriter, state.func.getBody());
return success(); return success();
} }