refactor Pim constant folding pass
share contiguous address resolution in PimCommon; group patterns in a subdir for each pass with pattern files
This commit is contained in:
@@ -0,0 +1,265 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Conversion pattern that rewrites an `onnx.Conv` into an im2col matrix
/// followed by an `onnx.Gemm`. The rewrite logic is defined out of line in
/// ConvToGemm::matchAndRewrite.
struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
  // Inherit the base-class constructors so the pattern can be built from a context.
  using OpConversionPattern<ONNXConvOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXConvOp convOp,
                                ONNXConvOpAdaptor convOpAdaptor,
                                ConversionPatternRewriter& rewriter) const override;
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Lowers a 2D `onnx.Conv` (static shapes, group == 1) to im2col + `onnx.Gemm`.
///
/// Emitted IR, in order:
///   1. W collapsed to [numChannelsOut, patchSize] and transposed to
///      [patchSize, numChannelsOut];
///   2. a `spatial::SpatWeightedCompute` that zero-pads X when padding is
///      required and gathers one [1, patchSize] row per (n, oh, ow) output
///      position, concatenated into [numPatches, patchSize];
///   3. an `onnx.Gemm` computing im2col @ W^T (+ bias);
///   4. a second `spatial::SpatWeightedCompute` that reshapes the Gemm result
///      to NHWC and transposes it back to NCHW.
///
/// Unsupported convolutions are rejected with a match failure before any IR is
/// created, instead of relying on asserts that vanish in release builds.
///
/// \param convOp         the convolution being rewritten.
/// \param convOpAdaptor  adaptor giving access to the converted operands.
/// \param rewriter       rewriter used for every IR mutation.
/// \returns success() once `convOp` has been replaced, failure() otherwise.
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
                                          ONNXConvOpAdaptor convOpAdaptor,
                                          ConversionPatternRewriter& rewriter) const {
  Location loc = convOp.getLoc();
  Value x = convOpAdaptor.getX();
  Value w = convOpAdaptor.getW();
  Value b = convOpAdaptor.getB();

  auto xType = dyn_cast<RankedTensorType>(x.getType());
  auto wType = dyn_cast<RankedTensorType>(w.getType());
  auto outType = dyn_cast<RankedTensorType>(convOp.getY().getType());
  if (!xType || !wType || !outType)
    return rewriter.notifyMatchFailure(convOp, "Only support ranked tensors");
  if (!xType.hasStaticShape() || !wType.hasStaticShape() || !outType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp, "Only support static shapes");
  if (xType.getRank() != 4 || wType.getRank() != 4 || outType.getRank() != 4)
    return rewriter.notifyMatchFailure(convOp, "Only support 2D convolution");
  if (convOp.getGroup() != 1)
    return rewriter.notifyMatchFailure(convOp, "Only support group=1");

  const int64_t batchSize = xType.getDimSize(0);
  const int64_t numChannelsIn = xType.getDimSize(1);
  const int64_t xHeight = xType.getDimSize(2);
  const int64_t xWidth = xType.getDimSize(3);
  const int64_t numChannelsOut = wType.getDimSize(0);
  const int64_t wHeight = wType.getDimSize(2);
  const int64_t wWidth = wType.getDimSize(3);
  const int64_t outHeight = outType.getDimSize(2);
  const int64_t outWidth = outType.getDimSize(3);

  // Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
  auto getI64 = [](ArrayAttr arr, size_t idx) -> int64_t { return cast<IntegerAttr>(arr[idx]).getInt(); };

  const auto stridesAttr = convOp.getStrides();
  const auto dilationsAttr = convOp.getDilations();
  const auto padsAttr = convOp.getPads();

  // A 2D convolution carries two strides/dilations and four pads; reject
  // anything else before indexing into the attribute arrays.
  if ((stridesAttr && stridesAttr->size() != 2) || (dilationsAttr && dilationsAttr->size() != 2)
      || (padsAttr && padsAttr->size() != 4))
    return rewriter.notifyMatchFailure(convOp, "Unexpected number of strides/dilations/pads for 2D convolution");

  const int64_t strideHeight = stridesAttr ? getI64(*stridesAttr, 0) : 1;
  const int64_t strideWidth = stridesAttr ? getI64(*stridesAttr, 1) : 1;
  const int64_t dilationHeight = dilationsAttr ? getI64(*dilationsAttr, 0) : 1;
  const int64_t dilationWidth = dilationsAttr ? getI64(*dilationsAttr, 1) : 1;

  int64_t padHeightBegin = 0;
  int64_t padHeightEnd = 0;
  int64_t padWidthBegin = 0;
  int64_t padWidthEnd = 0;

  if (padsAttr) {
    // ONNX pads layout for 2D: [hBegin, wBegin, hEnd, wEnd]
    padHeightBegin = getI64(*padsAttr, 0);
    padWidthBegin = getI64(*padsAttr, 1);
    padHeightEnd = getI64(*padsAttr, 2);
    padWidthEnd = getI64(*padsAttr, 3);
  }
  else {
    // Compute padding from auto_pad attribute
    const auto autoPad = convOp.getAutoPad();
    if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
      const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
      const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
      const int64_t totalPadH =
          std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
      const int64_t totalPadW =
          std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);

      if (autoPad == "SAME_UPPER") {
        // Odd padding puts the extra element at the end
        padHeightBegin = totalPadH / 2;
        padHeightEnd = totalPadH - padHeightBegin;
        padWidthBegin = totalPadW / 2;
        padWidthEnd = totalPadW - padWidthBegin;
      }
      else { // SAME_LOWER: odd padding puts the extra element at the beginning
        padHeightEnd = totalPadH / 2;
        padHeightBegin = totalPadH - padHeightEnd;
        padWidthEnd = totalPadW / 2;
        padWidthBegin = totalPadW - padWidthEnd;
      }
    }
    // "NOTSET" or "VALID" -> all pads stay 0
  }

  // im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
  //   A (im2col):  [numPatches, patchSize]  -- one row per output spatial position
  //   B (weights): [patchSize, cOut]        -- W^T, stored in crossbar columns
  //   Gemm output: [numPatches, cOut]
  const int64_t patchSize = numChannelsIn * wHeight * wWidth;
  const int64_t numPatchesPerBatch = outHeight * outWidth;
  const int64_t numPatches = batchSize * numPatchesPerBatch;

  // With no output position there is nothing to concatenate into the im2col matrix.
  if (numPatches == 0)
    return rewriter.notifyMatchFailure(convOp, "Convolution produces no output elements");

  auto elemType = xType.getElementType();
  auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
  auto rowType = RankedTensorType::get({1, patchSize}, elemType);
  auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
  auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
  auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
  auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());

  // Prepare weight matrix W for crossbar storage:
  // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
  Value wFlat = tensor::CollapseShapeOp::create(rewriter,
                                                loc,
                                                wFlatType,
                                                w,
                                                SmallVector<ReassociationIndices> {
                                                  {0},
                                                  {1, 2, 3}
  });
  Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));

  // Pass bias through directly; Gemm handles rank-1 C canonicalization.
  // getDefiningOp<OpTy>() is null-safe, so a bias coming from a block argument
  // is treated as present instead of dereferencing a null operation.
  const bool hasB = !b.getDefiningOp<ONNXNoneOp>();
  Value gemmC;
  if (hasB)
    gemmC = b;
  else
    gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());

  auto im2colComputeOp =
      spatial::SpatWeightedCompute::create(rewriter, loc, im2colType, SmallVector<Value>(), ValueRange {x});

  // Create the body through the rewriter; this also moves the insertion point into the new block.
  Region& im2colRegion = im2colComputeOp.getBody();
  Block* im2colBlock = rewriter.createBlock(&im2colRegion, im2colRegion.end(), TypeRange {x.getType()}, {loc});

  Value paddedInput = im2colBlock->getArgument(0);

  // Pad input with zeros if needed:
  // [batchSize, numChannelsIn, xHeight, xWidth]
  //   -> [batchSize, numChannelsIn, xHeight + padHeight, xWidth + padWidth]
  if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
    const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
    const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
    auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
    SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
                                         rewriter.getIndexAttr(0),
                                         rewriter.getIndexAttr(padHeightBegin),
                                         rewriter.getIndexAttr(padWidthBegin)};
    SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
                                          rewriter.getIndexAttr(0),
                                          rewriter.getIndexAttr(padHeightEnd),
                                          rewriter.getIndexAttr(padWidthEnd)};
    auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);

    // The pad region receives one index per dimension and yields the padding value.
    SmallVector<Type> padArgTypes(4, rewriter.getIndexType());
    SmallVector<Location> padArgLocs(4, loc);
    Region& padRegion = padOp.getRegion();
    rewriter.createBlock(&padRegion, padRegion.end(), padArgTypes, padArgLocs);
    auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
    tensor::YieldOp::create(rewriter, loc, zero.getResult());

    rewriter.setInsertionPointAfter(padOp);
    paddedInput = padOp.getResult();
  }

  // Build im2col [numPatches, patchSize]:
  // For each batch/output position (n, oh, ow), extract the patch from x.
  // Slice sizes, strides and result type do not depend on the position, so they
  // are built once; only the offsets change per patch. Dilation is expressed as
  // the slice stride over the spatial dimensions.
  const SmallVector<OpFoldResult> patchSizes = {rewriter.getIndexAttr(1),
                                                rewriter.getIndexAttr(numChannelsIn),
                                                rewriter.getIndexAttr(wHeight),
                                                rewriter.getIndexAttr(wWidth)};
  const SmallVector<OpFoldResult> patchStrides = {rewriter.getIndexAttr(1),
                                                  rewriter.getIndexAttr(1),
                                                  rewriter.getIndexAttr(dilationHeight),
                                                  rewriter.getIndexAttr(dilationWidth)};
  auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);

  SmallVector<Value> im2colRows;
  im2colRows.reserve(numPatches);
  for (int64_t n = 0; n < batchSize; n++) {
    for (int64_t oh = 0; oh < outHeight; oh++) {
      for (int64_t ow = 0; ow < outWidth; ow++) {
        SmallVector<OpFoldResult> patchOffsets = {rewriter.getIndexAttr(n),
                                                  rewriter.getIndexAttr(0),
                                                  rewriter.getIndexAttr(oh * strideHeight),
                                                  rewriter.getIndexAttr(ow * strideWidth)};
        Value patch = tensor::ExtractSliceOp::create(
            rewriter, loc, patchType, paddedInput, patchOffsets, patchSizes, patchStrides);

        // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
        Value row = tensor::CollapseShapeOp::create(rewriter,
                                                    loc,
                                                    rowType,
                                                    patch,
                                                    SmallVector<ReassociationIndices> {
                                                      {0},
                                                      {1, 2, 3}
        });
        im2colRows.push_back(row);
      }
    }
  }

  // Concatenate all rows: [numPatches, patchSize]
  Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
  spatial::SpatYieldOp::create(rewriter, loc, im2col);

  rewriter.setInsertionPointAfter(im2colComputeOp);

  // Gemm: A @ B + C = im2col @ W^T + b
  // [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
  auto gemmOp = ONNXGemmOp::create(rewriter,
                                   loc,
                                   gemmOutType,
                                   im2colComputeOp.getResult(0),
                                   wTrans,
                                   gemmC,
                                   rewriter.getF32FloatAttr(1.0f),
                                   rewriter.getF32FloatAttr(1.0f),
                                   rewriter.getBoolAttr(false),
                                   rewriter.getBoolAttr(false));
  Value gemmOut = gemmOp.getY();

  auto collectComputeOp =
      spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector<Value>(), ValueRange {gemmOut});

  Region& collectRegion = collectComputeOp.getBody();
  Block* collectBlock =
      rewriter.createBlock(&collectRegion, collectRegion.end(), TypeRange {gemmOut.getType()}, {loc});
  Value gemmOutArg = collectBlock->getArgument(0);

  // Restore to NCHW layout:
  // [numPatches, numChannelsOut]
  // -> [batchSize, outHeight, outWidth, numChannelsOut]
  // -> [batchSize, numChannelsOut, outHeight, outWidth]
  Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
                                                loc,
                                                nhwcType,
                                                gemmOutArg,
                                                SmallVector<ReassociationIndices> {
                                                  {0, 1, 2},
                                                  {3}
  });
  Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));

  spatial::SpatYieldOp::create(rewriter, loc, nchwOut);

  rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
  return success();
}
|
||||
|
||||
/// Registers the Conv lowering pattern (ConvToGemm) into `patterns`.
void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ConvToGemm>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,482 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
// Name of the unit attribute set on the compute op that produces the softmax
// divisor; it is checked later to tell that compute apart from the others.
constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME{"computeWithSoftmaxDivisor"};
|
||||
|
||||
/// Folds a scalar factor into a constant floating-point tensor.
///
/// When `factor` is 1 the input value is returned untouched. Otherwise `value`
/// must be produced by an `arith.constant` holding a dense floating-point
/// attribute; a new constant with every element multiplied by `factor` is
/// created at the rewriter's current insertion point.
///
/// \param value     tensor value to scale.
/// \param factor    multiplicative factor (e.g. Gemm alpha/beta).
/// \param rewriter  rewriter used to create the scaled constant.
/// \param loc       location attached to the new constant.
/// \returns the (possibly new) value, or failure() when `value` is not a dense
///          floating-point constant or the scaling produces an invalid or
///          overflowing result.
static FailureOr<Value> materializeScaledConstantTensor(Value value,
                                                        float factor,
                                                        ConversionPatternRewriter& rewriter,
                                                        Location loc) {
  if (factor == 1.0f)
    return value;

  auto constantOp = value.getDefiningOp<arith::ConstantOp>();
  if (!constantOp)
    return failure();

  auto denseAttr = dyn_cast<DenseFPElementsAttr>(constantOp.getValue());
  if (!denseAttr)
    return failure();

  // APFloat operations report a bitmask of status flags. Rounding (opInexact)
  // and gradual underflow are the normal outcome of floating-point arithmetic
  // and must not be treated as errors; only these flags are.
  constexpr unsigned fatalStatusMask = APFloat::opInvalidOp | APFloat::opDivByZero | APFloat::opOverflow;

  // APFloat::multiply requires both operands to share the same semantics, so
  // bring the (single precision) factor to the element type of the tensor.
  const llvm::fltSemantics& elementSemantics = cast<FloatType>(denseAttr.getElementType()).getFloatSemantics();
  APFloat scale(factor);
  bool losesInfo = false;
  if (scale.convert(elementSemantics, APFloat::rmNearestTiesToEven, &losesInfo) & fatalStatusMask)
    return failure();

  SmallVector<APFloat> scaledValues;
  scaledValues.reserve(denseAttr.getNumElements());
  for (const APFloat& originalValue : denseAttr.getValues<APFloat>()) {
    APFloat scaledValue(originalValue);
    const APFloat::opStatus status = scaledValue.multiply(scale, APFloat::rmNearestTiesToEven);
    if (status & fatalStatusMask)
      return failure();
    scaledValues.push_back(std::move(scaledValue));
  }

  auto scaledAttr = DenseFPElementsAttr::get(denseAttr.getType(), scaledValues);
  return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
}
|
||||
|
||||
/// Conversion pattern that splits a multi-row `onnx.Gemm` into one
/// single-row Gemm per output row. The rewrite logic is defined out of line
/// in GemmToManyGemv::matchAndRewrite.
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
  // Inherit the base-class constructors so the pattern can be built from a context.
  using OpConversionPattern<ONNXGemmOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
                                ONNXGemmOpAdaptor gemmOpAdaptor,
                                ConversionPatternRewriter& rewriter) const override;
};
|
||||
|
||||
/// Conversion pattern that lowers a vector-matrix `onnx.Gemm` (gemv) to
/// spatial compute ops. The member functions are defined out of line.
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
  // Registered with an explicit pattern benefit of 1.
  GemvToSpatialCompute(MLIRContext* ctx)
      : OpConversionPattern<ONNXGemmOp>(ctx, /*benefit=*/1) {}

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
                                ONNXGemmOpAdaptor gemmOpAdaptor,
                                ConversionPatternRewriter& rewriter) const override;

