add memory coalescing pass
Validate Operations / validate-operations (push) Has been cancelled

better reports
refactor for more code-reuse and patter usage
fixes
This commit is contained in:
NiccoloN
2026-05-12 18:17:00 +02:00
parent 4f3570520c
commit 41de3cb150
26 changed files with 930 additions and 385 deletions
@@ -94,7 +94,7 @@ void ONNXToSpatialPass::runOnOperation() {
RewritePatternSet prePatterns(ctx);
populatePrePatterns(prePatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
llvm::dbgs() << "Failed to apply pre-patterns, continuing...\n";
moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing");
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
@@ -148,7 +148,8 @@ void ONNXToSpatialPass::runOnOperation() {
computeOpsCount++;
if (computeOpsCount > coresCount) {
llvm::dbgs() << "Number of compute ops exceeds the core count\n";
entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
<< coresCount << ")";
signalPassFailure();
return;
}
@@ -157,7 +158,7 @@ void ONNXToSpatialPass::runOnOperation() {
PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass());
if (failed(cleanupPM.run(moduleOp)))
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
moduleOp.emitWarning("failed to run ONNX-to-Spatial canonicalization cleanup; continuing");
annotateWeightsConstants(*entryFunc);
@@ -70,85 +70,6 @@ private:
} // namespace
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static void lowerChannelSend(spatial::SpatChannelSendOp sendOp, IRRewriter& rewriter) {
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput());
auto targetCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sendOp.getTargetCoreId()));
rewriter.setInsertionPoint(sendOp);
PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(), sizeAttr, targetCoreIdAttr);
rewriter.eraseOp(sendOp);
}
static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
if (receiveOp->use_empty()) {
rewriter.eraseOp(receiveOp);
return;
}
auto outputType = cast<ShapedType>(receiveOp.getResult().getType());
rewriter.setInsertionPoint(receiveOp);
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
Value received =
PimReceiveOp::create(rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
rewriter.replaceOp(receiveOp, received);
}
static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp, IRRewriter& rewriter) {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(sendTensorOp.getTargetCoreIds().size());
for (int32_t targetCoreId : sendTensorOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
rewriter.setInsertionPoint(sendTensorOp);
PimSendTensorOp::create(
rewriter, sendTensorOp.getLoc(), sendTensorOp.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.eraseOp(sendTensorOp);
}
static void lowerChannelReceiveTensor(spatial::SpatChannelReceiveTensorOp receiveTensorOp, IRRewriter& rewriter) {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(receiveTensorOp.getSourceCoreIds().size());
for (int32_t sourceCoreId : receiveTensorOp.getSourceCoreIds())
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
rewriter.setInsertionPoint(receiveTensorOp);
auto outputType = cast<ShapedType>(receiveTensorOp.getOutput().getType());
Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType).getResult();
Value received = PimReceiveTensorOp::create(rewriter,
receiveTensorOp.getLoc(),
receiveTensorOp.getOutput().getType(),
outputBuffer,
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
.getOutput();
rewriter.replaceOp(receiveTensorOp, received);
}
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
rewriter.setInsertionPoint(extractRowsOp);
auto inputType = cast<RankedTensorType>(extractRowsOp.getInput().getType());
SmallVector<Value> replacements;
replacements.reserve(extractRowsOp.getNumResults());
for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
auto outputType = cast<RankedTensorType>(output.getType());
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)),
rewriter.getIndexAttr(inputType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
replacements.push_back(
tensor::ExtractSliceOp::create(
rewriter, extractRowsOp.getLoc(), outputType, extractRowsOp.getInput(), offsets, sizes, strides)
.getResult());
}
rewriter.replaceOp(extractRowsOp, replacements);
}
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
auto memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
@@ -216,97 +137,6 @@ static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, V
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
}
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatConcatOp> concatOps;
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
for (auto concatOp : concatOps) {
if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
continue;
SmallVector<Value> packedInputs;
bool changed = false;
rewriter.setInsertionPoint(concatOp);
for (unsigned index = 0; index < concatOp.getInputs().size();) {
Value input = concatOp.getInputs()[index];
if (input.getDefiningOp<tensor::ExtractSliceOp>()) {
unsigned endIndex = index + 1;
while (endIndex < concatOp.getInputs().size()
&& concatOp.getInputs()[endIndex].getDefiningOp<tensor::ExtractSliceOp>())
++endIndex;
Value packedInput = createPackedExtractSliceTensor(
concatOp.getInputs().slice(index, endIndex - index), rewriter, concatOp.getLoc());
if (packedInput) {
packedInputs.push_back(packedInput);
changed = true;
index = endIndex;
continue;
}
}
auto result = dyn_cast<OpResult>(input);
if (!result) {
packedInputs.push_back(input);
++index;
continue;
}
Operation* owner = result.getOwner();
unsigned startIndex = result.getResultNumber();
unsigned endIndex = index + 1;
while (endIndex < concatOp.getInputs().size()) {
auto nextResult = dyn_cast<OpResult>(concatOp.getInputs()[endIndex]);
if (!nextResult || nextResult.getOwner() != owner
|| nextResult.getResultNumber() != startIndex + (endIndex - index))
break;
++endIndex;
}
unsigned count = endIndex - index;
Value packedInput;
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc());
if (packedInput) {
packedInputs.push_back(packedInput);
changed = true;
}
else {
for (unsigned oldIndex = index; oldIndex < endIndex; ++oldIndex)
packedInputs.push_back(concatOp.getInputs()[oldIndex]);
}
index = endIndex;
}
if (!changed)
continue;
auto newConcat = pim::PimConcatOp::create(
rewriter,
concatOp.getLoc(),
concatOp.getOutput().getType(),
concatOp.getAxisAttr(),
ValueRange(packedInputs),
createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), cast<ShapedType>(concatOp.getOutput().getType()))
.getResult());
rewriter.replaceOp(concatOp, newConcat.getOutput());
}
auto eraseUnusedOps = [&](auto tag) {
using OpTy = decltype(tag);
SmallVector<OpTy> ops;
funcOp.walk([&](OpTy op) { ops.push_back(op); });
for (auto op : llvm::reverse(ops))
if (op->use_empty())
rewriter.eraseOp(op);
};
eraseUnusedOps(tensor::ConcatOp {});
eraseUnusedOps(tensor::ExtractSliceOp {});
eraseUnusedOps(spatial::SpatExtractRowsOp {});
}
void SpatialToPimPass::runOnOperation() {
coreId = 1;
ModuleOp moduleOp = getOperation();
@@ -380,7 +210,12 @@ void SpatialToPimPass::runOnOperation() {
}
}
compactSpatialTensorGroups(funcOp, rewriter);
{
RewritePatternSet patterns(ctx);
populateTensorPackingPatterns(patterns);
walkAndApplyPatterns(funcOp, std::move(patterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
}
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
@@ -392,37 +227,8 @@ void SpatialToPimPass::runOnOperation() {
markOpToRemove(receiveOp);
continue;
}
if (receiveOp->use_empty()) {
rewriter.eraseOp(receiveOp);
continue;
}
lowerChannelReceive(receiveOp, rewriter);
}
SmallVector<spatial::SpatChannelReceiveTensorOp> receiveTensorOps;
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveTensorOp>())
receiveTensorOps.push_back(op);
for (auto receiveTensorOp : receiveTensorOps)
lowerChannelReceiveTensor(receiveTensorOp, rewriter);
SmallVector<spatial::SpatChannelSendOp> sendOps;
for (auto op : funcOp.getOps<spatial::SpatChannelSendOp>())
sendOps.push_back(op);
for (auto sendOp : sendOps)
lowerChannelSend(sendOp, rewriter);
SmallVector<spatial::SpatChannelSendTensorOp> sendTensorOps;
for (auto op : funcOp.getOps<spatial::SpatChannelSendTensorOp>())
sendTensorOps.push_back(op);
for (auto sendTensorOp : sendTensorOps)
lowerChannelSendTensor(sendTensorOp, rewriter);
SmallVector<spatial::SpatExtractRowsOp> extractRowsOps;
for (auto op : funcOp.getOps<spatial::SpatExtractRowsOp>())
extractRowsOps.push_back(op);
for (auto extractRowsOp : extractRowsOps)
lowerExtractRows(extractRowsOp, rewriter);
{
RewritePatternSet coreBodyPatterns(ctx);
populateWithGenerated(coreBodyPatterns);
@@ -457,7 +263,12 @@ void SpatialToPimPass::runOnOperation() {
return;
}
compactSpatialTensorGroups(funcOp, rewriter);
{
RewritePatternSet patterns(ctx);
populateTensorPackingPatterns(patterns);
walkAndApplyPatterns(funcOp, std::move(patterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
}
{
ConversionTarget communicationTarget(*ctx);
@@ -1,8 +1,96 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct PackSpatialConcatInputsPattern final : OpRewritePattern<spatial::SpatConcatOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatConcatOp concatOp, PatternRewriter& rewriter) const override {
if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
return failure();
SmallVector<Value> packedInputs;
bool changed = false;
for (unsigned index = 0; index < concatOp.getInputs().size();) {
Value input = concatOp.getInputs()[index];
if (input.getDefiningOp<tensor::ExtractSliceOp>()) {
unsigned endIndex = index + 1;
while (endIndex < concatOp.getInputs().size()
&& concatOp.getInputs()[endIndex].getDefiningOp<tensor::ExtractSliceOp>())
++endIndex;
Value packedInput = createPackedExtractSliceTensor(
concatOp.getInputs().slice(index, endIndex - index), rewriter, concatOp.getLoc());
if (packedInput) {
packedInputs.push_back(packedInput);
changed = true;
index = endIndex;
continue;
}
}
auto result = dyn_cast<OpResult>(input);
if (!result) {
packedInputs.push_back(input);
++index;
continue;
}
Operation* owner = result.getOwner();
unsigned startIndex = result.getResultNumber();
unsigned endIndex = index + 1;
while (endIndex < concatOp.getInputs().size()) {
auto nextResult = dyn_cast<OpResult>(concatOp.getInputs()[endIndex]);
if (!nextResult || nextResult.getOwner() != owner
|| nextResult.getResultNumber() != startIndex + (endIndex - index))
break;
++endIndex;
}
unsigned count = endIndex - index;
Value packedInput;
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc());
if (packedInput) {
packedInputs.push_back(packedInput);
changed = true;
}
else {
for (unsigned oldIndex = index; oldIndex < endIndex; ++oldIndex)
packedInputs.push_back(concatOp.getInputs()[oldIndex]);
}
index = endIndex;
}
if (!changed)
return failure();
auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
auto newConcat = pim::PimConcatOp::create(rewriter,
concatOp.getLoc(),
concatOp.getOutput().getType(),
concatOp.getAxisAttr(),
ValueRange(packedInputs),
tensor::EmptyOp::create(rewriter,
concatOp.getLoc(),
outputType.getShape(),
outputType.getElementType())
.getResult());
rewriter.replaceOp(concatOp, newConcat.getOutput());
return success();
}
};
} // namespace
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
@@ -146,4 +234,23 @@ Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Loca
return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
.getResult();
}
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
patterns.add<PackSpatialConcatInputsPattern>(patterns.getContext());
}
void eraseUnusedTensorPackingOps(func::FuncOp funcOp, IRRewriter& rewriter) {
auto eraseUnusedOps = [&](auto tag) {
using OpTy = decltype(tag);
SmallVector<OpTy> ops;
funcOp.walk([&](OpTy op) { ops.push_back(op); });
for (auto op : llvm::reverse(ops))
if (op->use_empty())
rewriter.eraseOp(op);
};
eraseUnusedOps(tensor::ConcatOp {});
eraseUnusedOps(tensor::ExtractSliceOp {});
eraseUnusedOps(spatial::SpatExtractRowsOp {});
}
} // namespace onnx_mlir
@@ -1,6 +1,7 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -19,5 +20,7 @@ mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsO
mlir::OpBuilder& builder,
mlir::Location loc);
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
void eraseUnusedTensorPackingOps(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir