add pim.vmm verifier and fix vmm lowering

reuse code for subviews
This commit is contained in:
NiccoloN
2026-05-12 15:13:50 +02:00
parent 628dc630a4
commit 4f3570520c
15 changed files with 358 additions and 207 deletions
@@ -4,9 +4,9 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <optional>
@@ -47,8 +47,8 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
}
static Value createPoolFillElement(
ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
static Value
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
if (!useMinimumValue)
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
@@ -65,8 +65,10 @@ static Value createPoolFillElement(
llvm_unreachable("unsupported pool element type");
}
static Value createPoolFillTensor(
ConversionPatternRewriter& rewriter, Location loc, RankedTensorType tensorType, bool useMinimumValue) {
static Value createPoolFillTensor(ConversionPatternRewriter& rewriter,
Location loc,
RankedTensorType tensorType,
bool useMinimumValue) {
auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue);
return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement);
}
@@ -90,10 +92,8 @@ static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
inputType.getDimSize(3) + padLeft + padRight},
inputType.getElementType(),
inputType.getEncoding());
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padTop),
rewriter.getIndexAttr(padLeft)};
SmallVector<OpFoldResult> lowPads = {
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(padTop), rewriter.getIndexAttr(padLeft)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padBottom),
@@ -104,8 +104,8 @@ static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
Value padValue = createPoolFillElement(
rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
Value padValue =
createPoolFillElement(rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
tensor::YieldOp::create(rewriter, loc, padValue);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
@@ -279,7 +279,8 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
constexpr size_t numInputs = 1;
auto computeOp =
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
Value paddedInput = createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
Value paddedInput =
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
@@ -307,8 +308,8 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
Value reducedWindow = createPoolFillTensor(
rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
Value reducedWindow =
createPoolFillTensor(rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
Value paddedInH = windowBaseH;
@@ -324,18 +325,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
}
SmallVector<OpFoldResult> offsets = {batchIndex,
rewriter.getIndexAttr(channelTile * xbarSize),
paddedInH,
paddedInW};
SmallVector<OpFoldResult> offsets = {
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
@@ -344,36 +341,28 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
}
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
SmallVector<OpFoldResult> scaleOffsets = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(channelTile * xbarSize),
outHeightIndex,
outWidthIndex};
SmallVector<OpFoldResult> scaleOffsets = {
rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> scaleStrides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> scaleStrides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scaleSlice = tensor::ExtractSliceOp::create(
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
}
SmallVector<OpFoldResult> outputOffsets = {batchIndex,
rewriter.getIndexAttr(channelTile * xbarSize),
outHeightIndex,
outWidthIndex};
SmallVector<OpFoldResult> outputOffsets = {
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> outputStrides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> outputStrides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
updatedOutput = tensor::InsertSliceOp::create(
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
}