private:
  /// Follows operand 0 of each defining op, starting from `startValue`, until
  /// a value produced by an ONNXExpOp is found.
  static Value resolveONNXExpOpFromUseChain(Value startValue);

  /// Reduces the softmax denominator across the given computes and divides
  /// every output tile by it; `outputOpsAndResNums` is updated in place.
  static LogicalResult softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
                                                   Value& softmaxChannel,
                                                   ConversionPatternRewriter& rewriter,
                                                   SpatialReducer& reducer,
                                                   ONNXGemmOp& gemmOp,
                                                   Location& loc);
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Decomposes a Gemm with several output rows into one single-row Gemm (gemv)
/// per row, then concatenates the rows inside a `spatial::SpatWeightedCompute`.
///
/// alpha and beta are folded into B and C (both must then be constants, see
/// materializeScaledConstantTensor), so every generated gemv uses
/// alpha = beta = 1. A rank-1 bias is viewed as [1, N]; a bias with one row per
/// output row is sliced row by row, any other bias must be a vector and is
/// shared by all rows.
///
/// \param gemmOp         the Gemm being rewritten.
/// \param gemmOpAdaptor  adaptor giving access to the converted operands.
/// \param rewriter       rewriter used for every IR mutation.
/// \returns success() once `gemmOp` has been replaced; failure() when the op
///          has at most one row (nothing to split) or is not supported.
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
                                              ONNXGemmOpAdaptor gemmOpAdaptor,
                                              ConversionPatternRewriter& rewriter) const {
  Location loc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();

  // Rows of A are sliced below, which is only meaningful for a non-transposed A.
  if (gemmOpAdaptor.getTransA())
    return rewriter.notifyMatchFailure(gemmOp, "A should have been transposed already");

  // getDefiningOp<OpTy>() is null-safe: a C coming from a block argument counts as present.
  const bool hasC = !c.getDefiningOp<ONNXNoneOp>();

  auto aType = dyn_cast<RankedTensorType>(a.getType());
  auto outType = dyn_cast<RankedTensorType>(gemmOp.getY().getType());
  if (!aType || !outType || !aType.hasStaticShape() || !outType.hasStaticShape())
    return rewriter.notifyMatchFailure(gemmOp, "Only support static shapes");
  if (aType.getRank() != 2 || outType.getRank() != 2)
    return rewriter.notifyMatchFailure(gemmOp, "Only support rank 2 tensors for A and Y");

  const int64_t numOutRows = aType.getDimSize(0);

  // Only decompose when there are multiple rows to split
  if (numOutRows <= 1)
    return failure();

  if (hasC) {
    auto originalCType = dyn_cast<RankedTensorType>(c.getType());
    if (!originalCType || (originalCType.getRank() != 1 && originalCType.getRank() != 2))
      return rewriter.notifyMatchFailure(gemmOp, "Only support rank 2 tensor for C");
  }

  auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
  if (failed(scaledB))
    return failure();
  b = *scaledB;

  RankedTensorType cType = nullptr;
  bool cHasNumOutRows = false;
  if (hasC) {
    auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
    cType = cast<RankedTensorType>(c.getType());

    // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
    if (cType.getRank() == 1) {
      auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
      c = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, c, SmallVector<ReassociationIndices> {{0, 1}});
      cType = expandedType;
    }

    cHasNumOutRows = cType.getDimSize(0) == numOutRows;
    // A bias that is not sliced per row is handed unchanged to every gemv.
    if (!cHasNumOutRows && !isVectorShape(cType.getShape()))
      return rewriter.notifyMatchFailure(gemmOp, "C should be a vector");
  }

  auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());

  // Row slices share sizes/strides/types; only the row offset changes.
  const SmallVector<OpFoldResult> unitStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
  const SmallVector<OpFoldResult> aRowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
  auto aRowType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());

  SmallVector<OpFoldResult> cRowSizes;
  RankedTensorType cRowType = nullptr;
  if (cHasNumOutRows) {
    cRowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
    cRowType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
  }

  SmallVector<Value> gemvOps;
  gemvOps.reserve(numOutRows);
  for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
    const SmallVector<OpFoldResult> rowOffsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};

    Value aSlice =
        tensor::ExtractSliceOp::create(rewriter, loc, aRowType, a, rowOffsets, aRowSizes, unitStrides).getResult();

    // Without C this stays the none value; with a per-row C take the matching row.
    Value cSlice = c;
    if (cHasNumOutRows)
      cSlice =
          tensor::ExtractSliceOp::create(rewriter, loc, cRowType, c, rowOffsets, cRowSizes, unitStrides).getResult();

    auto gemvOp = ONNXGemmOp::create(rewriter,
                                     loc,
                                     outRowType,
                                     aSlice,
                                     b,
                                     cSlice,
                                     rewriter.getF32FloatAttr(1.0f),
                                     rewriter.getF32FloatAttr(1.0f),
                                     gemmOp.getTransAAttr(),
                                     gemmOp.getTransBAttr());
    gemvOps.push_back(gemvOp.getY());
  }

  auto concatComputeOp =
      spatial::SpatWeightedCompute::create(rewriter, loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);

  // Create the body through the rewriter, with one block argument per gemv
  // result; this also moves the insertion point into the new block.
  SmallVector<Type> concatArgTypes;
  concatArgTypes.reserve(gemvOps.size());
  for (Value gemvResult : gemvOps)
    concatArgTypes.push_back(gemvResult.getType());
  SmallVector<Location> concatArgLocs(gemvOps.size(), loc);
  Region& concatRegion = concatComputeOp.getBody();
  Block* concatBlock = rewriter.createBlock(&concatRegion, concatRegion.end(), concatArgTypes, concatArgLocs);

  auto blockArgs = concatBlock->getArguments();
  auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, blockArgs);
  spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());

  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}
