Files
Raptor/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp
NiccoloN bb6dcd38a3 replace deprecated "rewriter.create()" pattern
refactor PIM to Pim everywhere except for the accelerator name
2026-03-20 13:30:53 +01:00

552 lines
23 KiB
C++

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_os_ostream.h"
#include <cassert>
#include <filesystem>
#include <string>
#include <utility>
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
using namespace onnx_mlir;
using namespace pim;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
// Module-level pass that rewrites the Spatial dialect ops of the PIM entry
// function into Pim dialect ops (see runOnOperation for the sequence of steps).
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
// Command-line name and human-readable description of the pass.
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPimPass() = default;
// Copies none of the member state: the new instance starts with
// default-initialized members.
SpatialToPimPass(const SpatialToPimPass& pass) {}
void runOnOperation() final;
private:
// One entry per operand of the entry function's return op, filled by
// addResultBuffer and later installed as the return operands by
// replaceReturnOpOperands.
SmallVector<Value> outputTensors;
// Id given to the next PimCoreOp; runOnOperation sets it to 1 and
// runOnComputeOp post-increments it for every core it creates.
size_t coreId = 0;
// Ops scheduled for deletion; runOnOperation erases them in reverse order
// once the whole lowering is done.
SmallVector<Operation*> operationsToRemove;
// Fills `outputTensors` from the operands of `returnOp`.
void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
// Rewrites the function arguments to memrefs and inserts host-to-device
// copies for compute inputs that are not produced by another compute op.
void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
// Re-routes the users of an existing channel receive op through addReceiveOps.
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
// Inserts receive ops in every compute op that consumes `channelSourceOp`.
// `channelTensorType` and `useBroadcastOp` are in/out parameters: both may be
// updated when the consumer is a reshape op.
void addReceiveOps(Value& channelSourceOp,
spatial::SpatChannelNewOp& channel,
Type& channelTensorType,
bool& useBroadcastOp,
IRRewriter& rewriter);
// Replaces the block argument matching operand `argIndex` of `computeOp`
// with a (broadcast) receive on `channel`.
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex,
spatial::SpatChannelNewOp& channel,
Type& tensorType,
bool useBroadcastOp,
IRRewriter& rewriter);
// Lowers one spat.compute op into a pim.core op.
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
// Widens the output tensors of PimVMMOp to `crossbarSize` columns.
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
// Installs `outputTensors` as the operands of `returnOp`.
void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
} // namespace
void SpatialToPimPass::runOnOperation() {
coreId = 1;
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext();
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect, func::FuncDialect, BuiltinDialect>();
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addResultBuffer(returnOp, rewriter);
allocateAndInitializeCoreLocalVariables(funcOp, rewriter);
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
operationsToRemove.push_back(receiveOp);
runOnReceiveOp(receiveOp, rewriter);
}
for (auto computeOp : funcOp.getOps<spatial::SpatWeightedCompute>()) {
operationsToRemove.push_back(computeOp);
runOnComputeOp(computeOp, rewriter);
}
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
replaceReturnOpOperands(returnOp, rewriter);
// Remove all ComputeOps
for (auto opToRemove : llvm::reverse(operationsToRemove)) {
if (!opToRemove->use_empty()) {
opToRemove->dump();
for (auto user : opToRemove->getUsers())
user->dump();
assert(false && "opToRemove should be unused at this point");
}
rewriter.eraseOp(opToRemove);
}
// Dump to file for debug
dumpModule(moduleOp, "pim");
}
/// Lowers a single `spat.compute` op into a `pim.core` op.
///
/// For every used result of the compute op, the matching yielded value is
/// either copied straight into one of the function output buffers or sent
/// through a freshly created channel to its consumers. Afterwards the yield is
/// replaced by a halt and the body is moved into a new core op.
void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
  Location loc = computeOp->getLoc();
  auto& body = computeOp.getRegion().front();
  auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());

  if (computeOp.getNumResults() != yieldOp.getNumOperands())
    llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");

  // Size in bytes of a whole tensor of the given type.
  auto tensorByteSize = [](TensorType type) -> size_t {
    size_t numElements = type.getNumElements();
    size_t elementSize = type.getElementTypeBitWidth() / 8;
    return numElements * elementSize;
  };

  // Emits, right after `deviceValue` is defined, a device-to-host copy into
  // `hostBuffer`. `firstOffset` is passed as the first offset attribute of the
  // copy op and the second offset attribute is always 0.
  // NOTE(review): presumably first = destination offset, second = source
  // offset, both in bytes - confirm against the PimMemCopyDevToHostOp definition.
  auto storeToHostBuffer = [&](Value hostBuffer, Value deviceValue, size_t firstOffset, size_t byteCount) {
    rewriter.setInsertionPointAfterValue(deviceValue);
    PimMemCopyDevToHostOp::create(rewriter,
                                  loc,
                                  hostBuffer.getType(),
                                  hostBuffer,
                                  deviceValue,
                                  rewriter.getI32IntegerAttr(firstOffset),
                                  rewriter.getI32IntegerAttr(0),
                                  rewriter.getI32IntegerAttr(byteCount));
  };

  /*
   * Here we assume that ReturnOp are only reachable by the following patterns:
   *
   * 1)
   * %0 = spat.compute([...])
   * [%0 has one user, which is a ConcatOp]
   * %1 = tensor.concat(%0)
   * [%1 has one user, which is a ReturnOp]
   * return %1
   *
   * 2)
   * %0 = spat.compute([...])
   * [%0 has one user, which is a ReturnOp]
   * return %0
   *
   * In both cases the yielded tensor can be stored directly to the output
   * buffer. Returns true when one of the two patterns matched and the store
   * has been emitted.
   */
  auto tryStoreToFunctionOutput = [&](OpOperand& onlyUse, Value yieldValue, TensorType yieldType) -> bool {
    Operation* user = onlyUse.getOwner();

    // Pattern 2): the result is returned as-is.
    if (isa<func::ReturnOp>(user)) {
      Value outputTensor = outputTensors[onlyUse.getOperandNumber()];
      storeToHostBuffer(outputTensor, yieldValue, 0, tensorByteSize(yieldType));
      return true;
    }

    // Pattern 1): the result goes through a concat whose only user is the return.
    if (!isa<tensor::ConcatOp, spatial::SpatImgConcatOp>(user))
      return false;
    Value concatValue = user->getResult(0);
    auto concatUses = concatValue.getUses();
    if (rangeLength(concatUses) != 1)
      return false;
    OpOperand& concatUse = *concatUses.begin();
    if (!isa<func::ReturnOp>(concatUse.getOwner()))
      return false;

    // The offset is the accumulated byte size of the concat operands that
    // come before this result.
    size_t byteOffset = 0;
    for (Value preceding : user->getOperands().take_front(onlyUse.getOperandNumber())) {
      auto precedingType = cast<ShapedType>(preceding.getType());
      byteOffset += precedingType.getNumElements() * precedingType.getElementTypeBitWidth() / 8;
    }

    Value outputTensor = outputTensors[concatUse.getOperandNumber()];
    storeToHostBuffer(outputTensor, yieldValue, byteOffset, tensorByteSize(yieldType));
    return true;
  };

  for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
    // Results nobody reads need no store and no channel.
    if (result.use_empty())
      continue;

    auto yieldType = cast<TensorType>(yieldValue.getType());
    auto resultUses = result.getUses();
    auto numResultUses = rangeLength(resultUses);

    if (numResultUses == 1 && tryStoreToFunctionOutput(*resultUses.begin(), yieldValue, yieldType))
      continue;

    // Neither output pattern matched: route the value through a channel.
    // 1. Create the channel right before the compute op.
    rewriter.setInsertionPoint(computeOp);
    auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
    auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);

    // 2. Insert the receiving side in every consumer. A result with several
    //    users needs a broadcast channel. `addReceiveOps` may additionally
    //    switch `useBroadcastOp` to true when the single user is a ReshapeOp
    //    that itself feeds several ComputeOps.
    bool useBroadcastOp = (numResultUses > 1);
    addReceiveOps(result, channelOp, yieldType, useBroadcastOp, rewriter);

    // 3. Insert the sending side right after the yielded value is defined.
    rewriter.setInsertionPointAfterValue(yieldValue);
    if (useBroadcastOp)
      spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, channelOp, yieldValue);
    else
      spatial::SpatChannelSendOp::create(rewriter, loc, channelOp, yieldValue);
  }

  // The body now terminates with a `HaltOp` instead of the `YieldOp`.
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);

  // Create the `pim.core` right after the `spat.compute` and move the body
  // blocks (stripped of their arguments) into it.
  rewriter.setInsertionPointAfter(computeOp);
  auto coreOp = PimCoreOp::create(rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
  auto& coreBlocks = coreOp.getBody().getBlocks();
  body.eraseArguments(0, body.getNumArguments());
  coreBlocks.splice(coreBlocks.begin(), computeOp.getBody().getBlocks());

  // Give the emptied compute op a fresh block that only contains a halt.
  // The block is handed to the region by push_back.
  // NOTE(review): presumably this keeps the op valid until it is erased at
  // the end of the pass - confirm.
  Block* placeholderBlock = new Block();
  computeOp.getBody().push_back(placeholderBlock);
  rewriter.setInsertionPointToEnd(placeholderBlock);
  PimHaltOp::create(rewriter, computeOp.getLoc());
}
/// Widens the output of every `PimVMMOp` in `funcOp` whose second dimension
/// differs from `crossbarSize`.
///
/// The output buffer, the VMM result and the chain of destination-style ops
/// feeding the buffer are retyped to `[rows, crossbarSize]`. All other users of
/// the result are redirected to an `extract_slice` that restores the original
/// shape.
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Follows `start` backwards through destination-style ops: as long as the
  // current value is a DPS result with a tied operand, that operand is retyped
  // and becomes the next value to inspect.
  auto retypeTiedDpsChain = [](Value start, RankedTensorType widenedType) {
    Value current = start;
    while (Operation* producer = current.getDefiningOp()) {
      auto dpsProducer = dyn_cast<DestinationStyleOpInterface>(producer);
      if (!dpsProducer)
        break;
      OpOperand* tiedOperand = dpsProducer.getTiedOpOperand(cast<OpResult>(current));
      if (!tiedOperand)
        break;
      Value tiedValue = tiedOperand->get();
      assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
      tiedValue.setType(widenedType);
      current = tiedValue;
    }
  };

  funcOp.walk([&](PimVMMOp vmmOp) {
    auto outBuffer = vmmOp.getOutBuf();
    auto vmmResult = vmmOp.getOutRes();
    auto originalShape = getTensorShape(outBuffer);
    assert(isHVectorShape(originalShape));

    const int64_t targetColumns = static_cast<int64_t>(crossbarSize);
    if (originalShape[1] == targetColumns)
      return;

    // Retype the buffer (and whatever it is tied to) plus the result.
    auto widenedShape = SmallVector<int64_t> {originalShape[0], targetColumns};
    auto widenedType = RankedTensorType::get(widenedShape, outBuffer.getType().getElementType());
    retypeTiedDpsChain(outBuffer, widenedType);
    outBuffer.setType(widenedType);
    vmmResult.setType(widenedType);

    // Slice [0, 0] .. [rows, cols] with unit strides: the original extent.
    IntegerAttr zero = rewriter.getIndexAttr(0);
    IntegerAttr one = rewriter.getIndexAttr(1);
    IntegerAttr originalRows = rewriter.getIndexAttr(originalShape[0]);
    IntegerAttr originalColumns = rewriter.getIndexAttr(originalShape[1]);
    SmallVector<OpFoldResult> offsets = {zero, zero};
    SmallVector<OpFoldResult> sizes = {originalRows, originalColumns};
    SmallVector<OpFoldResult> strides = {one, one};

    rewriter.setInsertionPointAfter(vmmOp);
    auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), vmmResult, offsets, sizes, strides);

    // Everyone but the slice itself (and the VMM op) now reads the slice.
    SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
    vmmResult.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
  });
}
/// Appends to `outputTensors` one value per operand of `returnOp`, in operand
/// order.
///
/// - An operand defined by a constant-like op is used as its own output value.
/// - Any other operand gets a new empty tensor of the same shaped type, created
///   at the start of the block that contains `returnOp`.
///
/// An operand without a defining op (a block argument) cannot be handled: an
/// error is emitted on `returnOp` and the pass is marked as failed. A
/// placeholder buffer is still appended so that `outputTensors` stays aligned
/// with the return operands for the remaining steps of the pass.
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  outputTensors.reserve(returnOp->getNumOperands());
  rewriter.setInsertionPointToStart(returnOp->getBlock());

  for (auto indexedOperand : llvm::enumerate(returnOp->getOperands())) {
    Value returnValue = indexedOperand.value();
    // Null when the returned value is a block argument: must be checked
    // before looking at op traits.
    Operation* returnValueDefiningOp = returnValue.getDefiningOp();

    if (returnValueDefiningOp && returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
      assert(!hasWeightAlways(returnValueDefiningOp));
      outputTensors.push_back(returnValue);
      continue;
    }

    if (!returnValueDefiningOp) {
      returnOp.emitOpError() << "operand #" << indexedOperand.index()
                             << " has no defining operation (block argument); no output buffer can be filled for it";
      signalPassFailure();
    }

    auto newOutputTensor =
        createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
    outputTensors.push_back(newOutputTensor);
  }
}
/// Prepares the inputs of every compute op of `funcOp`.
///
/// 1. Each tensor argument of the function is replaced by a memref argument of
///    the same shape and element type, bridged with a `bufferization.to_tensor`
///    inserted at the start of the function body.
/// 2. For each compute input that is not the result of another compute op, the
///    matching block argument of the compute body is replaced by the result of
///    a host-to-device copy placed before its earliest user. When the input is
///    a `tensor.extract_slice`, the compute op is rewired to the slice source
///    and the slice position is folded into the copy offset.
/// 3. Slice ops left without uses are erased.
void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();

  // Replaces all uses of `valueToReplace` with a device tensor initialized
  // from `hostTensor`, starting `elementsOffset` elements into the host tensor.
  auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
    auto tensorType = cast<ShapedType>(valueToReplace.getType());
    Type elementType = tensorType.getElementType();
    size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;

    rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
    auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
    auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
        rewriter,
        loc,
        tensorType,
        deviceTensor,
        hostTensor,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));

    rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult());
  };

  // Replace input tensors with memRefs. The new argument is inserted right
  // after the old one, which is then erased, so index `i` always refers to
  // the next argument still to be converted.
  for (size_t i = 0; i < funcOp.getNumArguments(); i++) {
    BlockArgument tensorArg = funcOp.getArgument(i);
    DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i);
    ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
    MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());

    funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc);
    BlockArgument memRefArg = funcOp.getArgument(i + 1);

    Block& block = funcOp.getBody().front();
    rewriter.setInsertionPoint(&block.front());
    auto toTensorOp =
        bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());

    tensorArg.replaceAllUsesWith(toTensorOp);
    funcOp.eraseArgument(i);
  }

  llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
  for (auto computeOp : funcOp.getBody().getOps<spatial::SpatWeightedCompute>()) {
    unsigned numComputeWeights = computeOp.getWeights().size();
    for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
      TypedValue<TensorType> tensorSource;
      int64_t elementsOffset = 0;

      // `getDefiningOp<OpTy>()` is null-tolerant: it simply yields a null op
      // when the value has no defining op or is defined by another kind of op.
      if (auto sliceOp = computeOpInput.getDefiningOp<tensor::ExtractSliceOp>()) {
        tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());

        ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
        ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
        ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
        ArrayRef<int64_t> sliceStrides = sliceOp.getStaticStrides();
        assert("Extracting slice non-contiguous in memory"
               && isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides));

        // Linearize the slice offsets: the offset along dimension i is
        // weighted by the product of the source extents of the inner dims.
        for (size_t i = 0; i < sliceOffsets.size(); i++) {
          int64_t partialOffset = sliceOffsets[i];
          if (partialOffset != 0)
            for (size_t j = i + 1; j < sourceShape.size(); j++)
              partialOffset *= sourceShape[j];
          elementsOffset += partialOffset;
        }

        // Read straight from the slice source; the offset now lives in the copy.
        computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource);
        sliceOpsToRemove.insert(sliceOp);
      }
      else
        tensorSource = cast<TypedValue<TensorType>>(computeOpInput);

      // Compute results must be transferred through channels via send/receive
      if (tensorSource.getDefiningOp<spatial::SpatWeightedCompute>())
        continue;

      BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
      insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset);
    }
  }

  // A slice may still be used elsewhere; only erase the ones left unused.
  for (auto sliceOp : sliceOpsToRemove)
    if (sliceOp->use_empty())
      rewriter.eraseOp(sliceOp);
}
/// Replaces, inside the body of `computeOp`, the block argument that matches
/// operand number `argIndex` with a value received from `channel`.
///
/// `argIndex` counts the weights, which come first among the operands but have
/// no block argument, so they are subtracted to get the block argument index.
/// The receive op (broadcast or plain, depending on `useBroadcastOp`) produces
/// a value of type `tensorType` and is inserted before the earliest user of
/// the block argument.
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
                                                      unsigned int argIndex,
                                                      spatial::SpatChannelNewOp& channel,
                                                      Type& tensorType,
                                                      bool useBroadcastOp,
                                                      IRRewriter& rewriter) {
  auto numWeights = computeOp.getWeights().size();
  auto& body = computeOp.getRegion().front();
  auto targetArg = body.getArgument(argIndex - numWeights);

  // Receive the tensor just before the first use of the value.
  rewriter.setInsertionPoint(getEarliestUserWithinBlock(targetArg));

  Value received;
  if (useBroadcastOp)
    received = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
  else
    received = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);

  targetArg.replaceAllUsesWith(received);
}
/// Inserts a receive on `channel` in every compute op that consumes
/// `channelSourceOp`, either directly or through a `tosa.reshape`.
///
/// In/out parameters:
///  - `useBroadcastOp`: switched to true when it was false and the single user
///    is a reshape that has more than one use.
///  - `channelTensorType`: overwritten with the reshape result type whenever a
///    reshape user is processed.
///
/// Reshape users are scheduled for removal in `operationsToRemove`. Any other
/// kind of user is a fatal error.
void SpatialToPimPass::addReceiveOps(Value& channelSourceOp,
                                     spatial::SpatChannelNewOp& channel,
                                     Type& channelTensorType,
                                     bool& useBroadcastOp,
                                     IRRewriter& rewriter) {
  auto sourceOpUses = channelSourceOp.getUses();

  // A single reshape user that fans out to several consumers still requires
  // a broadcast channel.
  if (!useBroadcastOp) {
    // Without broadcast the source must have exactly one user.
    assert(rangeLength(sourceOpUses) == 1);
    Operation* onlyUser = sourceOpUses.begin()->getOwner();
    if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(onlyUser))
      if (rangeLength(reshapeOp.getOutput().getUses()) > 1)
        useBroadcastOp = true;
  }

  for (OpOperand& sourceUse : sourceOpUses) {
    Operation* user = sourceUse.getOwner();

    // Direct consumer: a ComputeOp.
    if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(user)) {
      replaceBlockArgumentWithRecvOp(
          computeUser, sourceUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
      continue;
    }

    // Otherwise the only accepted user is a ReshapeOp feeding ComputeOps.
    auto reshapeOp = dyn_cast<tosa::ReshapeOp>(user);
    if (!reshapeOp) {
      channelSourceOp.getDefiningOp()->getParentOp()->getParentOp()->dump();
      sourceUse.getOwner()->dump();
      llvm_unreachable("User of Value that now needs to be received by channel is not a ComputeOp nor a ReshapeOp");
    }

    // From here on the received tensor has the type of the reshape result.
    channelTensorType = reshapeOp.getResult().getType();
    for (OpOperand& reshapeUse : reshapeOp.getOutput().getUses()) {
      auto reshapeConsumer = dyn_cast<spatial::SpatWeightedCompute>(reshapeUse.getOwner());
      if (!reshapeConsumer)
        llvm_unreachable("ReshapeOp users must be ComputeOps");
      replaceBlockArgumentWithRecvOp(
          reshapeConsumer, reshapeUse.getOperandNumber(), channel, channelTensorType, useBroadcastOp, rewriter);
    }

    // Remove the reshapeOp, so that the sourceOp has no users
    operationsToRemove.push_back(reshapeOp);
  }
}
/// Makes `returnOp` return the values stored in `outputTensors`, position by
/// position, and erases a concat op that produced an old operand once nothing
/// uses it any more.
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  const size_t numOperands = returnOp->getNumOperands();
  for (size_t position = 0; position < numOperands; ++position) {
    // Capture the old operand before it is overwritten: the use check below
    // must look at the value that was returned originally.
    Value previousValue = returnOp->getOperand(position);
    Operation* previousProducer = previousValue.getDefiningOp();

    rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(position, outputTensors[position]); });

    // A concat that only fed the return is dead now and can be removed.
    if (!isAConcatOp(previousProducer))
      continue;
    if (rangeLength(previousValue.getUses()) == 0)
      rewriter.eraseOp(previousProducer);
  }
}
/// Re-routes the users of an existing channel receive op: every consumer gets
/// its own receive on the same channel (see `addReceiveOps`). If that requires
/// a broadcast, the matching send op is replaced by a broadcast send.
void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
  auto channelOp = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp());

  auto otherEnd = getOtherEndOfChannel(receiveOp, true, rewriter);
  if (failed(otherEnd))
    llvm_unreachable("ChannelReceiveOp has no matching SendOp");
  auto sendOp = cast<spatial::SpatChannelSendOp>(*otherEnd);

  auto receivedType = receiveOp.getType();
  Value receivedValue = receiveOp.getResult();

  // More than one use of the received value means broadcast.
  auto numReceiveUses = rangeLength(receivedValue.getUses());
  assert(numReceiveUses > 0);
  bool useBroadcastOp = numReceiveUses > 1;

  addReceiveOps(receivedValue, channelOp, receivedType, useBroadcastOp, rewriter);
  if (!useBroadcastOp)
    return;

  // The value turned out to have several consumers, so the original SendOp
  // has to be replaced with a BroadcastSendOp.
  rewriter.setInsertionPoint(sendOp);
  rewriter.replaceOpWithNewOp<spatial::SpatChannelBroadcastSendOp>(sendOp, sendOp.getChannel(), sendOp.getData());
}
/// Factory: returns a new, default-constructed instance of the Spatial-to-Pim
/// conversion pass.
std::unique_ptr<Pass> createSpatialToPimPass() {
  std::unique_ptr<Pass> pass = std::make_unique<SpatialToPimPass>();
  return pass;
}
} // namespace onnx_mlir