Files
Raptor/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp
T
2026-05-22 21:52:28 +02:00

380 lines
16 KiB
C++

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <utility>
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "Conversion/SpatialToPim/Common.hpp"
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "Conversion/SpatialToPim/PhaseVerification.hpp"
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "Pass/PIMPasses.h"
#include "SpatialToPimPass.hpp"
using namespace mlir;
using namespace onnx_mlir;
using namespace pim;
namespace onnx_mlir {
namespace raptor {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
} // namespace raptor
/// Returns a constant, zero-initialised `memref.global` whose type matches
/// `tensorType` (same shape and element type), creating it when the enclosing
/// module has none yet.
///
/// An existing global is reused only if it is marked constant, has the exact
/// memref type and carries a dense initial value equal to the all-zero
/// attribute. A new global is created as "private" at the start of the module
/// body under a name that does not clash with any symbol already in the module.
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
  auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
  auto memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  auto zeroAttr = DenseElementsAttr::get(tensorType, rewriter.getZeroAttr(tensorType.getElementType()));

  // Look for an already materialised zero constant of the right type.
  for (auto candidate : moduleOp.getOps<memref::GlobalOp>()) {
    const bool reusable =
        candidate.getConstant() && candidate.getType() == memRefType && candidate.getInitialValue();
    if (reusable && dyn_cast<DenseElementsAttr>(*candidate.getInitialValue()) == zeroAttr)
      return candidate;
  }

  // The stem encodes only rank and element count, so two different zero
  // tensors can share it; the numeric suffix below resolves such clashes.
  const std::string nameStem =
      ("__pim_zero_" + Twine(tensorType.getRank()) + "d_" + Twine(tensorType.getNumElements())).str();
  std::string symbolName = nameStem;
  for (unsigned suffix = 0; SymbolTable::lookupSymbolIn(moduleOp, symbolName); ++suffix)
    symbolName = (nameStem + "_" + Twine(suffix)).str();