|
||||
|
||||
/// Lowers a vector-matrix Gemm (gemv) to spatial compute ops.
///
/// Steps:
///   1. alpha/beta are folded into the B/C constants. This happens before any
///      transpose or reshape is created, because the folding needs to see the
///      defining `arith.constant` directly.
///   2. A and B are transposed when transA/transB are set; a rank-1 C is viewed
///      as [1, N].
///   3. A is sliced per crossbar per core, B is tiled into crossbar-sized
///      tiles, C is sliced per output column slice.
///   4. For every output column slice, each core gets a
///      `spatial::SpatWeightedCompute` that sums one weighted VMM per input
///      slice it owns; a reduce compute then adds the per-core partial results
///      and the C slice.
///   5. All output slices are concatenated along axis 1.
///
/// \param gemmOp         the Gemm being rewritten.
/// \param gemmOpAdaptor  adaptor giving access to the converted operands.
/// \param rewriter       rewriter used for every IR mutation.
/// \returns success() once `gemmOp` has been replaced; failure() when the op is
///          not a gemv or is not supported.
LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
                                                    ONNXGemmOpAdaptor gemmOpAdaptor,
                                                    ConversionPatternRewriter& rewriter) const {
  Location gemmLoc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();
  Value out = gemmOp.getY();

  const float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
  const float beta = gemmOpAdaptor.getBeta().convertToFloat();
  const bool transA = gemmOpAdaptor.getTransA();
  const bool transB = gemmOpAdaptor.getTransB();

  auto aType = dyn_cast<RankedTensorType>(a.getType());
  auto bType = dyn_cast<RankedTensorType>(b.getType());
  auto outType = dyn_cast<RankedTensorType>(out.getType());
  if (!aType || !bType || !outType)
    return rewriter.notifyMatchFailure(gemmOp, "Only support ranked tensors");

  // getDefiningOp<OpTy>() is null-safe: a C coming from a block argument counts as present.
  const bool hasC = !c.getDefiningOp<ONNXNoneOp>();
  RankedTensorType cType = nullptr;
  // Type of C once a rank-1 bias [N] is viewed as [1, N]; equal to cType otherwise.
  RankedTensorType cMatrixType = nullptr;
  if (hasC) {
    cType = dyn_cast<RankedTensorType>(c.getType());
    if (!cType || (cType.getRank() != 1 && cType.getRank() != 2))
      return rewriter.notifyMatchFailure(gemmOp, "Only support rank 2 tensor for C");
    cMatrixType = cType;
    if (cType.getRank() == 1)
      cMatrixType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
  }

  if (!aType.hasStaticShape() || !bType.hasStaticShape() || (hasC && !cType.hasStaticShape())
      || !outType.hasStaticShape())
    return rewriter.notifyMatchFailure(gemmOp, "Only support static shapes");
  if (aType.getRank() != 2 || bType.getRank() != 2)
    return rewriter.notifyMatchFailure(gemmOp, "Only support rank 2 tensors for A and B");

  if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cMatrixType.getShape())))
    // Not a gemv
    return failure();

  // Fold alpha into B and beta into C while they are still the original
  // constants: scaling is element-wise, so it commutes with the transpose and
  // reshape created below.
  if (alpha != 1.0f) {
    auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
    if (failed(scaledB))
      return failure();
    b = *scaledB;
  }
  if (hasC && beta != 1.0f) {
    auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
  }

  if (transA) {
    auto aShape = aType.getShape();
    auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
    a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
    // Keep the cached type in sync: the slice counts below are derived from it.
    aType = cast<RankedTensorType>(a.getType());
  }
  if (transB) {
    auto bShape = bType.getShape();
    auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
    b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
    bType = cast<RankedTensorType>(b.getType());
  }

  // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
  if (hasC && cType.getRank() == 1) {
    c = tensor::ExpandShapeOp::create(rewriter, gemmLoc, cMatrixType, c, SmallVector<ReassociationIndices> {{0, 1}});
    cType = cMatrixType;
  }

  auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
  auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
  // B has one vertical slice per horizontal slice of A, and the output (like C)
  // has one horizontal slice per horizontal slice of B.
  auto bNumVSlices = aNumHSlices;
  auto outNumHSlices = bNumHSlices;
  auto outLastHSliceSize = bLastHSliceSize;

  const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());

  DenseMap<CoreId, SmallVector<Value>> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);

  DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> bTiles =
      tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);

  SmallVector<Value> cHSlices;
  // A [1, 1] bias is broadcast to the full output width before slicing
  if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
    c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
  if (hasC)
    cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);

  RankedTensorType outHSliceType =
      RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, outType.getElementType());
  RankedTensorType outLastHSliceType =
      RankedTensorType::get({1, static_cast<int64_t>(bLastHSliceSize)}, outType.getElementType());

  SmallVector<Value> outHSlices;
  outHSlices.reserve(outNumHSlices);
  for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
    // A non-zero remainder means the last slice is narrower than a crossbar
    RankedTensorType currOutHSliceType = outHSliceType;
    if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
      currOutHSliceType = outLastHSliceType;

    SmallVector<Value> partialResults;
    partialResults.reserve(coresPerVSlice);
    for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
      SmallVector<Value>& coreInputs = aHSlices[coreId];

      SmallVector<Value> weights;
      weights.reserve(coreInputs.size());
      for (size_t aSliceId = 0; aSliceId < coreInputs.size(); aSliceId++)
        weights.push_back(bTiles[outSliceId][coreId][aSliceId]);

      auto computeOp =
          spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, weights, coreInputs);

      // Create the body through the rewriter, with one block argument per input
      // slice; this also moves the insertion point into the new block.
      SmallVector<Type> computeArgTypes;
      computeArgTypes.reserve(coreInputs.size());
      for (Value coreInput : coreInputs)
        computeArgTypes.push_back(coreInput.getType());
      SmallVector<Location> computeArgLocs(coreInputs.size(), gemmLoc);
      Region& computeRegion = computeOp.getBody();
      Block* computeBlock =
          rewriter.createBlock(&computeRegion, computeRegion.end(), computeArgTypes, computeArgLocs);

      auto computeArgs = computeBlock->getArguments();
      SmallVector<Value> vmmOutputs;
      vmmOutputs.reserve(computeArgs.size());
      // Iterate over the slices owned by THIS core: the block arguments and the
      // weights of this compute are indexed per core, not over all slices of A.
      for (size_t aHSliceId = 0; aHSliceId < computeArgs.size(); aHSliceId++)
        vmmOutputs.push_back(
            spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
      assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");

      Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
      spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
      rewriter.setInsertionPointAfter(computeOp);

      partialResults.push_back(computeOp.getResult(0));
    }

    // The bias slice is simply one more addend of the reduction
    if (hasC) {
      Value cHSlice = cHSlices[outSliceId];
      partialResults.push_back(cHSlice);
    }

    auto reduceComputeOp =
        spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);

    SmallVector<Type> reduceArgTypes;
    reduceArgTypes.reserve(partialResults.size());
    for (Value partialResult : partialResults)
      reduceArgTypes.push_back(partialResult.getType());
    SmallVector<Location> reduceArgLocs(partialResults.size(), gemmLoc);
    Region& reduceRegion = reduceComputeOp.getBody();
    Block* reduceBlock = rewriter.createBlock(&reduceRegion, reduceRegion.end(), reduceArgTypes, reduceArgLocs);

    auto blockArgs = reduceBlock->getArguments();
    Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
    spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
    rewriter.setInsertionPointAfter(reduceComputeOp);

    outHSlices.push_back(reduceComputeOp.getResult(0));
  }

  auto concatComputeOp =
      spatial::SpatWeightedCompute::create(rewriter, gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);

  SmallVector<Type> concatArgTypes;
  concatArgTypes.reserve(outHSlices.size());
  for (Value outHSlice : outHSlices)
    concatArgTypes.push_back(outHSlice.getType());
  SmallVector<Location> concatArgLocs(outHSlices.size(), gemmLoc);
  Region& concatRegion = concatComputeOp.getBody();
  Block* concatBlock = rewriter.createBlock(&concatRegion, concatRegion.end(), concatArgTypes, concatArgLocs);

  auto blockArgs = concatBlock->getArguments();
  auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
  spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());

  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}
