//===----------------------------------------------------------------------===//
// Rewrite patterns that move function-level tensors (extract_slice results,
// arith constants, function arguments) into memref globals / compute bodies.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/Support/Casting.h"
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
#include "llvm/Support/LogicalResult.h"
|
|
|
|
#include "Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) {
|
|
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
|
unsigned inputCount = compute.getInputs().size();
|
|
if (inputCount == 0)
|
|
return std::nullopt;
|
|
|
|
unsigned inputBegin = compute->getNumOperands() - inputCount;
|
|
if (operandNumber < inputBegin)
|
|
return std::nullopt;
|
|
return operandNumber - inputBegin;
|
|
}
|
|
|
|
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner)) {
|
|
unsigned inputCount = computeBatch.getInputs().size();
|
|
if (inputCount == 0)
|
|
return std::nullopt;
|
|
|
|
unsigned inputBegin = computeBatch->getNumOperands() - inputCount;
|
|
if (operandNumber < inputBegin)
|
|
return std::nullopt;
|
|
return operandNumber - inputBegin;
|
|
}
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
|
|
if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
|
|
return failure();
|
|
|
|
for (auto& uses : extractSliceOp->getUses()) {
|
|
if (isa<spatial::SpatCompute>(uses.getOwner())) {
|
|
if (!getDirectComputeInputIndex(uses.getOwner(), uses.getOperandNumber()))
|
|
return failure();
|
|
}
|
|
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
|
|
return failure();
|
|
}
|
|
}
|
|
|
|
llvm::DenseMap<Operation*, Value> mapSpatToExtract;
|
|
|
|
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
|
|
|
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
|
|
auto inputIndex = getDirectComputeInputIndex(spatCompute, uses.getOperandNumber());
|
|
if (!inputIndex)
|
|
return failure();
|
|
auto BBArgIndex = *inputIndex;
|
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
|
|
|
if (BBArgValue.use_empty())
|
|
continue;
|
|
|
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
|
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
|
|
}
|
|
|
|
rewriter.startOpModification(spatCompute.getOperation());
|
|
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatCompute.getOperation()]);
|
|
spatCompute.getInputsMutable().erase(BBArgIndex);
|
|
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
|
}
|
|
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
|
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
|
if (!inputIndex)
|
|
return failure();
|
|
auto BBArgIndex = *inputIndex;
|
|
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
|
|
|
if (BBArgValue.use_empty())
|
|
continue;
|
|
|
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
|
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
|
|
}
|
|
|
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
|
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatComputeBatch.getOperation()]);
|
|
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
|
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
|
}
|
|
else {
|
|
{
|
|
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
|
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
|
|
}
|
|
|
|
rewriter.startOpModification(spatCompute.getOperation());
|
|
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
|
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
|
}
|
|
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
|
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
|
|
}
|
|
|
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
|
uses.set(mapSpatToExtract[spatComputeBatch.getOperation()]);
|
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
rewriter.eraseOp(extractSliceOp);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
|
|
static int i = 0;
|
|
Location loc = constantOp.getLoc();
|
|
|
|
if (hasWeightAlways(constantOp))
|
|
return failure();
|
|
|
|
if (!isa<func::FuncOp>(constantOp->getParentOp()))
|
|
return failure();
|
|
|
|
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
|
|
if (isa<spatial::SpatCompute>(op))
|
|
return false;
|
|
if (isa<func::FuncOp>(op->getParentOp()))
|
|
return true;
|
|
return false;
|
|
}))
|
|
return failure();
|
|
|
|
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
|
|
|
|
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
|
|
|
|
if (constRankedTensorType) {
|
|
mlir::MemRefType memRefType =
|
|
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
|
|
std::string argName = "const_" + std::to_string(i++);
|
|
memref::GlobalOp::create(rewriter,
|
|
loc,
|
|
rewriter.getStringAttr(argName),
|
|
rewriter.getStringAttr("private"),
|
|
TypeAttr::get(memRefType),
|
|
constantOp.getValueAttr(),
|
|
rewriter.getUnitAttr(),
|
|
{});
|
|
|
|
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
|
|
|
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
|
auto constUsers = constUses.getOwner();
|
|
|
|
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
|
auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
|
|
if (!inputIndex)
|
|
return failure();
|
|
auto BBArgIndex = *inputIndex;
|
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
auto toTensor = bufferization::ToTensorOp::create(
|
|
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
|
}
|
|
|
|
rewriter.startOpModification(spatCompute.getOperation());
|
|
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatCompute.getOperation()]);
|
|
spatCompute.getInputsMutable().erase(BBArgIndex);
|
|
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
|
}
|
|
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
|
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
|
if (!inputIndex)
|
|
return failure();
|
|
auto BBArgIndex = *inputIndex;
|
|
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
auto toTensor = bufferization::ToTensorOp::create(
|
|
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
|
}
|
|
|
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
|
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
|
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
|
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
|
}
|
|
else {
|
|
{
|
|
|
|
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
auto toTensor = bufferization::ToTensorOp::create(
|
|
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
|
}
|
|
|
|
rewriter.startOpModification(spatCompute.getOperation());
|
|
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
|
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
|
}
|
|
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
auto toTensor = bufferization::ToTensorOp::create(
|
|
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
|
}
|
|
|
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
|
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
else if (constantOp.getType().isIntOrIndexOrFloat()) {
|
|
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
|
|
|
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
|
auto constUsers = constUses.getOwner();
|
|
|
|
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
|
auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
|
|
if (!inputIndex)
|
|
return failure();
|
|
auto BBArgIndex = *inputIndex;
|
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
auto newConst = rewriter.clone(*constantOp);
|
|
|
|
rewriter.startOpModification(spatCompute.getOperation());
|
|
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
|
|
spatCompute.getInputsMutable().erase(BBArgIndex);
|
|
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
|
}
|
|
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
|
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
|
if (!inputIndex)
|
|
return failure();
|
|
auto BBArgIndex = *inputIndex;
|
|
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
auto newConst = rewriter.clone(*constantOp);
|
|
|
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
|
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
|
|
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
|
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
|
}
|
|
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
|
if (!mapSpatComputeToConst.contains(parent)) {
|
|
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
|
auto newConst = rewriter.clone(*constantOp);
|
|
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
|
|
}
|
|
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
|
|
}
|
|
else {
|
|
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
|
|
assert(batchParent && "Global Constant used direcly not within a compute");
|
|
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
|
|
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
|
|
auto newConst = rewriter.clone(*constantOp);
|
|
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
|
|
}
|
|
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
|
|
}
|
|
}
|
|
}
|
|
rewriter.eraseOp(constantOp);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override {
|
|
|
|
if (funcOp.getArguments().empty())
|
|
return failure();
|
|
|
|
if (llvm::all_of(funcOp.getArguments(),
|
|
[](mlir::BlockArgument blockArgument) { return blockArgument.use_empty(); }))
|
|
return failure();
|
|
|
|
Location loc = funcOp.getLoc();
|
|
|
|
for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) {
|
|
if (arg.getUses().empty())
|
|
continue;
|
|
|
|
rewriter.setInsertionPoint(funcOp.getOperation());
|
|
|
|
assert(isa<mlir::RankedTensorType>(arg.getType()));
|
|
|
|
auto argRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(arg.getType());
|
|
mlir::MemRefType memRefType =
|
|
mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType());
|
|
|
|
std::string argName = "arg_" + std::to_string(index);
|
|
|
|
memref::GlobalOp::create(rewriter,
|
|
loc,
|
|
rewriter.getStringAttr(argName),
|
|
rewriter.getStringAttr("private"),
|
|
TypeAttr::get(memRefType),
|
|
{},
|
|
{},
|
|
{});
|
|
|
|
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
|
auto argUser = argUses.getOwner();
|
|
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
|
|
auto inputIndex = getDirectComputeInputIndex(spatCompute, argUses.getOperandNumber());
|
|
if (!inputIndex)
|
|
return failure();
|
|
auto BBArgIndex = *inputIndex;
|
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
auto toTensor = bufferization::ToTensorOp::create(
|
|
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
|
|
rewriter.startOpModification(spatCompute.getOperation());
|
|
BBArgValue.replaceAllUsesWith(toTensor);
|
|
spatCompute.getInputsMutable().erase(BBArgIndex);
|
|
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
|
}
|
|
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
|
|
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, argUses.getOperandNumber());
|
|
if (!inputIndex)
|
|
return failure();
|
|
auto BBArgIndex = *inputIndex;
|
|
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
auto toTensor = bufferization::ToTensorOp::create(
|
|
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
|
|
rewriter.startOpModification(spatComputeBatch.getOperation());
|
|
BBArgValue.replaceAllUsesWith(toTensor);
|
|
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
|
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
|
}
|
|
else {
|
|
rewriter.setInsertionPoint(argUser);
|
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
auto toTensor = bufferization::ToTensorOp::create(
|
|
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
rewriter.startOpModification(argUser);
|
|
argUses.set(toTensor);
|
|
rewriter.finalizeOpModification(argUser);
|
|
}
|
|
}
|
|
}
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
|
|
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
|
|
patterns.getContext());
|
|
}
|
|
|
|
} // namespace onnx_mlir
|