merge remote changes

This commit is contained in:
NiccoloN
2026-05-03 22:30:46 +02:00
parent b605585b1f
commit 62b0a6e19d
15 changed files with 1116 additions and 181 deletions

View File

@@ -1,3 +1,4 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -11,6 +12,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
@@ -183,6 +185,7 @@ void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc);
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
@@ -199,19 +202,36 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
Value source = funcSource(toRemoveOp);
rewriter.setInsertionPointAfter(toRemoveOp);
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
return false;
}
// Wraps a function-level tensor.extract_slice into a fresh spatial.compute
// region: the compute takes the slice's source as its single input, the slice
// is cloned inside the region against the block argument, and its result is
// yielded. The original op is replaced and erased.
// Returns true when a rewrite happened, false when `inst` is not an
// extract_slice (or is null).
bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) {
  auto sliceOp = llvm::dyn_cast_if_present<tensor::ExtractSliceOp>(inst);
  if (!sliceOp)
    return false;

  Value sliceSource = sliceOp.getSource();
  rewriter.setInsertionPointAfter(sliceOp);

  // New compute op with one input (the slice source) and zero init operands.
  auto computeOp = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sliceSource);
  Region& body = computeOp.getBody();
  Block* entry = rewriter.createBlock(&body, body.end(), {sliceSource.getType()}, {loc});
  computeOp.getProperties().setOperandSegmentSizes({0, 1});

  // Clone the slice inside the region, remapping its source to the block arg.
  rewriter.setInsertionPointToEnd(entry);
  IRMapping valueMap;
  valueMap.map(sliceSource, entry->getArgument(0));
  Operation* clonedSlice = rewriter.clone(*inst, valueMap);
  spatial::SpatYieldOp::create(rewriter, loc, clonedSlice->getResults());

  inst->replaceAllUsesWith(computeOp->getResults());
  inst->erase();
  return true;
}
@@ -245,6 +265,24 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
inst->erase();
return true;
}
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
for (auto source : sources) {
sourceTypes.push_back(source.getType());
sourceLoc.push_back(loc);
}
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
return false;
}
@@ -306,6 +344,89 @@ static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewrite
return cast<Value>(mapped);
}
// Returns true when the transitive producer chain of `op` bottoms out in an
// arith.constant that hasWeightAlways() classifies as weight-like.
//
// The walk follows single-source "view" ops (extract_slice, spatial
// extract_rows, expand/collapse_shape, ONNX transpose) up to their defining
// op. Concat-like ops are answered immediately by checking whether any
// operand's producer has a weight. Spatial compute ops terminate the walk with
// `false`; any other op is a hard error — the set of global instructions is
// expected to be closed.
//
// (Name keeps the historical "Opernad" spelling because callers reference it.)
bool sourceOpernadHasWeightAlways(Operation* op) {
  if (op == nullptr)
    return false;
  // Step `op` to the producer of `value`. Returns false (stop the walk, no
  // weight found) when the value has no defining op, e.g. a block argument.
  auto followSource = [&op](Value value) -> bool {
    Operation* definingOp = value.getDefiningOp();
    if (!definingOp)
      return false;
    op = definingOp;
    return true;
  };
  // True when any operand of a concat-like op is produced by a weight.
  auto anyOperandHasWeight = [](Operation* concatLike) -> bool {
    for (Value operand : concatLike->getOperands())
      if (hasWeightAlways(operand.getDefiningOp()))
        return true;
    return false;
  };
  Operation* source = nullptr;
  do {
    if (isa<spatial::SpatCompute, spatial::SpatComputeBatch>(*op))
      return false;
    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(*op)) {
      if (!followSource(extractSliceOp.getSource()))
        return false;
    }
    else if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(*op)) {
      if (!followSource(extractRowsOp.getInput()))
        return false;
    }
    else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(*op)) {
      if (!followSource(expandShapeOp.getSrc()))
        return false;
    }
    else if (auto transposeOp = dyn_cast<ONNXTransposeOp>(*op)) {
      if (!followSource(transposeOp.getData()))
        return false;
    }
    else if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(*op)) {
      if (!followSource(collapseShapeOp.getSrc()))
        return false;
    }
    else if (auto constantOp = dyn_cast<arith::ConstantOp>(*op)) {
      source = constantOp;  // chain root found; loop condition below exits
    }
    else if (isa<tensor::ConcatOp, spatial::SpatConcatOp>(*op)) {
      return anyOperandHasWeight(op);
    }
    else {
      op->dump();
      llvm_unreachable("Global instruction not handled in func");
    }
  }
  while (source == nullptr);
  return hasWeightAlways(source);
}
// TODO what we want to keep in global?
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
@@ -314,8 +435,14 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
while (keep) {
keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
keep |= encapsulator<tensor::ExtractSliceOp>(
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
instruction)
|| isa<func::ReturnOp>(instruction)
|| sourceOpernadHasWeightAlways(&instruction))
continue;
keep |= encapsulateSlice(rewriter, loc, &instruction);
keep |= encapsulator<tensor::ExpandShapeOp>(
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });

View File

@@ -23,7 +23,10 @@ static Value extractSliceAt(
sizes.push_back(rewriter.getIndexAttr(dim));
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(size);
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
SmallVector<int64_t> resultShape(inputType.getShape());
resultShape[axis] = size;
auto resultType = RankedTensorType::get(resultShape, inputType.getElementType());
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides);
}
struct Split : OpConversionPattern<ONNXSplitOp> {
@@ -49,12 +52,7 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
if (!resultType || !resultType.hasStaticShape())
return failure();
int64_t sliceSize = resultType.getShape()[axis];
auto computeOp =
createSpatCompute<1>(rewriter, splitOp.getLoc(), TypeRange {resultType}, {}, adaptor.getInput(), [&](Value x) {
Value output = extractSliceAt(x, axis, offset, sliceSize, rewriter, splitOp.getLoc());
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), output);
});
outputs.push_back(computeOp.getResult(0));
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
offset += sliceSize;
}

View File

@@ -5,6 +5,7 @@ add_public_tablegen_target(SpatialToPimIncGen)
add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp
Common.cpp
Patterns.cpp
EXCLUDE_FROM_OM_LIBS

View File

@@ -0,0 +1,385 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
// Sinks a function-level tensor.extract_slice into every spatial compute
// region that consumes it (directly as an input operand, or indirectly from
// inside a region), then erases the original slice from the function body.
// One clone of the slice is created per target compute op and shared among
// that op's uses.
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
    Location loc = extractSliceOp.getLoc();  // currently unused; clones keep the original loc
    // Only slices sitting directly in the function body are moved.
    if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
      return failure();
    // Pre-check: bail out if any direct SpatCompute use is not a region input
    // (e.g. an init operand), or if any use is a plain function-level op.
    for (auto& uses : extractSliceOp->getUses()) {
      if (isa<spatial::SpatCompute>(uses.getOwner())) {
        auto spatCompute = cast<spatial::SpatCompute>(uses.getOwner());
        if (spatCompute.getInputs().empty())
          return failure();
        if (uses.getOperandNumber() < spatCompute.getInputs().getBeginOperandIndex())
          return failure();
      }
      else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
        return failure();
      }
    }
    // Maps each target compute op to the slice clone created inside it, so
    // multiple uses in the same region share one clone.
    llvm::DenseMap<Operation*, Value> mapSpatToExtract;
    for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
      if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
        // Input operands and block arguments are index-aligned.
        auto BBArgIndex = uses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
        auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
        // NOTE(review): skipping an unused block arg leaves the slice as a
        // live operand of this compute, yet the eraseOp below still runs —
        // confirm this path cannot leave a dangling operand.
        if (BBArgValue.use_empty())
          continue;
        rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
        if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
          auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
          mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
        }
        rewriter.startOpModification(spatCompute.getOperation());
        BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatCompute.getOperation()]);
        spatCompute.getInputsMutable().erase(BBArgIndex);
        spatCompute.getBody().front().eraseArgument(BBArgIndex);
        rewriter.finalizeOpModification(spatCompute.getOperation());
      }
      else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
        // Same handling as SpatCompute, for the batched variant.
        auto BBArgIndex = uses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
        auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
        if (BBArgValue.use_empty())
          continue;
        rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
        if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
          auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
          mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
        }
        rewriter.startOpModification(spatComputeBatch.getOperation());
        BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatComputeBatch.getOperation()]);
        spatComputeBatch.getInputsMutable().erase(BBArgIndex);
        spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
        rewriter.finalizeOpModification(spatComputeBatch.getOperation());
      }
      else {
        // Use is an op *inside* a compute region: retarget the use itself to a
        // clone placed at the top of that region.
        {
          if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
            rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
            if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
              auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
              mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
            }
            rewriter.startOpModification(spatCompute.getOperation());
            uses.set(mapSpatToExtract[spatCompute.getOperation()]);
            rewriter.finalizeOpModification(spatCompute.getOperation());
          }
          else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
            rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
            if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
              auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
              mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
            }
            rewriter.startOpModification(spatComputeBatch.getOperation());
            uses.set(mapSpatToExtract[spatComputeBatch.getOperation()]);
            rewriter.finalizeOpModification(spatComputeBatch.getOperation());
          }
        }
      }
    }
    rewriter.eraseOp(extractSliceOp);
    return success();
  }
};
// Promotes a function-level arith.constant consumed (directly or from inside a
// region) by spatial compute ops. Ranked-tensor constants become a private
// memref.global; in-region uses are replaced with memref.get_global +
// bufferization.to_tensor. Scalar (int/index/float) constants are simply
// re-materialized inside the consuming regions. The original constant is
// erased afterwards.
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
    // Process-wide counter for unique global symbol names.
    // NOTE(review): shared mutable state across pattern instances; not
    // thread-safe — confirm the driver applies this single-threaded.
    static int i = 0;
    Location loc = constantOp.getLoc();
    // Weight-like constants are handled by the weight pipeline, not here.
    if (hasWeightAlways(constantOp))
      return failure();
    if (!isa<func::FuncOp>(constantOp->getParentOp()))
      return failure();
    // Nothing to do when every user is a plain function-level op and none is a
    // spatial compute: the constant can stay where it is.
    if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
          if (isa<spatial::SpatCompute>(op))
            return false;
          if (isa<func::FuncOp>(op->getParentOp()))
            return true;
          return false;
        }))
      return failure();
    rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
    auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
    // Guard: for any other type (e.g. vectors) erasing the constant below
    // would leave its remaining uses dangling, so refuse to match.
    if (!constRankedTensorType && !constantOp.getType().isIntOrIndexOrFloat())
      return failure();
    if (constRankedTensorType) {
      mlir::MemRefType memRefType =
          mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
      std::string argName = "const_" + std::to_string(i++);
      // Private, constant-initialized global carrying the tensor payload.
      memref::GlobalOp::create(rewriter,
          loc,
          rewriter.getStringAttr(argName),
          rewriter.getStringAttr("private"),
          TypeAttr::get(memRefType),
          constantOp.getValueAttr(),
          rewriter.getUnitAttr(),
          {});
      // One get_global/to_tensor pair per consuming compute, shared by all of
      // that compute's uses.
      llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
      for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
        auto constUsers = constUses.getOwner();
        if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
          // Direct input operand: rewire the matching block argument, then
          // drop both the operand and the argument.
          auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
          auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
          rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
          if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
            auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
            auto toTensor = bufferization::ToTensorOp::create(
                rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
            mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
          }
          rewriter.startOpModification(spatCompute.getOperation());
          BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatCompute.getOperation()]);
          spatCompute.getInputsMutable().erase(BBArgIndex);
          spatCompute.getBody().front().eraseArgument(BBArgIndex);
          rewriter.finalizeOpModification(spatCompute.getOperation());
        }
        else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
          // Same handling for the batched compute variant.
          auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
          auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
          rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
          if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
            auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
            auto toTensor = bufferization::ToTensorOp::create(
                rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
            mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
          }
          rewriter.startOpModification(spatComputeBatch.getOperation());
          BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
          spatComputeBatch.getInputsMutable().erase(BBArgIndex);
          spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
          rewriter.finalizeOpModification(spatComputeBatch.getOperation());
        }
        else {
          // Use from inside a compute region: retarget the use to the shared
          // to_tensor value created at the top of that region.
          {
            if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
              rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
              if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
                auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
                auto toTensor = bufferization::ToTensorOp::create(
                    rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
                mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
              }
              rewriter.startOpModification(spatCompute.getOperation());
              constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
              rewriter.finalizeOpModification(spatCompute.getOperation());
            }
            else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
              rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
              if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
                auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
                auto toTensor = bufferization::ToTensorOp::create(
                    rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
                mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
              }
              rewriter.startOpModification(spatComputeBatch.getOperation());
              constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
              rewriter.finalizeOpModification(spatComputeBatch.getOperation());
            }
          }
        }
      }
    }
    else if (constantOp.getType().isIntOrIndexOrFloat()) {
      // Scalar constant: cheap to clone, so just re-materialize it inside
      // every consuming region.
      llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
      for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
        auto constUsers = constUses.getOwner();
        if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
          // NOTE(review): direct-operand path clones one constant per use
          // without consulting the cache; duplicates are left for CSE.
          auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
          auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
          rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
          auto newConst = rewriter.clone(*constantOp);
          rewriter.startOpModification(spatCompute.getOperation());
          BBArgValue.replaceAllUsesWith(newConst->getResult(0));
          spatCompute.getInputsMutable().erase(BBArgIndex);
          spatCompute.getBody().front().eraseArgument(BBArgIndex);
          rewriter.finalizeOpModification(spatCompute.getOperation());
        }
        else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
          auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
          auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
          rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
          auto newConst = rewriter.clone(*constantOp);
          rewriter.startOpModification(spatComputeBatch.getOperation());
          BBArgValue.replaceAllUsesWith(newConst->getResult(0));
          spatComputeBatch.getInputsMutable().erase(BBArgIndex);
          spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
          rewriter.finalizeOpModification(spatComputeBatch.getOperation());
        }
        else {
          // In-region use: one shared clone per enclosing compute.
          if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
            if (!mapSpatComputeToConst.contains(parent)) {
              rewriter.setInsertionPoint(&parent.getBody().front().front());
              auto newConst = rewriter.clone(*constantOp);
              mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
            }
            constUses.set(mapSpatComputeToConst[parent.getOperation()]);
          }
          else {
            auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
            assert(batchParent && "Global Constant used directly not within a compute");
            if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
              rewriter.setInsertionPoint(&batchParent.getBody().front().front());
              auto newConst = rewriter.clone(*constantOp);
              mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
            }
            constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
          }
        }
      }
    }
    // All uses have been rewired above; the original constant is now dead.
    rewriter.eraseOp(constantOp);
    return success();
  }
};
// Replaces every use of each ranked-tensor function argument with a private,
// uninitialized memref.global: users get a memref.get_global +
// bufferization.to_tensor pair (placed at the top of the consuming compute
// region, or right before a plain user). Direct compute-input uses also drop
// the corresponding operand and block argument. Arguments themselves are not
// removed from the signature.
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override {
    if (funcOp.getArguments().empty())
      return failure();
    // Once all argument uses are rewritten the pattern stops matching, which
    // guarantees termination under the driver.
    if (llvm::all_of(funcOp.getArguments(),
            [](mlir::BlockArgument blockArgument) { return blockArgument.use_empty(); }))
      return failure();
    Location loc = funcOp.getLoc();
    for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) {
      if (arg.use_empty())
        continue;
      rewriter.setInsertionPoint(funcOp.getOperation());
      auto argRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(arg.getType());
      assert(argRankedTensorType && "function argument expected to be a ranked tensor");
      // Defensive: in release builds (assert compiled out) skip non-tensor
      // arguments instead of dereferencing a null type below.
      if (!argRankedTensorType)
        continue;
      mlir::MemRefType memRefType =
          mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType());
      std::string argName = "arg_" + std::to_string(index);
      // Uninitialized private global standing in for the runtime argument.
      memref::GlobalOp::create(rewriter,
          loc,
          rewriter.getStringAttr(argName),
          rewriter.getStringAttr("private"),
          TypeAttr::get(memRefType),
          {},
          {},
          {});
      for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
        auto argUser = argUses.getOwner();
        if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
          // Direct input operand: rewire the block argument, then drop both
          // the operand and the argument.
          auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
          auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
          rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
          auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
          auto toTensor = bufferization::ToTensorOp::create(
              rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
          rewriter.startOpModification(spatCompute.getOperation());
          BBArgValue.replaceAllUsesWith(toTensor);
          spatCompute.getInputsMutable().erase(BBArgIndex);
          spatCompute.getBody().front().eraseArgument(BBArgIndex);
          rewriter.finalizeOpModification(spatCompute.getOperation());
        }
        else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
          // Same handling for the batched compute variant.
          auto BBArgIndex = argUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
          auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
          rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
          auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
          auto toTensor = bufferization::ToTensorOp::create(
              rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
          rewriter.startOpModification(spatComputeBatch.getOperation());
          BBArgValue.replaceAllUsesWith(toTensor);
          spatComputeBatch.getInputsMutable().erase(BBArgIndex);
          spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
          rewriter.finalizeOpModification(spatComputeBatch.getOperation());
        }
        else {
          // Any other user: materialize the pair right before it and retarget
          // just this use.
          rewriter.setInsertionPoint(argUser);
          auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
          auto toTensor = bufferization::ToTensorOp::create(
              rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
          rewriter.startOpModification(argUser);
          argUses.set(toTensor);
          rewriter.finalizeOpModification(argUser);
        }
      }
    }
    return success();
  }
};
} // namespace
// Registers the globalization patterns defined in this file. The patterns are
// added in the same relative order as before: slice sinking, then func-arg
// promotion, then constant promotion.
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
  MLIRContext* context = patterns.getContext();
  patterns.add<MoveExtractSliceIntoCompute>(context);
  patterns.add<FuncOpArgToGlobalMemoryPattern>(context);
  patterns.add<ArithConstToGlobalMemoryPattern>(context);
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,10 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
/// Registers the tensor-to-memref globalization patterns (extract_slice
/// sinking, func-arg promotion, and arith.constant promotion to
/// memref.global) into `patterns`.
void populateGlobalTensorToMemrefPatterns(mlir::RewritePatternSet& patterns);
}