|
||||
|
||||
/// Walks backwards from `startValue`, following operand 0 of each defining
/// operation, until a value produced by an ONNXExpOp is found.
///
/// \param startValue  value from which the walk starts.
/// \returns the value produced by the ONNXExpOp. When the chain ends without
///          reaching one (a block argument or an operation without operands)
///          the function asserts; in builds without assertions it returns a
///          null Value instead of dereferencing a null operation.
Value GemvToSpatialCompute::resolveONNXExpOpFromUseChain(Value startValue) {
  Value walker = startValue;

  while (walker) {
    Operation* definingOp = walker.getDefiningOp();
    // Block arguments have no defining op: the chain ends here
    if (!definingOp)
      break;

    if (llvm::isa<ONNXExpOp>(definingOp))
      return walker;

    // Nothing to follow past an operation without operands
    if (definingOp->getNumOperands() == 0)
      break;

    walker = definingOp->getOperand(0);
  }

  assert(false
         && "Unwinded the whole chain of operations while trying to "
            "find ONNXExpOp, but did not find it");
  return Value();
}
|
||||
|
||||
/// Applies the softmax normalization across a set of compute ops.
///
/// Each compute sums its own tile (SpatSumOp) and the partial sums are combined
/// through `reducer` with SpatVAddOp. The compute that ends up holding the
/// final divisor is tagged with COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME and
/// broadcasts the divisor on `softmaxChannel`; every other compute receives it
/// from the channel. Each compute then yields one extra operand: its ONNXExpOp
/// tile divided by the divisor.
///
/// \param outputOpsAndResNums  in: computes (and result numbers) to normalize;
///                             out: the same computes paired with the index of
///                             the newly yielded, normalized tile.
/// \param softmaxChannel       channel used to broadcast the divisor.
/// \param rewriter             rewriter used to create the new operations.
/// \param reducer              reducer implementing the reduction pattern.
/// \param gemmOp               not used by this function.
/// \param loc                  location attached to the created operations.
/// \returns success().
LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector<OpAndResNum>& outputOpsAndResNums,
                                                                Value& softmaxChannel,
                                                                ConversionPatternRewriter& rewriter,
                                                                SpatialReducer& reducer,
                                                                ONNXGemmOp& gemmOp,
                                                                Location& loc) {
  // TODO: Check case with one compute op

  // Cast vector of Value into vector of ComputeOp
  SmallVector<ComputeAndResNum> softmaxOpsToReduce =
      llvm::to_vector(llvm::map_range(outputOpsAndResNums, [](const OpAndResNum& opAndResNum) {
        return std::make_pair(cast<spatial::SpatWeightedCompute>(opAndResNum.first), opAndResNum.second);
      }));

  // Build the type directly: RankedTensorType::Builder only keeps a reference
  // to the shape, which would dangle when initialized from a braced temporary.
  const TensorType scalarTensorType = RankedTensorType::get({1}, Float32Type::get(rewriter.getContext()));

  reducer.applyReducePattern(
      softmaxOpsToReduce,
      [&](Value a, Value b) { return spatial::SpatVAddOp::create(rewriter, loc, scalarTensorType, a, b); },
      /* preprocess = */
      [&](Value a) { return spatial::SpatSumOp::create(rewriter, loc, scalarTensorType, a); },
      [&](Value softmaxDivisor) {
        // Signal that this is the compute with the softmax divisor
        auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp());
        computeOp->setAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME, rewriter.getUnitAttr());

        // Broadcast the divisor to all the cores
        rewriter.setInsertionPointAfterValue(softmaxDivisor);
        spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, softmaxChannel, softmaxDivisor);

        /*
         * softmaxDividend = onnx.exp (...)
         * sum = spat.SumOp(softmaxDividend)
         * [following can be repeated N times, thus walk the use chain]
         * softmaxDivisor = spat.sadd(sum, ...)
         */
        Value softmaxDividend = resolveONNXExpOpFromUseChain(softmaxDivisor.getDefiningOp()->getOperand(0));

        // Make sure the dividend is actually produced by an ONNXExpOp
        assert(llvm::isa<ONNXExpOp>(softmaxDividend.getDefiningOp())
               && "Dividend of softmax reduction is not an ONNXExpOp");
        // Only used by the assertion above
        (void) softmaxDividend;

        // Do not divide here, divide after this
        return softmaxDivisor;
      });

  // In all the cores, insert a ChannelRecvOp and divide the output tile by
  // the reduced denominator.
  outputOpsAndResNums.clear();
  outputOpsAndResNums.reserve(softmaxOpsToReduce.size());
  for (auto& computeToDivideOpAndResNum : softmaxOpsToReduce) {
    auto computeToDivide = computeToDivideOpAndResNum.first;
    Block& computeBody = computeToDivide.getBody().front();
    auto yieldOp = cast<spatial::SpatYieldOp>(computeBody.getTerminator());

    Value divisor;

    // Check if this compute contains the softmax divisor: if so, find the
    // ChannelBroadcastSendOp, otherwise receive the value from the channel
    // using ChannelBroadcastReceiveOp
    if (computeToDivide->hasAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME)) {
      bool found = false;
      for (auto broadcastOp : computeBody.getOps<spatial::SpatChannelBroadcastSendOp>()) {
        assert(found == false
               && "More than one ChannelBroadcastSendOp in "
                  "compute? How is this possible?");
        found = true;

        divisor = broadcastOp.getData();
      }

      assert(found
             && "No ChannelBroadcastSendOp in compute where softmax "
                "divisor was specified to be?");
      // Only used by the assertions above
      (void) found;
    }
    else {
      rewriter.setInsertionPoint(yieldOp);
      divisor = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, loc, scalarTensorType, softmaxChannel);
    }

    // Walk the chain of operations until we find the ONNXExpOp: this is
    // needed because some computes may have a different amount of `VAddOp`s
    // due to the tree reduction (e.g. some may have no VAddOp, some may have
    // multiples)
    Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second));

    rewriter.setInsertionPoint(yieldOp);
    Value newOutputTile = spatial::SpatVSDivOp::create(rewriter, loc, oldOutputTile.getType(), oldOutputTile, divisor);
    // The normalized tile is appended as a new yield operand; its index is
    // what callers get back in outputOpsAndResNums.
    auto yieldOperandNum = yieldOp->getNumOperands();
    yieldOp->insertOperands(yieldOperandNum, newOutputTile);

    outputOpsAndResNums.push_back({computeToDivideOpAndResNum.first, yieldOperandNum});
  }

  return success();
}
|
||||
|
||||
/// Adds the GemmToManyGemv and GemvToSpatialCompute patterns, in that order,
/// to `patterns`. Both are constructed with `ctx`.
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<GemmToManyGemv, GemvToSpatialCompute>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,108 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Rewrites an ONNXMatMulOp with a rank-2 A (m x k) and a rank-3 B
/// (batch x k x n) into one ONNXGemmOp per (batch, column) pair.
///
/// For every column of every batch of B, the k x 1 column is sliced out,
/// collapsed to a 1 x k row and multiplied by A^T (k x m), producing a 1 x m
/// row. All rows are concatenated inside a SpatWeightedCompute into a
/// (batch * n) x m tensor, which is expanded to batch x n x m and finally
/// transposed to batch x m x n.
///
/// The pattern fails (without touching the IR) when any type is unranked or
/// dynamic, when the ranks or dimensions are inconsistent, or when there would
/// be no rows to concatenate.
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
    auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
    auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
    auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
    if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
        || !outType.hasStaticShape())
      return failure();
    if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
      return failure();

    const int64_t batch = rhsType.getDimSize(0);
    const int64_t k = rhsType.getDimSize(1);
    const int64_t n = rhsType.getDimSize(2);
    const int64_t m = lhsType.getDimSize(0);
    if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
        || outType.getDimSize(2) != n)
      return failure();

    // With no (batch, column) pair there is no row to concatenate, and a
    // concat without operands is not valid IR.
    if (batch == 0 || n == 0)
      return failure();

    Location loc = matmulOp.getLoc();
    auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
    auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
    auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
    auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
    auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
    auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());

    Value lhsTransposed =
        ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
    Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());

    // Slice sizes, strides and the collapse reassociation do not depend on the
    // loop indices, so they are built once.
    const SmallVector<OpFoldResult> sizes = {
        rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
    const SmallVector<OpFoldResult> strides = {
        rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
    const SmallVector<ReassociationIndices> rowReassociation = {{0}, {1, 2}};

    SmallVector<Value> gemmRows;
    gemmRows.reserve(batch * n);
    for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
      for (int64_t colIdx = 0; colIdx < n; colIdx++) {
        SmallVector<OpFoldResult> offsets = {
            rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};

        // 1 x k x 1 column of B, flattened to a 1 x k row.
        Value rhsSlice =
            tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
        Value rhsRow = tensor::CollapseShapeOp::create(rewriter, loc, rhsRowType, rhsSlice, rowReassociation);

        // (1 x k) * (k x m) -> 1 x m, no bias, alpha = beta = 1, no transposes.
        auto gemmOp = ONNXGemmOp::create(rewriter,
            loc,
            gemmRowType,
            rhsRow,
            lhsTransposed,
            none,
            rewriter.getF32FloatAttr(1.0f),
            rewriter.getF32FloatAttr(1.0f),
            rewriter.getBoolAttr(false),
            rewriter.getBoolAttr(false));
        gemmRows.push_back(gemmOp.getY());
      }
    }

    // Concatenate all rows inside a compute op that takes them as inputs.
    auto concatComputeOp =
        spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows);

    auto* concatBlock = new Block();
    for (Value gemmRow : gemmRows)
      concatBlock->addArgument(gemmRow.getType(), loc);
    concatComputeOp.getBody().push_back(concatBlock);
    rewriter.setInsertionPointToStart(concatBlock);

    auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments());
    spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());

    // (batch * n) x m -> batch x n x m -> batch x m x n.
    rewriter.setInsertionPointAfter(concatComputeOp);
    Value gemmOut = concatComputeOp.getResult(0);
    Value gemmExpanded = tensor::ExpandShapeOp::create(
        rewriter, loc, gemmExpandedType, gemmOut, SmallVector<ReassociationIndices>{{0, 1}, {2}});
    Value result = ONNXTransposeOp::create(
        rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));

    rewriter.replaceOp(matmulOp, result);
    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Adds the MatMulRank3ToGemm rewrite pattern, constructed with `ctx`, to
