Merge remote-tracking branch 'origin/main'
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -9,7 +9,9 @@ AGENTS.md
|
|||||||
CMakeUserPresets.json
|
CMakeUserPresets.json
|
||||||
|
|
||||||
build
|
build
|
||||||
|
build_release
|
||||||
cmake-build-debug
|
cmake-build-debug
|
||||||
cmake-build-release
|
cmake-build-release
|
||||||
|
compile.sh
|
||||||
|
|
||||||
**/__*
|
**/__*
|
||||||
|
|||||||
@@ -178,32 +178,6 @@ void PimMemory::report(llvm::raw_ostream& file) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// void PimMemory::report(llvm::raw_ostream& file) {
|
|
||||||
// std::vector orderedList(globalMemEntriesMap.begin(), globalMemEntriesMap.end());
|
|
||||||
// std::sort(
|
|
||||||
// orderedList.begin(), orderedList.end(), [](auto lft, auto rgt) { return lft.second.address < rgt.second.address;
|
|
||||||
// });
|
|
||||||
// auto newEnd = std::unique(orderedList.begin(), orderedList.end(), [](auto lft, auto rgt) {
|
|
||||||
// return (lft.first.getDefiningOp() == rgt.first.getDefiningOp()) && (lft.second.address == rgt.second.address);
|
|
||||||
// });
|
|
||||||
// orderedList.erase(newEnd, orderedList.end());
|
|
||||||
// mlir::OpPrintingFlags flags;
|
|
||||||
// flags.assumeVerified(true);
|
|
||||||
// for (auto& [value, memEntry] : orderedList) {
|
|
||||||
// if (auto op = value.getDefiningOp()) {
|
|
||||||
// file.indent(4) << op << ": ";
|
|
||||||
// op->print(file, flags);
|
|
||||||
// file << "\n";
|
|
||||||
// file.indent(6) << "Address: " << llvm::format_hex(memEntry.address, 10) << "\n";
|
|
||||||
// file.indent(6) << "Memory: " << formatMemory(memEntry.size) << "\n";
|
|
||||||
// }
|
|
||||||
// else {
|
|
||||||
// file.indent(4) << value << "\n";
|
|
||||||
// file.indent(6) << "Address: " << llvm::format_hex(memEntry.address, 10) << "\n";
|
|
||||||
// file.indent(6) << "Memory: " << formatMemory(memEntry.size) << "\n";
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
void PimMemory::remove(mlir::Value val) {
|
void PimMemory::remove(mlir::Value val) {
|
||||||
if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())
|
if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ private:
|
|||||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||||
LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
|
LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||||
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
||||||
|
void populateEmptyFunction(func::FuncOp funcOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -64,7 +65,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
auto loc = batchOp.getLoc();
|
auto loc = batchOp.getLoc();
|
||||||
rewriter.setInsertionPoint(batchOp);
|
rewriter.setInsertionPoint(batchOp);
|
||||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
auto computeOp =
|
||||||
|
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||||
computeOp.getProperties().setOperandSegmentSizes(
|
computeOp.getProperties().setOperandSegmentSizes(
|
||||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||||
|
|
||||||
@@ -75,8 +77,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
|
|||||||
blockArgTypes.push_back(arg.getType());
|
blockArgTypes.push_back(arg.getType());
|
||||||
blockArgLocs.push_back(loc);
|
blockArgLocs.push_back(loc);
|
||||||
}
|
}
|
||||||
auto* newBlock = rewriter.createBlock(
|
auto* newBlock =
|
||||||
&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
||||||
@@ -183,6 +185,8 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
|
|
||||||
|
populateEmptyFunction(*entryFunc);
|
||||||
|
|
||||||
if (failed(encapsulateGlobalInstruction(*entryFunc))) {
|
if (failed(encapsulateGlobalInstruction(*entryFunc))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
@@ -376,8 +380,7 @@ LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcO
|
|||||||
while (keep) {
|
while (keep) {
|
||||||
keep = false;
|
keep = false;
|
||||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
||||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
|
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatExtractRowsOp>(instruction)
|
||||||
instruction)
|
|
||||||
|| isa<func::ReturnOp>(instruction))
|
|| isa<func::ReturnOp>(instruction))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
@@ -490,6 +493,47 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ONNXToSpatialPass::populateEmptyFunction(func::FuncOp funcOp) {
|
||||||
|
IRRewriter rewriter(&getContext());
|
||||||
|
IRMapping mapper;
|
||||||
|
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||||
|
if (!computes.empty())
|
||||||
|
return;
|
||||||
|
auto returnOp = llvm::cast<func::ReturnOp>(funcOp.getRegion().front().getTerminator());
|
||||||
|
rewriter.setInsertionPoint(returnOp);
|
||||||
|
|
||||||
|
SmallVector<Type> sourceTypes;
|
||||||
|
SmallVector<Location> sourceLoc;
|
||||||
|
for (auto source : funcOp.getArguments()) {
|
||||||
|
sourceTypes.push_back(source.getType());
|
||||||
|
sourceLoc.push_back(source.getLoc());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto newCompute = spatial::SpatCompute::create(
|
||||||
|
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
||||||
|
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
||||||
|
for (auto [bbArg, computeArg] : llvm::zip(BB->getArguments(), newCompute.getOperands()))
|
||||||
|
mapper.map(computeArg, bbArg);
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sourceTypes.size()});
|
||||||
|
rewriter.setInsertionPointToEnd(BB);
|
||||||
|
for (Operation& inst : funcOp.getOps())
|
||||||
|
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst))
|
||||||
|
rewriter.clone(inst, mapper);
|
||||||
|
|
||||||
|
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
||||||
|
for (size_t i = 0; i < yield.getNumOperands(); ++i)
|
||||||
|
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
||||||
|
|
||||||
|
for (Operation& inst : llvm::make_early_inc_range(funcOp.getOps()))
|
||||||
|
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst)){
|
||||||
|
inst.dropAllUses();
|
||||||
|
rewriter.eraseOp(&inst);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
|
||||||
|
returnOp.setOperand(index, computeResult);
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
|
||||||
static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) {
|
static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) {
|
||||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||||
unsigned inputCount = compute.getInputs().size();
|
unsigned inputCount = compute.getInputs().size();
|
||||||
|
|||||||
Reference in New Issue
Block a user