Transpose and Refactor of Patterns
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -0,0 +1,227 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
|
||||
std::string name = baseName.str();
|
||||
unsigned suffix = 0;
|
||||
while (SymbolTable::lookupSymbolIn(symbolTableOp, name))
|
||||
name = (baseName + "_" + Twine(suffix++)).str();
|
||||
return name;
|
||||
}
|
||||
|
||||
static memref::GlobalOp createPrivateMemrefGlobalWithUniqueName(PatternRewriter& rewriter,
|
||||
Location loc,
|
||||
ModuleOp moduleOp,
|
||||
StringRef baseName,
|
||||
MemRefType type,
|
||||
Attribute initialValue = {},
|
||||
UnitAttr constant = {}) {
|
||||
std::string symbolName = makeUniqueSymbolName(moduleOp, baseName);
|
||||
return memref::GlobalOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getStringAttr(symbolName),
|
||||
rewriter.getStringAttr("private"),
|
||||
TypeAttr::get(type),
|
||||
initialValue,
|
||||
constant,
|
||||
IntegerAttr {});
|
||||
}
|
||||
|
||||
// Sinks top-level tensor slices into compute regions so later lowering sees local runtime work.
|
||||
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
|
||||
if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
|
||||
return failure();
|
||||
|
||||
for (auto& uses : extractSliceOp->getUses()) {
|
||||
if (isa<spatial::SpatCompute>(uses.getOwner())) {
|
||||
if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
|
||||
return failure();
|
||||
}
|
||||
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
llvm::DenseMap<Operation*, Value> mapSpatToExtract;
|
||||
|
||||
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
||||
|
||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgValue = spatCompute.getInputArgument(*inputIndex);
|
||||
if (!BBArgValue)
|
||||
return failure();
|
||||
|
||||
if (BBArgValue->use_empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex);
|
||||
if (!BBArgValue)
|
||||
return failure();
|
||||
|
||||
if (BBArgValue->use_empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatComputeBatch.getOperation(), *inputIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||
}
|
||||
else {
|
||||
{
|
||||
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatCompute.getOperation());
|
||||
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||
uses.set(mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.eraseOp(extractSliceOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
|
||||
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override {
|
||||
|
||||
if (funcOp.getArguments().empty())
|
||||
return failure();
|
||||
|
||||
if (llvm::all_of(funcOp.getArguments(),
|
||||
[](mlir::BlockArgument blockArgument) { return blockArgument.use_empty(); }))
|
||||
return failure();
|
||||
|
||||
Location loc = funcOp.getLoc();
|
||||
|
||||
for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) {
|
||||
if (arg.getUses().empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(funcOp.getOperation());
|
||||
|
||||
assert(isa<mlir::RankedTensorType>(arg.getType()));
|
||||
|
||||
auto argRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(arg.getType());
|
||||
mlir::MemRefType memRefType =
|
||||
mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType());
|
||||
|
||||
std::string baseName = ("arg_" + Twine(index)).str();
|
||||
auto globalOp = createPrivateMemrefGlobalWithUniqueName(
|
||||
rewriter, loc, funcOp->getParentOfType<ModuleOp>(), baseName, memRefType);
|
||||
std::string argName = globalOp.getSymName().str();
|
||||
|
||||
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
||||
auto argUser = argUses.getOwner();
|
||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, toTensor);
|
||||
}
|
||||
else {
|
||||
rewriter.setInsertionPoint(argUser);
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
rewriter.startOpModification(argUser);
|
||||
argUses.set(toTensor);
|
||||
rewriter.finalizeOpModification(argUser);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user