/// `patterns`.
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<MatMulRank3ToGemm>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,427 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Reduces `valuesToReduce` to a single Value with `reduce`, inserting spatial
/// channels between compute ops where needed.
///
/// Steps:
///  1. A single input is returned untouched.
///  2. `preprocess`, when set, is applied to each value right after its
///     definition.
///  3. Values that live in the same compute op are folded together inside that
///     compute op.
///  4. The per-compute partial results are combined two at a time: the value
///     in the earlier compute op is sent through a freshly created channel,
///     received in the later one and reduced there. This repeats until one
///     value is left.
///  5. `postprocess`, when set, is applied to that final value.
///
/// `valuesToReduce` is modified in place. When more than one value is passed
/// in, it holds on return one partial result per distinct compute op, in order
/// of first appearance.
///
/// \param valuesToReduce non-empty list of values, each located in the body of
///        a SpatWeightedCompute.
/// \param reduce binary combiner; called with the rewriter insertion point
///        already set where the new op must be created.
/// \param preprocess optional per-value transformation (may be empty).
/// \param postprocess optional transformation of the final value (may be
///        empty).
/// \return the value holding the complete reduction.
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
    ConversionPatternRewriter& rewriter,
    std::function<Value(const Value&, const Value&)> reduce,
    std::function<Value(const Value&)> preprocess,
    std::function<Value(const Value&)> postprocess) {
  assert(!valuesToReduce.empty() && "Cannot reduce an empty list of values");

  // Simple case: if we have only one input, just return it.
  // NOTE(review): neither `preprocess` nor `postprocess` runs on this path;
  // confirm with the callers that this is intended (e.g. an average-pooling
  // window with a single in-bounds tile skips its division).
  if (valuesToReduce.size() == 1)
    return valuesToReduce[0];

  if (preprocess) {
    for (auto& valToReduce : valuesToReduce) {
      rewriter.setInsertionPointAfterValue(valToReduce);
      valToReduce = preprocess(valToReduce);
    }
  }

  // It is possible that `valuesToReduce` contains several entries for the same
  // computeOp. In this case, we need to apply the reduction within-compute.
  //
  // `partialResults` keeps the last value of each compute op in order of first
  // appearance, so the emitted IR does not depend on pointer hashing; the map
  // only provides the lookup from compute op to its slot.
  SmallVector<Value> partialResults;
  std::unordered_map<Operation*, size_t> slotForCompute;
  for (Value valToReduce : valuesToReduce) {
    Operation* computeOp = valToReduce.getParentBlock()->getParentOp();
    assert(isa<spatial::SpatWeightedCompute>(computeOp) && "Expected a ComputeOp");

    auto [slotIt, isNewCompute] = slotForCompute.try_emplace(computeOp, partialResults.size());
    if (isNewCompute) {
      partialResults.push_back(valToReduce);
      continue;
    }

    // We have already seen this computeOp: reduce with its last value, right
    // after whichever of the two operands is defined later. A block argument
    // has no defining op and precedes every operation of the block.
    Value lastWithinComputeValue = partialResults[slotIt->second];
    Operation* currentDef = valToReduce.getDefiningOp();
    Operation* lastDef = lastWithinComputeValue.getDefiningOp();
    const bool currentComesFirst = !currentDef || (lastDef && currentDef->isBeforeInBlock(lastDef));
    if (currentComesFirst)
      rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
    else
      rewriter.setInsertionPointAfterValue(valToReduce);

    partialResults[slotIt->second] = reduce(lastWithinComputeValue, valToReduce);
  }

  // Hand the per-compute partial results back to the caller.
  valuesToReduce.assign(partialResults.begin(), partialResults.end());

  Location loc = valuesToReduce[0].getLoc();
  auto channelType = spatial::SpatChannelType::get(rewriter.getContext());

  // Pairwise reduction:
  // - Take two inputs at a time and reduce them into a single one, producing a
  //   list of roughly half the size.
  // - Repeat until there is only one input left.
  SmallVector<Value> currentLevel = std::move(partialResults);
  while (currentLevel.size() > 1) {
    SmallVector<Value> nextLevel;
    nextLevel.reserve((currentLevel.size() + 1) / 2);

    for (size_t i = 0; i + 1 < currentLevel.size(); i += 2) {
      Value firstValue = currentLevel[i];
      Value secondValue = currentLevel[i + 1];

      Operation* firstCompute = firstValue.getParentBlock()->getParentOp();
      Operation* secondCompute = secondValue.getParentBlock()->getParentOp();

      assert(isa<spatial::SpatWeightedCompute>(firstCompute));
      assert(isa<spatial::SpatWeightedCompute>(secondCompute));

      // Always send from the earlier compute op to the later one.
      if (secondCompute->isBeforeInBlock(firstCompute)) {
        std::swap(firstValue, secondValue);
        std::swap(firstCompute, secondCompute);
      }

      // 1. Add a channel before the first computeOp
      rewriter.setInsertionPoint(firstCompute);
      auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);

      // 2. Add a sendOp after the first value
      rewriter.setInsertionPointAfterValue(firstValue);
      spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue);

      // 3. Add a receiveOp after the second value
      rewriter.setInsertionPointAfterValue(secondValue);
      auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel);

      // 4. Apply reduction between second value and received value
      rewriter.setInsertionPointAfterValue(receivedValue);
      Value reduced = reduce(receivedValue, secondValue);

      nextLevel.push_back(reduced);
    }

    // With an odd number of inputs the last one has no partner: carry it over
    // to the next level unchanged.
    if (currentLevel.size() % 2 == 1)
      nextLevel.push_back(currentLevel.back());

    currentLevel = std::move(nextLevel);
  }

  assert(currentLevel.size() == 1 && "Internal error: expected a single input at this point.");

  Value finalValue = currentLevel[0];

  if (postprocess) {
    rewriter.setInsertionPointAfterValue(finalValue);
    finalValue = postprocess(finalValue);
  }

  return finalValue;
}
|
||||
|
||||
/// Tells whether pooling with `PoolOp` needs a post-processing step on the
/// reduced window. The generic answer is "no"; operations that need one
/// provide an explicit specialization.
template <typename PoolOp>
bool hasPostProcessPoolingWindow() {
  const bool needsPostProcessing = false;
  return needsPostProcessing;
}
|
||||
|
||||
template <>
|
||||
bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
PoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
ONNXAveragePoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||
|
||||
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
|
||||
|
||||
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
|
||||
|
||||
// Put a spat.const before the computeOp, and use its value. We do this to be
|
||||
// compatible with the current code generation, which assumes constant to be
|
||||
// loaded in global memory, which is allocated by adding a spat.const OP
|
||||
// directly under func.func (i.e. alongside ComputeOps)
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
auto divisorValue = spatial::SpatConstantOp::create(rewriter,
|
||||
loc,
|
||||
scalarTensor,
|
||||
rewriter.getI64IntegerAttr(divisorNumber),
|
||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
||||
|
||||
rewriter.setInsertionPointAfterValue(valueToDivide);
|
||||
return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
||||
}
|
||||
|
||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
PoolingBaseConverter(MLIRContext* ctx)
|
||||
: OpConversionPattern<PoolOp>(ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Value X = adaptor.getX();
|
||||
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
|
||||
Value Y = poolOp.getResult();
|
||||
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
|
||||
|
||||
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
|
||||
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
|
||||
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
|
||||
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
|
||||
|
||||
if (adaptor.getAutoPad() != "NOTSET")
|
||||
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
|
||||
|
||||
size_t pad_x, pad_y;
|
||||
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
|
||||
if (padUnpackError.has_value())
|
||||
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
|
||||
|
||||
Location loc = poolOp.getLoc();
|
||||
|
||||
size_t input_h = getImageHeight(xShape);
|
||||
size_t input_w = getImageWidth(xShape);
|
||||
size_t output_h = getImageHeight(yShape);
|
||||
size_t output_w = getImageWidth(yShape);
|
||||
size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
|
||||
size_t channelTileRest = getImageChannel(xShape) % crossbarSize;
|
||||
|
||||
// 1: Tile the input tensor
|
||||
// Input tiles need to be indexed by:
|
||||
// a. Channel Tile
|
||||
// b. Pixel `x` position
|
||||
// c. Pixel `y` position
|
||||
// For example: inputTiles[channelTile][x][y]
|
||||
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
|
||||
// Suppose that the input tensor is produced by concatenating the results of
|
||||
// many ComputeOps. Get the result tiles from these ComputeOps.
|
||||
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
|
||||
channelTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
|
||||
|
||||
auto resolveErrorOpt =
|
||||
resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter);
|
||||
if (resolveErrorOpt.has_value())
|
||||
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
|
||||
|
||||
// TODO: This requires a core for each input tile, which is not ideal. We
|
||||
// can do better.
|
||||
// If some input tiles come from the func.func operands, load
|
||||
// them into a computeOp and yield them
|
||||
for (size_t t = 0; t < channelTileCount; t++) {
|
||||
for (size_t x = 0; x < input_w; x++) {
|
||||
for (size_t y = 0; y < input_h; y++) {
|
||||
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
Location tileLoc = extractSliceOp.getLoc();
|
||||
|
||||
auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter,
|
||||
tileLoc,
|
||||
extractSliceOp.getResultType(),
|
||||
/* xbarWeights =*/ValueRange(),
|
||||
extractSliceOp.getResult());
|
||||
|
||||
Block* tempComputeOpBlock = new Block();
|
||||
tempComputeOp.getBody().push_back(tempComputeOpBlock);
|
||||
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
|
||||
|
||||
rewriter.setInsertionPointToStart(tempComputeOpBlock);
|
||||
spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg);
|
||||
rewriter.setInsertionPointAfter(tempComputeOp);
|
||||
inputTiles[t][x][y] = tempComputeOp.getResult(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2: Tile the output tensor
|
||||
// Output tiles need to be indexed by:
|
||||
// a. Channel Tile
|
||||
// b. Pixel `x` position
|
||||
// c. Pixel `y` position
|
||||
// For example: outputTiles[channelTile][x][y]
|
||||
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
|
||||
SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
|
||||
channelTileCount, SmallVector<SmallVector<Value>>(output_w, SmallVector<Value>(output_h, nullptr)));
|
||||
|
||||
// List of values to pool for each output pixel
|
||||
SmallVector<Value> valuesToPool;
|
||||
|
||||
// Iterate each output tile
|
||||
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
|
||||
// Iterate each output pixel
|
||||
for (size_t outX = 0; outX < output_w; outX++) {
|
||||
for (size_t outY = 0; outY < output_h; outY++) {
|
||||
|
||||
// Each output pixel tile is computed by pooling a window of input
|
||||
// pixel tiles
|
||||
valuesToPool.clear();
|
||||
size_t tilesSkippedByPadding = 0;
|
||||
|
||||
auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x);
|
||||
auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y);
|
||||
|
||||
for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
|
||||
for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
|
||||
if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) {
|
||||
tilesSkippedByPadding++;
|
||||
continue;
|
||||
}
|
||||
|
||||
Value inputTile = inputTiles[outTile][inX][inY];
|
||||
|
||||
Value valueToPool;
|
||||
if (auto computeProducer = inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
|
||||
|
||||
int resultNumber = getResultIndex(computeProducer, inputTile);
|
||||
|
||||
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(computeProducer.getBody().front().getTerminator());
|
||||
valueToPool = yieldInComputeOp.getOperand(resultNumber);
|
||||
}
|
||||
else if (auto receiveProducer = inputTile.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
|
||||
auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter);
|
||||
if (failed(sendOpOpt)) {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"ChannelReceiveOp does not have a matching "
|
||||
"ChannelSendOp.");
|
||||
}
|
||||
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
||||
|
||||
valueToPool = sendOp.getData();
|
||||
}
|
||||
else {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"Input tile for Pooling is not produced by a "
|
||||
"WeightedComputeOp nor a receiveOp");
|
||||
}
|
||||
|
||||
valuesToPool.push_back(valueToPool);
|
||||
}
|
||||
}
|
||||
|
||||
assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense.");
|
||||
// assert(computeOpsForPooling.size() != 1 &&
|
||||
// "Pooling computed on one tiles make no sense??? Or maybe
|
||||
// this " "should have been simplified earlier???");
|
||||
|
||||
std::function<Value(const Value&)> postProcessFn = nullptr;
|
||||
if (hasPostProcessPoolingWindow<PoolOp>()) {
|
||||
postProcessFn = [&](const Value prevFinalRes) {
|
||||
return postProcessPoolingWindow(
|
||||
rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
|
||||
};
|
||||
}
|
||||
|
||||
Value reducedWithinCompute = applyReducePatternNew(
|
||||
valuesToPool,
|
||||
rewriter,
|
||||
[&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); },
|
||||
nullptr,
|
||||
postProcessFn);
|
||||
|
||||
// Send this value through a channel, and receive it in the
|
||||
// `func.func`. During lowering, we will need to "move it" into the
|
||||
// users computeOps
|
||||
auto computeOpOfReduced =
|
||||
cast<spatial::SpatWeightedCompute>(reducedWithinCompute.getDefiningOp()->getParentOp());
|
||||
|
||||
// Create a new channel before the computeOp
|
||||
rewriter.setInsertionPoint(computeOpOfReduced);
|
||||
auto reduceChannel =
|
||||
spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext()));
|
||||
|
||||
// Send value through the channel
|
||||
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
|
||||
spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute);
|
||||
|
||||
// Receive after the computeOp
|
||||
rewriter.setInsertionPointAfter(computeOpOfReduced);
|
||||
auto receivedValue =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel);
|
||||
|
||||
outputTiles[outTile][outX][outY] = receivedValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: outputTiles are not the results of the computeOps! We need to add
|
||||
// them!
|
||||
|
||||
std::unordered_map<Operation*, SmallVector<std::tuple<size_t, size_t, size_t, Value>>> computeOpNeedingResults;
|
||||
|
||||
// Iterate each output tile
|
||||
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
|
||||
// Iterate each output pixel
|
||||
for (size_t outX = 0; outX < output_w; outX++) {
|
||||
for (size_t outY = 0; outY < output_h; outY++) {
|
||||
auto outputTile = outputTiles[outTile][outX][outY];
|
||||
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
|
||||
if (!outputTileProducer) {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"Output tile for Pooling is not produced by a "
|
||||
"WeightedComputeOp.");
|
||||
}
|
||||
|
||||
computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
|
||||
|
||||
rewriter.replaceOp(poolOp, outputImage);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Adds to `patterns` the PoolingBaseConverter instantiations for
/// ONNXMaxPoolSingleOutOp (reduced with SpatVMaxOp) and ONNXAveragePoolOp
/// (reduced with SpatVAddOp), in that order.
void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  using MaxPoolConverter =
      PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>;
  using AveragePoolConverter =
      PoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>;

  patterns.insert<MaxPoolConverter>(ctx);
  patterns.insert<AveragePoolConverter>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,89 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Replaces an ONNXReduceMeanV13Op that averages over the two spatial
/// dimensions of a rank-4 tensor with an ONNXAveragePoolOp whose kernel and
/// strides cover the whole image.
///
/// The rewrite is only equivalent when:
///  - the input is a statically shaped rank-4 tensor,
///  - keepdims is 1 (the result keeps the two reduced dimensions as 1 x 1),
///  - axes is exactly {2, 3} (negative values are accepted and normalized).
/// Anything else is reported as a match failure instead of being rewritten.
struct ReduceMeanConversionPattern : public OpConversionPattern<ONNXReduceMeanV13Op> {

  ReduceMeanConversionPattern(MLIRContext* ctx)
      : OpConversionPattern(ctx) {}

  LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMean,
      ONNXReduceMeanV13OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {

    // Get the input tensor.
    Value inputTensor = adaptor.getData();
    auto inputTensorType = dyn_cast<RankedTensorType>(inputTensor.getType());
    if (!inputTensorType || inputTensorType.getRank() != 4 || !inputTensorType.hasStaticShape())
      return rewriter.notifyMatchFailure(reduceMean, "expected a statically shaped rank-4 input tensor");

    // The pooled result has shape N x C x 1 x 1, which only matches the
    // original result when the reduced dimensions are kept.
    if (reduceMean.getKeepdims() != 1)
      return rewriter.notifyMatchFailure(reduceMean, "only keepdims = 1 is supported");

    // The reduction must cover exactly the two spatial dimensions.
    ArrayAttr axesAttr = reduceMean.getAxesAttr();
    if (!axesAttr || axesAttr.size() != 2)
      return rewriter.notifyMatchFailure(reduceMean, "expected exactly two reduction axes");
    bool reducesDim2 = false;
    bool reducesDim3 = false;
    for (Attribute axisAttr : axesAttr) {
      int64_t axis = cast<IntegerAttr>(axisAttr).getInt();
      if (axis < 0)
        axis += inputTensorType.getRank();
      reducesDim2 |= axis == 2;
      reducesDim3 |= axis == 3;
    }
    if (!reducesDim2 || !reducesDim3)
      return rewriter.notifyMatchFailure(reduceMean, "only a reduction over axes {2, 3} is supported");

    // The kernel (and the stride) span the whole image, so a single output
    // pixel is produced per channel.
    const int64_t image_height = inputTensorType.getDimSize(2);
    const int64_t image_width = inputTensorType.getDimSize(3);

    auto kernelShape = rewriter.getI64ArrayAttr({image_height, image_width});
    auto strides = rewriter.getI64ArrayAttr({image_height, image_width});
    auto dilations = rewriter.getI64ArrayAttr({1, 1});

    // Set the pads to 0.
    auto pads = rewriter.getI64ArrayAttr({0, 0, 0, 0});

    // Create the resulting tensor type.
    auto resultType = RankedTensorType::get(
        /*shape=*/ {inputTensorType.getDimSize(0), inputTensorType.getDimSize(1), 1, 1},
        /*elementType=*/inputTensorType.getElementType());

    // Create the ONNXAveragePoolOp.
    auto averagePool = ONNXAveragePoolOp::create(rewriter,
        reduceMean.getLoc(),
        resultType,
        inputTensor,
        /*auto_pad=*/"NOTSET",
        /*ceil_mode=*/0,
        /*count_include_pad=*/1,
        dilations,
        /*kernel_shape=*/kernelShape,
        /*pads=*/pads,
        /*strides=*/strides);

    // Replace the ONNXReduceMeanV13Op with the ONNXAveragePoolOp.
    rewriter.replaceOp(reduceMean, averagePool.getResult());

    return success();
  }
};
|
||||
|
||||
/// Adds the ReduceMeanConversionPattern, constructed with `ctx`, to
/// `patterns`.
void populateReduceMeanConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ReduceMeanConversionPattern>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,31 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Replaces an ONNXConcatOp with a tensor::ConcatOp over the same inputs.
///
/// ONNX allows a negative `axis` that counts from the last dimension, while
/// tensor::ConcatOp needs a non-negative dimension, so the axis is normalized
/// against the rank of the first input. The pattern fails when the first
/// input is not a ranked tensor or when the axis is out of range.
struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
  ONNXConcatToTensorConcat(MLIRContext* ctx)
      : OpConversionPattern(ctx) {}

  LogicalResult matchAndRewrite(ONNXConcatOp concatOp,
      ONNXConcatOpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    auto inputs = adaptor.getInputs();
    if (inputs.empty())
      return rewriter.notifyMatchFailure(concatOp, "expected at least one input");

    auto firstInputType = dyn_cast<RankedTensorType>(inputs.front().getType());
    if (!firstInputType)
      return rewriter.notifyMatchFailure(concatOp, "expected ranked tensor inputs");

    // Map an axis in [-rank, rank) to [0, rank).
    const int64_t rank = firstInputType.getRank();
    int64_t axis = adaptor.getAxis();
    if (axis < 0)
      axis += rank;
    if (axis < 0 || axis >= rank)
      return rewriter.notifyMatchFailure(concatOp, "concatenation axis is out of range");

    rewriter.replaceOpWithNewOp<tensor::ConcatOp>(concatOp, axis, inputs);

    return success();
  }
};
|
||||
|
||||
/// Adds the ONNXConcatToTensorConcat pattern, constructed with `ctx`, to
/// `patterns`.
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ONNXConcatToTensorConcat>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,121 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Returns true when every extent in `shape` is strictly positive (an empty
/// shape trivially qualifies).
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
  for (int64_t extent : shape)
    if (extent <= 0)
      return false;
  return true;
}
|
||||
|
||||
/// Tries to express the reshape `sourceShape` -> `resultShape` as a collapse,
/// i.e. as a partition of the source dimensions into contiguous groups whose
/// products are the result dimensions.
///
/// Source dimensions are consumed greedily from left to right: each group is
/// extended until its product equals the current result dimension. Unit source
/// dimensions left over once every result dimension has been matched are
/// appended to the last group (e.g. [2, 3, 1] -> [2, 3] gives {0}, {1, 2}).
///
/// \param reassociation cleared on entry; on success holds one group of source
///        indices per result dimension. Its content is unspecified on failure.
/// \return true when such a grouping exists.
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
    ArrayRef<int64_t> resultShape,
    SmallVector<ReassociationIndices>& reassociation) {
  reassociation.clear();

  size_t sourceIdx = 0;
  size_t resultIdx = 0;
  while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
    int64_t sourceProduct = sourceShape[sourceIdx];
    int64_t resultProduct = resultShape[resultIdx];

    ReassociationIndices group;
    group.push_back(sourceIdx);
    while (sourceProduct != resultProduct) {
      // Overshooting means the result dimension would split a source
      // dimension, which a collapse cannot do.
      if (sourceProduct > resultProduct)
        return false;
      sourceIdx++;
      if (sourceIdx >= sourceShape.size())
        return false;
      group.push_back(sourceIdx);
      sourceProduct *= sourceShape[sourceIdx];
    }

    reassociation.push_back(group);
    sourceIdx++;
    resultIdx++;
  }

  // Every result dimension must have received a group.
  if (resultIdx != resultShape.size())
    return false;

  // Remaining source dimensions can only be unit dimensions; they do not
  // change the product of the last group, so they are folded into it. With a
  // rank-0 result there is no group and the reassociation stays empty.
  for (; sourceIdx < sourceShape.size(); sourceIdx++) {
    if (sourceShape[sourceIdx] != 1)
      return false;
    if (!reassociation.empty())
      reassociation.back().push_back(sourceIdx);
  }

  return true;
}
|
||||
|
||||
/// Tries to express the reshape `sourceShape` -> `resultShape` as an expand,
/// i.e. as a partition of the result dimensions into contiguous groups whose
/// products are the source dimensions.
///
/// Result dimensions are consumed greedily from left to right: each group is
/// extended until its product equals the current source dimension. Unit result
/// dimensions left over once every source dimension has been matched are
/// appended to the last group (e.g. [6] -> [2, 3, 1] gives {0, 1, 2}).
///
/// \param reassociation cleared on entry; on success holds one group of result
///        indices per source dimension. Its content is unspecified on failure.
/// \return true when such a grouping exists.
static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
    ArrayRef<int64_t> resultShape,
    SmallVector<ReassociationIndices>& reassociation) {
  reassociation.clear();

  size_t sourceIdx = 0;
  size_t resultIdx = 0;
  while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
    int64_t sourceProduct = sourceShape[sourceIdx];
    int64_t resultProduct = resultShape[resultIdx];

    ReassociationIndices group;
    group.push_back(resultIdx);
    while (resultProduct != sourceProduct) {
      // Overshooting means the source dimension would split a result
      // dimension, which an expand cannot do.
      if (resultProduct > sourceProduct)
        return false;
      resultIdx++;
      if (resultIdx >= resultShape.size())
        return false;
      group.push_back(resultIdx);
      resultProduct *= resultShape[resultIdx];
    }

    reassociation.push_back(group);
    sourceIdx++;
    resultIdx++;
  }

  // Every source dimension must have received a group.
  if (sourceIdx != sourceShape.size())
    return false;

  // Remaining result dimensions can only be unit dimensions; they do not
  // change the product of the last group, so they are folded into it. With a
  // rank-0 source there is no group and the reassociation stays empty.
  for (; resultIdx < resultShape.size(); resultIdx++) {
    if (resultShape[resultIdx] != 1)
      return false;
    if (!reassociation.empty())
      reassociation.back().push_back(resultIdx);
  }

  return true;
}
|
||||
|
||||
/// Lowers an ONNXReshapeOp between statically shaped ranked tensors:
///  - identical source and result types: the op is replaced by its input;
///  - lower result rank: replaced by a tensor::CollapseShapeOp when a collapse
///    reassociation can be inferred;
///  - higher result rank: replaced by a tensor::ExpandShapeOp when an expand
///    reassociation can be inferred.
/// Every other case (including shapes with non-positive extents) is a match
/// failure.
struct ONNXReshapeToTensorReshape : OpConversionPattern<ONNXReshapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
      ONNXReshapeOpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
    auto outputType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());

    // Both sides must be ranked tensors with fully static shapes.
    const bool bothStatic = inputType && outputType && inputType.hasStaticShape() && outputType.hasStaticShape();
    if (!bothStatic)
      return failure();

    const bool bothPositive =
        haveStaticPositiveShape(inputType.getShape()) && haveStaticPositiveShape(outputType.getShape());
    if (!bothPositive)
      return failure();

    // Nothing to reshape: forward the input.
    if (inputType == outputType) {
      rewriter.replaceOp(reshapeOp, adaptor.getData());
      return success();
    }

    const int64_t inputRank = inputType.getRank();
    const int64_t outputRank = outputType.getRank();
    SmallVector<ReassociationIndices> groups;

    if (inputRank > outputRank) {
      if (!inferCollapseReassociation(inputType.getShape(), outputType.getShape(), groups))
        return failure();
      rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, outputType, adaptor.getData(), groups);
      return success();
    }

    if (inputRank < outputRank) {
      if (!inferExpandReassociation(inputType.getShape(), outputType.getShape(), groups))
        return failure();
      rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, outputType, adaptor.getData(), groups);
      return success();
    }

    // Same rank but different shape: not expressible as a single
    // collapse/expand.
    return failure();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Registers the ONNX reshape conversion pattern
