compact spatial IR through different new operations and dedicated syntax
fast spatial node merging with batch operations
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -174,6 +175,31 @@ using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename RewriterT>
|
||||
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
|
||||
assert(!inputs.empty() && "spat.concat requires at least one input");
|
||||
if (inputs.size() == 1)
|
||||
return inputs.front();
|
||||
|
||||
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
|
||||
auto outputShape = llvm::to_vector(firstType.getShape());
|
||||
int64_t concatDimSize = 0;
|
||||
bool concatDimDynamic = false;
|
||||
|
||||
for (mlir::Value input : inputs) {
|
||||
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
|
||||
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
|
||||
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
|
||||
concatDimDynamic = true;
|
||||
else
|
||||
concatDimSize += inputType.getDimSize(axis);
|
||||
}
|
||||
|
||||
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
|
||||
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
||||
}
|
||||
|
||||
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||
auto createSpatCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
|
||||
@@ -54,6 +54,43 @@ private:
|
||||
|
||||
} // namespace
|
||||
|
||||
static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
SmallVector<spatial::SpatComputeBatch> batchOps;
|
||||
funcOp.walk([&](spatial::SpatComputeBatch batchOp) { batchOps.push_back(batchOp); });
|
||||
|
||||
for (auto batchOp : batchOps) {
|
||||
if (batchOp.getLaneCount() != 1)
|
||||
continue;
|
||||
|
||||
auto loc = batchOp.getLoc();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||
computeOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
|
||||
Block& templateBlock = batchOp.getBody().front();
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
for (BlockArgument arg : templateBlock.getArguments()) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
||||
mapper.map(oldArg, newArg);
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : templateBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
batchOp.replaceAllUsesWith(computeOp.getResults());
|
||||
rewriter.eraseOp(batchOp);
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::runOnOperation() {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = &getContext();
|
||||
@@ -124,6 +161,8 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
foldSingleLaneComputeBatches(*entryFunc);
|
||||
|
||||
// Count the number of compute ops and check they do not exceed the core count
|
||||
if (coresCount != -1) {
|
||||
int computeOpsCount = 0;
|
||||
@@ -196,8 +235,12 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
IRMapping mapper;
|
||||
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
||||
mapper.map(source, bbArg);
|
||||
auto newConcat = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
|
||||
auto newConcat = spatial::SpatConcatOp::create(rewriter,
|
||||
loc,
|
||||
toRemoveOp.getType(),
|
||||
rewriter.getI64IntegerAttr(toRemoveOp.getDim()),
|
||||
ValueRange(BB->getArguments()));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat.getOutput());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
|
||||
@@ -147,161 +147,148 @@ static Value buildPackedBias(bool hasBias,
|
||||
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
||||
}
|
||||
|
||||
static SmallVector<Value> createIm2colRowComputes(Value x,
|
||||
RankedTensorType xType,
|
||||
RankedTensorType im2colType,
|
||||
RankedTensorType im2colRowType,
|
||||
RankedTensorType gemmInputRowType,
|
||||
int64_t batchSize,
|
||||
int64_t numChannelsIn,
|
||||
int64_t xHeight,
|
||||
int64_t xWidth,
|
||||
int64_t wHeight,
|
||||
int64_t wWidth,
|
||||
int64_t padHeightBegin,
|
||||
int64_t padHeightEnd,
|
||||
int64_t padWidthBegin,
|
||||
int64_t padWidthEnd,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
int64_t outWidth,
|
||||
int64_t patchSize,
|
||||
int64_t numPatches,
|
||||
int64_t numPatchesPerBatch,
|
||||
int64_t packFactor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value createIm2colRowComputes(Value x,
|
||||
RankedTensorType xType,
|
||||
RankedTensorType im2colType,
|
||||
RankedTensorType im2colRowType,
|
||||
RankedTensorType gemmInputRowsType,
|
||||
int64_t batchSize,
|
||||
int64_t numChannelsIn,
|
||||
int64_t xHeight,
|
||||
int64_t xWidth,
|
||||
int64_t wHeight,
|
||||
int64_t wWidth,
|
||||
int64_t padHeightBegin,
|
||||
int64_t padHeightEnd,
|
||||
int64_t padWidthBegin,
|
||||
int64_t padWidthEnd,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
int64_t outWidth,
|
||||
int64_t patchSize,
|
||||
int64_t numPatches,
|
||||
int64_t numPatchesPerBatch,
|
||||
int64_t packFactor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto elemType = xType.getElementType();
|
||||
constexpr size_t numInputs = 1;
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||
SmallVector<Type> resultTypes(packedNumRows, gemmInputRowType);
|
||||
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, resultTypes, {}, x, [&](Value xArg) {
|
||||
Value paddedInput = xArg;
|
||||
auto im2colComputeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
|
||||
Value paddedInput = xArg;
|
||||
|
||||
// Pad input with zeros if needed:
|
||||
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
|
||||
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
||||
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
||||
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
||||
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padHeightBegin),
|
||||
rewriter.getIndexAttr(padWidthBegin)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padHeightEnd),
|
||||
rewriter.getIndexAttr(padWidthEnd)};
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int i = 0; i < 4; i++)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
paddedInput = padOp.getResult();
|
||||
}
|
||||
// Pad input with zeros if needed:
|
||||
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
|
||||
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
||||
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
||||
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
||||
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padHeightBegin),
|
||||
rewriter.getIndexAttr(padWidthBegin)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padHeightEnd),
|
||||
rewriter.getIndexAttr(padWidthEnd)};
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int i = 0; i < 4; i++)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
paddedInput = padOp.getResult();
|
||||
}
|
||||
|
||||
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
|
||||
// until the late PIM unrolling step.
|
||||
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
|
||||
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
|
||||
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
|
||||
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
|
||||
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
|
||||
// until the late PIM unrolling step.
|
||||
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
|
||||
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
|
||||
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
|
||||
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
|
||||
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||
|
||||
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
|
||||
rewriter.setInsertionPointToStart(im2colLoop.getBody());
|
||||
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
|
||||
rewriter.setInsertionPointToStart(im2colLoop.getBody());
|
||||
|
||||
Value patchIndex = im2colLoop.getInductionVar();
|
||||
Value im2colAcc = im2colLoop.getRegionIterArgs().front();
|
||||
Value patchIndex = im2colLoop.getInductionVar();
|
||||
Value im2colAcc = im2colLoop.getRegionIterArgs().front();
|
||||
|
||||
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
||||
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
||||
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
||||
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(numChannelsIn),
|
||||
rewriter.getIndexAttr(wHeight),
|
||||
rewriter.getIndexAttr(wWidth)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(dilationHeight),
|
||||
rewriter.getIndexAttr(dilationWidth)};
|
||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(numChannelsIn),
|
||||
rewriter.getIndexAttr(wHeight),
|
||||
rewriter.getIndexAttr(wWidth)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(dilationHeight),
|
||||
rewriter.getIndexAttr(dilationWidth)};
|
||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||
|
||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
im2colRowType,
|
||||
patch,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
im2colRowType,
|
||||
patch,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
|
||||
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
|
||||
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value updatedIm2col =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
|
||||
scf::YieldOp::create(rewriter, loc, updatedIm2col);
|
||||
|
||||
rewriter.setInsertionPointAfter(im2colLoop);
|
||||
Value im2col = im2colLoop.getResult(0);
|
||||
|
||||
Value gemmInputRows = im2col;
|
||||
if (packFactor != 1) {
|
||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
|
||||
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
groupedType,
|
||||
paddedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
packedType,
|
||||
groupedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
}
|
||||
|
||||
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
||||
});
|
||||
|
||||
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
|
||||
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value updatedIm2col =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
|
||||
scf::YieldOp::create(rewriter, loc, updatedIm2col);
|
||||
|
||||
rewriter.setInsertionPointAfter(im2colLoop);
|
||||
Value im2col = im2colLoop.getResult(0);
|
||||
|
||||
Value gemmInputRows = im2col;
|
||||
if (packFactor != 1) {
|
||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
|
||||
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
groupedType,
|
||||
paddedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
packedType,
|
||||
groupedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
}
|
||||
|
||||
SmallVector<Value> rowResults;
|
||||
rowResults.reserve(packedNumRows);
|
||||
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(packFactor * patchSize)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
rowResults.push_back(
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, loc, rowResults);
|
||||
});
|
||||
|
||||
SmallVector<Value> rows;
|
||||
rows.reserve(im2colComputeOp.getNumResults());
|
||||
for (Value result : im2colComputeOp.getResults())
|
||||
rows.push_back(result);
|
||||
return rows;
|
||||
return im2colComputeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
@@ -319,15 +306,12 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
|
||||
Value gemmOut;
|
||||
if (packFactor == 1) {
|
||||
gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front()
|
||||
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
|
||||
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||
}
|
||||
else {
|
||||
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
||||
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
||||
Value packedOutput = gemmRowArgs.size() == 1
|
||||
? gemmRowArgs.front()
|
||||
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
|
||||
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
@@ -509,35 +493,36 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
// A_packed: [ceil(numPatches / N), N * patchSize]
|
||||
// B_packed: [N * patchSize, N * cOut]
|
||||
// Y_packed: [ceil(numPatches / N), N * cOut]
|
||||
auto gemmInputRowType = RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
|
||||
auto gemmOutputRowType =
|
||||
RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
||||
SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
|
||||
xType,
|
||||
im2colType,
|
||||
rowType,
|
||||
gemmInputRowType,
|
||||
batchSize,
|
||||
numChannelsIn,
|
||||
xHeight,
|
||||
xWidth,
|
||||
wHeight,
|
||||
wWidth,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
outWidth,
|
||||
patchSize,
|
||||
numPatches,
|
||||
numPatchesPerBatch,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc);
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
|
||||
auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType);
|
||||
auto gemmOutputRowsType =
|
||||
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
||||
Value gemmInputRows = createIm2colRowComputes(x,
|
||||
xType,
|
||||
im2colType,
|
||||
rowType,
|
||||
gemmInputRowsType,
|
||||
batchSize,
|
||||
numChannelsIn,
|
||||
xHeight,
|
||||
xWidth,
|
||||
wHeight,
|
||||
wWidth,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
outWidth,
|
||||
patchSize,
|
||||
numPatches,
|
||||
numPatchesPerBatch,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc);
|
||||
|
||||
Value gemmB = buildPackedWeight(wDenseAttr,
|
||||
wTrans,
|
||||
@@ -553,25 +538,20 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
Value gemmC = buildPackedBias(
|
||||
hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
||||
|
||||
SmallVector<Value> gemmRows;
|
||||
gemmRows.reserve(gemmInputRows.size());
|
||||
for (Value gemmInputRow : gemmInputRows) {
|
||||
Value gemmRow = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmOutputRowType,
|
||||
gemmInputRow,
|
||||
gemmB,
|
||||
gemmC,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
gemmRows.push_back(gemmRow);
|
||||
}
|
||||
Value gemmRows = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmOutputRowsType,
|
||||
gemmInputRows,
|
||||
gemmB,
|
||||
gemmC,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
|
||||
rewriter.replaceOp(convOp,
|
||||
createCollectedConvOutput(gemmRows,
|
||||
createCollectedConvOutput(ValueRange {gemmRows},
|
||||
convOp.getType(),
|
||||
gemmOutType,
|
||||
nhwcType,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
@@ -65,6 +66,66 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
struct GemmToSpatialComputeBatch : OpConversionPattern<ONNXGemmOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
static SmallVector<Value> materializeBatchRowSlices(Value matrix,
|
||||
RankedTensorType matrixType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t numRows = matrixType.getDimSize(0);
|
||||
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
|
||||
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
|
||||
|
||||
auto buildRowSlices = [&](Value matrixArg) {
|
||||
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
|
||||
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||
};
|
||||
|
||||
auto cloneBatchInputChainIntoSliceCompute =
|
||||
[&](Value rootInput, SmallVector<Operation*> chainOps, Value rootValue) -> SmallVector<Value> {
|
||||
auto sliceCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) {
|
||||
Value transformedMatrix = input;
|
||||
if (!chainOps.empty()) {
|
||||
IRMapping mapper;
|
||||
mapper.map(rootValue, input);
|
||||
for (Operation* chainOp : chainOps)
|
||||
rewriter.clone(*chainOp, mapper);
|
||||
transformedMatrix = cast<Value>(mapper.lookup(matrix));
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix));
|
||||
});
|
||||
SmallVector<Value> rowSlices(sliceCompute->result_begin(), sliceCompute->result_end());
|
||||
return rowSlices;
|
||||
};
|
||||
|
||||
SmallVector<Operation*> chainOps;
|
||||
Value rootValue = matrix;
|
||||
while (Operation* definingOp = rootValue.getDefiningOp()) {
|
||||
if (auto rootCompute = dyn_cast<spatial::SpatCompute>(definingOp)) {
|
||||
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
|
||||
return cloneBatchInputChainIntoSliceCompute(
|
||||
rootCompute.getResult(cast<OpResult>(rootValue).getResultNumber()), reversedChainOps, rootValue);
|
||||
}
|
||||
|
||||
if (definingOp->getNumOperands() != 1)
|
||||
break;
|
||||
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
||||
break;
|
||||
|
||||
chainOps.push_back(definingOp);
|
||||
rootValue = definingOp->getOperand(0);
|
||||
}
|
||||
|
||||
return buildRowSlices(matrix);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
@@ -156,8 +217,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
}
|
||||
|
||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
|
||||
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs));
|
||||
});
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
@@ -313,15 +373,116 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
|
||||
auto concatComputeOp =
|
||||
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
|
||||
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
|
||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
|
||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs));
|
||||
});
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
ONNXGemmOpAdaptor gemmOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = gemmOp.getLoc();
|
||||
Value a = gemmOpAdaptor.getA();
|
||||
Value b = gemmOpAdaptor.getB();
|
||||
Value c = gemmOpAdaptor.getC();
|
||||
|
||||
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
|
||||
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
|
||||
auto aType = cast<RankedTensorType>(a.getType());
|
||||
auto bType = cast<RankedTensorType>(b.getType());
|
||||
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() && outType.hasStaticShape());
|
||||
|
||||
const int64_t numOutRows = aType.getDimSize(0);
|
||||
if (numOutRows <= 1)
|
||||
return failure();
|
||||
|
||||
// Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize
|
||||
if (aType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue())
|
||||
|| outType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue()))
|
||||
return failure();
|
||||
|
||||
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
|
||||
if (failed(scaledB))
|
||||
return failure();
|
||||
b = *scaledB;
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
|
||||
if (gemmOpAdaptor.getTransB()) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
(void) bType;
|
||||
|
||||
Value sharedBias;
|
||||
if (hasC) {
|
||||
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||
if (failed(scaledC))
|
||||
return failure();
|
||||
c = *scaledC;
|
||||
auto cType = cast<RankedTensorType>(c.getType());
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
c,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
}
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
// Row-specific bias can't share a single template body; fall through to GemmToManyGemv
|
||||
if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
|
||||
return failure();
|
||||
if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
|
||||
c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc);
|
||||
sharedBias = c;
|
||||
}
|
||||
|
||||
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
||||
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
|
||||
|
||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
|
||||
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
|
||||
|
||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
||||
loc,
|
||||
TypeRange(resultTypes),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
|
||||
ValueRange(weights),
|
||||
ValueRange(aSlices));
|
||||
|
||||
Block* body = rewriter.createBlock(
|
||||
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
|
||||
Value vmmResult = spatial::SpatWeightedVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
|
||||
Value laneResult = vmmResult;
|
||||
if (sharedBias)
|
||||
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
|
||||
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
|
||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
|
||||
});
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<GemmToSpatialComputeBatch>(ctx, PatternBenefit(2));
|
||||
patterns.insert<GemmToManyGemv>(ctx);
|
||||
patterns.insert<GemvToSpatialCompute>(ctx);
|
||||
}
|
||||
|
||||
@@ -232,9 +232,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
}));
|
||||
}
|
||||
|
||||
Value result = batchResults.size() == 1
|
||||
? batchResults.front()
|
||||
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, batchResults).getResult();
|
||||
Value result = createSpatConcat(rewriter, loc, /*axis=*/0, batchResults);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -100,8 +100,7 @@ static Value buildReduceMeanKeepdims(Value input,
|
||||
for (Value slice : slices)
|
||||
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
|
||||
|
||||
return reducedSlices.size() == 1 ? reducedSlices.front()
|
||||
: tensor::ConcatOp::create(rewriter, loc, axis, reducedSlices).getResult();
|
||||
return createSpatConcat(rewriter, loc, axis, reducedSlices);
|
||||
}
|
||||
|
||||
static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
|
||||
@@ -33,9 +33,7 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
|
||||
|
||||
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
||||
assert(!values.empty() && "Expected at least one value to concatenate.");
|
||||
if (values.size() == 1)
|
||||
return values.front();
|
||||
return tensor::ConcatOp::create(rewriter, loc, axis, values);
|
||||
return createSpatConcat(rewriter, loc, axis, values);
|
||||
}
|
||||
|
||||
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||
|
||||
@@ -47,8 +47,7 @@ buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRe
|
||||
for (Value slice : slices)
|
||||
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
||||
|
||||
return rebuiltSlices.size() == 1 ? rebuiltSlices.front()
|
||||
: tensor::ConcatOp::create(rewriter, loc, axis, rebuiltSlices).getResult();
|
||||
return createSpatConcat(rewriter, loc, axis, rebuiltSlices);
|
||||
}
|
||||
|
||||
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -17,7 +18,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
||||
auto inputs = adaptor.getInputs();
|
||||
int64_t axis = adaptor.getAxis();
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::ConcatOp>(maxpoolOp, axis, inputs);
|
||||
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ static Value concatGatherSlices(Value data,
|
||||
}
|
||||
if (slices.empty())
|
||||
return {};
|
||||
return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult();
|
||||
return createSpatConcat(rewriter, loc, axis, slices);
|
||||
}
|
||||
|
||||
static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
@@ -130,9 +130,7 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
|
||||
return failure();
|
||||
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
|
||||
}
|
||||
result = rows.size() == 1
|
||||
? rows.front()
|
||||
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult();
|
||||
result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows);
|
||||
}
|
||||
else {
|
||||
return failure();
|
||||
|
||||
@@ -50,7 +50,7 @@ static Value buildNearestResize(Value input,
|
||||
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
|
||||
}
|
||||
|
||||
return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult();
|
||||
return createSpatConcat(rewriter, loc, axis, slices);
|
||||
}
|
||||
|
||||
struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
|
||||
@@ -7,23 +7,12 @@
|
||||
#include <cstddef>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
|
||||
auto attr = op->getAttrOfType<IntegerAttr>(attrName);
|
||||
assert(attr && "required precomputed channel attr is missing");
|
||||
return IntegerAttr::get(builder.getI32Type(), attr.getInt());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
|
||||
/*
|
||||
EXAMPLE RUN:
|
||||
@@ -74,37 +63,6 @@ IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
||||
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
||||
}
|
||||
|
||||
IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
|
||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
|
||||
return getRequiredI32Attr(builder, channelNewOp, kChannelSourceCoreIdAttrName);
|
||||
}
|
||||
|
||||
IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
|
||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
|
||||
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
|
||||
}
|
||||
|
||||
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
|
||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName);
|
||||
}
|
||||
|
||||
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
|
||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
|
||||
}
|
||||
|
||||
mlir::Value
|
||||
createPimReceiveFromSpatialChannel(PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
|
||||
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
|
||||
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
|
||||
return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||
auto users = value.getUsers();
|
||||
|
||||
|
||||
@@ -2,16 +2,10 @@
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id";
|
||||
inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id";
|
||||
|
||||
/**
|
||||
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
||||
* its static tensor input.
|
||||
@@ -30,17 +24,6 @@ size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
||||
|
||||
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||
|
||||
mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||
|
||||
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||
|
||||
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
|
||||
|
||||
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
|
||||
|
||||
mlir::Value createPimReceiveFromSpatialChannel(
|
||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
|
||||
|
||||
template <class T>
|
||||
size_t rangeLength(const mlir::iterator_range<T> range) {
|
||||
return std::distance(range.begin(), range.end());
|
||||
|
||||
@@ -9,17 +9,6 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def HasSpatialChannelSourceCoreIdAttr: Constraint<
|
||||
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
|
||||
"spatial channel has precomputed source core id">;
|
||||
|
||||
def HasSpatialChannelTargetCoreIdAttr: Constraint<
|
||||
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
|
||||
"spatial channel has precomputed target core id">;
|
||||
|
||||
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
|
||||
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
|
||||
|
||||
def onnxToPimTranspose : Pat<
|
||||
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
||||
(PimTransposeOp $data, $perms,
|
||||
@@ -80,18 +69,4 @@ def spatToPimVSoftmax : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatChannelSendToPimSend : Pat<
|
||||
(SpatChannelSendOp $channel, $input),
|
||||
(PimSendOp $input,
|
||||
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
|
||||
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
|
||||
[(HasSpatialChannelTargetCoreIdAttr $channel)]
|
||||
>;
|
||||
|
||||
def spatChannelReceiveToPimReceive : Pat<
|
||||
(SpatChannelReceiveOp:$srcOpRes $channel),
|
||||
(createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
|
||||
[(HasSpatialChannelSourceCoreIdAttr $channel)]
|
||||
>;
|
||||
|
||||
#endif // SPATIAL_TO_PIM
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user