diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 5a5a1d1..7e58697 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -1,3 +1,5 @@ +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -143,7 +145,8 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { OperationFolder constantFolder(&getContext()); ConversionTarget target(*ctx); - target.addLegalDialect(); + coreBodyTarget.addLegalOp(); + SmallVector coreOps; funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); }); for (auto coreOp : coreOps) { - if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) { + if (failed(applyFullConversion(coreOp.getOperation(), coreBodyTarget, frozenCoreBodyPatterns))) { coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core"); signalPassFailure(); return; @@ -232,7 +254,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { SmallVector coreBatchOps; funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); }); for (auto coreBatchOp : coreBatchOps) { - if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) { + if (failed(applyFullConversion(coreBatchOp.getOperation(), coreBodyTarget, frozenCoreBodyPatterns))) { coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch"); signalPassFailure(); return; diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index 86594a1..10c7591 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -1,5 +1,7 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Diagnostics.h" @@ -119,10 +121,45 @@ static bool isConstantIndexLike(Value value) { return matchPattern(value, m_ConstantInt(&constantValue)); } +static bool isSupportedLaneAffineExpr(AffineExpr expr) { + switch (expr.getKind()) { + case AffineExprKind::Constant: + case AffineExprKind::DimId: return true; + case AffineExprKind::SymbolId: return false; + case AffineExprKind::Add: { + auto binaryExpr = cast(expr); + return isSupportedLaneAffineExpr(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS()); + } + case AffineExprKind::Mul: { + auto binaryExpr = cast(expr); + return (isa(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS())) + || (isa(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS())); + } + case AffineExprKind::FloorDiv: + case AffineExprKind::CeilDiv: + case AffineExprKind::Mod: { + auto binaryExpr = cast(expr); + return isa(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS()); + } + } + llvm_unreachable("unexpected affine expression kind"); +} + static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) { if (value == laneArg || isConstantIndexLike(value)) return true; + auto affineApply = value.getDefiningOp(); + if (affineApply) { + if (affineApply.getAffineMap().getNumResults() != 1 || affineApply.getAffineMap().getNumSymbols() != 0) + return false; + if (!llvm::all_of(affineApply.getMapOperands(), + [&](Value operand) { return isSupportedLaneOffsetExpr(operand, laneArg); })) { + return false; + } + return isSupportedLaneAffineExpr(affineApply.getAffineMap().getResult(0)); + } + auto extractOp = value.getDefiningOp(); if (extractOp) { auto constantTensor = extractOp.getTensor().getDefiningOp(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index e320916..752eaf0 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -1,3 +1,4 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -6,6 +7,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" @@ -15,8 +17,6 @@ #include "llvm/Support/FormatVariadic.h" #include -#include -#include #include #include #include @@ -593,40 +593,90 @@ bool allEqual(ArrayRef values) { return true; } -Value createLaneIndexedIndexValue(MaterializerState& state, - MaterializedClass& materializedClass, - ArrayRef values, - Location loc) { - assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); - assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); +struct IndexedIndexPattern { + int64_t base = 0; + int64_t step = 0; + int64_t period = 1; + int64_t innerStep = 0; + int64_t outerStep = 0; + bool isTiled = false; +}; - if (allEqual(values)) - return createIndexConstant(state, materializedClass.op, values.front()); +bool matchAffineSequence(ArrayRef values, IndexedIndexPattern& pattern) { + assert(!values.empty() && "expected at least one value"); - auto batch = cast(materializedClass.op); - auto laneArg = batch.getLaneArgument(); - assert(laneArg && "expected compute_batch lane argument"); + pattern.base = values.front(); + pattern.step = values.size() == 1 ? 0 : values[1] - values[0]; + pattern.isTiled = false; - Value table = createIndexTensorConstant(state, materializedClass.op, values); - return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult(); + for (auto [index, value] : llvm::enumerate(values)) { + int64_t expected = pattern.base + pattern.step * static_cast(index); + if (value != expected) + return false; + } + + return true; } -Value createLaneIndexedIndexValue(MaterializerState& state, - MaterializedClass& materializedClass, - ArrayRef values, - Location loc) { - assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); - assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); +bool matchTiledAffineSequence(ArrayRef values, IndexedIndexPattern& pattern) { + assert(!values.empty() && "expected at least one value"); - if (allEqual(values)) - return createIndexConstant(state, materializedClass.op, values.front()); + for (int64_t period = 2; period <= static_cast(values.size() / 2); ++period) { + int64_t base = values.front(); + int64_t innerStep = values[1] - values[0]; + int64_t outerStep = values[period] - values[0]; - auto batch = cast(materializedClass.op); - auto laneArg = batch.getLaneArgument(); - assert(laneArg && "expected compute_batch lane argument"); + bool matches = true; + for (auto [index, value] : llvm::enumerate(values)) { + int64_t i = static_cast(index); + int64_t expected = base + outerStep * (i / period) + innerStep * (i % period); + if (value != expected) { + matches = false; + break; + } + } - Value table = createIndexTensorConstant(state, materializedClass.op, values); - return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult(); + if (!matches) + continue; + + pattern.base = base; + pattern.period = period; + pattern.innerStep = innerStep; + pattern.outerStep = outerStep; + pattern.isTiled = true; + return true; + } + + return false; +} + +std::optional getIndexedIndexPattern(ArrayRef values) { + assert(!values.empty() && "expected at least one value"); + + IndexedIndexPattern pattern; + if (matchAffineSequence(values, pattern)) + return pattern; + if (matchTiledAffineSequence(values, pattern)) + return pattern; + + return std::nullopt; +} + +Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) { + MLIRContext* context = state.func.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + + AffineExpr expr; + if (!pattern.isTiled) { + expr = getAffineConstantExpr(pattern.base, context) + d0 * pattern.step; + } + else { + expr = getAffineConstantExpr(pattern.base, context) + d0.floorDiv(pattern.period) * pattern.outerStep + + (d0 % pattern.period) * pattern.innerStep; + } + + AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); + return affine::AffineApplyOp::create(state.rewriter, loc, map, ValueRange {index}).getResult(); } Value createIndexedIndexValue( @@ -636,6 +686,9 @@ Value createIndexedIndexValue( if (allEqual(values)) return createIndexConstant(state, anchor, values.front()); + if (std::optional pattern = getIndexedIndexPattern(values)) + return createAffineIndexValue(state, *pattern, index, loc); + Value table = createIndexTensorConstant(state, anchor, values); return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult(); } @@ -644,11 +697,41 @@ Value createIndexedIndexValue( MaterializerState& state, Operation* anchor, ArrayRef values, Value index, Location loc) { assert(!values.empty() && "expected at least one indexed value"); - if (allEqual(values)) - return createIndexConstant(state, anchor, values.front()); + SmallVector widened; + widened.reserve(values.size()); + for (int32_t value : values) + widened.push_back(value); - Value table = createIndexTensorConstant(state, anchor, values); - return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult(); + return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc); +} + +Value createLaneIndexedIndexValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef values, + Location loc) { + assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); + assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); + + auto batch = cast(materializedClass.op); + auto laneArg = batch.getLaneArgument(); + assert(laneArg && "expected compute_batch lane argument"); + + return createIndexedIndexValue(state, materializedClass.op, values, *laneArg, loc); +} + +Value createLaneIndexedIndexValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef values, + Location loc) { + assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); + assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); + + SmallVector widened; + widened.reserve(values.size()); + for (int32_t value : values) + widened.push_back(value); + + return createLaneIndexedIndexValue(state, materializedClass, ArrayRef(widened), loc); } FailureOr> @@ -675,38 +758,13 @@ Value createOriginalLaneValue(MaterializerState& state, auto batch = cast(materializedClass.op); auto laneArg = batch.getLaneArgument(); assert(laneArg && "expected materialized compute_batch lane argument"); - bool identity = true; - for (auto [lane, peer] : llvm::enumerate(peers)) { - if (peer.laneCount != 1 || peer.laneStart != lane) { - identity = false; - break; - } - } - if (identity) - return *laneArg; - - bool affineWithBase = true; - int64_t base = static_cast(peers.front().laneStart); - for (auto [lane, peer] : llvm::enumerate(peers)) { - if (peer.laneCount != 1 || static_cast(peer.laneStart) != base + static_cast(lane)) { - affineWithBase = false; - break; - } - } - if (affineWithBase) { - if (base == 0) - return *laneArg; - Value baseValue = createIndexConstant(state, materializedClass.op, base); - return arith::AddIOp::create(state.rewriter, loc, *laneArg, baseValue).getResult(); - } SmallVector laneValues; laneValues.reserve(peers.size()); for (const ComputeInstance& peer : peers) laneValues.push_back(peer.laneStart); - Value table = createIndexTensorConstant(state, materializedClass.op, laneValues); - return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult(); + return createIndexedIndexValue(state, materializedClass.op, ArrayRef(laneValues), *laneArg, loc); } bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps) { @@ -1686,6 +1744,8 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch if (failed(eraseOldComputeOps(state))) return failure(); + LogicalResult _ = runRegionDCE(state.rewriter, state.func.getBody()); + return success(); }