  // Restore the caller's insertion point once the global has been created.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());
  return memref::GlobalOp::create(rewriter,
                                  loc,
                                  rewriter.getStringAttr(symbolName),
                                  rewriter.getStringAttr("private"),
                                  TypeAttr::get(memRefType),
                                  zeroAttr,
                                  rewriter.getUnitAttr(),
                                  IntegerAttr {});
}
/// Creates an empty tensor of `tensorType` at the rewriter's insertion point
/// and fills it from a zero-initialised global with a host-to-device copy.
///
/// The batch flavour of the copy op is emitted when the new buffer is nested
/// inside a `PimCoreBatchOp`; otherwise the plain copy op is used with a zero
/// host index for both index operands. Returns the output value of the copy.
static Value createZeroedDeviceHVector(IRRewriter& rewriter,
                                       Location loc,
                                       RankedTensorType tensorType,
                                       OperationFolder& constantFolder) {
  auto deviceBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
  auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
  auto hostZeros = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
  auto hostZeroIndex = getOrCreateHostIndexConstant(deviceBuffer.getOperation(), 0, constantFolder);
  auto byteCountAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));

  // Outside of a core batch the non-batched copy takes explicit index operands.
  if (!deviceBuffer->getParentOfType<PimCoreBatchOp>()) {
    auto copyOp = PimMemCopyHostToDevOp::create(
        rewriter, loc, tensorType, hostZeroIndex, hostZeroIndex, deviceBuffer, hostZeros, byteCountAttr);
    return copyOp.getOutput();
  }

  auto noOffsetAttr = rewriter.getI32IntegerAttr(0);
  auto batchCopyOp = PimMemCopyHostToDevBatchOp::create(
      rewriter, loc, tensorType, deviceBuffer, hostZeros, noOffsetAttr, noOffsetAttr, byteCountAttr);
  return batchCopyOp.getOutput();
}
/// Widens a horizontal vector to `crossbarSize` columns.
///
/// A vector that already has `crossbarSize` columns is returned unchanged.
/// Otherwise a zeroed tensor of the padded type is created and the original
/// vector is copied into it at offset zero; the copy's output is returned.
/// Asserts that the input is a horizontal vector no wider than one crossbar.
static Value
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
  auto inputType = cast<RankedTensorType>(vector.getType());
  ArrayRef<int64_t> dims = inputType.getShape();
  const auto crossbarWidth = static_cast<int64_t>(crossbarSize);
  assert(isHVectorShape(dims) && "expected a horizontal vector");
  assert(dims[1] <= crossbarWidth && "vector width must fit in one crossbar");

  // Nothing to pad when the vector already spans the full crossbar width.
  if (dims[1] == crossbarWidth)
    return vector;

  Type elementType = inputType.getElementType();
  Attribute encoding = inputType.getEncoding();
  auto paddedType = RankedTensorType::get({dims[0], crossbarWidth}, elementType, encoding);
  Value zeroedBuffer = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);

  // Only the bytes of the original (narrower) vector are copied.
  auto noOffsetAttr = rewriter.getI32IntegerAttr(0);
  auto byteCountAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(inputType)));
  auto copyOp =
      PimMemCopyOp::create(rewriter, loc, paddedType, zeroedBuffer, vector, noOffsetAttr, noOffsetAttr, byteCountAttr);
  return copyOp.getOutput();
}
// Pass entry point: lowers the Spatial ops of the PIM entry function to the
// PIM dialect. The phases below run in a fixed order; every fallible phase
// emits a diagnostic, signals pass failure and returns early on error.
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
// Reset the per-run state kept in the pass object.
coreId = 0;
outputTensors.clear();
operationsToRemove.clear();
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext();
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
moduleOp.emitError("failed to locate the PIM entry function during Spatial-to-PIM lowering");
signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
OperationFolder constantFolder(&getContext());
// Phase 1: initial lowering. The target dialects are legal, and the Spatial
// ops listed below are kept legal for now; all of them except the two
// *TensorBatchOp variants are explicitly marked illegal in
// communicationTarget further down.
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
bufferization::BufferizationDialect,
func::FuncDialect,
memref::MemRefDialect,
scf::SCFDialect,
BuiltinDialect>();
target.addLegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelReceiveTensorBatchOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatChannelSendTensorBatchOp,
spatial::SpatExtractRowsOp>();
// populateWithGenerated is presumably provided by the SpatialToPim.hpp.inc
// included at the top of this file -- TODO confirm.
RewritePatternSet initialPatterns(ctx);
populateWithGenerated(initialPatterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
signalPassFailure();
return;
}
// Phase 2: apply the global tensor materialization patterns over the module.
RewritePatternSet globalTensorPatterns(ctx);
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
// The terminator of the entry function's first block is captured here and
// reused by replaceReturnWithOutputBuffers after the core lowering.
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addReturnOutputBuffers(returnOp, rewriter);
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
signalPassFailure();
return;
}
// Phase 3: lower every top-level spat.compute / spat.compute_batch. The
// originals are only marked here; they are erased later by eraseOpsToRemove().
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp);
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
computeOp.emitOpError("failed to lower spat.compute to pim.core");
signalPassFailure();
return;
}
}
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
markOpToRemove(computeBatchOp);
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
signalPassFailure();
return;
}
}
// Phase 4: first round of tensor packing patterns, then drop the packing ops
// that ended up unused.
RewritePatternSet initialTensorPackingPatterns(ctx);
populateTensorPackingPatterns(initialTensorPackingPatterns);
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
// Top-level channel receives whose users are all already scheduled for
// removal are scheduled for removal too. The ops are snapshotted into a
// vector before being inspected.
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
receiveOps.push_back(op);
for (auto receiveOp : receiveOps) {
bool onlyPendingRemovalUsers = llvm::all_of(
receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
if (onlyPendingRemovalUsers) {
markOpToRemove(receiveOp);
continue;
}
}
// Phase 5: run the generated patterns again, this time as a full conversion
// rooted at each pim.core / pim.core_batch. The pattern set is frozen once
// and shared; the ops are collected first and converted afterwards.
RewritePatternSet coreBodyPatterns(ctx);
populateWithGenerated(coreBodyPatterns);
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) {
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core");
signalPassFailure();
return;
}
}
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) {
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch");
signalPassFailure();
return;
}
}
// Phase 6: pad VMM operands/results to the crossbar width, rewire the
// function return and finally erase everything marked for removal.
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
replaceReturnWithOutputBuffers(returnOp, rewriter);
eraseOpsToRemove();
// Phase 7: second round of tensor packing patterns on the rewritten IR.
RewritePatternSet finalTensorPackingPatterns(ctx);
populateTensorPackingPatterns(finalTensorPackingPatterns);
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
// Phase 8: lower the communication ops. Same legal dialects as before, but
// the Spatial ops listed below are now explicitly illegal, and the whole
// entry function must convert fully.
ConversionTarget communicationTarget(*ctx);
communicationTarget.addLegalDialect<PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
bufferization::BufferizationDialect,
func::FuncDialect,
memref::MemRefDialect,
scf::SCFDialect,
BuiltinDialect>();
communicationTarget.addLegalOp<ModuleOp>();
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatExtractRowsOp>();
RewritePatternSet communicationPatterns(ctx);
populateChannelLoweringPatterns(communicationPatterns);
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
signalPassFailure();
return;
}
// Phase 9: verify the result of the lowering on the whole module.
if (failed(verifySpatialToPimBoundary(moduleOp))) {
moduleOp.emitError("Spatial-to-PIM boundary verification failed");
signalPassFailure();
return;
}
// Dump the lowered module under the "pim0" tag for debugging.
dumpModule(moduleOp, "pim0");
}
/// Rewrites every `PimVMMOp` under `funcOp` so that its input, output buffer
/// and result are `crossbarSize` columns wide.
///
/// The input is padded through `padHVectorInputToCrossbarSize`. When the
/// result was narrower than the crossbar, a fresh wide output buffer is
/// created and a `tensor.extract_slice` placed right after the VMM recovers
/// the original result type; all previous users of the VMM result are
/// redirected to that slice.
void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
  OperationFolder constantFolder(funcOp.getContext());

  funcOp.walk([&](PimVMMOp vmmOp) {
    Location vmmLoc = vmmOp.getLoc();
    auto narrowType = cast<RankedTensorType>(vmmOp.getOutput().getType());
    ArrayRef<int64_t> narrowShape = narrowType.getShape();
    const auto crossbarWidth = static_cast<int64_t>(crossbarSize);
    assert(isHVectorShape(narrowShape) && "expected a horizontal vector output");
    assert(narrowShape[1] <= crossbarWidth && "output width must fit in one crossbar");
    const bool alreadyFullWidth = narrowShape[1] == crossbarWidth;

    // Everything the VMM consumes is materialised right before it.
    rewriter.setInsertionPoint(vmmOp);
    Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmLoc, vmmOp.getInput(), constantFolder);
    auto wideType =
        RankedTensorType::get({narrowShape[0], crossbarWidth}, narrowType.getElementType(), narrowType.getEncoding());

    // The existing output buffer is kept when it already has the full width.
    Value wideOutputBuffer = vmmOp.getOutputBuffer();
    if (!alreadyFullWidth)
      wideOutputBuffer = createEmptyTensorFromShaped(rewriter, vmmLoc, wideType).getResult();

    vmmOp.getInputMutable().assign(paddedInput);
    vmmOp.getOutputBufferMutable().assign(wideOutputBuffer);
    vmmOp.getOutput().setType(wideType);
    if (alreadyFullWidth)
      return;

    // Slice the leading rows x columns back out so existing users keep seeing
    // a value of the original type.
    OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
    OpFoldResult unitIndex = rewriter.getIndexAttr(1);
    SmallVector<OpFoldResult> offsets = {zeroIndex, zeroIndex};
    SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(narrowShape[0]), rewriter.getIndexAttr(narrowShape[1])};
    SmallVector<OpFoldResult> strides = {unitIndex, unitIndex};
    rewriter.setInsertionPointAfter(vmmOp);
    auto sliceOp =
        tensor::ExtractSliceOp::create(rewriter, vmmLoc, narrowType, vmmOp.getOutput(), offsets, sizes, strides);

    // The slice itself must keep reading the widened VMM result.
    SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
    vmmOp.getOutput().replaceAllUsesExcept(sliceOp.getResult(), exceptions);
  });
}
/// Inserts host-to-device copies for the globals read by input-less
/// `spat.compute` ops that sit directly in the body of `funcOp`.
///
/// Only compute ops with no inputs and no block arguments are considered.
/// Inside such an op, every `memref.get_global` whose symbol name starts with
/// "arg" or "const_" must have exactly one user; the first result of that user
/// is copied into a fresh `tensor.empty` through `PimMemCopyHostToDevOp`, and
/// all other uses of that result are redirected to the copy.
///
/// \param funcOp   Function whose top-level compute ops are processed.
/// \param rewriter Rewriter used to create ops; its insertion point is moved.
/// \returns failure (after emitting a diagnostic) when a selected global does
///          not have exactly one user or that user produces no result;
///          success otherwise.
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
                                                                                IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();
  OperationFolder constantFolder(funcOp.getContext());

  // Stages `inputTensor` into a device tensor of the same shape. Tensors whose
  // element type is not byte-sized are left untouched.
  auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
    auto tensorType = cast<ShapedType>(inputTensor.getType());
    Type elementType = tensorType.getElementType();
    if (!hasByteSizedElementType(elementType))
      return;
    size_t elementByteSize = getElementTypeSizeInBytes(elementType);
    rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
    auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
    auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
        rewriter,
        loc,
        tensorType,
        getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder),
        getOrCreateHostIndexConstant(
            deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize), constantFolder),
        deviceTensor,
        inputTensor,
        rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
    // The copy itself must keep reading the original host tensor.
    rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
  };

  for (auto& op : funcOp.getBody().getOps()) {
    auto computeOp = dyn_cast<spatial::SpatCompute>(op);
    if (!computeOp)
      continue;
    if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
      continue;

    for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
      llvm::StringRef globalName = getGlobal.getName();
      if (!globalName.starts_with("arg") && !globalName.starts_with("const_"))
        continue;

      // Report a broken invariant as a diagnostic instead of an assert: with
      // NDEBUG the assert vanishes, and dereferencing the user/result iterators
      // of a global without users (or of a result-less user) is undefined.
      if (!getGlobal->hasOneUse())
        return getGlobal.emitOpError("global must have a single entry point in the compute");
      Operation* globalUser = *getGlobal->getUsers().begin();
      if (globalUser->getNumResults() == 0)
        return globalUser->emitOpError("user of a core-local global must produce a result");

      insertMemCopyHostToDev(globalUser->getResult(0), 0);
    }
  }
  return success();
}
/// Schedules `op` for erasure by `eraseOpsToRemove()`. Each operation is
/// recorded at most once, in the order it was first marked.
void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
  const bool alreadyScheduled = llvm::is_contained(operationsToRemove, op);
  if (alreadyScheduled)
    return;
  operationsToRemove.push_back(op);
}
void raptor::SpatialToPimPass::eraseOpsToRemove() {
for (Operation* op : operationsToRemove) {
op->dropAllUses();
op->erase();
}
}
/// Factory for the Spatial-to-PIM lowering pass.
std::unique_ptr<Pass> createSpatialToPimPass() {
  std::unique_ptr<Pass> pass = std::make_unique<raptor::SpatialToPimPass>();
  return pass;
}
} // namespace onnx_mlir