saner SpatialToPimPass architecture
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-22 07:27:54 +02:00
parent 43ed3914b8
commit 074eb183c7
8 changed files with 142 additions and 176 deletions
@@ -23,54 +23,28 @@
#include <cassert>
#include <utility>
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "Conversion/SpatialToPim/Common.hpp"
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "Conversion/SpatialToPim/PhaseVerification.hpp"
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "Pass/PIMPasses.h"
#include "SpatialToPimPass.hpp"
using namespace mlir;
using namespace onnx_mlir;
using namespace pim;
namespace onnx_mlir {
namespace {
namespace raptor {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPimPass() = default;
SpatialToPimPass(const SpatialToPimPass& pass) {}
void runOnOperation() final;
private:
SmallVector<OutputTensorFactory> outputTensors;
size_t coreId = 0;
SmallVector<Operation*> operationsToRemove;
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void markOpToRemove(Operation* op);
void eraseOpsToRemove();
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
};
} // namespace
} // namespace raptor
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
@@ -150,8 +124,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
}
void SpatialToPimPass::runOnOperation() {
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
coreId = 0;
outputTensors.clear();
operationsToRemove.clear();
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext();
@@ -197,18 +173,16 @@ void SpatialToPimPass::runOnOperation() {
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
ReturnPathState returnPathState {outputTensors, operationsToRemove};
addReturnOutputBuffers(returnOp, rewriter);
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
signalPassFailure();
return;
}
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove, constantFolder};
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp);
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
computeOp.emitOpError("failed to lower spat.compute to pim.core");
signalPassFailure();
return;
@@ -217,7 +191,7 @@ void SpatialToPimPass::runOnOperation() {
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
markOpToRemove(computeBatchOp);
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
signalPassFailure();
return;
@@ -266,7 +240,7 @@ void SpatialToPimPass::runOnOperation() {
}
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
replaceReturnWithOutputBuffers(returnOp, rewriter);
eraseOpsToRemove();
RewritePatternSet finalTensorPackingPatterns(ctx);
@@ -309,7 +283,7 @@ void SpatialToPimPass::runOnOperation() {
dumpModule(moduleOp, "pim0");
}
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
OperationFolder constantFolder(funcOp.getContext());
funcOp.walk([&](PimVMMOp vmmOp) {
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
@@ -343,7 +317,8 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
});
}
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
IRRewriter& rewriter) {
Location loc = funcOp.getLoc();
OperationFolder constantFolder(funcOp.getContext());
@@ -387,18 +362,18 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
return success();
}
void SpatialToPimPass::markOpToRemove(Operation* op) {
void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
if (!llvm::is_contained(operationsToRemove, op))
operationsToRemove.push_back(op);
}
void SpatialToPimPass::eraseOpsToRemove() {
void raptor::SpatialToPimPass::eraseOpsToRemove() {
for (Operation* op : operationsToRemove) {
op->dropAllUses();
op->erase();
}
}
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<raptor::SpatialToPimPass>(); }
} // namespace onnx_mlir