/// (`ONNXReshapeToTensorReshape`) in `patterns`, constructing it with `ctx`.
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ONNXReshapeToTensorReshape>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,35 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <typename OpTy, typename OpAdaptorTy>
|
||||
struct RemoveUnusedHelperOps : OpRewritePattern<OpTy> {
|
||||
RemoveUnusedHelperOps(MLIRContext* ctx)
|
||||
: OpRewritePattern<OpTy>(ctx) {}
|
||||
|
||||
void initialize() { this->setHasBoundedRewriteRecursion(); }
|
||||
|
||||
LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const final {
|
||||
if (op.getResult().use_empty()) {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
/// Registers the dead-op cleanup patterns in `patterns`: one
/// `RemoveUnusedHelperOps` instantiation each for `tensor::ConcatOp`,
/// `spatial::SpatImgConcatOp` and `ONNXReshapeOp`, added in that order.
void populateRemoveUnusedHelperOpsPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  using UnusedTensorConcat = RemoveUnusedHelperOps<tensor::ConcatOp, tensor::ConcatOpAdaptor>;
  using UnusedImgConcat = RemoveUnusedHelperOps<spatial::SpatImgConcatOp, spatial::SpatImgConcatOpAdaptor>;
  using UnusedReshape = RemoveUnusedHelperOps<ONNXReshapeOp, ONNXReshapeOpAdaptor>;

  patterns.add<UnusedTensorConcat, UnusedImgConcat, UnusedReshape>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user