Single Concat Fix
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -9,7 +9,9 @@ AGENTS.md
|
||||
CMakeUserPresets.json
|
||||
|
||||
build
|
||||
build_release
|
||||
cmake-build-debug
|
||||
cmake-build-release
|
||||
compile.sh
|
||||
|
||||
**/__*
|
||||
|
||||
@@ -169,32 +169,6 @@ void PimMemory::report(llvm::raw_ostream& file) {
|
||||
}
|
||||
}
|
||||
|
||||
// void PimMemory::report(llvm::raw_ostream& file) {
|
||||
// std::vector orderedList(globalMemEntriesMap.begin(), globalMemEntriesMap.end());
|
||||
// std::sort(
|
||||
// orderedList.begin(), orderedList.end(), [](auto lft, auto rgt) { return lft.second.address < rgt.second.address;
|
||||
// });
|
||||
// auto newEnd = std::unique(orderedList.begin(), orderedList.end(), [](auto lft, auto rgt) {
|
||||
// return (lft.first.getDefiningOp() == rgt.first.getDefiningOp()) && (lft.second.address == rgt.second.address);
|
||||
// });
|
||||
// orderedList.erase(newEnd, orderedList.end());
|
||||
// mlir::OpPrintingFlags flags;
|
||||
// flags.assumeVerified(true);
|
||||
// for (auto& [value, memEntry] : orderedList) {
|
||||
// if (auto op = value.getDefiningOp()) {
|
||||
// file.indent(4) << op << ": ";
|
||||
// op->print(file, flags);
|
||||
// file << "\n";
|
||||
// file.indent(6) << "Address: " << llvm::format_hex(memEntry.address, 10) << "\n";
|
||||
// file.indent(6) << "Memory: " << formatMemory(memEntry.size) << "\n";
|
||||
// }
|
||||
// else {
|
||||
// file.indent(4) << value << "\n";
|
||||
// file.indent(6) << "Address: " << llvm::format_hex(memEntry.address, 10) << "\n";
|
||||
// file.indent(6) << "Memory: " << formatMemory(memEntry.size) << "\n";
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
void PimMemory::remove(mlir::Value val) {
|
||||
if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())
|
||||
|
||||
@@ -49,6 +49,7 @@ private:
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||
LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
||||
void populateEmptyFunction(func::FuncOp funcOp);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -64,7 +65,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
|
||||
|
||||
auto loc = batchOp.getLoc();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||
computeOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
|
||||
@@ -75,8 +77,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
||||
@@ -183,6 +185,8 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
populateEmptyFunction(*entryFunc);
|
||||
|
||||
if (failed(encapsulateGlobalInstruction(*entryFunc))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
@@ -376,8 +380,7 @@ LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcO
|
||||
while (keep) {
|
||||
keep = false;
|
||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
|
||||
instruction)
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatExtractRowsOp>(instruction)
|
||||
|| isa<func::ReturnOp>(instruction))
|
||||
continue;
|
||||
|
||||
@@ -490,6 +493,47 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun
|
||||
return success();
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::populateEmptyFunction(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(&getContext());
|
||||
IRMapping mapper;
|
||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||
if (!computes.empty())
|
||||
return;
|
||||
auto returnOp = llvm::cast<func::ReturnOp>(funcOp.getRegion().front().getTerminator());
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
|
||||
SmallVector<Type> sourceTypes;
|
||||
SmallVector<Location> sourceLoc;
|
||||
for (auto source : funcOp.getArguments()) {
|
||||
sourceTypes.push_back(source.getType());
|
||||
sourceLoc.push_back(source.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute = spatial::SpatCompute::create(
|
||||
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
||||
for (auto [bbArg, computeArg] : llvm::zip(BB->getArguments(), newCompute.getOperands()))
|
||||
mapper.map(computeArg, bbArg);
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sourceTypes.size()});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
for (Operation& inst : funcOp.getOps())
|
||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst))
|
||||
rewriter.clone(inst, mapper);
|
||||
|
||||
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
||||
for (size_t i = 0; i < yield.getNumOperands(); ++i)
|
||||
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
||||
|
||||
for (Operation& inst : llvm::make_early_inc_range(funcOp.getOps()))
|
||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst)){
|
||||
inst.dropAllUses();
|
||||
rewriter.eraseOp(&inst);
|
||||
}
|
||||
|
||||
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
|
||||
returnOp.setOperand(index, computeResult);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -23,6 +23,16 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct EraseSpatNopPattern : public mlir::OpRewritePattern<spatial::SpatNopOp> {
|
||||
using OpRewritePattern<spatial::SpatNopOp>::OpRewritePattern;
|
||||
mlir::LogicalResult matchAndRewrite(spatial::SpatNopOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
rewriter.eraseOp(op);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) {
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||
unsigned inputCount = compute.getInputs().size();
|
||||
@@ -416,7 +426,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
|
||||
} // namespace
|
||||
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
|
||||
patterns.add<EraseSpatNopPattern, MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user