compact syntax for spatial tensor ops
Validate Operations / validate-operations (push) Has been cancelled

better IR compaction after dcp merge
remove pim.mvm op
better memory report
This commit is contained in:
NiccoloN
2026-05-12 13:35:25 +02:00
parent 80a7298552
commit 628dc630a4
15 changed files with 419 additions and 305 deletions
@@ -7,7 +7,6 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
@@ -37,15 +36,10 @@ static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp se
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
Value input = mapper.lookup(sendTensorBatchOp.getInput());
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>())
if (concatOp.getDim() == 0)
if (Value packedInput =
createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorBatchOp.getLoc()))
input = packedInput;
pim::PimSendTensorBatchOp::create(
rewriter, sendTensorBatchOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
pim::PimSendTensorBatchOp::create(rewriter,
sendTensorBatchOp.getLoc(),
mapper.lookup(sendTensorBatchOp.getInput()),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
}
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
@@ -21,12 +21,6 @@ def spatToPimVMM : Pat<
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimMVM : Pat<
(SpatMVMOp:$srcOpRes $weightIndex, $vector),
(PimMVMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVAdd : Pat<
(SpatVAddOp:$srcOpRes $a, $b),
(PimVVAddOp $a, $b,
@@ -11,7 +11,6 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
@@ -105,12 +104,8 @@ static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
rewriter.setInsertionPoint(sendTensorOp);
Value input = sendTensorOp.getInput();
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>())
if (concatOp.getDim() == 0)
if (Value packedInput = createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorOp.getLoc()))
input = packedInput;
PimSendTensorOp::create(rewriter, sendTensorOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
PimSendTensorOp::create(
rewriter, sendTensorOp.getLoc(), sendTensorOp.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.eraseOp(sendTensorOp);
}
@@ -152,38 +147,6 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
rewriter.replaceOp(extractRowsOp, replacements);
}
static Value createPackedExtractRowsSlice(
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
return {};
int64_t rowsPerValue = rowType.getDimSize(0);
if (ShapedType::isDynamic(rowsPerValue))
return {};
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(inputType.getRank());
sizes.reserve(inputType.getRank());
strides.reserve(inputType.getRank());
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
strides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
strides.push_back(rewriter.getIndexAttr(1));
}
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
.getResult();
}
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatConcatOp> concatOps;
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
@@ -262,11 +225,6 @@ static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter
.getResult());
rewriter.replaceOp(concatOp, newConcat.getOutput());
}
RewritePatternSet tensorPackingPatterns(funcOp.getContext());
populateTensorPackingPatterns(tensorPackingPatterns);
(void) applyPatternsGreedily(funcOp, std::move(tensorPackingPatterns));
auto eraseUnusedOps = [&](auto tag) {
using OpTy = decltype(tag);
SmallVector<OpTy> ops;
@@ -3,26 +3,6 @@
using namespace mlir;
namespace onnx_mlir {
namespace {
// Replaces concat-of-adjacent-slices with one packed slice to keep batch sends compact.
struct FoldConcatOfContiguousSlices : OpRewritePattern<tensor::ConcatOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ConcatOp op, PatternRewriter& rewriter) const override {
if (op.getDim() != 0)
return failure();
Value packed = createPackedExtractSliceTensor(op.getInputs(), rewriter, op.getLoc());
if (!packed)
return failure();
rewriter.replaceOp(op, packed);
return success();
}
};
} // namespace
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
@@ -30,6 +10,67 @@ RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count
return RankedTensorType::get(packedShape, elementType.getElementType());
}
Value extractPackedChunk(
Value packedValue, RankedTensorType chunkType, unsigned index, OpBuilder& builder, Location loc) {
auto packedType = dyn_cast<RankedTensorType>(packedValue.getType());
if (packedType && packedType == chunkType && index == 0)
return packedValue;
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(chunkType.getRank());
sizes.reserve(chunkType.getRank());
strides.reserve(chunkType.getRank());
offsets.push_back(builder.getIndexAttr(static_cast<int64_t>(index) * chunkType.getDimSize(0)));
sizes.push_back(builder.getIndexAttr(chunkType.getDimSize(0)));
strides.push_back(builder.getIndexAttr(1));
for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) {
offsets.push_back(builder.getIndexAttr(0));
sizes.push_back(builder.getIndexAttr(chunkType.getDimSize(dim)));
strides.push_back(builder.getIndexAttr(1));
}
return tensor::ExtractSliceOp::create(builder, loc, chunkType, packedValue, offsets, sizes, strides).getResult();
}
Value createPackedExtractRowsSlice(
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, OpBuilder& builder, Location loc) {
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
return {};
int64_t rowsPerValue = rowType.getDimSize(0);
if (ShapedType::isDynamic(rowsPerValue))
return {};
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(inputType.getRank());
sizes.reserve(inputType.getRank());
strides.reserve(inputType.getRank());
offsets.push_back(builder.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
sizes.push_back(builder.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
strides.push_back(builder.getIndexAttr(1));
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
offsets.push_back(builder.getIndexAttr(0));
sizes.push_back(builder.getIndexAttr(inputType.getDimSize(dim)));
strides.push_back(builder.getIndexAttr(1));
}
bool coversWholeSource = packedType == inputType && startIndex == 0;
if (coversWholeSource)
return extractRowsOp.getInput();
return tensor::ExtractSliceOp::create(builder, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
.getResult();
}
Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) {
if (values.empty())
return {};
@@ -105,9 +146,4 @@ Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Loca
return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
.getResult();
}
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
patterns.add<FoldConcatOfContiguousSlices>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -3,11 +3,21 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
mlir::Value extractPackedChunk(mlir::Value packedValue,
mlir::RankedTensorType chunkType,
unsigned index,
mlir::OpBuilder& builder,
mlir::Location loc);
mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsOp,
unsigned startIndex,
unsigned count,
mlir::OpBuilder& builder,
mlir::Location loc);
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir