// --- Repository metadata (captured from the code-hosting UI; kept as a comment
// --- so the translation unit remains valid C++) ---
// File: Raptor/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp
// Author: NiccoloN (628dc630a4)
// CI: Validate Operations / validate-operations (push) — cancelled
// Recent changes:
//   - compact syntax for spatial tensor ops
//   - better IR compaction after dcp merge
//   - remove pim.mvm op
//   - better memory report
// Last modified: 2026-05-12 13:35:25 +02:00 (528 lines, 20 KiB, C++)
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <utility>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
using namespace mlir;
using namespace onnx_mlir;
using namespace pim;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
// Module pass that lowers the Spatial dialect into PIM-ready IR: spatial
// compute/communication ops become pim.* ops, output buffers are
// materialized, tensor groups are compacted, and the resulting boundary is
// verified (see runOnOperation below for the phase ordering).
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPimPass() = default;
// Standard MLIR pass-cloning copy ctor: per-run state below is deliberately
// not copied (coreId is re-initialized at the top of runOnOperation).
SpatialToPimPass(const SpatialToPimPass& pass) {}
void runOnOperation() final;
private:
// Factories used to rebuild the entry function's output tensors after
// lowering (populated by addReturnOutputBuffers / core lowering).
SmallVector<OutputTensorFactory> outputTensors;
// Next PIM core id to assign; set to 1 at the start of each run.
size_t coreId = 0;
// Ops queued for deferred erasure once all rewrites have completed.
SmallVector<Operation*> operationsToRemove;
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void markOpToRemove(Operation* op);
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
};
} // namespace
// Map a Spatial core id onto the corresponding PIM core id. Currently a
// plain numeric narrowing, kept as a named function so the id mapping has a
// single home if the numbering schemes ever diverge.
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) {
  return static_cast<int32_t>(spatialCoreId);
}
// Rewrite a spatial channel-send into the equivalent pim.send, carrying the
// payload size in bytes and the destination core id translated into the PIM
// numbering.
static void lowerChannelSend(spatial::SpatChannelSendOp sendOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(sendOp);
  Value payload = sendOp.getInput();
  auto destCoreAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sendOp.getTargetCoreId()));
  PimSendOp::create(rewriter, sendOp.getLoc(), payload, getTensorSizeInBytesAttr(rewriter, payload), destCoreAttr);
  rewriter.eraseOp(sendOp);
}
// Rewrite a spatial channel-receive into pim.receive: allocate an empty
// destination tensor and receive into it. A receive whose result has no
// users is simply erased — nothing would consume the data.
static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
if (receiveOp->use_empty()) {
rewriter.eraseOp(receiveOp);
return;
}
auto outputType = cast<ShapedType>(receiveOp.getResult().getType());
rewriter.setInsertionPoint(receiveOp);
// Destination buffer the PIM receive writes into.
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
Value received =
PimReceiveOp::create(rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
rewriter.replaceOp(receiveOp, received);
}
// Rewrite a spatial tensor-send into pim.send_tensor, translating every
// destination core id into the PIM numbering.
static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp, IRRewriter& rewriter) {
  SmallVector<int32_t> pimTargets;
  pimTargets.reserve(sendTensorOp.getTargetCoreIds().size());
  for (int32_t spatialTarget : sendTensorOp.getTargetCoreIds())
    pimTargets.push_back(translateSpatialCoreIdToPimCoreId(spatialTarget));
  rewriter.setInsertionPoint(sendTensorOp);
  auto targetsAttr = rewriter.getDenseI32ArrayAttr(pimTargets);
  PimSendTensorOp::create(rewriter, sendTensorOp.getLoc(), sendTensorOp.getInput(), targetsAttr);
  rewriter.eraseOp(sendTensorOp);
}
// Rewrite a spatial tensor-receive into pim.receive_tensor: allocate an
// empty destination tensor and receive into it from the translated source
// cores, then replace all uses of the original result.
static void lowerChannelReceiveTensor(spatial::SpatChannelReceiveTensorOp receiveTensorOp, IRRewriter& rewriter) {
  SmallVector<int32_t> pimSources;
  pimSources.reserve(receiveTensorOp.getSourceCoreIds().size());
  for (int32_t spatialSource : receiveTensorOp.getSourceCoreIds())
    pimSources.push_back(translateSpatialCoreIdToPimCoreId(spatialSource));
  rewriter.setInsertionPoint(receiveTensorOp);
  Type resultType = receiveTensorOp.getOutput().getType();
  Value destination =
      createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), cast<ShapedType>(resultType)).getResult();
  auto pimReceive = PimReceiveTensorOp::create(
      rewriter, receiveTensorOp.getLoc(), resultType, destination, rewriter.getDenseI32ArrayAttr(pimSources));
  rewriter.replaceOp(receiveTensorOp, pimReceive.getOutput());
}
// Lower spatial.extract_rows into one tensor.extract_slice per output. Each
// output takes the next contiguous band of rows from the input, so the band
// offset is the running sum of the preceding outputs' row counts. (The
// previous formula, rowIndex * outputType.getDimSize(0), silently assumed
// every output had the same number of rows; the running offset is identical
// in that case and correct when row counts differ.)
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(extractRowsOp);
  auto inputType = cast<RankedTensorType>(extractRowsOp.getInput().getType());
  SmallVector<Value> replacements;
  replacements.reserve(extractRowsOp.getNumResults());
  int64_t rowOffset = 0;  // first row of the current output's band
  for (Value output : extractRowsOp.getOutputs()) {
    auto outputType = cast<RankedTensorType>(output.getType());
    SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowOffset), rewriter.getIndexAttr(0)};
    SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)),
        rewriter.getIndexAttr(inputType.getDimSize(1))};
    SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
    replacements.push_back(
        tensor::ExtractSliceOp::create(
            rewriter, extractRowsOp.getLoc(), outputType, extractRowsOp.getInput(), offsets, sizes, strides)
            .getResult());
    rowOffset += outputType.getDimSize(0);
  }
  rewriter.replaceOp(extractRowsOp, replacements);
}
// Merge runs of concat inputs that originate from slice-like producers into
// single packed tensors, then rebuild each changed concat as pim.concat with
// an explicit output buffer. Finally sweep slice/extract/concat ops that
// lost all their users during compaction.
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
// Collect first, then mutate: rewriting while walking would invalidate the walk.
SmallVector<spatial::SpatConcatOp> concatOps;
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
for (auto concatOp : concatOps) {
// Only axis-0 concats with at least one input are candidates.
if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
continue;
SmallVector<Value> packedInputs;
bool changed = false;
rewriter.setInsertionPoint(concatOp);
for (unsigned index = 0; index < concatOp.getInputs().size();) {
Value input = concatOp.getInputs()[index];
// Case 1: a run of tensor.extract_slice inputs — try to pack the whole
// run into one slice.
if (input.getDefiningOp<tensor::ExtractSliceOp>()) {
unsigned endIndex = index + 1;
while (endIndex < concatOp.getInputs().size()
&& concatOp.getInputs()[endIndex].getDefiningOp<tensor::ExtractSliceOp>())
++endIndex;
Value packedInput = createPackedExtractSliceTensor(
concatOp.getInputs().slice(index, endIndex - index), rewriter, concatOp.getLoc());
if (packedInput) {
packedInputs.push_back(packedInput);
changed = true;
index = endIndex;
continue;
}
// Packing failed; fall through and handle this input via Case 2 below.
}
auto result = dyn_cast<OpResult>(input);
if (!result) {
// Block argument or similar non-result value: keep as-is.
packedInputs.push_back(input);
++index;
continue;
}
// Case 2: a run of consecutive results of the same producer op
// (result numbers startIndex, startIndex+1, ...).
Operation* owner = result.getOwner();
unsigned startIndex = result.getResultNumber();
unsigned endIndex = index + 1;
while (endIndex < concatOp.getInputs().size()) {
auto nextResult = dyn_cast<OpResult>(concatOp.getInputs()[endIndex]);
if (!nextResult || nextResult.getOwner() != owner
|| nextResult.getResultNumber() != startIndex + (endIndex - index))
break;
++endIndex;
}
unsigned count = endIndex - index;
Value packedInput;
// Currently only spatial.extract_rows producers can be packed as a run.
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc());
if (packedInput) {
packedInputs.push_back(packedInput);
changed = true;
}
else {
// No packing possible: forward the original inputs unchanged.
for (unsigned oldIndex = index; oldIndex < endIndex; ++oldIndex)
packedInputs.push_back(concatOp.getInputs()[oldIndex]);
}
index = endIndex;
}
if (!changed)
continue;
// Rebuild as pim.concat over the packed inputs, with a fresh empty tensor
// as the destination buffer.
auto newConcat = pim::PimConcatOp::create(
rewriter,
concatOp.getLoc(),
concatOp.getOutput().getType(),
concatOp.getAxisAttr(),
ValueRange(packedInputs),
createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), cast<ShapedType>(concatOp.getOutput().getType()))
.getResult());
rewriter.replaceOp(concatOp, newConcat.getOutput());
}
// Erase ops of type OpTy that lost all users; reverse order removes users
// before their producers so chains become dead incrementally.
auto eraseUnusedOps = [&](auto tag) {
using OpTy = decltype(tag);
SmallVector<OpTy> ops;
funcOp.walk([&](OpTy op) { ops.push_back(op); });
for (auto op : llvm::reverse(ops))
if (op->use_empty())
rewriter.eraseOp(op);
};
eraseUnusedOps(tensor::ConcatOp {});
eraseUnusedOps(tensor::ExtractSliceOp {});
eraseUnusedOps(spatial::SpatExtractRowsOp {});
}
// Pass driver. Phases, in order: (1) pattern-based partial conversion that
// keeps communication/slicing ops legal; (2) global tensor materialization;
// (3) output-buffer and core-local-variable setup; (4) lowering of compute /
// compute-batch ops into pim cores; (5) tensor-group compaction and explicit
// lowering of channel and extract-rows ops; (6) full conversion inside core
// bodies; (7) VMM out-tensor enlargement, return-path rewrite, deferred
// erasure; (8) final full conversion that outlaws all remaining spatial ops;
// (9) boundary verification and debug dump.
void SpatialToPimPass::runOnOperation() {
// Per-run core id counter (core 0 is implicitly the host/entry side —
// TODO confirm against translateSpatialCoreIdToPimCoreId's numbering).
coreId = 1;
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext();
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
// Target for the generated patterns: everything PIM-adjacent is legal, and
// the communication/slicing spatial ops stay legal here — they are lowered
// explicitly later, after compaction.
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
bufferization::BufferizationDialect,
func::FuncDialect,
memref::MemRefDialect,
scf::SCFDialect,
BuiltinDialect>();
target.addLegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelReceiveTensorBatchOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatChannelSendTensorBatchOp,
spatial::SpatExtractRowsOp>();
{
// Phase 1: tablegen-generated patterns (SpatialToPim.hpp.inc).
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
}
{
// Phase 2: materialize global tensors (walk-and-apply, no legality check).
RewritePatternSet patterns(ctx);
populateGlobalTensorMaterializationPatterns(patterns);
walkAndApplyPatterns(moduleOp, std::move(patterns));
}
// Phase 3: attach output buffers to the return and set up core-local state.
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
signalPassFailure();
return;
}
// Phase 4: lower compute and compute-batch regions into pim cores; the
// originals are queued for deferred removal.
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove};
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp);
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
signalPassFailure();
return;
}
}
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
markOpToRemove(computeBatchOp);
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
signalPassFailure();
return;
}
}
// Phase 5: compact concat inputs, then lower the communication ops that
// were kept legal in phase 1. Each kind is collected first because the
// lowering erases the op being iterated.
compactSpatialTensorGroups(funcOp, rewriter);
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
receiveOps.push_back(op);
for (auto receiveOp : receiveOps) {
// A receive used only by ops already queued for removal is itself queued
// rather than lowered — its data path is dead.
bool onlyPendingRemovalUsers = llvm::all_of(
receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
if (onlyPendingRemovalUsers) {
markOpToRemove(receiveOp);
continue;
}
if (receiveOp->use_empty()) {
rewriter.eraseOp(receiveOp);
continue;
}
lowerChannelReceive(receiveOp, rewriter);
}
SmallVector<spatial::SpatChannelReceiveTensorOp> receiveTensorOps;
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveTensorOp>())
receiveTensorOps.push_back(op);
for (auto receiveTensorOp : receiveTensorOps)
lowerChannelReceiveTensor(receiveTensorOp, rewriter);
SmallVector<spatial::SpatChannelSendOp> sendOps;
for (auto op : funcOp.getOps<spatial::SpatChannelSendOp>())
sendOps.push_back(op);
for (auto sendOp : sendOps)
lowerChannelSend(sendOp, rewriter);
SmallVector<spatial::SpatChannelSendTensorOp> sendTensorOps;
for (auto op : funcOp.getOps<spatial::SpatChannelSendTensorOp>())
sendTensorOps.push_back(op);
for (auto sendTensorOp : sendTensorOps)
lowerChannelSendTensor(sendTensorOp, rewriter);
SmallVector<spatial::SpatExtractRowsOp> extractRowsOps;
for (auto op : funcOp.getOps<spatial::SpatExtractRowsOp>())
extractRowsOps.push_back(op);
for (auto extractRowsOp : extractRowsOps)
lowerExtractRows(extractRowsOp, rewriter);
{
// Phase 6: run the generated patterns to completion inside each core /
// core-batch body (regions skipped by the module-level conversion).
RewritePatternSet coreBodyPatterns(ctx);
populateWithGenerated(coreBodyPatterns);
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) {
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
}
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) {
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
}
}
// Phase 7: pad VMM outputs to crossbar width, rewrite the return to use the
// output buffers, then erase everything queued for removal.
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
ReturnPathState returnPathState {outputTensors, operationsToRemove};
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
signalPassFailure();
return;
}
// Second compaction pass: the return-path rewrite and erasures can expose
// new packable concat inputs.
compactSpatialTensorGroups(funcOp, rewriter);
{
// Phase 8: final conversion — the previously-legal spatial communication
// ops are now illegal and must all be lowered by the channel patterns.
ConversionTarget communicationTarget(*ctx);
communicationTarget.addLegalDialect<PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
bufferization::BufferizationDialect,
func::FuncDialect,
memref::MemRefDialect,
scf::SCFDialect,
BuiltinDialect>();
communicationTarget.addLegalOp<ModuleOp>();
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatExtractRowsOp>();
RewritePatternSet communicationPatterns(ctx);
populateChannelLoweringPatterns(communicationPatterns);
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
signalPassFailure();
return;
}
}
// Phase 9: verify no spatial ops leaked across the boundary.
if (failed(verifySpatialToPimBoundary(moduleOp))) {
signalPassFailure();
return;
}
// Dump to file for debug
dumpModule(moduleOp, "pim0");
}
// Widen every pim.vmm output tensor to the full crossbar column count. The
// VMM result must be crossbar-sized in hardware, so the output buffer (and
// the chain of tied DPS operands feeding it) is retyped in place, and a
// tensor.extract_slice restores the original logical shape for all other
// users of the result.
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
// Recursively retype the tied-operand chain of destination-style ops that
// produces `value`, so every buffer feeding the VMM output gets the new
// widened type. Explicit self-reference (`self`) because lambdas cannot
// name themselves.
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return;
auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp);
if (!dpsDefiningOp)
return;
auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
if (!tiedOperand)
return;
Value tiedValue = tiedOperand->get();
// Retyping a value with other users would corrupt them, hence the
// single-use requirement.
assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
tiedValue.setType(newType);
self(tiedValue, newType, self);
};
funcOp.walk([&](PimVMMOp vmmOp) {
auto outTensorOperand = vmmOp.getOutputBuffer();
auto resultTensor = vmmOp.getOutput();
auto outShape = getTensorShape(outTensorOperand);
// VMM outputs are expected to be row vectors: shape (1 x N) —
// isHVectorShape presumably checks this; dimension 1 is the one widened.
assert(isHVectorShape(outShape));
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
if (outTensorOperand == vmmOp.getInput()) {
// In-place VMM (output buffer aliases the input): the input must keep
// its type, so allocate a fresh widened buffer instead of retyping.
rewriter.setInsertionPoint(vmmOp);
auto newOutputBuffer =
tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
}
else {
// Otherwise retype the buffer and everything tied upstream of it.
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
outTensorOperand.setType(newType);
}
resultTensor.setType(newType);
// Slice the original (pre-widening) shape back out for existing users.
IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
IntegerAttr oneAttr = rewriter.getIndexAttr(1);
IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
SmallVector<OpFoldResult> offsets = {zeroAttr, zeroAttr};
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
rewriter.setInsertionPointAfter(vmmOp);
auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
// Keep the VMM and the slice itself pointing at the widened tensor.
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
}
});
}
// For compute regions with no tensor inputs, wrap each argument/constant
// global read in a pim.mem_copy_host_to_dev so the data is staged into a
// device-side tensor before use. Always succeeds at present; the
// LogicalResult return leaves room for future failure modes.
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
Location loc = funcOp.getLoc();
// Insert a host-to-device copy right after `inputTensor`'s producer and
// redirect all of its other users to the copied (device) tensor.
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(inputTensor.getType());
Type elementType = tensorType.getElementType();
// Assumes the element bit width is a multiple of 8 — TODO confirm no
// sub-byte element types reach this point.
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
rewriter,
loc,
tensorType,
deviceTensor,
inputTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
};
for (auto& op : funcOp.getBody().getOps())
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
// Only self-contained computes (no inputs, no body arguments) read
// their data through globals and need staging.
if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
continue;
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
// Globals named "arg*" (function arguments) or "const_*" (constants)
// are the ones materialized from the host side.
if (getGlobal.getName().starts_with("arg") || getGlobal.getName().starts_with("const_")) {
assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute");
// The single user is expected to be a to_tensor-style op whose first
// result is the tensor view of the global — TODO confirm.
auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin();
insertMemCopyHostToDev(toTensorOpValue, 0);
}
}
}
return success();
}
// Queue `op` for deferred erasure; duplicates are skipped so the pending
// list behaves like an ordered set.
void SpatialToPimPass::markOpToRemove(Operation* op) {
  if (llvm::is_contained(operationsToRemove, op))
    return;
  operationsToRemove.push_back(op);
}
// Public factory used by pass registration (see PIMPasses.h).
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
} // namespace onnx_mlir