conv now lowers correctly down to bufferized pim
This commit is contained in:
@@ -6,13 +6,23 @@
|
|||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
#include "Conv.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ONNXConvOp convOp,
|
||||||
|
ONNXConvOpAdaptor convOpAdaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||||
ONNXConvOpAdaptor convOpAdaptor,
|
ONNXConvOpAdaptor convOpAdaptor,
|
||||||
@@ -100,76 +110,15 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
const int64_t numPatches = outHeight * outWidth;
|
const int64_t numPatches = outHeight * outWidth;
|
||||||
|
|
||||||
auto elemType = xType.getElementType();
|
auto elemType = xType.getElementType();
|
||||||
|
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
||||||
// Pad input with zeros if needed:
|
|
||||||
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
|
|
||||||
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
|
||||||
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
|
||||||
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
|
||||||
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
|
||||||
auto paddedType = RankedTensorType::get({1, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
|
||||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
|
||||||
rewriter.getIndexAttr(0),
|
|
||||||
rewriter.getIndexAttr(padHeightBegin),
|
|
||||||
rewriter.getIndexAttr(padWidthBegin)};
|
|
||||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
|
||||||
rewriter.getIndexAttr(0),
|
|
||||||
rewriter.getIndexAttr(padHeightEnd),
|
|
||||||
rewriter.getIndexAttr(padWidthEnd)};
|
|
||||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, x, lowPads, highPads);
|
|
||||||
auto* padBlock = new Block();
|
|
||||||
for (int i = 0; i < 4; i++)
|
|
||||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
|
||||||
padOp.getRegion().push_back(padBlock);
|
|
||||||
rewriter.setInsertionPointToStart(padBlock);
|
|
||||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
|
||||||
rewriter.setInsertionPointAfter(padOp);
|
|
||||||
x = padOp.getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build im2col [numPatches, patchSize]:
|
|
||||||
// For each output position (oh, ow), extract the patch from x
|
|
||||||
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
|
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
|
||||||
SmallVector<Value> im2colRows;
|
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
||||||
im2colRows.reserve(numPatches);
|
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
||||||
|
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
||||||
for (int64_t oh = 0; oh < outHeight; oh++) {
|
auto nhwcType = RankedTensorType::get({1, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
||||||
for (int64_t ow = 0; ow < outWidth; ow++) {
|
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0),
|
|
||||||
rewriter.getIndexAttr(0),
|
|
||||||
rewriter.getIndexAttr(oh * strideHeight),
|
|
||||||
rewriter.getIndexAttr(ow * strideWidth)};
|
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(numChannelsIn),
|
|
||||||
rewriter.getIndexAttr(wHeight),
|
|
||||||
rewriter.getIndexAttr(wWidth)};
|
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(dilationHeight),
|
|
||||||
rewriter.getIndexAttr(dilationWidth)};
|
|
||||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
|
||||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, x, offsets, sizes, strides);
|
|
||||||
|
|
||||||
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
|
|
||||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
rowType,
|
|
||||||
patch,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0},
|
|
||||||
{1, 2, 3}
|
|
||||||
});
|
|
||||||
im2colRows.push_back(row);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Concatenate all rows: [numPatches, patchSize]
|
|
||||||
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
|
|
||||||
|
|
||||||
// Prepare weight matrix W for crossbar storage:
|
// Prepare weight matrix W for crossbar storage:
|
||||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||||
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
|
||||||
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
|
||||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
wFlatType,
|
wFlatType,
|
||||||
@@ -196,23 +145,98 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
else
|
else
|
||||||
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
|
||||||
|
auto im2colComputeOp =
|
||||||
|
spatial::SpatWeightedCompute::create(rewriter, loc, im2colType, SmallVector<Value>(), ValueRange {x});
|
||||||
|
|
||||||
|
auto* im2colBlock = new Block();
|
||||||
|
im2colBlock->addArgument(x.getType(), loc);
|
||||||
|
im2colComputeOp.getBody().push_back(im2colBlock);
|
||||||
|
rewriter.setInsertionPointToStart(im2colBlock);
|
||||||
|
|
||||||
|
Value paddedInput = im2colBlock->getArgument(0);
|
||||||
|
|
||||||
|
// Pad input with zeros if needed:
|
||||||
|
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
|
||||||
|
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
||||||
|
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
||||||
|
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
||||||
|
auto paddedType = RankedTensorType::get({1, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
||||||
|
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(padHeightBegin),
|
||||||
|
rewriter.getIndexAttr(padWidthBegin)};
|
||||||
|
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(padHeightEnd),
|
||||||
|
rewriter.getIndexAttr(padWidthEnd)};
|
||||||
|
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
|
||||||
|
auto* padBlock = new Block();
|
||||||
|
for (int i = 0; i < 4; i++)
|
||||||
|
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||||
|
padOp.getRegion().push_back(padBlock);
|
||||||
|
rewriter.setInsertionPointToStart(padBlock);
|
||||||
|
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
||||||
|
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||||
|
rewriter.setInsertionPointAfter(padOp);
|
||||||
|
paddedInput = padOp.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build im2col [numPatches, patchSize]:
|
||||||
|
// For each output position (oh, ow), extract the patch from x
|
||||||
|
SmallVector<Value> im2colRows;
|
||||||
|
im2colRows.reserve(numPatches);
|
||||||
|
for (int64_t oh = 0; oh < outHeight; oh++) {
|
||||||
|
for (int64_t ow = 0; ow < outWidth; ow++) {
|
||||||
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(oh * strideHeight),
|
||||||
|
rewriter.getIndexAttr(ow * strideWidth)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(numChannelsIn),
|
||||||
|
rewriter.getIndexAttr(wHeight),
|
||||||
|
rewriter.getIndexAttr(wWidth)};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(dilationHeight),
|
||||||
|
rewriter.getIndexAttr(dilationWidth)};
|
||||||
|
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||||
|
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||||
|
|
||||||
|
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
|
||||||
|
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
rowType,
|
||||||
|
patch,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2, 3}
|
||||||
|
});
|
||||||
|
im2colRows.push_back(row);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concatenate all rows: [numPatches, patchSize]
|
||||||
|
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(im2colComputeOp);
|
||||||
|
|
||||||
// Gemm: A @ B + C = im2col @ W^T + b
|
// Gemm: A @ B + C = im2col @ W^T + b
|
||||||
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
|
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
|
||||||
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
|
||||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
auto gemmOp = ONNXGemmOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
gemmOutType,
|
gemmOutType,
|
||||||
im2col,
|
im2colComputeOp.getResult(0),
|
||||||
wTrans,
|
wTrans,
|
||||||
gemmC,
|
gemmC,
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
rewriter.getBoolAttr(false),
|
rewriter.getBoolAttr(false),
|
||||||
rewriter.getBoolAttr(false));
|
rewriter.getBoolAttr(false));
|
||||||
|
|
||||||
Value gemmOut = gemmOp.getY();
|
Value gemmOut = gemmOp.getY();
|
||||||
|
|
||||||
auto collectComputeOp =
|
auto collectComputeOp =
|
||||||
spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector<Value>(), gemmOut);
|
spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector<Value>(), ValueRange {gemmOut});
|
||||||
|
|
||||||
auto* collectBlock = new Block();
|
auto* collectBlock = new Block();
|
||||||
collectBlock->addArgument(gemmOut.getType(), loc);
|
collectBlock->addArgument(gemmOut.getType(), loc);
|
||||||
@@ -225,7 +249,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
// [numPatches, numChannelsOut]
|
// [numPatches, numChannelsOut]
|
||||||
// -> [1, outHeight, outWidth, numChannelsOut]
|
// -> [1, outHeight, outWidth, numChannelsOut]
|
||||||
// -> [1, numChannelsOut, outHeight, outWidth]
|
// -> [1, numChannelsOut, outHeight, outWidth]
|
||||||
auto nhwcType = RankedTensorType::get({1, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
|
||||||
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
|
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
nhwcType,
|
nhwcType,
|
||||||
@@ -238,7 +261,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
|
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
|
||||||
|
|
||||||
rewriter.replaceOp(convOp, collectComputeOp);
|
rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
#include "llvm/Support/LogicalResult.h"
|
|
||||||
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
struct ConvToGemm : mlir::OpConversionPattern<mlir::ONNXConvOp> {
|
|
||||||
ConvToGemm(mlir::MLIRContext* ctx)
|
|
||||||
: OpConversionPattern(ctx) {}
|
|
||||||
|
|
||||||
mlir::LogicalResult matchAndRewrite(mlir::ONNXConvOp convOp,
|
|
||||||
mlir::ONNXConvOpAdaptor convOpAdaptor,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter) const override;
|
|
||||||
};
|
|
||||||
|
|
||||||
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -300,7 +300,8 @@ struct ChannelBroadcastReceiveOpInterface
|
|||||||
|
|
||||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||||
|
|
||||||
auto outputSize = cast<ShapedType>(outputTensor.getType()).getNumElements();
|
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||||
|
auto outputSize = outputType.getNumElements() * outputType.getElementTypeBitWidth() / 8;
|
||||||
|
|
||||||
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
||||||
if (!channelNewOp) {
|
if (!channelNewOp) {
|
||||||
@@ -356,7 +357,8 @@ struct ChannelBroadcastSendOpInterface
|
|||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Turn the channel send to pim.send
|
* Turn the channel send into a device-to-host copy into the shared
|
||||||
|
* broadcast buffer that receive ops load from later.
|
||||||
*/
|
*/
|
||||||
LogicalResult bufferize(Operation* op,
|
LogicalResult bufferize(Operation* op,
|
||||||
RewriterBase& rewriter,
|
RewriterBase& rewriter,
|
||||||
@@ -389,8 +391,18 @@ struct ChannelBroadcastSendOpInterface
|
|||||||
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
|
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto srcType = cast<ShapedType>(srcTensor.getType());
|
||||||
|
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
|
||||||
|
|
||||||
rewriter.setInsertionPoint(op);
|
rewriter.setInsertionPoint(op);
|
||||||
replaceOpWithBufferizedValues(rewriter, op, {bufferAllocation, srcMemRef});
|
rewriter.create<pim::PimMemCopyDevToHostOp>(op->getLoc(),
|
||||||
|
bufferAllocation.getType(),
|
||||||
|
bufferAllocation,
|
||||||
|
srcMemRef,
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(sizeInBytes));
|
||||||
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
@@ -15,7 +16,8 @@ namespace onnx_mlir {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static bool isAddressOnlyHostOp(Operation* op) {
|
static bool isAddressOnlyHostOp(Operation* op) {
|
||||||
return isa<memref::AllocOp,
|
return isa<arith::ConstantOp,
|
||||||
|
memref::AllocOp,
|
||||||
memref::GetGlobalOp,
|
memref::GetGlobalOp,
|
||||||
memref::SubViewOp,
|
memref::SubViewOp,
|
||||||
memref::CastOp,
|
memref::CastOp,
|
||||||
|
|||||||
Reference in New Issue
Block a user