compact syntax for spatial tensor ops
Validate Operations / validate-operations (push) Has been cancelled

better IR compaction after dcp merge
remove pim.mvm op
better memory report
This commit is contained in:
NiccoloN
2026-05-12 13:35:25 +02:00
parent 80a7298552
commit 628dc630a4
15 changed files with 419 additions and 305 deletions
@@ -11,7 +11,6 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
@@ -105,12 +104,8 @@ static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
rewriter.setInsertionPoint(sendTensorOp);
Value input = sendTensorOp.getInput();
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>())
if (concatOp.getDim() == 0)
if (Value packedInput = createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorOp.getLoc()))
input = packedInput;
PimSendTensorOp::create(rewriter, sendTensorOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
PimSendTensorOp::create(
rewriter, sendTensorOp.getLoc(), sendTensorOp.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.eraseOp(sendTensorOp);
}
@@ -152,38 +147,6 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
rewriter.replaceOp(extractRowsOp, replacements);
}
static Value createPackedExtractRowsSlice(
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
return {};
int64_t rowsPerValue = rowType.getDimSize(0);
if (ShapedType::isDynamic(rowsPerValue))
return {};
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(inputType.getRank());
sizes.reserve(inputType.getRank());
strides.reserve(inputType.getRank());
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
strides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
strides.push_back(rewriter.getIndexAttr(1));
}
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
.getResult();
}
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatConcatOp> concatOps;
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
@@ -262,11 +225,6 @@ static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter
.getResult());
rewriter.replaceOp(concatOp, newConcat.getOutput());
}
RewritePatternSet tensorPackingPatterns(funcOp.getContext());
populateTensorPackingPatterns(tensorPackingPatterns);
(void) applyPatternsGreedily(funcOp, std::move(tensorPackingPatterns));
auto eraseUnusedOps = [&](auto tag) {
using OpTy = decltype(tag);
SmallVector<OpTy> ops;