add better createSpatCompute helper
@@ -6,6 +6,7 @@
 
+#include <cassert>
 
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 
@@ -138,83 +139,76 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
   else
     gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
 
-  auto im2colComputeOp =
-      spatial::SpatWeightedCompute::create(rewriter, loc, im2colType, SmallVector<Value>(), ValueRange {x});
-
-  auto* im2colBlock = new Block();
-  im2colBlock->addArgument(x.getType(), loc);
-  im2colComputeOp.getBody().push_back(im2colBlock);
-  rewriter.setInsertionPointToStart(im2colBlock);
-
-  Value paddedInput = im2colBlock->getArgument(0);
-
-  // Pad input with zeros if needed:
-  // [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
-  if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
-    const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
-    const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
-    auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
-    SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
-        rewriter.getIndexAttr(0),
-        rewriter.getIndexAttr(padHeightBegin),
-        rewriter.getIndexAttr(padWidthBegin)};
-    SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
-        rewriter.getIndexAttr(0),
-        rewriter.getIndexAttr(padHeightEnd),
-        rewriter.getIndexAttr(padWidthEnd)};
-    auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
-    auto* padBlock = new Block();
-    for (int i = 0; i < 4; i++)
-      padBlock->addArgument(rewriter.getIndexType(), loc);
-    padOp.getRegion().push_back(padBlock);
-    rewriter.setInsertionPointToStart(padBlock);
-    auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
-    tensor::YieldOp::create(rewriter, loc, zero.getResult());
-    rewriter.setInsertionPointAfter(padOp);
-    paddedInput = padOp.getResult();
-  }
-
-  // Build im2col [numPatches, patchSize]:
-  // For each batch/output position (n, oh, ow), extract the patch from x
-  SmallVector<Value> im2colRows;
-  im2colRows.reserve(numPatches);
-  for (int64_t n = 0; n < batchSize; n++) {
-    for (int64_t oh = 0; oh < outHeight; oh++) {
-      for (int64_t ow = 0; ow < outWidth; ow++) {
-        SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
-            rewriter.getIndexAttr(0),
-            rewriter.getIndexAttr(oh * strideHeight),
-            rewriter.getIndexAttr(ow * strideWidth)};
-        SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
-            rewriter.getIndexAttr(numChannelsIn),
-            rewriter.getIndexAttr(wHeight),
-            rewriter.getIndexAttr(wWidth)};
-        SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
-            rewriter.getIndexAttr(1),
-            rewriter.getIndexAttr(dilationHeight),
-            rewriter.getIndexAttr(dilationWidth)};
-        auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
-        Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
-
-        // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
-        Value row = tensor::CollapseShapeOp::create(rewriter,
-            loc,
-            rowType,
-            patch,
-            SmallVector<ReassociationIndices> {
-                {0},
-                {1, 2, 3}
-            });
-        im2colRows.push_back(row);
-      }
-    }
-  }
-
-  // Concatenate all rows: [numPatches, patchSize]
-  Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
-  spatial::SpatYieldOp::create(rewriter, loc, im2col);
-
-  rewriter.setInsertionPointAfter(im2colComputeOp);
+  constexpr size_t numInputs = 1;
+  auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
+    Value paddedInput = xArg;
+
+    // Pad input with zeros if needed:
+    // [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
+    if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
+      const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
+      const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
+      auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
+      SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
+          rewriter.getIndexAttr(0),
+          rewriter.getIndexAttr(padHeightBegin),
+          rewriter.getIndexAttr(padWidthBegin)};
+      SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
+          rewriter.getIndexAttr(0),
+          rewriter.getIndexAttr(padHeightEnd),
+          rewriter.getIndexAttr(padWidthEnd)};
+      auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
+      auto* padBlock = new Block();
+      for (int i = 0; i < 4; i++)
+        padBlock->addArgument(rewriter.getIndexType(), loc);
+      padOp.getRegion().push_back(padBlock);
+      rewriter.setInsertionPointToStart(padBlock);
+      auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
+      tensor::YieldOp::create(rewriter, loc, zero.getResult());
+      rewriter.setInsertionPointAfter(padOp);
+      paddedInput = padOp.getResult();
+    }
+
+    // Build im2col [numPatches, patchSize]:
+    // For each batch/output position (n, oh, ow), extract the patch from x
+    SmallVector<Value> im2colRows;
+    im2colRows.reserve(numPatches);
+    for (int64_t n = 0; n < batchSize; n++) {
+      for (int64_t oh = 0; oh < outHeight; oh++) {
+        for (int64_t ow = 0; ow < outWidth; ow++) {
+          SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
+              rewriter.getIndexAttr(0),
+              rewriter.getIndexAttr(oh * strideHeight),
+              rewriter.getIndexAttr(ow * strideWidth)};
+          SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
+              rewriter.getIndexAttr(numChannelsIn),
+              rewriter.getIndexAttr(wHeight),
+              rewriter.getIndexAttr(wWidth)};
+          SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
+              rewriter.getIndexAttr(1),
+              rewriter.getIndexAttr(dilationHeight),
+              rewriter.getIndexAttr(dilationWidth)};
+          auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
+          Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
+
+          // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
+          Value row = tensor::CollapseShapeOp::create(rewriter,
+              loc,
+              rowType,
+              patch,
+              SmallVector<ReassociationIndices> {
+                  {0},
+                  {1, 2, 3}
+              });
+          im2colRows.push_back(row);
+        }
+      }
+    }
+
+    // Concatenate all rows: [numPatches, patchSize]
+    Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
+    spatial::SpatYieldOp::create(rewriter, loc, im2col);
+  });
 
   // Gemm: A @ B + C = im2col @ W^T + b
   // [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
@@ -231,30 +225,23 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
   Value gemmOut = gemmOp.getY();
 
   auto collectComputeOp =
-      spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector<Value>(), ValueRange {gemmOut});
-
-  auto* collectBlock = new Block();
-  collectBlock->addArgument(gemmOut.getType(), loc);
-  collectComputeOp.getBody().push_back(collectBlock);
-  rewriter.setInsertionPointToStart(collectBlock);
-
-  auto gemmOutArg = collectBlock->getArguments().front();
-
-  // Restore to NCHW layout:
-  // [numPatches, numChannelsOut]
-  // -> [1, outHeight, outWidth, numChannelsOut]
-  // -> [1, numChannelsOut, outHeight, outWidth]
-  Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
-      loc,
-      nhwcType,
-      gemmOutArg,
-      SmallVector<ReassociationIndices> {
-          {0, 1, 2},
-          {3}
-      });
-  Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
-
-  spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
+      createSpatCompute<numInputs>(rewriter, loc, convOp.getType(), {}, ValueRange {gemmOut}, [&](Value gemmOutArg) {
+        // Restore to NCHW layout:
+        // [numPatches, numChannelsOut]
+        // -> [1, outHeight, outWidth, numChannelsOut]
+        // -> [1, numChannelsOut, outHeight, outWidth]
+        Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
+            loc,
+            nhwcType,
+            gemmOutArg,
+            SmallVector<ReassociationIndices> {
+                {0, 1, 2},
+                {3}
+            });
+        Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
+
+        spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
+      });
 
   rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
   return success();
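Note: the diff only shows the helper's call sites; the definition of createSpatCompute is not part of these hunks (it presumably lives in Common.hpp, next to the new <cassert> include). A minimal sketch of what a helper with this shape might look like, inferred from the two call sites: the name, the numInputs template parameter, and the argument order come from the diff, while everything else (the InsertionGuard behavior, the weights parameter, the single-input restriction) is an assumption, not the commit's actual implementation.

template <size_t NumInputs, typename BodyBuilderT>
spatial::SpatWeightedCompute createSpatCompute(PatternRewriter &rewriter,
    Location loc, Type resultType, ValueRange weights, ValueRange inputs,
    BodyBuilderT &&buildBody) {
  // Hypothetical sketch only. Folds away the boilerplate the old code
  // repeated at each call site: create the op, build its body block with
  // one argument per input, run the body builder, restore the insertion point.
  static_assert(NumInputs == 1, "sketch covers only the single-input call sites above");
  assert(inputs.size() == NumInputs && "input count must match the template parameter");

  auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultType, weights, inputs);

  // One block argument per input, mirroring what the old code did by hand
  // with new Block() + addArgument + push_back.
  auto* block = new Block();
  for (Value input : inputs)
    block->addArgument(input.getType(), loc);
  computeOp.getBody().push_back(block);

  // The guard restores the insertion point to just after computeOp on return,
  // replacing the explicit rewriter.setInsertionPointAfter(...) the old code needed.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(block);
  buildBody(block->getArgument(0));
  return computeOp;
}

Either way, the callback terminates the block itself with spatial::SpatYieldOp, as both lambdas in the diff do.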
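As a sanity check on the comments above ("Gemm: A @ B + C = im2col @ W^T + b"), here is the shape arithmetic the lowering relies on, as a small standalone example. The output-size formula is the standard ONNX Conv one (that computation happens earlier in the pattern, outside these hunks), and the concrete numbers are made up for illustration.

#include <cstdint>
#include <cstdio>

int main() {
  // Made-up Conv configuration: 1x3x32x32 input, 3x3 kernel, pad 1, stride 1.
  const int64_t batchSize = 1, numChannelsIn = 3;
  const int64_t xHeight = 32, xWidth = 32;
  const int64_t wHeight = 3, wWidth = 3;
  const int64_t strideHeight = 1, strideWidth = 1;
  const int64_t dilationHeight = 1, dilationWidth = 1;
  const int64_t padHeightBegin = 1, padHeightEnd = 1;
  const int64_t padWidthBegin = 1, padWidthEnd = 1;

  // Standard ONNX Conv output-size formula.
  const int64_t outHeight =
      (xHeight + padHeightBegin + padHeightEnd - dilationHeight * (wHeight - 1) - 1) / strideHeight + 1;
  const int64_t outWidth =
      (xWidth + padWidthBegin + padWidthEnd - dilationWidth * (wWidth - 1) - 1) / strideWidth + 1;

  // im2col matrix: one row per output position, one column per weight element.
  const int64_t numPatches = batchSize * outHeight * outWidth; // 1 * 32 * 32 = 1024
  const int64_t patchSize = numChannelsIn * wHeight * wWidth;  // 3 * 3 * 3  = 27

  // Gemm: [numPatches, patchSize] @ [patchSize, numChannelsOut] -> [numPatches, numChannelsOut]
  std::printf("im2col is %lld x %lld\n", (long long)numPatches, (long long)patchSize);
  return 0;
}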