Merge remote-tracking branch 'origin/main'

This commit is contained in:
NiccoloN
2026-05-08 13:12:47 +02:00
4 changed files with 52 additions and 31 deletions

2
.gitignore vendored
View File

@@ -9,7 +9,9 @@ AGENTS.md
CMakeUserPresets.json CMakeUserPresets.json
build build
build_release
cmake-build-debug cmake-build-debug
cmake-build-release cmake-build-release
compile.sh
**/__* **/__*

View File

@@ -178,32 +178,6 @@ void PimMemory::report(llvm::raw_ostream& file) {
} }
} }
// void PimMemory::report(llvm::raw_ostream& file) {
// std::vector orderedList(globalMemEntriesMap.begin(), globalMemEntriesMap.end());
// std::sort(
// orderedList.begin(), orderedList.end(), [](auto lft, auto rgt) { return lft.second.address < rgt.second.address;
// });
// auto newEnd = std::unique(orderedList.begin(), orderedList.end(), [](auto lft, auto rgt) {
// return (lft.first.getDefiningOp() == rgt.first.getDefiningOp()) && (lft.second.address == rgt.second.address);
// });
// orderedList.erase(newEnd, orderedList.end());
// mlir::OpPrintingFlags flags;
// flags.assumeVerified(true);
// for (auto& [value, memEntry] : orderedList) {
// if (auto op = value.getDefiningOp()) {
// file.indent(4) << op << ": ";
// op->print(file, flags);
// file << "\n";
// file.indent(6) << "Address: " << llvm::format_hex(memEntry.address, 10) << "\n";
// file.indent(6) << "Memory: " << formatMemory(memEntry.size) << "\n";
// }
// else {
// file.indent(4) << value << "\n";
// file.indent(6) << "Address: " << llvm::format_hex(memEntry.address, 10) << "\n";
// file.indent(6) << "Memory: " << formatMemory(memEntry.size) << "\n";
// }
// }
// }
void PimMemory::remove(mlir::Value val) { void PimMemory::remove(mlir::Value val) {
if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end()) if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())

View File

@@ -49,6 +49,7 @@ private:
void annotateWeightsConstants(func::FuncOp funcOp) const; void annotateWeightsConstants(func::FuncOp funcOp) const;
LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp); LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp); LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
void populateEmptyFunction(func::FuncOp funcOp);
}; };
} // namespace } // namespace
@@ -64,7 +65,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
auto loc = batchOp.getLoc(); auto loc = batchOp.getLoc();
rewriter.setInsertionPoint(batchOp); rewriter.setInsertionPoint(batchOp);
auto computeOp = spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs()); auto computeOp =
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
computeOp.getProperties().setOperandSegmentSizes( computeOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())}); {static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
@@ -75,8 +77,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
blockArgTypes.push_back(arg.getType()); blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(loc); blockArgLocs.push_back(loc);
} }
auto* newBlock = rewriter.createBlock( auto* newBlock =
&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper; IRMapping mapper;
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments())) for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
@@ -183,6 +185,8 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc); annotateWeightsConstants(*entryFunc);
populateEmptyFunction(*entryFunc);
if (failed(encapsulateGlobalInstruction(*entryFunc))) { if (failed(encapsulateGlobalInstruction(*entryFunc))) {
signalPassFailure(); signalPassFailure();
return; return;
@@ -376,8 +380,7 @@ LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcO
while (keep) { while (keep) {
keep = false; keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) { for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>( if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatExtractRowsOp>(instruction)
instruction)
|| isa<func::ReturnOp>(instruction)) || isa<func::ReturnOp>(instruction))
continue; continue;
@@ -490,6 +493,47 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun
return success(); return success();
} }
// Fallback for functions that contain no spatial.compute op: wraps the whole
// function body (everything except the func.return terminator) into a single
// SpatCompute region, so downstream stages always see at least one compute op.
// No-op when the function already contains a SpatCompute.
// NOTE(review): assumes funcOp's region has a single block whose terminator
// is a func::ReturnOp (the llvm::cast below asserts on anything else) —
// confirm this invariant holds for every entry function this pass visits.
void ONNXToSpatialPass::populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(&getContext());
IRMapping mapper;
// Materialize the filtered op range; a non-empty list means a SpatCompute
// already exists and there is nothing to do.
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
if (!computes.empty())
return;
auto returnOp = llvm::cast<func::ReturnOp>(funcOp.getRegion().front().getTerminator());
// New ops are created just before the terminator.
rewriter.setInsertionPoint(returnOp);
// The new compute's entry-block arguments mirror the function arguments 1:1
// (same types, same locations).
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
for (auto source : funcOp.getArguments()) {
sourceTypes.push_back(source.getType());
sourceLoc.push_back(source.getLoc());
}
// Create the compute op with the function's result types as its results and
// all function arguments as operands; the segment sizes set below classify
// every operand as an "input" (the weights segment is left empty).
auto newCompute = spatial::SpatCompute::create(
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
// Map each outer operand to its corresponding block argument so that cloned
// ops reference values defined inside the new region.
for (auto [bbArg, computeArg] : llvm::zip(BB->getArguments(), newCompute.getOperands()))
mapper.map(computeArg, bbArg);
// 0 weights, sourceTypes.size() inputs — must match the operand list above.
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sourceTypes.size()});
rewriter.setInsertionPointToEnd(BB);
// Clone every top-level op of the function (skipping compute ops and the
// terminator) into the new block, remapping operands through `mapper`.
for (Operation& inst : funcOp.getOps())
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst))
rewriter.clone(inst, mapper);
// Yield the values the function used to return, then patch each yield
// operand to its cloned counterpart (lookupOrDefault keeps values that were
// not cloned, e.g. block arguments, untouched).
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
for (size_t i = 0; i < yield.getNumOperands(); ++i)
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
// Erase the now-duplicated original ops. dropAllUses() is called first so
// that erasing in forward order cannot trip use-before-erase checks when a
// later (also doomed) op still references an earlier result.
for (Operation& inst : llvm::make_early_inc_range(funcOp.getOps()))
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst)){
inst.dropAllUses();
rewriter.eraseOp(&inst);
}
// Finally, return the compute's results in place of the erased values.
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
returnOp.setOperand(index, computeResult);
}
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); } std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -23,6 +23,7 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) { static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) {
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) { if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
unsigned inputCount = compute.getInputs().size(); unsigned inputCount = compute.getInputs().size();