View File

@@ -1,20 +1,26 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_os_ostream.h"
#include <cassert>
@@ -24,6 +30,7 @@
#include <utility>
#include "Conversion/ONNXToSpatial/Common.hpp"
#include "Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -53,7 +60,7 @@ struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>>
void runOnOperation() final;
private:
SmallVector<Value> outputTensors;
SmallVector<std::function<Value(IRRewriter& rewriter, Location loc)>> outputTensors;
size_t coreId = 0;
SmallVector<Operation*> operationsToRemove;
@@ -179,7 +186,22 @@ static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan
}
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
auto inputType = cast<RankedTensorType>(extractRowsOp.getInput().getType());
Value input = extractRowsOp.getInput();
RankedTensorType inputType;
if (auto tensorType = dyn_cast<RankedTensorType>(input.getType())) {
inputType = tensorType;
}
else if (auto memRefType = dyn_cast<MemRefType>(input.getType())) {
inputType = RankedTensorType::get(memRefType.getShape(), memRefType.getElementType());
rewriter.setInsertionPoint(extractRowsOp);
input = bufferization::ToTensorOp::create(
rewriter, extractRowsOp.getLoc(), inputType, input, rewriter.getUnitAttr(), rewriter.getUnitAttr())
.getResult();
}
else {
extractRowsOp.emitOpError("requires a ranked tensor or memref input during Spatial-to-PIM lowering");
return;
}
int64_t numCols = inputType.getDimSize(1);
SmallVector<Value> replacements;
@@ -187,11 +209,16 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
rewriter.setInsertionPoint(extractRowsOp);
for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!outputType) {
extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering");
return;
}
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto rowSlice = tensor::ExtractSliceOp::create(
rewriter, extractRowsOp.getLoc(), cast<RankedTensorType>(output.getType()), extractRowsOp.getInput(), offsets, sizes, strides);
rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides);
replacements.push_back(rowSlice.getResult());
}
@@ -205,6 +232,75 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
rewriter.replaceOp(concatOp, concatenated);
}
// Succeeds when `computeOp` is a single-input/single-result helper compute
// whose body is a straight-line chain of channel-use ops from the block
// argument to the yield (plus, optionally, tensor.empty/arith.constant
// leaves). On success the chain is written top-down into `helperChain`.
// With `requireReturnUse` set, the compute's result must additionally have
// exactly one use, and that use must be the function return.
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
    SmallVectorImpl<Operation*>& helperChain,
    bool requireReturnUse = true) {
  if (computeOp.getInputs().size() != 1)
    return failure();
  if (computeOp.getNumResults() != 1)
    return failure();
  if (requireReturnUse) {
    Value result = computeOp.getResult(0);
    if (!result.hasOneUse())
      return failure();
    if (!isa<func::ReturnOp>(*result.getUsers().begin()))
      return failure();
  }
  Block& body = computeOp.getBody().front();
  if (body.getNumArguments() != 1)
    return failure();
  auto terminator = dyn_cast<spatial::SpatYieldOp>(body.getTerminator());
  if (!terminator || terminator.getNumOperands() != 1)
    return failure();
  // Walk backwards from the yielded value towards the block argument,
  // collecting the chain bottom-up.
  SmallVector<Operation*> chainBottomUp;
  Value cursor = terminator.getOperands().front();
  Value entryArg = body.getArgument(0);
  while (cursor != entryArg) {
    Operation* producer = cursor.getDefiningOp();
    if (!producer || producer->getBlock() != &body || !isChannelUseChainOp(producer))
      return failure();
    chainBottomUp.push_back(producer);
    cursor = producer->getOperand(0);
  }
  // Every remaining op in the body must be a chain member or a harmless leaf.
  SmallPtrSet<Operation*, 8> chainMembers(chainBottomUp.begin(), chainBottomUp.end());
  for (Operation& bodyOp : body.without_terminator()) {
    if (chainMembers.contains(&bodyOp))
      continue;
    if (isa<tensor::EmptyOp, arith::ConstantOp>(bodyOp))
      continue;
    return failure();
  }
  helperChain.assign(chainBottomUp.rbegin(), chainBottomUp.rend());
  return success();
}
// Inlines an input-less, single-result spatial.compute whose result is used
// only by batch-style consumers (SpatComputeBatch / PimCoreBatchOp): the body
// ops are cloned right before the compute and all result uses are redirected
// to the cloned yield value. Returns true on success, false when the shape
// does not match.
// NOTE(review): the compute op itself is not erased here — presumably the
// caller removes the now-dead op; confirm.
static bool inlineInputlessHelperComputeForBatchUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
  if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
    return false;
  // All users must be batch consumers; anything else keeps the region intact.
  if (!llvm::all_of(computeOp.getResult(0).getUsers(),
          [](Operation* user) { return isa<spatial::SpatComputeBatch, pim::PimCoreBatchOp>(user); }))
    return false;
  Block& block = computeOp.getBody().front();
  if (block.getNumArguments() != 0)
    return false;
  auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
  if (!yieldOp || yieldOp.getNumOperands() != 1)
    return false;
  rewriter.setInsertionPoint(computeOp);
  IRMapping mapping;
  for (Operation& op : block.without_terminator()) {
    // Let the helper pre-clone any operands that need remapping before the op
    // itself is cloned (semantics of cloneMappedHelperOperands defined
    // elsewhere in this file).
    cloneMappedHelperOperands(&op, mapping, rewriter);
    Operation* clonedOp = rewriter.clone(op, mapping);
    for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
      mapping.map(originalResult, newResult);
    // Keep clones in original order by advancing the insertion point.
    rewriter.setInsertionPointAfter(clonedOp);
  }
  Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
  computeOp.getResult(0).replaceAllUsesWith(replacement);
  return true;
}
struct ReturnUseInfo {
size_t returnIndex;
SmallVector<Operation*> helperChain;
@@ -295,6 +391,20 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
}
SmallVector<Operation*> helperChain;
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
return std::nullopt;
if (failed(collectHelperComputeChain(helperCompute, helperChain)))
return std::nullopt;
currentValue = helperCompute.getResult(0);
auto currentUses = currentValue.getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
currentUser = currentUses.begin()->getOwner();
}
while (isChannelUseChainOp(currentUser)) {
helperChain.push_back(currentUser);
auto currentUses = currentUser->getResult(0).getUses();
@@ -419,21 +529,22 @@ static void cloneHelperChain(Value sourceValue,
}
}
static void emitHostCopy(IRRewriter& rewriter,
Location loc,
Value outputTensor,
Value sourceValue,
int32_t hostTargetOffset,
int32_t deviceSourceOffset,
int32_t sizeInBytes) {
PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
outputTensor,
sourceValue,
rewriter.getI32IntegerAttr(hostTargetOffset),
rewriter.getI32IntegerAttr(deviceSourceOffset),
rewriter.getI32IntegerAttr(sizeInBytes));
// Emits a pim.memcopy_dev_to_host copying `sizeInBytes` bytes from
// `sourceValue` (device, at `deviceSourceOffset`) into `outputTensor` (host,
// at `hostTargetOffset`) and returns the op's output value so callers can
// chain further uses of the updated tensor.
static Value emitHostCopy(IRRewriter& rewriter,
    Location loc,
    Value outputTensor,
    Value sourceValue,
    int32_t hostTargetOffset,
    int32_t deviceSourceOffset,
    int32_t sizeInBytes) {
  auto hostOffsetAttr = rewriter.getI32IntegerAttr(hostTargetOffset);
  auto deviceOffsetAttr = rewriter.getI32IntegerAttr(deviceSourceOffset);
  auto sizeAttr = rewriter.getI32IntegerAttr(sizeInBytes);
  auto copyOp = PimMemCopyDevToHostOp::create(rewriter,
      loc,
      outputTensor.getType(),
      outputTensor,
      sourceValue,
      hostOffsetAttr,
      deviceOffsetAttr,
      sizeAttr);
  return copyOp.getOutput();
}
void SpatialToPimPass::runOnOperation() {
@@ -458,12 +569,21 @@ void SpatialToPimPass::runOnOperation() {
scf::SCFDialect,
BuiltinDialect>();
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
{
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
}
{
RewritePatternSet patterns(ctx);
populateGlobalTensorToMemrefPatterns(patterns);
walkAndApplyPatterns(moduleOp, std::move(patterns));
}
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
@@ -489,7 +609,8 @@ void SpatialToPimPass::runOnOperation() {
}
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
funcOp.walk([&](spatial::SpatChannelReceiveOp op) { receiveOps.push_back(op); });
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
receiveOps.push_back(op);
for (auto receiveOp : receiveOps) {
bool onlyPendingRemovalUsers = llvm::all_of(
receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
@@ -505,22 +626,26 @@ void SpatialToPimPass::runOnOperation() {
}
SmallVector<spatial::SpatChannelReceiveManyOp> receiveManyOps;
funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { receiveManyOps.push_back(op); });
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveManyOp>())
receiveManyOps.push_back(op);
for (auto receiveManyOp : receiveManyOps)
lowerChannelReceiveMany(receiveManyOp, rewriter);
SmallVector<spatial::SpatChannelSendOp> sendOps;
funcOp.walk([&](spatial::SpatChannelSendOp op) { sendOps.push_back(op); });
for (auto op : funcOp.getOps<spatial::SpatChannelSendOp>())
sendOps.push_back(op);
for (auto sendOp : sendOps)
lowerChannelSend(sendOp, rewriter);
SmallVector<spatial::SpatChannelSendManyOp> sendManyOps;
funcOp.walk([&](spatial::SpatChannelSendManyOp op) { sendManyOps.push_back(op); });
for (auto op : funcOp.getOps<spatial::SpatChannelSendManyOp>())
sendManyOps.push_back(op);
for (auto sendManyOp : sendManyOps)
lowerChannelSendMany(sendManyOp, rewriter);
SmallVector<spatial::SpatExtractRowsOp> extractRowsOps;
funcOp.walk([&](spatial::SpatExtractRowsOp op) { extractRowsOps.push_back(op); });
for (auto op : funcOp.getOps<spatial::SpatExtractRowsOp>())
extractRowsOps.push_back(op);
for (auto extractRowsOp : extractRowsOps)
lowerExtractRows(extractRowsOp, rewriter);
@@ -560,6 +685,36 @@ void SpatialToPimPass::runOnOperation() {
assert(false && "tracked op removal reached a cycle or missed dependency");
}
SmallVector<spatial::SpatConcatOp> remainingConcatOps;
funcOp.walk([&](spatial::SpatConcatOp op) { remainingConcatOps.push_back(op); });
for (auto concatOp : remainingConcatOps)
lowerConcat(concatOp, rewriter);
SmallVector<spatial::SpatChannelReceiveOp> remainingReceiveOps;
funcOp.walk([&](spatial::SpatChannelReceiveOp op) { remainingReceiveOps.push_back(op); });
for (auto receiveOp : remainingReceiveOps)
lowerChannelReceive(receiveOp, rewriter);
SmallVector<spatial::SpatChannelReceiveManyOp> remainingReceiveManyOps;
funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { remainingReceiveManyOps.push_back(op); });
for (auto receiveManyOp : remainingReceiveManyOps)
lowerChannelReceiveMany(receiveManyOp, rewriter);
SmallVector<spatial::SpatChannelSendOp> remainingSendOps;
funcOp.walk([&](spatial::SpatChannelSendOp op) { remainingSendOps.push_back(op); });
for (auto sendOp : remainingSendOps)
lowerChannelSend(sendOp, rewriter);
SmallVector<spatial::SpatChannelSendManyOp> remainingSendManyOps;
funcOp.walk([&](spatial::SpatChannelSendManyOp op) { remainingSendManyOps.push_back(op); });
for (auto sendManyOp : remainingSendManyOps)
lowerChannelSendMany(sendManyOp, rewriter);
SmallVector<spatial::SpatExtractRowsOp> remainingExtractRowsOps;
funcOp.walk([&](spatial::SpatExtractRowsOp op) { remainingExtractRowsOps.push_back(op); });
for (auto extractRowsOp : remainingExtractRowsOps)
lowerExtractRows(extractRowsOp, rewriter);
// Dump to file for debug
bool hasSpatialOps = false;
moduleOp.walk([&](Operation* op) {
@@ -579,6 +734,13 @@ void SpatialToPimPass::runOnOperation() {
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
if (inlineInputlessHelperComputeForBatchUsers(computeOp, rewriter))
return;
SmallVector<Operation*> helperChain;
if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
return;
auto& block = computeOp.getRegion().front();
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
@@ -616,9 +778,9 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
auto storedType = cast<ShapedType>(storedValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
Value outputTensor = outputTensors[returnUse->returnIndex];
if (auto storedOp = storedValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp);
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
emitHostCopy(rewriter,
loc,
outputTensor,
@@ -637,8 +799,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
Value outputTensor = outputTensors[resultIndexInReturn];
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy(rewriter,
loc,
outputTensor,
@@ -654,13 +816,13 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
}
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
Value outputTensor = outputTensors[concatReturnUse->returnIndex];
auto outputType = cast<ShapedType>(outputTensor.getType());
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
if (concatReturnUse->helperChain.empty()) {
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
emitHostCopy(rewriter,
loc,
outputTensor,
@@ -671,7 +833,15 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
continue;
}
auto storedType = cast<RankedTensorType>(yieldValue.getType());
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
if (!storedType) {
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
signalPassFailure();
return;
}
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
for (auto [dim, idx] : llvm::enumerate(sourceIndices))
@@ -701,19 +871,18 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
auto scalarTensorType =
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
rewriter.setInsertionPointAfterValue(yieldValue);
auto elementSlice = tensor::ExtractSliceOp::create(
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
rewriter.setInsertionPointAfter(elementSlice);
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
emitHostCopy(rewriter,
loc,
outputTensor,
elementSlice.getResult(),
static_cast<int32_t>(destinationFlatOffset * elementSize),
0,
static_cast<int32_t>(elementSize));
outputTensor = emitHostCopy(rewriter,
loc,
outputTensor,
elementSlice.getResult(),
static_cast<int32_t>(destinationFlatOffset * elementSize),
0,
static_cast<int32_t>(elementSize));
}
continue;
}
@@ -848,6 +1017,26 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
continue;
}
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper);
auto clonedTensor = cloned->getResult(0);
auto clonedType = cast<ShapedType>(clonedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
clonedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, clonedTensor))
.getOutput();
mapper.map(toTensorOp.getResult(), copied);
continue;
}
}
for (Value operand : op.getOperands()) {
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
continue;
@@ -922,17 +1111,33 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
outputTensors.reserve(returnOp->getNumOperands());
rewriter.setInsertionPointToStart(returnOp->getBlock());
for (auto returnValue : returnOp->getOperands()) {
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
assert(!hasWeightAlways(returnValueDefiningOp));
outputTensors.push_back(returnValue);
outputTensors.push_back([returnValue](IRRewriter& rewriter, Location loc) -> Value { return returnValue; });
}
else {
auto newOutputTensor =
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
outputTensors.push_back(newOutputTensor);
auto outRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(returnValue.getType());
auto memRefType = mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
std::string outputName = "output_" + std::to_string(index);
rewriter.setInsertionPoint(returnOp.getParentOp());
memref::GlobalOp::create(rewriter,
returnOp.getLoc(),
rewriter.getStringAttr(outputName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
{},
{},
{});
outputTensors.push_back(
[memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
return toTensor.getResult();
});
}
}
}
@@ -940,11 +1145,11 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
Location loc = funcOp.getLoc();
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(valueToReplace.getType());
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(inputTensor.getType());
Type elementType = tensorType.getElementType();
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
@@ -953,86 +1158,27 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
loc,
tensorType,
deviceTensor,
hostTensor,
inputTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult());
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
};
// Replace input tensors with memRefs
SmallVector<bufferization::ToTensorOp, 8> inputTensors;
for (size_t i = 0; i < funcOp.getNumArguments(); i++) {
BlockArgument tensorArg = funcOp.getArgument(i);
DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i);
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
BlockArgument memRefArg = funcOp.getArgument(i + 1);
Block& block = funcOp.getBody().front();
rewriter.setInsertionPoint(&block.front());
auto toTensorOp =
bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
inputTensors.push_back(toTensorOp);
tensorArg.replaceAllUsesWith(toTensorOp);
if (failed(funcOp.eraseArgument(i)))
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
}
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
for (auto& op : funcOp.getBody().getOps())
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
unsigned numComputeWeights = computeOp.getWeights().size();
for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
TypedValue<TensorType> tensorSource;
int64_t elementsOffset = 0;
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
continue;
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
ArrayRef<int64_t> sliceStrides = sliceOp.getStaticStrides();
assert("Extracting slice non-contiguous in memory"
&& isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides));
for (size_t i = 0; i < sliceOffsets.size(); i++) {
int64_t partialOffset = sliceOffsets[i];
if (partialOffset != 0)
for (size_t j = i + 1; j < sourceShape.size(); j++)
partialOffset *= sourceShape[j];
elementsOffset += partialOffset;
}
computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource);
sliceOpsToRemove.insert(sliceOp);
if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
continue;
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
if (getGlobal.getName().starts_with("arg") || getGlobal.getName().starts_with("const_")) {
assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute");
auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin();
insertMemCopyHostToDev(toTensorOpValue, 0);
}
else
tensorSource = cast<TypedValue<TensorType>>(computeOpInput);
// Values already produced inside the device-side graph must not be
// copied back through a host-to-device staging step here.
if (isa<spatial::SpatCompute, spatial::SpatChannelReceiveOp>(tensorSource.getDefiningOp()))
continue;
BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset);
}
}
for (auto sliceOp : sliceOpsToRemove)
if (sliceOp->getUses().empty())
rewriter.eraseOp(sliceOp);
return success();
}
@@ -1050,7 +1196,7 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
Operation* onlyUser = *op->getUsers().begin();
isExclusivelyOwnedByReturnChain =
isa<func::ReturnOp, tensor::ConcatOp>(onlyUser) || isChannelUseChainOp(onlyUser);
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatCompute>(onlyUser) || isChannelUseChainOp(onlyUser);
}
if (!isExclusivelyOwnedByReturnChain)
return;
@@ -1062,6 +1208,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
return;
}
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
markOpToRemove(computeOp);
for (Value input : computeOp.getInputs())
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
markOpToRemove(concatOp);
for (Value operand : concatOp.getOperands())
@@ -1070,12 +1223,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
};
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
auto loc = returnOp.getLoc();
for (auto it : llvm::enumerate(originalOperands)) {
size_t orderWithinReturn = it.index();
Operation* returnOperand = it.value().getDefiningOp();
rewriter.modifyOpInPlace(returnOp,
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
rewriter.setInsertionPoint(returnOp);
Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
}
}