add better createSpatCompute helper
This commit is contained in:
@@ -104,20 +104,39 @@ inline auto getTensorShape(mlir::Value tensor) {
|
|||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
||||||
|
// Returns the arguments of `block` wrapped in an mlir::ValueRange.
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||||
|
|
||||||
// Previous version: calls fn with block->getArgument(Is)... and discards the result (returns void).
template <typename Fn, size_t... Is>
|
// Calls fn with block->getArgument(Is)... (one argument per index in the sequence)
// and returns whatever fn returns; decltype(auto) keeps the callee's exact return type.
template <typename Fn, size_t... Is>
|
||||||
void invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||||
std::forward<Fn>(fn)(block->getArgument(Is)...);
|
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Maps any index to mlir::Value, so that a pack of indices can be expanded
// into a pack of mlir::Value argument types (see InvokeWithBlockArgsResult).
template <size_t>
|
||||||
|
using ValueArg = mlir::Value;
|
||||||
|
|
||||||
|
// Result type of invoking Fn with one mlir::Value per index of Seq.
// The primary template is only declared; the partial specialization for
// std::index_sequence provides the nested `type`.
template <typename Fn, typename Seq>
|
||||||
|
struct InvokeWithBlockArgsResult;
|
||||||
|
|
||||||
|
template <typename Fn, size_t... Is>
|
||||||
|
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
|
||||||
|
// Fn is invoked with sizeof...(Is) arguments, each of type mlir::Value.
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Shorthand for InvokeWithBlockArgsResult<Fn, Seq>::type.
template <typename Fn, typename Seq>
|
||||||
|
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
|
||||||
|
|
||||||
|
// Result type of invoking Fn with a single mlir::ValueRange argument.
template <typename Fn>
|
||||||
|
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
template <size_t NumInputs, typename BodyFn>
|
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||||
spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter,
|
auto createSpatCompute(RewriterT& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
mlir::TypeRange resultTypes,
|
mlir::TypeRange resultTypes,
|
||||||
mlir::ValueRange weights,
|
mlir::ValueRange weights,
|
||||||
mlir::ValueRange inputs,
|
mlir::ValueRange inputs,
|
||||||
BodyFn&& body) {
|
BodyFn&& body) {
|
||||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
@@ -128,10 +147,61 @@ spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter&
|
|||||||
computeOp.getBody().push_back(block);
|
computeOp.getBody().push_back(block);
|
||||||
rewriter.setInsertionPointToStart(block);
|
rewriter.setInsertionPointToStart(block);
|
||||||
|
|
||||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||||
|
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
|
||||||
|
auto bodyResult =
|
||||||
|
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||||
|
if (mlir::failed(bodyResult)) {
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
rewriter.eraseOp(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
||||||
|
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
return computeOp;
|
return computeOp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a spatial::SpatWeightedCompute op and populates its body via `body`.
//
// This overload has no NumInputs template argument: it adds one block argument
// per value in `inputs` and hands all of them to `body` as one mlir::ValueRange.
//
// `body` must return either void or mlir::LogicalResult (enforced by the
// static_assert below):
//   - void:          the created op is returned directly.
//   - LogicalResult: an mlir::FailureOr<spatial::SpatWeightedCompute> is returned;
//                    if `body` fails, the op is erased through the rewriter and
//                    failure is returned.
// On every path the rewriter's insertion point is moved to just after the
// compute op (on failure, before the op is erased).
template <typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatCompute(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
|
// The body block gets one argument per input value, with that input's type.
auto* block = new mlir::Block();
|
||||||
|
for (mlir::Value input : inputs)
|
||||||
|
block->addArgument(input.getType(), loc);
|
||||||
|
|
||||||
|
computeOp.getBody().push_back(block);
|
||||||
|
// Ops created by `body` through the rewriter are inserted into the new block.
rewriter.setInsertionPointToStart(block);
|
||||||
|
|
||||||
|
// Return type of body(mlir::ValueRange); selects the return path at compile time.
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||||
|
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
|
||||||
|
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||||
|
if (mlir::failed(bodyResult)) {
|
||||||
|
// The insertion point is still inside the op's body (set above);
// move it out before the op is erased.
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
rewriter.eraseOp(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Any return type other than LogicalResult or void is rejected at compile time.
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
||||||
|
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
return computeOp;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -138,83 +139,76 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
else
|
else
|
||||||
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
|
||||||
auto im2colComputeOp =
|
constexpr size_t numInputs = 1;
|
||||||
spatial::SpatWeightedCompute::create(rewriter, loc, im2colType, SmallVector<Value>(), ValueRange {x});
|
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
|
||||||
|
Value paddedInput = xArg;
|
||||||
|
|
||||||
auto* im2colBlock = new Block();
|
// Pad input with zeros if needed:
|
||||||
im2colBlock->addArgument(x.getType(), loc);
|
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
|
||||||
im2colComputeOp.getBody().push_back(im2colBlock);
|
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
||||||
rewriter.setInsertionPointToStart(im2colBlock);
|
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
||||||
|
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
||||||
|
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
||||||
|
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(padHeightBegin),
|
||||||
|
rewriter.getIndexAttr(padWidthBegin)};
|
||||||
|
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(padHeightEnd),
|
||||||
|
rewriter.getIndexAttr(padWidthEnd)};
|
||||||
|
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
|
||||||
|
auto* padBlock = new Block();
|
||||||
|
for (int i = 0; i < 4; i++)
|
||||||
|
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||||
|
padOp.getRegion().push_back(padBlock);
|
||||||
|
rewriter.setInsertionPointToStart(padBlock);
|
||||||
|
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
||||||
|
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||||
|
rewriter.setInsertionPointAfter(padOp);
|
||||||
|
paddedInput = padOp.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
Value paddedInput = im2colBlock->getArgument(0);
|
// Build im2col [numPatches, patchSize]:
|
||||||
|
// For each batch/output position (n, oh, ow), extract the patch from x
|
||||||
|
SmallVector<Value> im2colRows;
|
||||||
|
im2colRows.reserve(numPatches);
|
||||||
|
for (int64_t n = 0; n < batchSize; n++) {
|
||||||
|
for (int64_t oh = 0; oh < outHeight; oh++) {
|
||||||
|
for (int64_t ow = 0; ow < outWidth; ow++) {
|
||||||
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
|
||||||
|
rewriter.getIndexAttr(0),
|
||||||
|
rewriter.getIndexAttr(oh * strideHeight),
|
||||||
|
rewriter.getIndexAttr(ow * strideWidth)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(numChannelsIn),
|
||||||
|
rewriter.getIndexAttr(wHeight),
|
||||||
|
rewriter.getIndexAttr(wWidth)};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(dilationHeight),
|
||||||
|
rewriter.getIndexAttr(dilationWidth)};
|
||||||
|
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||||
|
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||||
|
|
||||||
// Pad input with zeros if needed:
|
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
|
||||||
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
|
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||||
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
loc,
|
||||||
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
rowType,
|
||||||
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
patch,
|
||||||
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
SmallVector<ReassociationIndices> {
|
||||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
{0},
|
||||||
rewriter.getIndexAttr(0),
|
{1, 2, 3}
|
||||||
rewriter.getIndexAttr(padHeightBegin),
|
});
|
||||||
rewriter.getIndexAttr(padWidthBegin)};
|
im2colRows.push_back(row);
|
||||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
}
|
||||||
rewriter.getIndexAttr(0),
|
|
||||||
rewriter.getIndexAttr(padHeightEnd),
|
|
||||||
rewriter.getIndexAttr(padWidthEnd)};
|
|
||||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
|
|
||||||
auto* padBlock = new Block();
|
|
||||||
for (int i = 0; i < 4; i++)
|
|
||||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
|
||||||
padOp.getRegion().push_back(padBlock);
|
|
||||||
rewriter.setInsertionPointToStart(padBlock);
|
|
||||||
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
|
||||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
|
||||||
rewriter.setInsertionPointAfter(padOp);
|
|
||||||
paddedInput = padOp.getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build im2col [numPatches, patchSize]:
|
|
||||||
// For each batch/output position (n, oh, ow), extract the patch from x
|
|
||||||
SmallVector<Value> im2colRows;
|
|
||||||
im2colRows.reserve(numPatches);
|
|
||||||
for (int64_t n = 0; n < batchSize; n++) {
|
|
||||||
for (int64_t oh = 0; oh < outHeight; oh++) {
|
|
||||||
for (int64_t ow = 0; ow < outWidth; ow++) {
|
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
|
|
||||||
rewriter.getIndexAttr(0),
|
|
||||||
rewriter.getIndexAttr(oh * strideHeight),
|
|
||||||
rewriter.getIndexAttr(ow * strideWidth)};
|
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(numChannelsIn),
|
|
||||||
rewriter.getIndexAttr(wHeight),
|
|
||||||
rewriter.getIndexAttr(wWidth)};
|
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(dilationHeight),
|
|
||||||
rewriter.getIndexAttr(dilationWidth)};
|
|
||||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
|
||||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
|
||||||
|
|
||||||
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
|
|
||||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
rowType,
|
|
||||||
patch,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0},
|
|
||||||
{1, 2, 3}
|
|
||||||
});
|
|
||||||
im2colRows.push_back(row);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Concatenate all rows: [numPatches, patchSize]
|
// Concatenate all rows: [numPatches, patchSize]
|
||||||
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
|
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
||||||
|
});
|
||||||
rewriter.setInsertionPointAfter(im2colComputeOp);
|
|
||||||
|
|
||||||
// Gemm: A @ B + C = im2col @ W^T + b
|
// Gemm: A @ B + C = im2col @ W^T + b
|
||||||
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
|
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
|
||||||
@@ -231,30 +225,23 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
Value gemmOut = gemmOp.getY();
|
Value gemmOut = gemmOp.getY();
|
||||||
|
|
||||||
auto collectComputeOp =
|
auto collectComputeOp =
|
||||||
spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector<Value>(), ValueRange {gemmOut});
|
createSpatCompute<numInputs>(rewriter, loc, convOp.getType(), {}, ValueRange {gemmOut}, [&](Value gemmOutArg) {
|
||||||
|
// Restore to NCHW layout:
|
||||||
|
// [numPatches, numChannelsOut]
|
||||||
|
// -> [1, outHeight, outWidth, numChannelsOut]
|
||||||
|
// -> [1, numChannelsOut, outHeight, outWidth]
|
||||||
|
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
nhwcType,
|
||||||
|
gemmOutArg,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1, 2},
|
||||||
|
{3}
|
||||||
|
});
|
||||||
|
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
|
||||||
|
|
||||||
auto* collectBlock = new Block();
|
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
|
||||||
collectBlock->addArgument(gemmOut.getType(), loc);
|
});
|
||||||
collectComputeOp.getBody().push_back(collectBlock);
|
|
||||||
rewriter.setInsertionPointToStart(collectBlock);
|
|
||||||
|
|
||||||
auto gemmOutArg = collectBlock->getArguments().front();
|
|
||||||
|
|
||||||
// Restore to NCHW layout:
|
|
||||||
// [numPatches, numChannelsOut]
|
|
||||||
// -> [1, outHeight, outWidth, numChannelsOut]
|
|
||||||
// -> [1, numChannelsOut, outHeight, outWidth]
|
|
||||||
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
nhwcType,
|
|
||||||
gemmOutArg,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1, 2},
|
|
||||||
{3}
|
|
||||||
});
|
|
||||||
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
|
|
||||||
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
|
|
||||||
|
|
||||||
rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
|
rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
|
||||||
return success();
|
return success();
|
||||||
|
|||||||
@@ -155,18 +155,10 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
gemvOps.push_back(gemvOp.getY());
|
gemvOps.push_back(gemvOp.getY());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto concatComputeOp =
|
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
|
||||||
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
|
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
||||||
auto* concatBlock = new Block();
|
});
|
||||||
for (auto gemvOp : gemvOps)
|
|
||||||
concatBlock->addArgument(gemvOp.getType(), loc);
|
|
||||||
concatComputeOp.getBody().push_back(concatBlock);
|
|
||||||
rewriter.setInsertionPointToStart(concatBlock);
|
|
||||||
|
|
||||||
auto blockArgs = concatBlock->getArguments();
|
|
||||||
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, blockArgs);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
|
||||||
|
|
||||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||||
return success();
|
return success();
|
||||||
@@ -289,25 +281,17 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||||
|
|
||||||
auto computeOp =
|
auto computeOp =
|
||||||
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);
|
createSpatCompute(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
|
||||||
|
SmallVector<Value> vmmOutputs;
|
||||||
|
vmmOutputs.reserve(aHSlicesArgs.size());
|
||||||
|
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
||||||
|
vmmOutputs.push_back(
|
||||||
|
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
||||||
|
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
|
||||||
|
|
||||||
auto* computeBlock = new Block();
|
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||||
for (auto aHSlice : aHSlices[coreId])
|
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||||
computeBlock->addArgument(aHSlice.getType(), gemmLoc);
|
});
|
||||||
computeOp.getBody().push_back(computeBlock);
|
|
||||||
rewriter.setInsertionPointToStart(computeBlock);
|
|
||||||
|
|
||||||
auto computeArgs = computeBlock->getArguments();
|
|
||||||
SmallVector<Value> vmmOutputs;
|
|
||||||
vmmOutputs.reserve(computeArgs.size());
|
|
||||||
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
|
|
||||||
vmmOutputs.push_back(
|
|
||||||
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
|
|
||||||
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
|
|
||||||
|
|
||||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
|
||||||
|
|
||||||
partialResults.push_back(computeOp.getResult(0));
|
partialResults.push_back(computeOp.getResult(0));
|
||||||
}
|
}
|
||||||
@@ -318,34 +302,20 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto reduceComputeOp =
|
auto reduceComputeOp =
|
||||||
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);
|
createSpatCompute(rewriter, gemmLoc, currOutHSliceType, {}, partialResults, [&](ValueRange blockArgs) {
|
||||||
|
SmallVector<Value> values(blockArgs.begin(), blockArgs.end());
|
||||||
auto* reduceBlock = new Block();
|
Value outHSlice = sumTensors(values, rewriter);
|
||||||
for (auto partialResult : partialResults)
|
spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
|
||||||
reduceBlock->addArgument(partialResult.getType(), gemmLoc);
|
});
|
||||||
reduceComputeOp.getBody().push_back(reduceBlock);
|
|
||||||
rewriter.setInsertionPointToStart(reduceBlock);
|
|
||||||
|
|
||||||
auto blockArgs = reduceBlock->getArguments();
|
|
||||||
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
|
|
||||||
rewriter.setInsertionPointAfter(reduceComputeOp);
|
|
||||||
|
|
||||||
outHSlices.push_back(reduceComputeOp.getResult(0));
|
outHSlices.push_back(reduceComputeOp.getResult(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto concatComputeOp =
|
auto concatComputeOp =
|
||||||
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
|
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
|
||||||
|
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
|
||||||
auto* concatBlock = new Block();
|
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
|
||||||
for (auto outHSlice : outHSlices)
|
});
|
||||||
concatBlock->addArgument(outHSlice.getType(), gemmLoc);
|
|
||||||
concatComputeOp.getBody().push_back(concatBlock);
|
|
||||||
rewriter.setInsertionPointToStart(concatBlock);
|
|
||||||
|
|
||||||
auto blockArgs = concatBlock->getArguments();
|
|
||||||
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
|
|
||||||
|
|
||||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||||
return success();
|
return success();
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
@@ -81,19 +82,11 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto concatComputeOp =
|
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) {
|
||||||
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows);
|
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
||||||
|
});
|
||||||
|
|
||||||
auto* concatBlock = new Block();
|
|
||||||
for (Value gemmRow : gemmRows)
|
|
||||||
concatBlock->addArgument(gemmRow.getType(), loc);
|
|
||||||
concatComputeOp.getBody().push_back(concatBlock);
|
|
||||||
rewriter.setInsertionPointToStart(concatBlock);
|
|
||||||
|
|
||||||
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments());
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(concatComputeOp);
|
|
||||||
Value gemmOut = concatComputeOp.getResult(0);
|
Value gemmOut = concatComputeOp.getResult(0);
|
||||||
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
|
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -154,91 +155,90 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|||||||
|
|
||||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||||
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
|
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
|
||||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector<Value>(), ValueRange {x});
|
constexpr size_t numInputs = 1;
|
||||||
|
auto computeOp =
|
||||||
|
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
|
||||||
|
SmallVector<Value> batchResults;
|
||||||
|
batchResults.reserve(batchSize);
|
||||||
|
|
||||||
auto* computeBlock = new Block();
|
for (int64_t batch = 0; batch < batchSize; ++batch) {
|
||||||
computeBlock->addArgument(xType, loc);
|
SmallVector<Value> rows;
|
||||||
computeOp.getBody().push_back(computeBlock);
|
rows.reserve(outputHeight);
|
||||||
rewriter.setInsertionPointToStart(computeBlock);
|
|
||||||
|
|
||||||
Value input = computeBlock->getArgument(0);
|
for (int64_t outH = 0; outH < outputHeight; ++outH) {
|
||||||
SmallVector<Value> batchResults;
|
SmallVector<Value> rowPixels;
|
||||||
batchResults.reserve(batchSize);
|
rowPixels.reserve(outputWidth);
|
||||||
|
|
||||||
for (int64_t batch = 0; batch < batchSize; ++batch) {
|
for (int64_t outW = 0; outW < outputWidth; ++outW) {
|
||||||
SmallVector<Value> rows;
|
SmallVector<Value> outputChannelTiles;
|
||||||
rows.reserve(outputHeight);
|
outputChannelTiles.reserve(channelTileCount);
|
||||||
|
|
||||||
for (int64_t outH = 0; outH < outputHeight; ++outH) {
|
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||||
SmallVector<Value> rowPixels;
|
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||||
rowPixels.reserve(outputWidth);
|
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||||
|
|
||||||
for (int64_t outW = 0; outW < outputWidth; ++outW) {
|
SmallVector<Value> windowValues;
|
||||||
SmallVector<Value> outputChannelTiles;
|
windowValues.reserve(kernelHeight * kernelWidth);
|
||||||
outputChannelTiles.reserve(channelTileCount);
|
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||||
|
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
|
||||||
|
if (inH < 0 || inH >= inputHeight)
|
||||||
|
continue;
|
||||||
|
|
||||||
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||||
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
|
||||||
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
if (inW < 0 || inW >= inputWidth)
|
||||||
|
continue;
|
||||||
|
|
||||||
SmallVector<Value> windowValues;
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
|
||||||
windowValues.reserve(kernelHeight * kernelWidth);
|
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
rewriter.getIndexAttr(inH),
|
||||||
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
|
rewriter.getIndexAttr(inW)};
|
||||||
if (inH < 0 || inH >= inputHeight)
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||||
continue;
|
rewriter.getIndexAttr(tileChannels),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1)};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(1)};
|
||||||
|
Value windowValue =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, tileType, xArg, offsets, sizes, strides);
|
||||||
|
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||||
|
windowValues.push_back(windowValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
if (windowValues.empty())
|
||||||
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
|
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
||||||
if (inW < 0 || inW >= inputWidth)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
|
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
|
||||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||||
rewriter.getIndexAttr(inH),
|
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||||
rewriter.getIndexAttr(inW)};
|
const int64_t divisor =
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
||||||
rewriter.getIndexAttr(tileChannels),
|
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
|
||||||
rewriter.getIndexAttr(1),
|
}
|
||||||
rewriter.getIndexAttr(1)};
|
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
outputChannelTiles.push_back(reducedWindow);
|
||||||
rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(1)};
|
|
||||||
Value windowValue =
|
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides);
|
|
||||||
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
|
||||||
windowValues.push_back(windowValue);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (windowValues.empty())
|
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
|
||||||
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
|
||||||
|
|
||||||
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
|
|
||||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
|
||||||
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
|
||||||
const int64_t divisor =
|
|
||||||
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
|
||||||
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
|
|
||||||
}
|
|
||||||
|
|
||||||
outputChannelTiles.push_back(reducedWindow);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
|
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
|
||||||
}
|
}
|
||||||
|
|
||||||
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
|
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
|
||||||
}
|
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
if (failed(computeOp))
|
||||||
|
return failure();
|
||||||
|
|
||||||
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
|
rewriter.replaceOp(poolOp, computeOp->getResult(0));
|
||||||
}
|
|
||||||
|
|
||||||
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
|
|
||||||
|
|
||||||
rewriter.replaceOp(poolOp, computeOp.getResult(0));
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user