saner SpatialToPimPass architecture
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -23,54 +23,28 @@
|
||||
#include <cassert>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "Conversion/SpatialToPim/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||
#include "Conversion/SpatialToPim/PhaseVerification.hpp"
|
||||
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "Dialect/Pim/PimOps.hpp"
|
||||
#include "Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "Pass/PIMPasses.h"
|
||||
#include "SpatialToPimPass.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace pim;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
namespace raptor {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
||||
|
||||
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
||||
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
||||
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
||||
|
||||
SpatialToPimPass() = default;
|
||||
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
SmallVector<OutputTensorFactory> outputTensors;
|
||||
size_t coreId = 0;
|
||||
SmallVector<Operation*> operationsToRemove;
|
||||
|
||||
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
|
||||
void markOpToRemove(Operation* op);
|
||||
void eraseOpsToRemove();
|
||||
|
||||
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace raptor
|
||||
|
||||
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
||||
@@ -150,8 +124,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
||||
}
|
||||
|
||||
void SpatialToPimPass::runOnOperation() {
|
||||
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
coreId = 0;
|
||||
outputTensors.clear();
|
||||
operationsToRemove.clear();
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
|
||||
@@ -197,18 +173,16 @@ void SpatialToPimPass::runOnOperation() {
|
||||
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
||||
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
|
||||
ReturnPathState returnPathState {outputTensors, operationsToRemove};
|
||||
addReturnOutputBuffers(returnOp, rewriter);
|
||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove, constantFolder};
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
markOpToRemove(computeOp);
|
||||
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
|
||||
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
|
||||
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
||||
signalPassFailure();
|
||||
return;
|
||||
@@ -217,7 +191,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
|
||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||
markOpToRemove(computeBatchOp);
|
||||
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
|
||||
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
|
||||
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
||||
signalPassFailure();
|
||||
return;
|
||||
@@ -266,7 +240,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
|
||||
replaceReturnWithOutputBuffers(returnOp, rewriter);
|
||||
eraseOpsToRemove();
|
||||
|
||||
RewritePatternSet finalTensorPackingPatterns(ctx);
|
||||
@@ -309,7 +283,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
dumpModule(moduleOp, "pim0");
|
||||
}
|
||||
|
||||
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||
@@ -343,7 +317,8 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
||||
});
|
||||
}
|
||||
|
||||
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
|
||||
IRRewriter& rewriter) {
|
||||
Location loc = funcOp.getLoc();
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
|
||||
@@ -387,18 +362,18 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
if (!llvm::is_contained(operationsToRemove, op))
|
||||
operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
void SpatialToPimPass::eraseOpsToRemove() {
|
||||
void raptor::SpatialToPimPass::eraseOpsToRemove() {
|
||||
for (Operation* op : operationsToRemove) {
|
||||
op->dropAllUses();
|
||||
op->erase();
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
|
||||
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<raptor::SpatialToPimPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user