better spatial IR compaction with better custom syntax, scf.for and
Validate Operations / validate-operations (push) Has been cancelled

spat.map
This commit is contained in:
NiccoloN
2026-05-06 12:21:58 +02:00
parent 285773fa55
commit b2dc9c38b6
12 changed files with 1442 additions and 274 deletions
@@ -1,5 +1,7 @@
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
@@ -35,6 +37,7 @@
#include <vector>
#include "DCPGraph/DCPAnalysis.hpp"
#include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
@@ -147,7 +150,7 @@ static SmallVector<int32_t> getMaterializedBatchCoreIds(size_t startCpu, size_t
}
static SmallVector<int32_t> getBatchCoreIds(Operation* op, size_t laneCount) {
if (auto coreIdsAttr = op->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName))
if (auto coreIdsAttr = op->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
if (auto coreIdAttr = op->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return SmallVector<int32_t>(laneCount, static_cast<int32_t>(coreIdAttr.getInt()));
@@ -304,7 +307,7 @@ static void sinkChannelsIntoBatchComputes(func::FuncOp funcOp,
SmallVector<int32_t> coreIds = getBatchCoreIds(batch, static_cast<size_t>(batch.getLaneCount()));
if (!coreIds.empty())
newBatch->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
auto* newBlock =
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange {}, ArrayRef<Location> {});
@@ -548,141 +551,6 @@ void sinkChannelsIntoComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
}
}
static void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
for (auto compute : funcOp.getOps<SpatCompute>()) {
Block& block = compute.getBody().front();
for (auto it = block.begin(); it != block.end();) {
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
if (receiveOp) {
SmallVector<spatial::SpatChannelReceiveOp> run;
Type outputType = receiveOp.getOutput().getType();
auto runIt = it;
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
if (!current || current.getOutput().getType() != outputType)
break;
run.push_back(current);
++runIt;
}
if (run.size() > 1) {
struct ReceiveEntry {
spatial::SpatChannelReceiveOp op;
size_t originalIndex = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
uint64_t channelId = 0;
};
SmallVector<ReceiveEntry> sortedEntries;
sortedEntries.reserve(run.size());
for (auto [originalIndex, op] : llvm::enumerate(run))
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
SmallVector<Type> outputTypes;
channelIds.reserve(sortedEntries.size());
sourceCoreIds.reserve(sortedEntries.size());
targetCoreIds.reserve(sortedEntries.size());
outputTypes.reserve(sortedEntries.size());
for (ReceiveEntry& entry : sortedEntries) {
(void) entry;
channelIds.push_back(nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
outputTypes.push_back(entry.op.getOutput().getType());
}
rewriter.setInsertionPoint(run.front());
auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter,
run.front().getLoc(),
TypeRange(outputTypes),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex));
for (auto op : run)
rewriter.eraseOp(op);
it = compactReceive->getIterator();
++it;
continue;
}
}
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
if (sendOp) {
SmallVector<spatial::SpatChannelSendOp> run;
Type inputType = sendOp.getInput().getType();
auto runIt = it;
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
if (!current || current.getInput().getType() != inputType)
break;
run.push_back(current);
++runIt;
}
if (run.size() > 1) {
struct SendEntry {
spatial::SpatChannelSendOp op;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
uint64_t channelId = 0;
};
SmallVector<SendEntry> sortedEntries;
sortedEntries.reserve(run.size());
for (auto op : run)
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
SmallVector<Value> inputs;
channelIds.reserve(sortedEntries.size());
sourceCoreIds.reserve(sortedEntries.size());
targetCoreIds.reserve(sortedEntries.size());
inputs.reserve(sortedEntries.size());
for (SendEntry& entry : sortedEntries) {
(void) entry;
channelIds.push_back(nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
inputs.push_back(entry.op.getInput());
}
rewriter.setInsertionPoint(run.front());
spatial::SpatChannelSendManyOp::create(rewriter,
run.front().getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
ValueRange(inputs));
for (auto op : run)
rewriter.eraseOp(op);
it = runIt;
continue;
}
}
++it;
}
}
}
void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
@@ -755,7 +623,7 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
rebatched.getProperties().setOperandSegmentSizes(
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
if (haveAllCoreIds)
rebatched->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
@@ -1879,6 +1747,9 @@ public:
rebatchEquivalentComputes(func, nextChannelId);
compactScalarChannelRuns(func, nextChannelId);
compactBatchChannelRuns(func);
compactRegularOpRuns(func);
compactRowWiseWvmmRuns(func);
if (!sortTopologically(&func.getBody().front())) {
func.emitOpError("failed to topologically order merged Spatial IR");
signalPassFailure();
@@ -2049,7 +1920,7 @@ private:
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
ValueRange(weights),
ValueRange(inputs));
rebatched->setAttr(onnx_mlir::kCoreIdAttrName,
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName,
rewriter.getDenseI32ArrayAttr(getMaterializedBatchCoreIds(currentCpu, laneCount)));
SmallVector<Type> blockArgTypes;
@@ -0,0 +1,577 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <tuple>
#include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
enum class RegularStepKind {
Wvmm,
VAddLhs,
VAddRhs,
};
struct RegularStep {
RegularStepKind kind;
int32_t weightIndex = 0;
Value invariantOperand;
Type resultType;
};
struct RegularChunk {
Operation* startOp = nullptr;
SmallVector<Operation*> ops;
SmallVector<RegularStep> steps;
Value input;
Value output;
};
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand
&& lhs.resultType == rhs.resultType;
}
static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChunk& rhs) {
if (lhs.input.getType() != rhs.input.getType() || lhs.output.getType() != rhs.output.getType()
|| lhs.steps.size() != rhs.steps.size()) {
return false;
}
return llvm::all_of(llvm::zip_equal(lhs.steps, rhs.steps),
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
}
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatWeightedVMMOp startOp) {
RegularChunk chunk;
chunk.startOp = startOp.getOperation();
chunk.input = startOp.getInput();
chunk.output = startOp.getOutput();
chunk.ops.push_back(startOp.getOperation());
chunk.steps.push_back(
{RegularStepKind::Wvmm, static_cast<int32_t>(startOp.getWeightIndex()), Value(), startOp.getOutput().getType()});
Value currentValue = startOp.getOutput();
while (currentValue.hasOneUse()) {
Operation* user = *currentValue.getUsers().begin();
if (user->getBlock() != startOp->getBlock())
break;
auto vaddOp = dyn_cast<spatial::SpatVAddOp>(user);
if (!vaddOp)
break;
if (vaddOp.getLhs() == currentValue)
chunk.steps.push_back({RegularStepKind::VAddLhs, 0, vaddOp.getRhs(), vaddOp.getOutput().getType()});
else if (vaddOp.getRhs() == currentValue)
chunk.steps.push_back({RegularStepKind::VAddRhs, 0, vaddOp.getLhs(), vaddOp.getOutput().getType()});
else
break;
chunk.ops.push_back(vaddOp);
chunk.output = vaddOp.getOutput();
currentValue = vaddOp.getOutput();
}
return chunk;
}
static void buildRegularMapBody(spatial::SpatMapOp mapOp, const RegularChunk& anchorChunk, IRRewriter& rewriter) {
auto* block = rewriter.createBlock(
&mapOp.getBody(), mapOp.getBody().end(), TypeRange {anchorChunk.input.getType()}, {anchorChunk.startOp->getLoc()});
rewriter.setInsertionPointToEnd(block);
IRMapping mapping;
mapping.map(anchorChunk.input, block->getArgument(0));
for (Operation* op : anchorChunk.ops) {
Operation* cloned = rewriter.clone(*op, mapping);
for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults()))
mapping.map(oldResult, newResult);
}
spatial::SpatYieldOp::create(
rewriter, anchorChunk.startOp->getLoc(), ValueRange {mapping.lookup(anchorChunk.output)});
}
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
assert(!run.empty() && "expected a non-empty regular chunk run");
const RegularChunk& anchorChunk = run.front();
SmallVector<Value> inputs;
SmallVector<Type> outputTypes;
inputs.reserve(run.size());
outputTypes.reserve(run.size());
for (const RegularChunk& chunk : run) {
inputs.push_back(chunk.input);
outputTypes.push_back(chunk.output.getType());
}
rewriter.setInsertionPoint(anchorChunk.startOp);
auto mapOp =
spatial::SpatMapOp::create(rewriter, anchorChunk.startOp->getLoc(), TypeRange(outputTypes), ValueRange(inputs));
buildRegularMapBody(mapOp, anchorChunk, rewriter);
for (auto [index, chunk] : llvm::enumerate(run)) {
Value output = chunk.output;
output.replaceAllUsesWith(mapOp.getResult(index));
}
SmallVector<Operation*> opsToErase;
for (const RegularChunk& chunk : run)
llvm::append_range(opsToErase, chunk.ops);
for (Operation* op : llvm::reverse(opsToErase))
rewriter.eraseOp(op);
}
} // namespace
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
Block& block = compute.getBody().front();
for (auto it = block.begin(); it != block.end();) {
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
if (receiveOp) {
SmallVector<spatial::SpatChannelReceiveOp> run;
Type outputType = receiveOp.getOutput().getType();
auto runIt = it;
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
if (!current || current.getOutput().getType() != outputType)
break;
run.push_back(current);
++runIt;
}
if (run.size() > 1) {
struct ReceiveEntry {
spatial::SpatChannelReceiveOp op;
size_t originalIndex = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
uint64_t channelId = 0;
};
SmallVector<ReceiveEntry> sortedEntries;
sortedEntries.reserve(run.size());
for (auto [originalIndex, op] : llvm::enumerate(run))
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
SmallVector<Type> outputTypes;
channelIds.reserve(sortedEntries.size());
sourceCoreIds.reserve(sortedEntries.size());
targetCoreIds.reserve(sortedEntries.size());
outputTypes.reserve(sortedEntries.size());
for (ReceiveEntry& entry : sortedEntries) {
(void) entry;
channelIds.push_back(nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
outputTypes.push_back(entry.op.getOutput().getType());
}
rewriter.setInsertionPoint(run.front());
auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter,
run.front().getLoc(),
TypeRange(outputTypes),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex));
for (auto op : run)
rewriter.eraseOp(op);
it = compactReceive->getIterator();
++it;
continue;
}
}
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
if (sendOp) {
SmallVector<spatial::SpatChannelSendOp> run;
Type inputType = sendOp.getInput().getType();
auto runIt = it;
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
if (!current || current.getInput().getType() != inputType)
break;
run.push_back(current);
++runIt;
}
if (run.size() > 1) {
struct SendEntry {
spatial::SpatChannelSendOp op;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
uint64_t channelId = 0;
};
SmallVector<SendEntry> sortedEntries;
sortedEntries.reserve(run.size());
for (auto op : run)
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
SmallVector<Value> inputs;
channelIds.reserve(sortedEntries.size());
sourceCoreIds.reserve(sortedEntries.size());
targetCoreIds.reserve(sortedEntries.size());
inputs.reserve(sortedEntries.size());
for (SendEntry& entry : sortedEntries) {
(void) entry;
channelIds.push_back(nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
inputs.push_back(entry.op.getInput());
}
rewriter.setInsertionPoint(run.front());
spatial::SpatChannelSendManyOp::create(rewriter,
run.front().getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
ValueRange(inputs));
for (auto op : run)
rewriter.eraseOp(op);
it = runIt;
continue;
}
}
++it;
}
}
}
void compactBatchChannelRuns(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
for (auto batch : funcOp.getOps<spatial::SpatComputeBatch>()) {
Block& block = batch.getBody().front();
for (auto it = block.begin(); it != block.end();) {
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it);
if (receiveOp) {
SmallVector<spatial::SpatChannelReceiveBatchOp> run;
Type outputType = receiveOp.getOutput().getType();
auto runIt = it;
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*runIt);
if (!current || current.getOutput().getType() != outputType)
break;
run.push_back(current);
++runIt;
}
if (run.size() > 1) {
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
SmallVector<Type> outputTypes;
outputTypes.reserve(run.size());
for (auto op : run) {
llvm::append_range(channelIds, op.getChannelIds());
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
outputTypes.push_back(op.getOutput().getType());
}
rewriter.setInsertionPoint(run.front());
auto compactReceive =
spatial::SpatChannelReceiveManyBatchOp::create(rewriter,
run.front().getLoc(),
TypeRange(outputTypes),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
for (auto [index, op] : llvm::enumerate(run))
op.getOutput().replaceAllUsesWith(compactReceive.getResult(index));
for (auto op : run)
rewriter.eraseOp(op);
it = compactReceive->getIterator();
++it;
continue;
}
}
auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it);
if (sendOp) {
SmallVector<spatial::SpatChannelSendBatchOp> run;
Type inputType = sendOp.getInput().getType();
auto runIt = it;
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatChannelSendBatchOp>(&*runIt);
if (!current || current.getInput().getType() != inputType)
break;
run.push_back(current);
++runIt;
}
if (run.size() > 1) {
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
SmallVector<Value> inputs;
inputs.reserve(run.size());
for (auto op : run) {
llvm::append_range(channelIds, op.getChannelIds());
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
inputs.push_back(op.getInput());
}
rewriter.setInsertionPoint(run.front());
spatial::SpatChannelSendManyBatchOp::create(rewriter,
run.front().getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
ValueRange(inputs));
for (auto op : run)
rewriter.eraseOp(op);
it = runIt;
continue;
}
}
++it;
}
}
}
void compactRegularOpRuns(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
auto compactInBlock = [&](Block& block) {
for (auto it = block.begin(); it != block.end();) {
auto startOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it);
if (!startOp) {
++it;
continue;
}
auto anchorChunk = analyzeRegularChunk(startOp);
if (failed(anchorChunk)) {
++it;
continue;
}
SmallVector<RegularChunk> run {*anchorChunk};
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
while (runIt != block.end()) {
auto candidateStart = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt);
if (!candidateStart)
break;
auto candidateChunk = analyzeRegularChunk(candidateStart);
if (failed(candidateChunk) || !areEquivalentRegularChunks(*anchorChunk, *candidateChunk))
break;
run.push_back(*candidateChunk);
runIt = std::next(runIt, static_cast<std::ptrdiff_t>(candidateChunk->ops.size()));
}
if (run.size() <= 1) {
++it;
continue;
}
compactRegularChunkRun(rewriter, run);
it = runIt;
}
};
for (auto compute : funcOp.getOps<spatial::SpatCompute>())
compactInBlock(compute.getBody().front());
for (auto batch : funcOp.getOps<spatial::SpatComputeBatch>())
compactInBlock(batch.getBody().front());
}
void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
Block& block = compute.getBody().front();
for (auto it = block.begin(); it != block.end();) {
auto wvmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it);
if (!wvmmOp) {
++it;
continue;
}
auto extractRowsOp = wvmmOp.getInput().getDefiningOp<spatial::SpatExtractRowsOp>();
auto rowResult = dyn_cast<OpResult>(wvmmOp.getInput());
auto outputType = dyn_cast<RankedTensorType>(wvmmOp.getOutput().getType());
if (!extractRowsOp || !rowResult || rowResult.getOwner() != extractRowsOp || !outputType
|| !outputType.hasStaticShape() || outputType.getRank() != 2 || outputType.getShape()[0] != 1) {
++it;
continue;
}
SmallVector<spatial::SpatWeightedVMMOp> run;
auto runIt = it;
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt);
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|| current.getInput().getType() != wvmmOp.getInput().getType()
|| current.getOutput().getType() != wvmmOp.getOutput().getType()) {
break;
}
auto currentRow = dyn_cast<OpResult>(current.getInput());
if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow))
break;
run.push_back(current);
++expectedRow;
++runIt;
}
if (run.size() <= 1) {
++it;
continue;
}
if (!run.front().getOutput().hasOneUse()) {
++it;
continue;
}
auto concatUse = run.front().getOutput().getUses().begin();
auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner());
if (!concatOp) {
++it;
continue;
}
unsigned concatStartIndex = concatUse->getOperandNumber();
bool validConcatRun = true;
for (auto [index, op] : llvm::enumerate(run)) {
if (!op.getOutput().hasOneUse()) {
validConcatRun = false;
break;
}
OpOperand& use = *op.getOutput().getUses().begin();
if (use.getOwner() != concatOp || use.getOperandNumber() != concatStartIndex + index) {
validConcatRun = false;
break;
}
}
if (!validConcatRun) {
++it;
continue;
}
auto inputType = dyn_cast<RankedTensorType>(wvmmOp.getInput().getType());
auto sourceType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
if (!inputType || !sourceType || !inputType.hasStaticShape() || !sourceType.hasStaticShape()) {
++it;
continue;
}
int64_t inputCols = inputType.getShape()[1];
int64_t outputCols = outputType.getShape()[1];
if (ShapedType::isDynamic(inputCols) || ShapedType::isDynamic(outputCols)) {
++it;
continue;
}
int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber());
int64_t runLength = static_cast<int64_t>(run.size());
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
rewriter.setInsertionPoint(run.front());
auto zero = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 0);
auto upper = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), runLength);
auto step = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 1);
auto packedInit =
tensor::EmptyOp::create(rewriter, run.front().getLoc(), packedType.getShape(), packedType.getElementType());
auto loop =
scf::ForOp::create(rewriter, run.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
{
OpBuilder::InsertionGuard guard(rewriter);
Block* loopBlock = loop.getBody();
rewriter.setInsertionPointToStart(loopBlock);
Value iv = loopBlock->getArgument(0);
Value acc = loopBlock->getArgument(1);
Value sourceRow = iv;
if (firstRow != 0) {
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), firstRow);
sourceRow = arith::AddIOp::create(rewriter, run.front().getLoc(), iv, firstRowValue);
}
SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)};
SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto extractedRow = tensor::ExtractSliceOp::create(rewriter,
run.front().getLoc(),
inputType,
extractRowsOp.getInput(),
extractOffsets,
extractSizes,
extractStrides);
auto loopWvmm = spatial::SpatWeightedVMMOp::create(
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto inserted = tensor::InsertSliceOp::create(
rewriter, run.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
scf::YieldOp::create(rewriter, run.front().getLoc(), inserted.getResult());
}
SmallVector<Value> newConcatInputs;
newConcatInputs.reserve(concatOp.getInputs().size() - run.size() + 1);
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
if (operandIndex == concatStartIndex)
newConcatInputs.push_back(loop.getResult(0));
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.size())
newConcatInputs.push_back(operand);
}
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); });
for (auto op : run)
rewriter.eraseOp(op);
it = loop->getIterator();
++it;
}
}
}
} // namespace onnx_mlir
@@ -0,0 +1,14 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include <cstdint>
namespace onnx_mlir {
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
void compactRowWiseWvmmRuns(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir