Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 568529ea5f | |||
| ca2e1645bb | |||
| 6933804003 |
@@ -1,2 +1,4 @@
|
|||||||
.idea
|
.idea
|
||||||
|
.claude
|
||||||
|
AGENTS.md
|
||||||
build
|
build
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ add_onnx_mlir_library(OMPIMAccel
|
|||||||
Pass/CountInstructionPass.cpp
|
Pass/CountInstructionPass.cpp
|
||||||
Pass/EmitPimJsonPass.cpp
|
Pass/EmitPimJsonPass.cpp
|
||||||
Pass/MessagePass.cpp
|
Pass/MessagePass.cpp
|
||||||
Pass/PimFoldHostConstantsPass.cpp
|
Pass/PimConstantFoldingPass.cpp
|
||||||
Pass/PimHostVerificationPass.cpp
|
Pass/PimHostVerificationPass.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|||||||
@@ -12,7 +12,15 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
std::string getOutputDir() { return outputBaseName.substr(0, outputBaseName.find_last_of('/')); }
|
std::string getOutputDir() {
|
||||||
|
if (outputBaseName.empty() || outputBaseName == "-")
|
||||||
|
return {};
|
||||||
|
|
||||||
|
size_t lastSlash = outputBaseName.find_last_of('/');
|
||||||
|
if (lastSlash == std::string::npos)
|
||||||
|
return ".";
|
||||||
|
return outputBaseName.substr(0, lastSlash);
|
||||||
|
}
|
||||||
|
|
||||||
void createDirectory(const std::string& directory) {
|
void createDirectory(const std::string& directory) {
|
||||||
std::error_code errorCode;
|
std::error_code errorCode;
|
||||||
@@ -21,7 +29,11 @@ void createDirectory(const std::string& directory) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void dumpModule(ModuleOp moduleOp, const std::string& name) {
|
void dumpModule(ModuleOp moduleOp, const std::string& name) {
|
||||||
std::string dialectsDir = getOutputDir() + "/dialects";
|
std::string outputDir = getOutputDir();
|
||||||
|
if (outputDir.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
std::string dialectsDir = outputDir + "/dialects";
|
||||||
createDirectory(dialectsDir);
|
createDirectory(dialectsDir);
|
||||||
|
|
||||||
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
||||||
@@ -143,4 +155,85 @@ FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, Rewr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||||
|
SmallVector<int64_t> strides(shape.size(), 1);
|
||||||
|
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||||
|
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
|
||||||
|
SmallVector<int64_t> indices(shape.size(), 0);
|
||||||
|
for (auto [dim, stride] : llvm::enumerate(strides)) {
|
||||||
|
indices[dim] = linearIndex / stride;
|
||||||
|
linearIndex %= stride;
|
||||||
|
}
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
|
||||||
|
int64_t linearIndex = 0;
|
||||||
|
for (auto [index, stride] : llvm::zip_equal(indices, strides))
|
||||||
|
linearIndex += index * stride;
|
||||||
|
return linearIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t getNumElements(ArrayRef<int64_t> shape) {
|
||||||
|
int64_t numElements = 1;
|
||||||
|
for (int64_t dim : shape)
|
||||||
|
numElements *= dim;
|
||||||
|
return numElements;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
|
||||||
|
ArrayRef<int64_t> offsets,
|
||||||
|
ArrayRef<int64_t> sizes,
|
||||||
|
ArrayRef<int64_t> strides) {
|
||||||
|
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
||||||
|
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||||
|
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||||
|
|
||||||
|
auto firstNonZeroOffset = std::find_if(
|
||||||
|
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
||||||
|
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
||||||
|
return offset != 0;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
||||||
|
auto [offset, size, dimension] = *firstNonZeroOffset;
|
||||||
|
if (size > dimension - offset)
|
||||||
|
return false;
|
||||||
|
++firstNonZeroOffset;
|
||||||
|
|
||||||
|
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
||||||
|
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
||||||
|
return size != 1;
|
||||||
|
}))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||||
|
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||||
|
|
||||||
|
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||||
|
auto [size, dimension] = sizeAndShape;
|
||||||
|
return size != dimension;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (firstDifferentSize != sizesAndShape.end()) {
|
||||||
|
++firstDifferentSize;
|
||||||
|
|
||||||
|
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
||||||
|
auto [size, _dimension] = sizeAndShape;
|
||||||
|
return size != 1;
|
||||||
|
}))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -6,6 +6,8 @@
|
|||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
@@ -32,4 +34,18 @@ mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::m
|
|||||||
llvm::FailureOr<mlir::Operation*>
|
llvm::FailureOr<mlir::Operation*>
|
||||||
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
|
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t>
|
||||||
|
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||||
|
llvm::ArrayRef<int64_t> offsets,
|
||||||
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
|
llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||||
#include "Conversion/SpatialToPim/SpatialToPimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
@@ -382,12 +381,9 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
|
|||||||
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
|
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
|
||||||
size_t totalElements = srcType.getNumElements();
|
size_t totalElements = srcType.getNumElements();
|
||||||
|
|
||||||
// Read permutation and compute its inverse
|
// Read permutation. Destination dim i corresponds to source dim perm[i].
|
||||||
SmallVector<int64_t> perm =
|
SmallVector<int64_t> perm =
|
||||||
map_to_vector(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
|
map_to_vector(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
|
||||||
SmallVector<int64_t> permInv(rank);
|
|
||||||
for (size_t i = 0; i < rank; i++)
|
|
||||||
permInv[perm[i]] = i;
|
|
||||||
|
|
||||||
// Destination shape: dstShape[i] = srcShape[perm[i]]
|
// Destination shape: dstShape[i] = srcShape[perm[i]]
|
||||||
SmallVector<int64_t> dstShape(rank);
|
SmallVector<int64_t> dstShape(rank);
|
||||||
@@ -412,10 +408,10 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
|
|||||||
remaining %= srcStrides[d];
|
remaining %= srcStrides[d];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute flat destination index: dstIdx[d] = srcIdx[permInv[d]]
|
// Compute flat destination index: dstIdx[d] = srcIdx[perm[d]]
|
||||||
size_t dstFlat = 0;
|
size_t dstFlat = 0;
|
||||||
for (size_t d = 0; d < rank; d++)
|
for (size_t d = 0; d < rank; d++)
|
||||||
dstFlat += srcIdx[permInv[d]] * dstStrides[d];
|
dstFlat += srcIdx[perm[d]] * dstStrides[d];
|
||||||
|
|
||||||
emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
|
emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,8 +46,8 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||||
pm.addPass(createPimFoldHostConstantsPass());
|
pm.addPass(createPimConstantFoldingPass());
|
||||||
pm.addPass(createMessagePass("Pim host constants folded"));
|
pm.addPass(createMessagePass("Pim constants folded"));
|
||||||
pm.addPass(createPimHostVerificationPass());
|
pm.addPass(createPimHostVerificationPass());
|
||||||
pm.addPass(createMessagePass("Pim host verified"));
|
pm.addPass(createMessagePass("Pim host verified"));
|
||||||
pm.addPass(createEmitPimJsonPass());
|
pm.addPass(createEmitPimJsonPass());
|
||||||
|
|||||||
@@ -38,11 +38,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
|
|
||||||
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
||||||
assert("Only support 2D convolution" && xType.getRank() == 4);
|
assert("Only support 2D convolution" && xType.getRank() == 4);
|
||||||
assert("Only support batch size 1 for input" && xType.getDimSize(0) == 1);
|
|
||||||
|
|
||||||
// We need to understand what is group
|
// We need to understand what is group
|
||||||
assert("Only support group=1" && convOp.getGroup() == 1);
|
assert("Only support group=1" && convOp.getGroup() == 1);
|
||||||
|
|
||||||
|
const int64_t batchSize = xType.getDimSize(0);
|
||||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||||
const int64_t xHeight = xType.getDimSize(2);
|
const int64_t xHeight = xType.getDimSize(2);
|
||||||
const int64_t xWidth = xType.getDimSize(3);
|
const int64_t xWidth = xType.getDimSize(3);
|
||||||
@@ -107,7 +107,8 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||||
// Gemm output: [numPatches, cOut]
|
// Gemm output: [numPatches, cOut]
|
||||||
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
||||||
const int64_t numPatches = outHeight * outWidth;
|
const int64_t numPatchesPerBatch = outHeight * outWidth;
|
||||||
|
const int64_t numPatches = batchSize * numPatchesPerBatch;
|
||||||
|
|
||||||
auto elemType = xType.getElementType();
|
auto elemType = xType.getElementType();
|
||||||
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
||||||
@@ -115,7 +116,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
||||||
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
||||||
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
||||||
auto nhwcType = RankedTensorType::get({1, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
||||||
|
|
||||||
// Prepare weight matrix W for crossbar storage:
|
// Prepare weight matrix W for crossbar storage:
|
||||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||||
@@ -160,7 +161,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
||||||
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
||||||
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
||||||
auto paddedType = RankedTensorType::get({1, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
||||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||||
rewriter.getIndexAttr(0),
|
rewriter.getIndexAttr(0),
|
||||||
rewriter.getIndexAttr(padHeightBegin),
|
rewriter.getIndexAttr(padHeightBegin),
|
||||||
@@ -182,36 +183,38 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build im2col [numPatches, patchSize]:
|
// Build im2col [numPatches, patchSize]:
|
||||||
// For each output position (oh, ow), extract the patch from x
|
// For each batch/output position (n, oh, ow), extract the patch from x
|
||||||
SmallVector<Value> im2colRows;
|
SmallVector<Value> im2colRows;
|
||||||
im2colRows.reserve(numPatches);
|
im2colRows.reserve(numPatches);
|
||||||
for (int64_t oh = 0; oh < outHeight; oh++) {
|
for (int64_t n = 0; n < batchSize; n++) {
|
||||||
for (int64_t ow = 0; ow < outWidth; ow++) {
|
for (int64_t oh = 0; oh < outHeight; oh++) {
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0),
|
for (int64_t ow = 0; ow < outWidth; ow++) {
|
||||||
rewriter.getIndexAttr(0),
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
|
||||||
rewriter.getIndexAttr(oh * strideHeight),
|
rewriter.getIndexAttr(0),
|
||||||
rewriter.getIndexAttr(ow * strideWidth)};
|
rewriter.getIndexAttr(oh * strideHeight),
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(ow * strideWidth)};
|
||||||
rewriter.getIndexAttr(numChannelsIn),
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(wHeight),
|
rewriter.getIndexAttr(numChannelsIn),
|
||||||
rewriter.getIndexAttr(wWidth)};
|
rewriter.getIndexAttr(wHeight),
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(wWidth)};
|
||||||
rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(dilationHeight),
|
rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(dilationWidth)};
|
rewriter.getIndexAttr(dilationHeight),
|
||||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
rewriter.getIndexAttr(dilationWidth)};
|
||||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||||
|
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||||
|
|
||||||
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
|
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
|
||||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
rowType,
|
rowType,
|
||||||
patch,
|
patch,
|
||||||
SmallVector<ReassociationIndices> {
|
SmallVector<ReassociationIndices> {
|
||||||
{0},
|
{0},
|
||||||
{1, 2, 3}
|
{1, 2, 3}
|
||||||
});
|
});
|
||||||
im2colRows.push_back(row);
|
im2colRows.push_back(row);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
|
|||||||
return returnValue;
|
return returnValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operation* getEarliestUserWithinBlock(Value value) {
|
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||||
auto users = value.getUsers();
|
auto users = value.getUsers();
|
||||||
|
|
||||||
assert(!users.empty());
|
assert(!users.empty());
|
||||||
@@ -67,23 +67,24 @@ Operation* getEarliestUserWithinBlock(Value value) {
|
|||||||
return earliestUser;
|
return earliestUser;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> getOpOperandsSortedByUses(Operation* operation) {
|
SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
|
||||||
auto operandsAndUses = map_to_vector(operation->getOperands(), [](Value operand) -> std::pair<Value, size_t> {
|
auto operandsAndUses =
|
||||||
|
map_to_vector(operation->getOperands(), [](mlir::Value operand) -> std::pair<mlir::Value, size_t> {
|
||||||
return {operand, std::distance(operand.use_begin(), operand.use_end())};
|
return {operand, std::distance(operand.use_begin(), operand.use_end())};
|
||||||
});
|
});
|
||||||
sort(operandsAndUses, [](auto a, auto b) { return a.second < b.second; });
|
sort(operandsAndUses, [](auto a, auto b) { return a.second < b.second; });
|
||||||
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
|
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
|
||||||
}
|
}
|
||||||
|
|
||||||
Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
|
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
|
||||||
assert("Only support operations with a single result" && operation->getNumResults() == 1);
|
assert("Only support operations with a single result" && operation->getNumResults() == 1);
|
||||||
Value result = operation->getResult(0);
|
mlir::Value result = operation->getResult(0);
|
||||||
auto resultType = result.getType();
|
auto resultType = result.getType();
|
||||||
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
|
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
|
||||||
|
|
||||||
SmallVector<Value> operands = getOpOperandsSortedByUses(operation);
|
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
|
||||||
auto validOperands =
|
auto validOperands =
|
||||||
make_filter_range(operands, [resultType](Value operand) { return operand.getType() == resultType; });
|
make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
|
||||||
auto bestOperand = validOperands.begin();
|
auto bestOperand = validOperands.begin();
|
||||||
|
|
||||||
if (bestOperand != validOperands.end())
|
if (bestOperand != validOperands.end())
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -39,71 +40,13 @@ mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operat
|
|||||||
|
|
||||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
|
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
|
||||||
|
|
||||||
static bool isMemoryContiguous(const mlir::ArrayRef<int64_t> srcShape,
|
|
||||||
const mlir::ArrayRef<int64_t> offsets,
|
|
||||||
const mlir::ArrayRef<int64_t> sizes,
|
|
||||||
const mlir::ArrayRef<int64_t> strides) {
|
|
||||||
// Check that all strides are 1
|
|
||||||
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
// Check offsets from right to left:
|
|
||||||
// The first offset_n at position n different from 0:
|
|
||||||
// - limits all sizes to the left to 1
|
|
||||||
// - limits size_n to dimension_n - offset_n
|
|
||||||
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
|
||||||
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
|
||||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
|
||||||
|
|
||||||
auto firstNonZeroOffset = std::find_if(
|
|
||||||
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
|
||||||
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
|
||||||
return offset != 0;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
|
||||||
auto [offset, size, dimension] = *firstNonZeroOffset;
|
|
||||||
if (size > dimension - offset)
|
|
||||||
return false;
|
|
||||||
++firstNonZeroOffset;
|
|
||||||
|
|
||||||
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
|
||||||
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
|
||||||
return size != 1;
|
|
||||||
}))
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check sizes from right to left:
|
|
||||||
// The first size_n at position n different from shape_n limits all sizes to the left to 1
|
|
||||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
|
||||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
|
||||||
|
|
||||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
|
||||||
auto [size, dimension] = sizeAndShape;
|
|
||||||
return size != dimension;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (firstDifferentSize != sizesAndShape.end()) {
|
|
||||||
++firstDifferentSize;
|
|
||||||
|
|
||||||
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
|
||||||
auto [size, _] = sizeAndShape;
|
|
||||||
return size != 1;
|
|
||||||
}))
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline mlir::tensor::EmptyOp
|
inline mlir::tensor::EmptyOp
|
||||||
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
||||||
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
|
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool isAConcatOp(mlir::Operation* op) {
|
inline bool isAConcatOp(mlir::Operation* op) {
|
||||||
return isa<mlir::tensor::ConcatOp>(op) || isa<spatial::SpatImgConcatOp>(op);
|
return llvm::isa<mlir::tensor::ConcatOp>(op) || llvm::isa<spatial::SpatImgConcatOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
#include "mlir/IR/BuiltinDialect.h"
|
#include "mlir/IR/BuiltinDialect.h"
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
@@ -52,20 +53,21 @@ private:
|
|||||||
|
|
||||||
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
||||||
void addReceiveOps(Value& channelSourceOp,
|
void addReceiveOps(Value channelSourceOp,
|
||||||
spatial::SpatChannelNewOp& channel,
|
spatial::SpatChannelNewOp& channel,
|
||||||
Type& channelTensorType,
|
bool useBroadcastOp,
|
||||||
bool& useBroadcastOp,
|
|
||||||
IRRewriter& rewriter);
|
IRRewriter& rewriter);
|
||||||
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||||
unsigned int argIndex,
|
unsigned int argIndex,
|
||||||
|
Value channelSourceOp,
|
||||||
|
Value consumerValue,
|
||||||
spatial::SpatChannelNewOp& channel,
|
spatial::SpatChannelNewOp& channel,
|
||||||
Type& tensorType,
|
|
||||||
bool useBroadcastOp,
|
bool useBroadcastOp,
|
||||||
IRRewriter& rewriter);
|
IRRewriter& rewriter);
|
||||||
|
void markOpToRemove(Operation* op);
|
||||||
|
|
||||||
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
@@ -76,6 +78,34 @@ private:
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
static bool isChannelUseChainOp(Operation* op) {
|
||||||
|
return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp>(
|
||||||
|
op);
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t countComputeLeafUsers(Value value) {
|
||||||
|
size_t leafUserCount = 0;
|
||||||
|
|
||||||
|
auto walkUses = [&](Value currentValue, auto& self) -> void {
|
||||||
|
for (OpOperand& use : currentValue.getUses()) {
|
||||||
|
Operation* owner = use.getOwner();
|
||||||
|
if (isa<spatial::SpatWeightedCompute>(owner)) {
|
||||||
|
leafUserCount++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isChannelUseChainOp(owner))
|
||||||
|
llvm_unreachable("Channel use chain contains unsupported op");
|
||||||
|
|
||||||
|
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
|
||||||
|
self(owner->getResult(0), self);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
walkUses(value, walkUses);
|
||||||
|
return leafUserCount;
|
||||||
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::runOnOperation() {
|
void SpatialToPimPass::runOnOperation() {
|
||||||
coreId = 1;
|
coreId = 1;
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
@@ -103,7 +133,10 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||||
|
|
||||||
addResultBuffer(returnOp, rewriter);
|
addResultBuffer(returnOp, rewriter);
|
||||||
allocateAndInitializeCoreLocalVariables(funcOp, rewriter);
|
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
|
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
|
||||||
operationsToRemove.push_back(receiveOp);
|
operationsToRemove.push_back(receiveOp);
|
||||||
@@ -129,7 +162,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
dumpModule(moduleOp, "pim");
|
dumpModule(moduleOp, "pim0");
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
|
||||||
@@ -233,14 +266,11 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
|
auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
|
||||||
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
||||||
|
|
||||||
// 2. Receive value through the channel
|
// 2. Receive value through the channel. Broadcast is needed whenever the
|
||||||
// If this result is used by more than one user, then use a "Broadcast"
|
// value eventually reaches more than one compute consumer, even through a
|
||||||
// channel operation. However, there is a special case: we have a single
|
// chain of view-like ops.
|
||||||
// user (a ReshapeOp) which in turn is used by multiple ComputeOps. In this
|
bool useBroadcastOp = countComputeLeafUsers(result) > 1;
|
||||||
// case, we need to use a "Broadcast" channel operation. `addReceiveOps`
|
addReceiveOps(result, channelOp, useBroadcastOp, rewriter);
|
||||||
// will detect this case and update `useBroadcastOp` accordingly.
|
|
||||||
bool useBroadcastOp = (numResultUses > 1);
|
|
||||||
addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter);
|
|
||||||
|
|
||||||
// 3. Send the value through the channel
|
// 3. Send the value through the channel
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
@@ -327,7 +357,7 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
|
|
||||||
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
|
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
|
||||||
@@ -359,7 +389,8 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
|
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
|
||||||
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
|
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
|
||||||
|
|
||||||
funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc);
|
if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
|
||||||
|
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
|
||||||
BlockArgument memRefArg = funcOp.getArgument(i + 1);
|
BlockArgument memRefArg = funcOp.getArgument(i + 1);
|
||||||
|
|
||||||
Block& block = funcOp.getBody().front();
|
Block& block = funcOp.getBody().front();
|
||||||
@@ -369,7 +400,8 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
inputTensors.push_back(toTensorOp);
|
inputTensors.push_back(toTensorOp);
|
||||||
|
|
||||||
tensorArg.replaceAllUsesWith(toTensorOp);
|
tensorArg.replaceAllUsesWith(toTensorOp);
|
||||||
funcOp.eraseArgument(i);
|
if (failed(funcOp.eraseArgument(i)))
|
||||||
|
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
|
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
|
||||||
@@ -383,6 +415,9 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
|
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
|
||||||
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
|
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
|
||||||
|
|
||||||
|
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
|
||||||
|
continue;
|
||||||
|
|
||||||
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
|
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
|
||||||
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
|
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
|
||||||
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
|
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
|
||||||
@@ -416,12 +451,15 @@ void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
for (auto sliceOp : sliceOpsToRemove)
|
for (auto sliceOp : sliceOpsToRemove)
|
||||||
if (sliceOp->getUses().empty())
|
if (sliceOp->getUses().empty())
|
||||||
rewriter.eraseOp(sliceOp);
|
rewriter.eraseOp(sliceOp);
|
||||||
|
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||||
unsigned int argIndex,
|
unsigned int argIndex,
|
||||||
|
Value channelSourceOp,
|
||||||
|
Value consumerValue,
|
||||||
spatial::SpatChannelNewOp& channel,
|
spatial::SpatChannelNewOp& channel,
|
||||||
Type& tensorType,
|
|
||||||
bool useBroadcastOp,
|
bool useBroadcastOp,
|
||||||
IRRewriter& rewriter) {
|
IRRewriter& rewriter) {
|
||||||
auto& computeBlock = computeOp.getRegion().front();
|
auto& computeBlock = computeOp.getRegion().front();
|
||||||
@@ -434,68 +472,68 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
|||||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||||
Value receivedValue;
|
Value receivedValue;
|
||||||
if (useBroadcastOp)
|
if (useBroadcastOp)
|
||||||
receivedValue = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
|
receivedValue =
|
||||||
|
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||||
else
|
else
|
||||||
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
|
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||||
|
|
||||||
blockArg.replaceAllUsesWith(receivedValue);
|
Value replacementValue = receivedValue;
|
||||||
|
if (consumerValue != channelSourceOp) {
|
||||||
|
SmallVector<Operation*> clonedChain;
|
||||||
|
Value currentValue = consumerValue;
|
||||||
|
while (currentValue != channelSourceOp) {
|
||||||
|
Operation* definingOp = currentValue.getDefiningOp();
|
||||||
|
if (!definingOp || !isChannelUseChainOp(definingOp))
|
||||||
|
llvm_unreachable("Unsupported channel use chain while replaying value into consumer compute");
|
||||||
|
|
||||||
|
clonedChain.push_back(definingOp);
|
||||||
|
currentValue = definingOp->getOperand(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
IRMapping mapping;
|
||||||
|
mapping.map(channelSourceOp, receivedValue);
|
||||||
|
for (Operation* op : llvm::reverse(clonedChain)) {
|
||||||
|
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||||
|
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||||
|
mapping.map(originalResult, newResult);
|
||||||
|
markOpToRemove(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
replacementValue = cast<Value>(mapping.lookup(consumerValue));
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type");
|
||||||
|
blockArg.replaceAllUsesWith(replacementValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::addReceiveOps(Value& channelSourceOp,
|
void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
|
||||||
spatial::SpatChannelNewOp& channel,
|
spatial::SpatChannelNewOp& channel,
|
||||||
Type& channelTensorType,
|
bool useBroadcastOp,
|
||||||
bool& useBroadcastOp,
|
|
||||||
IRRewriter& rewriter) {
|
IRRewriter& rewriter) {
|
||||||
auto sourceOpUses = channelSourceOp.getUses();
|
auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
|
||||||
|
for (OpOperand& use : currentValue.getUses()) {
|
||||||
// Check if we need to update `useBroadcastOp` to true, in the case of a reshapeOp with multiple users
|
Operation* owner = use.getOwner();
|
||||||
if (useBroadcastOp == false) {
|
if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) {
|
||||||
// if useBroadcastOp is false, then sourceOp must have only one user
|
|
||||||
assert(rangeLength(sourceOpUses) == 1);
|
|
||||||
|
|
||||||
if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(sourceOpUses.begin()->getOwner())) {
|
|
||||||
auto reshapeOpUses = reshapeOp.getOutput().getUses();
|
|
||||||
auto reshapeOpUsesCount = rangeLength(reshapeOpUses);
|
|
||||||
if (reshapeOpUsesCount > 1)
|
|
||||||
useBroadcastOp = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& resultUse : sourceOpUses) {
|
|
||||||
// The user must be a ComputeOp, or a reshapeOp which can be used by many ComputeOps
|
|
||||||
spatial::SpatWeightedCompute computeUser = dyn_cast<spatial::SpatWeightedCompute>(resultUse.getOwner());
|
|
||||||
|
|
||||||
if (computeUser) {
|
|
||||||
replaceBlockArgumentWithRecvOp(
|
|
||||||
computeUser, resultUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!computeUser) {
|
|
||||||
auto reshapeOp = dyn_cast<tosa::ReshapeOp>(resultUse.getOwner());
|
|
||||||
if (!reshapeOp) {
|
|
||||||
channelSourceOp.getDefiningOp()->getParentOp()->getParentOp()->dump();
|
|
||||||
resultUse.getOwner()->dump();
|
|
||||||
llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
|
|
||||||
}
|
|
||||||
|
|
||||||
// The tensorType now becomes the one of the reshapeOp
|
|
||||||
channelTensorType = reshapeOp.getResult().getType();
|
|
||||||
|
|
||||||
for (auto& reshapeUse : reshapeOp.getOutput().getUses()) {
|
|
||||||
computeUser = dyn_cast<spatial::SpatWeightedCompute>(reshapeUse.getOwner());
|
|
||||||
|
|
||||||
if (!computeUser)
|
|
||||||
llvm_unreachable("ReshapeOp users must be ComputeOps");
|
|
||||||
|
|
||||||
replaceBlockArgumentWithRecvOp(
|
replaceBlockArgumentWithRecvOp(
|
||||||
computeUser, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
|
computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the reshapeOp, so that the sourceOp has no users
|
if (!isChannelUseChainOp(owner))
|
||||||
operationsToRemove.push_back(reshapeOp);
|
llvm_unreachable("User of channel-carried value is not a compute nor a supported view-like op");
|
||||||
|
|
||||||
|
markOpToRemove(owner);
|
||||||
|
assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result");
|
||||||
|
self(owner->getResult(0), self);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
replayUsesIntoConsumers(channelSourceOp, replayUsesIntoConsumers);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||||
|
if (!llvm::is_contained(operationsToRemove, op))
|
||||||
|
operationsToRemove.push_back(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||||
@@ -527,15 +565,10 @@ void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
|
|||||||
|
|
||||||
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
||||||
|
|
||||||
auto tensorType = receiveOp.getType();
|
|
||||||
Value receiveRes = receiveOp.getResult();
|
Value receiveRes = receiveOp.getResult();
|
||||||
|
|
||||||
// Check if the receiveOp value has more than one user
|
bool useBroadcastOp = countComputeLeafUsers(receiveRes) > 1;
|
||||||
auto receiveUses = receiveRes.getUses();
|
addReceiveOps(receiveRes, channel, useBroadcastOp, rewriter);
|
||||||
auto receiveUsesCount = rangeLength(receiveUses);
|
|
||||||
assert(receiveUsesCount > 0);
|
|
||||||
bool useBroadcastOp = receiveUsesCount > 1;
|
|
||||||
addReceiveOps(receiveRes, channel, tensorType, useBroadcastOp, rewriter);
|
|
||||||
|
|
||||||
if (useBroadcastOp) {
|
if (useBroadcastOp) {
|
||||||
// When receiving, we actually noticed that the value has more than one
|
// When receiving, we actually noticed that the value has more than one
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
annotateWeightsMemrefs(moduleOp, funcOp);
|
annotateWeightsMemrefs(moduleOp, funcOp);
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
dumpModule(moduleOp, "pim_buf");
|
dumpModule(moduleOp, "pim1_buff");
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||||
|
|||||||
@@ -7,9 +7,10 @@ add_onnx_mlir_library(SpatialOps
|
|||||||
Transforms/SpatialBufferizableOpInterface.cpp
|
Transforms/SpatialBufferizableOpInterface.cpp
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
|
OMONNXIncGen
|
||||||
OMSpatialIncGen
|
OMSpatialIncGen
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRIR
|
MLIRIR
|
||||||
OMMlirDialects
|
OMMlirDialects
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,618 @@
|
|||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static Value stripMemRefCasts(Value value) {
|
||||||
|
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||||
|
value = castOp.getSource();
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||||
|
Location loc,
|
||||||
|
MemRefType globalType,
|
||||||
|
DenseElementsAttr denseAttr,
|
||||||
|
StringRef nameStem,
|
||||||
|
IntegerAttr alignment = {}) {
|
||||||
|
auto globalName = nameStem.str();
|
||||||
|
unsigned suffix = 0;
|
||||||
|
while (moduleOp.lookupSymbol(globalName))
|
||||||
|
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
|
||||||
|
|
||||||
|
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
|
||||||
|
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
|
||||||
|
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
|
||||||
|
return memref::GlobalOp::create(moduleBuilder,
|
||||||
|
loc,
|
||||||
|
globalName,
|
||||||
|
visibility,
|
||||||
|
globalType,
|
||||||
|
denseAttr,
|
||||||
|
/*constant=*/true,
|
||||||
|
alignment);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
|
||||||
|
value = stripMemRefCasts(value);
|
||||||
|
|
||||||
|
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
|
||||||
|
if (!getGlobalOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||||
|
if (!denseAttr)
|
||||||
|
return failure();
|
||||||
|
return denseAttr;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||||
|
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||||
|
if (!tensorType)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t rank = tensorType.getRank();
|
||||||
|
if (static_cast<int64_t>(perms.size()) != rank)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
llvm::SmallBitVector seen(rank);
|
||||||
|
SmallVector<int64_t> transposedShape;
|
||||||
|
transposedShape.reserve(rank);
|
||||||
|
for (int64_t perm : perms) {
|
||||||
|
if (perm < 0 || perm >= rank || seen.test(perm))
|
||||||
|
return failure();
|
||||||
|
seen.set(perm);
|
||||||
|
transposedShape.push_back(tensorType.getShape()[perm]);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType());
|
||||||
|
if (denseAttr.isSplat())
|
||||||
|
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
|
||||||
|
|
||||||
|
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
|
||||||
|
SmallVector<Attribute> transposedValues(originalValues.size());
|
||||||
|
|
||||||
|
SmallVector<int64_t> originalStrides(rank, 1);
|
||||||
|
SmallVector<int64_t> transposedStrides(rank, 1);
|
||||||
|
for (int64_t dim = rank - 2; dim >= 0; --dim) {
|
||||||
|
originalStrides[dim] = originalStrides[dim + 1] * tensorType.getShape()[dim + 1];
|
||||||
|
transposedStrides[dim] = transposedStrides[dim + 1] * transposedShape[dim + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> originalIndices(rank);
|
||||||
|
SmallVector<int64_t> transposedIndices(rank);
|
||||||
|
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
|
||||||
|
int64_t remaining = static_cast<int64_t>(linearIndex);
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
|
originalIndices[dim] = remaining / originalStrides[dim];
|
||||||
|
remaining %= originalStrides[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim)
|
||||||
|
transposedIndices[dim] = originalIndices[perms[dim]];
|
||||||
|
|
||||||
|
int64_t transposedLinearIndex = 0;
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim)
|
||||||
|
transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim];
|
||||||
|
|
||||||
|
transposedValues[transposedLinearIndex] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ConstantSubviewCopy {
|
||||||
|
DenseElementsAttr source;
|
||||||
|
SmallVector<int64_t> offsets;
|
||||||
|
SmallVector<int64_t> strides;
|
||||||
|
Operation* copyOp = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
|
||||||
|
if (!mapOp.getInputs().empty())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto yieldOp = dyn_cast<linalg::YieldOp>(mapOp.getMapper().front().getTerminator());
|
||||||
|
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Attribute attr;
|
||||||
|
if (!matchPattern(yieldOp.getValues().front(), m_Constant(&attr)))
|
||||||
|
return failure();
|
||||||
|
return attr;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
|
||||||
|
auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
|
||||||
|
if (!coreOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
|
||||||
|
if (!initType || !initType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto fillValue = getConstantMapYield(mapOp);
|
||||||
|
if (failed(fillValue))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto tensorType = RankedTensorType::get(initType.getShape(), initType.getElementType());
|
||||||
|
DenseElementsAttr splatAttr = DenseElementsAttr::get(tensorType, *fillValue);
|
||||||
|
|
||||||
|
auto moduleOp = mapOp->getParentOfType<ModuleOp>();
|
||||||
|
if (!moduleOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");
|
||||||
|
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPoint(coreOp);
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||||
|
|
||||||
|
size_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
|
||||||
|
if (elementByteWidth == 0)
|
||||||
|
return failure();
|
||||||
|
size_t totalBytes = initType.getNumElements() * elementByteWidth;
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(mapOp);
|
||||||
|
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||||
|
mapOp.getLoc(),
|
||||||
|
initType,
|
||||||
|
mapOp.getInit(),
|
||||||
|
getGlobalOp.getResult(),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
|
||||||
|
rewriter.eraseOp(mapOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct StaticSubviewInfo {
|
||||||
|
Value source;
|
||||||
|
SmallVector<int64_t> sourceShape;
|
||||||
|
SmallVector<int64_t> offsets;
|
||||||
|
SmallVector<int64_t> sizes;
|
||||||
|
SmallVector<int64_t> strides;
|
||||||
|
};
|
||||||
|
|
||||||
|
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||||
|
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||||
|
if (!subviewOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto source = stripMemRefCasts(subviewOp.getSource());
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||||
|
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||||
|
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
StaticSubviewInfo info;
|
||||||
|
info.source = source;
|
||||||
|
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||||
|
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||||
|
auto staticOffset = getConstantIntValue(offset);
|
||||||
|
if (!staticOffset)
|
||||||
|
return failure();
|
||||||
|
info.offsets.push_back(*staticOffset);
|
||||||
|
}
|
||||||
|
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||||
|
auto staticSize = getConstantIntValue(size);
|
||||||
|
if (!staticSize)
|
||||||
|
return failure();
|
||||||
|
info.sizes.push_back(*staticSize);
|
||||||
|
}
|
||||||
|
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||||
|
auto staticStride = getConstantIntValue(stride);
|
||||||
|
if (!staticStride)
|
||||||
|
return failure();
|
||||||
|
info.strides.push_back(*staticStride);
|
||||||
|
}
|
||||||
|
return info;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int64_t
|
||||||
|
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
|
||||||
|
SmallVector<int64_t> sourceIndices;
|
||||||
|
sourceIndices.reserve(info.sourceShape.size());
|
||||||
|
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
|
||||||
|
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
|
||||||
|
sourceIndices.push_back(info.offsets.back());
|
||||||
|
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
|
if (!copyOp->getParentOfType<pim::PimCoreOp>())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
|
||||||
|
auto dstSubview = getStaticSubviewInfo(copyOp.getDst());
|
||||||
|
const bool splitSrc = succeeded(srcSubview)
|
||||||
|
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
||||||
|
const bool splitDst = succeeded(dstSubview)
|
||||||
|
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
||||||
|
if (!splitSrc && !splitDst)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(copyOp.getSrc().getType());
|
||||||
|
auto dstType = dyn_cast<MemRefType>(copyOp.getDst().getType());
|
||||||
|
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (sourceType.getElementType() != dstType.getElementType())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||||
|
return failure();
|
||||||
|
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||||
|
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
||||||
|
if (elementByteWidth <= 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||||
|
if (copyOp.getSize() != totalBytes)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||||
|
if (sliceBytes <= 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||||
|
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||||
|
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(copyOp);
|
||||||
|
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||||
|
SmallVector<int64_t> outerIndices =
|
||||||
|
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||||
|
const int64_t srcByteOffset = copyOp.getSrcOffset()
|
||||||
|
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
|
||||||
|
: linearIndex * sliceBytes);
|
||||||
|
const int64_t dstByteOffset = copyOp.getDstOffset()
|
||||||
|
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
|
||||||
|
: linearIndex * sliceBytes);
|
||||||
|
pim::PimMemCopyOp::create(rewriter,
|
||||||
|
copyOp.getLoc(),
|
||||||
|
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
||||||
|
splitDst ? dstSubview->source : copyOp.getDst(),
|
||||||
|
splitSrc ? srcSubview->source : copyOp.getSrc(),
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(copyOp, copyOp.getDst());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
|
||||||
|
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||||
|
if (!allocType || !allocType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||||
|
const int64_t numElements = resultTensorType.getNumElements();
|
||||||
|
if (numElements < 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Attribute fillValue;
|
||||||
|
SmallVector<ConstantSubviewCopy> copies;
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> visitedAliases;
|
||||||
|
SmallVector<Value> pendingAliases;
|
||||||
|
pendingAliases.push_back(allocOp.getResult());
|
||||||
|
|
||||||
|
while (!pendingAliases.empty()) {
|
||||||
|
Value alias = pendingAliases.pop_back_val();
|
||||||
|
for (Operation* user : alias.getUsers()) {
|
||||||
|
if (!visitedAliases.insert(user).second)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (auto mapOp = dyn_cast<linalg::MapOp>(user)) {
|
||||||
|
if (mapOp.getInit() != alias)
|
||||||
|
return failure();
|
||||||
|
auto maybeFillValue = getConstantMapYield(mapOp);
|
||||||
|
if (failed(maybeFillValue))
|
||||||
|
return failure();
|
||||||
|
if (fillValue && fillValue != *maybeFillValue)
|
||||||
|
return failure();
|
||||||
|
fillValue = *maybeFillValue;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto subviewOp = dyn_cast<memref::SubViewOp>(user)) {
|
||||||
|
SmallVector<int64_t> offsets;
|
||||||
|
SmallVector<int64_t> strides;
|
||||||
|
offsets.reserve(subviewOp.getMixedOffsets().size());
|
||||||
|
strides.reserve(subviewOp.getMixedStrides().size());
|
||||||
|
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||||
|
auto staticOffset = getConstantIntValue(offset);
|
||||||
|
if (!staticOffset)
|
||||||
|
return failure();
|
||||||
|
offsets.push_back(*staticOffset);
|
||||||
|
}
|
||||||
|
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||||
|
auto staticStride = getConstantIntValue(stride);
|
||||||
|
if (!staticStride)
|
||||||
|
return failure();
|
||||||
|
strides.push_back(*staticStride);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Operation* subviewUser : subviewOp->getUsers()) {
|
||||||
|
if (auto copyOp = dyn_cast<memref::CopyOp>(subviewUser)) {
|
||||||
|
if (copyOp.getTarget() != subviewOp.getResult())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto denseAttr = getDenseGlobalValue(moduleOp, copyOp.getSource());
|
||||||
|
if (failed(denseAttr))
|
||||||
|
return failure();
|
||||||
|
copies.push_back({*denseAttr, offsets, strides, copyOp});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<pim::PimCoreOp, memref::DeallocOp>(user))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (auto castOp = dyn_cast<memref::CastOp>(user)) {
|
||||||
|
pendingAliases.push_back(castOp.getResult());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!fillValue)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<Attribute> resultValues(numElements, fillValue);
|
||||||
|
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||||
|
|
||||||
|
llvm::sort(copies, [](const ConstantSubviewCopy& lhs, const ConstantSubviewCopy& rhs) {
|
||||||
|
return lhs.copyOp->isBeforeInBlock(rhs.copyOp);
|
||||||
|
});
|
||||||
|
|
||||||
|
for (const ConstantSubviewCopy& copy : copies) {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(copy.source.getType());
|
||||||
|
if (!sourceType || !sourceType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (sourceType.getRank() != static_cast<int64_t>(copy.offsets.size())
|
||||||
|
|| sourceType.getRank() != static_cast<int64_t>(copy.strides.size()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
|
SmallVector<Attribute> sourceValues(copy.source.getValues<Attribute>());
|
||||||
|
for (auto [linearIndex, value] : llvm::enumerate(sourceValues)) {
|
||||||
|
SmallVector<int64_t> sourceIndices =
|
||||||
|
delinearizeIndex(static_cast<int64_t>(linearIndex), sourceType.getShape(), sourceStrides);
|
||||||
|
SmallVector<int64_t> resultIndices;
|
||||||
|
resultIndices.reserve(sourceIndices.size());
|
||||||
|
for (auto [offset, sourceIndex, stride] : llvm::zip_equal(copy.offsets, sourceIndices, copy.strides))
|
||||||
|
resultIndices.push_back(offset + sourceIndex * stride);
|
||||||
|
|
||||||
|
int64_t resultLinearIndex = linearizeIndex(resultIndices, resultStrides);
|
||||||
|
resultValues[resultLinearIndex] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return DenseElementsAttr::get(resultTensorType, resultValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
|
||||||
|
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
|
||||||
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceGetGlobal = transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
|
||||||
|
if (!sourceGetGlobal)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
|
||||||
|
if (!moduleOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
|
||||||
|
if (!sourceGlobal || !sourceGlobal.getConstant() || !sourceGlobal.getInitialValue())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto denseAttr = dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
|
||||||
|
if (!denseAttr)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t> perms;
|
||||||
|
perms.reserve(transposeOp.getPerms().size());
|
||||||
|
for (IntegerAttr attr : transposeOp.getPerms().getAsRange<IntegerAttr>())
|
||||||
|
perms.push_back(attr.getInt());
|
||||||
|
FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
|
||||||
|
if (failed(transposedAttr))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto transposedShape = cast<RankedTensorType>(transposedAttr->getType()).getShape();
|
||||||
|
if (!llvm::equal(transposedShape, resultType.getShape()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
MemRefType globalType = resultType;
|
||||||
|
|
||||||
|
auto newGlobal = createFoldedGlobal(moduleOp,
|
||||||
|
transposeOp.getLoc(),
|
||||||
|
globalType,
|
||||||
|
*transposedAttr,
|
||||||
|
sourceGlobal.getName().str() + "__folded_transpose",
|
||||||
|
sourceGlobal.getAlignmentAttr());
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(transposeOp);
|
||||||
|
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName());
|
||||||
|
|
||||||
|
bool isAlwaysWeight =
|
||||||
|
!transposeOp->getUsers().empty()
|
||||||
|
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
|
||||||
|
if (isAlwaysWeight) {
|
||||||
|
markWeightAlways(newGlobal);
|
||||||
|
markWeightAlways(newGetGlobal);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(memref::AllocOp allocOp, PatternRewriter& rewriter) const override {
|
||||||
|
auto moduleOp = allocOp->getParentOfType<ModuleOp>();
|
||||||
|
if (!moduleOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto foldedAttr = foldConstantAlloc(allocOp, moduleOp);
|
||||||
|
if (failed(foldedAttr))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto allocType = cast<MemRefType>(allocOp.getType());
|
||||||
|
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_constant");
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(allocOp);
|
||||||
|
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
|
||||||
|
|
||||||
|
SmallVector<Operation*> opsToErase;
|
||||||
|
SmallVector<memref::CastOp> castsToReplace;
|
||||||
|
bool allLiveUsersAreCoreOps = true;
|
||||||
|
for (Operation* user : llvm::make_early_inc_range(allocOp->getUsers())) {
|
||||||
|
if (isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp>(user)) {
|
||||||
|
opsToErase.push_back(user);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto castOp = dyn_cast<memref::CastOp>(user)) {
|
||||||
|
castsToReplace.push_back(castOp);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isa<pim::PimCoreOp>(user))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
|
||||||
|
return llvm::all_of(castOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
|
||||||
|
})) {
|
||||||
|
allLiveUsersAreCoreOps = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) {
|
||||||
|
return isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp, memref::CastOp, pim::PimCoreOp>(user);
|
||||||
|
})) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (allLiveUsersAreCoreOps) {
|
||||||
|
markWeightAlways(newGlobal);
|
||||||
|
markWeightAlways(newGetGlobal);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> preservedUsers(opsToErase.begin(), opsToErase.end());
|
||||||
|
for (memref::CastOp castOp : castsToReplace)
|
||||||
|
preservedUsers.insert(castOp);
|
||||||
|
rewriter.replaceAllUsesExcept(allocOp.getResult(), newGetGlobal.getResult(), preservedUsers);
|
||||||
|
|
||||||
|
for (memref::CastOp castOp : castsToReplace) {
|
||||||
|
rewriter.setInsertionPoint(castOp);
|
||||||
|
Value replacementCast = memref::CastOp::create(rewriter, castOp.getLoc(), castOp.getType(), newGetGlobal);
|
||||||
|
rewriter.replaceOp(castOp, replacementCast);
|
||||||
|
if (allLiveUsersAreCoreOps)
|
||||||
|
markWeightAlways(replacementCast.getDefiningOp());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Operation* op : llvm::make_early_inc_range(opsToErase)) {
|
||||||
|
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
||||||
|
for (Operation* subviewUser : llvm::make_early_inc_range(subviewOp->getUsers()))
|
||||||
|
rewriter.eraseOp(subviewUser);
|
||||||
|
if (op->use_empty())
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (allocOp.use_empty())
|
||||||
|
rewriter.eraseOp(allocOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
|
||||||
|
|
||||||
|
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
||||||
|
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||||
|
|
||||||
|
LogicalResult initialize(MLIRContext* context) override {
|
||||||
|
RewritePatternSet owningPatterns(context);
|
||||||
|
for (auto* dialect : context->getLoadedDialects())
|
||||||
|
dialect->getCanonicalizationPatterns(owningPatterns);
|
||||||
|
for (RegisteredOperationName op : context->getRegisteredOperations())
|
||||||
|
op.getCanonicalizationPatterns(owningPatterns, context);
|
||||||
|
owningPatterns
|
||||||
|
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, RewriteCoreSubviewCopyPattern>(
|
||||||
|
context);
|
||||||
|
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
GreedyRewriteConfig config;
|
||||||
|
config.enableFolding();
|
||||||
|
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dumpModule(getOperation(), "pim2_folded");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,181 +0,0 @@
|
|||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/MLIRContext.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/ADT/SmallBitVector.h"
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
|
||||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
|
||||||
if (!tensorType)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
int64_t rank = tensorType.getRank();
|
|
||||||
if (static_cast<int64_t>(perms.size()) != rank)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
llvm::SmallBitVector seen(rank);
|
|
||||||
SmallVector<int64_t> transposedShape;
|
|
||||||
transposedShape.reserve(rank);
|
|
||||||
for (int64_t perm : perms) {
|
|
||||||
if (perm < 0 || perm >= rank || seen.test(perm))
|
|
||||||
return failure();
|
|
||||||
seen.set(perm);
|
|
||||||
transposedShape.push_back(tensorType.getShape()[perm]);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType());
|
|
||||||
if (denseAttr.isSplat())
|
|
||||||
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
|
|
||||||
|
|
||||||
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
|
|
||||||
SmallVector<Attribute> transposedValues(originalValues.size());
|
|
||||||
|
|
||||||
SmallVector<int64_t> originalStrides(rank, 1);
|
|
||||||
SmallVector<int64_t> transposedStrides(rank, 1);
|
|
||||||
for (int64_t dim = rank - 2; dim >= 0; --dim) {
|
|
||||||
originalStrides[dim] = originalStrides[dim + 1] * tensorType.getShape()[dim + 1];
|
|
||||||
transposedStrides[dim] = transposedStrides[dim + 1] * transposedShape[dim + 1];
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<int64_t> originalIndices(rank);
|
|
||||||
SmallVector<int64_t> transposedIndices(rank);
|
|
||||||
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
|
|
||||||
int64_t remaining = static_cast<int64_t>(linearIndex);
|
|
||||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
|
||||||
originalIndices[dim] = remaining / originalStrides[dim];
|
|
||||||
remaining %= originalStrides[dim];
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int64_t dim = 0; dim < rank; ++dim)
|
|
||||||
transposedIndices[dim] = originalIndices[perms[dim]];
|
|
||||||
|
|
||||||
int64_t transposedLinearIndex = 0;
|
|
||||||
for (int64_t dim = 0; dim < rank; ++dim)
|
|
||||||
transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim];
|
|
||||||
|
|
||||||
transposedValues[transposedLinearIndex] = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
return DenseElementsAttr::get(transposedType, transposedValues);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
|
|
||||||
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
|
|
||||||
if (!resultType || !resultType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto sourceGetGlobal = transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
|
|
||||||
if (!sourceGetGlobal)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
|
|
||||||
if (!moduleOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
|
|
||||||
if (!sourceGlobal || !sourceGlobal.getConstant() || !sourceGlobal.getInitialValue())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
|
|
||||||
if (!denseAttr)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<int64_t> perms;
|
|
||||||
perms.reserve(transposeOp.getPerms().size());
|
|
||||||
for (IntegerAttr attr : transposeOp.getPerms().getAsRange<IntegerAttr>())
|
|
||||||
perms.push_back(attr.getInt());
|
|
||||||
FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
|
|
||||||
if (failed(transposedAttr))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto transposedShape = cast<RankedTensorType>(transposedAttr->getType()).getShape();
|
|
||||||
if (!llvm::equal(transposedShape, resultType.getShape()))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
MemRefType globalType = resultType;
|
|
||||||
|
|
||||||
auto globalName = sourceGlobal.getName().str() + "__folded_transpose";
|
|
||||||
unsigned suffix = 0;
|
|
||||||
while (moduleOp.lookupSymbol(globalName))
|
|
||||||
globalName = sourceGlobal.getName().str() + "__folded_transpose_" + std::to_string(++suffix);
|
|
||||||
|
|
||||||
auto visibility = rewriter.getStringAttr("private");
|
|
||||||
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
|
|
||||||
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
|
|
||||||
auto newGlobal = memref::GlobalOp::create(moduleBuilder,
|
|
||||||
transposeOp.getLoc(),
|
|
||||||
globalName,
|
|
||||||
visibility,
|
|
||||||
globalType,
|
|
||||||
*transposedAttr,
|
|
||||||
/*constant=*/true,
|
|
||||||
sourceGlobal.getAlignmentAttr());
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(transposeOp);
|
|
||||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName());
|
|
||||||
|
|
||||||
bool isAlwaysWeight =
|
|
||||||
!transposeOp->getUsers().empty()
|
|
||||||
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
|
|
||||||
if (isAlwaysWeight) {
|
|
||||||
markWeightAlways(newGlobal);
|
|
||||||
markWeightAlways(newGetGlobal);
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct PimFoldHostConstantsPass : PassWrapper<PimFoldHostConstantsPass, OperationPass<ModuleOp>> {
|
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimFoldHostConstantsPass)
|
|
||||||
|
|
||||||
StringRef getArgument() const override { return "fold-pim-host-constants-pass"; }
|
|
||||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
|
||||||
|
|
||||||
LogicalResult initialize(MLIRContext* context) override {
|
|
||||||
RewritePatternSet owningPatterns(context);
|
|
||||||
for (auto* dialect : context->getLoadedDialects())
|
|
||||||
dialect->getCanonicalizationPatterns(owningPatterns);
|
|
||||||
for (RegisteredOperationName op : context->getRegisteredOperations())
|
|
||||||
op.getCanonicalizationPatterns(owningPatterns, context);
|
|
||||||
owningPatterns.add<FoldConstantTransposePattern>(context);
|
|
||||||
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
void runOnOperation() override {
|
|
||||||
GreedyRewriteConfig config;
|
|
||||||
config.enableFolding();
|
|
||||||
if (failed(applyPatternsGreedily(getOperation(), *patterns, config)))
|
|
||||||
signalPassFailure();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimFoldHostConstantsPass() { return std::make_unique<PimFoldHostConstantsPass>(); }
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -15,7 +15,7 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
|||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimFoldHostConstantsPass();
|
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimHostVerificationPass();
|
std::unique_ptr<mlir::Pass> createPimHostVerificationPass();
|
||||||
|
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
|||||||
registerPass(createSpatialToGraphvizPass);
|
registerPass(createSpatialToGraphvizPass);
|
||||||
registerPass(createSpatialToPimPass);
|
registerPass(createSpatialToPimPass);
|
||||||
registerPass(createBufferizePimPass);
|
registerPass(createBufferizePimPass);
|
||||||
registerPass(createPimFoldHostConstantsPass);
|
registerPass(createPimConstantFoldingPass);
|
||||||
registerPass(createPimHostVerificationPass);
|
registerPass(createPimHostVerificationPass);
|
||||||
registerPass(createEmitPimJsonPass);
|
registerPass(createEmitPimJsonPass);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user