diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 080a335..102e2ae 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -3,9 +3,11 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" #include "llvm/Support/raw_ostream.h" @@ -53,9 +55,23 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) { void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { SmallDenseMap globalConstants; SmallVector, 16> globalAliases; + SmallVector args; + + + for (mlir::Value arg : funcOp.getArguments()){ + gatherMemEntry(arg); + args.push_back(arg); + } + funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { if (!hasWeightAlways(getGlobalOp)) { auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (globalMemrefOp.getName().starts_with("arg")){ + StringRef indexStr = globalMemrefOp.getName().substr(4); + int index = 0; + llvm::to_integer(indexStr,index, 10); + globalAliases.push_back({getGlobalOp.getResult(), args[index]}); + } auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult()); if (inserted) gatherMemEntry(getGlobalOp.getResult()); @@ -64,8 +80,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { } }); - for (mlir::Value arg : funcOp.getArguments()) - gatherMemEntry(arg); funcOp.walk([&](memref::AllocOp allocOp) { if (!allocOp->getParentOfType()) diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 28e7078..eaee28c 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -148,6 +148,7 @@ void ONNXToSpatialPass::runOnOperation() { llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n"; annotateWeightsConstants(*entryFunc); + encapsulateGlobalInstruction(*entryFunc); if (failed(promoteConstantInputsToWeights(*entryFunc))) { @@ -166,19 +167,46 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func if (T toRemoveOp = llvm::dyn_cast_if_present(inst)) { Value source = funcSource(toRemoveOp); rewriter.setInsertionPointAfter(toRemoveOp); - if (isa_and_present(source.getDefiningOp())) { - auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source); - auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); - newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); - rewriter.setInsertionPointToEnd(BB); - IRMapping mapper; - mapper.map(source, BB->getArgument(0)); - auto newInst = rewriter.clone(*inst, mapper); - spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults()); - inst->replaceAllUsesWith(newCompute->getResults()); - inst->erase(); - return true; + auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source); + auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); + newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); + rewriter.setInsertionPointToEnd(BB); + IRMapping mapper; + mapper.map(source, BB->getArgument(0)); + auto newInst = rewriter.clone(*inst, mapper); + spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults()); + inst->replaceAllUsesWith(newCompute->getResults()); + inst->erase(); + return true; + } + return false; +} + +bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) { + if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present(inst)) { + for (auto& use : toRemoveOp->getUses()) { + auto users = use.getOwner(); + if (auto spatCompUser = dyn_cast(users)) { + unsigned int poistionUses = use.getOperandNumber(); + if (poistionUses < spatCompUser.getInputs().getBeginOperandIndex()) + return false; + }else { + return false; + } } + auto source = toRemoveOp.getSource(); + rewriter.setInsertionPointAfter(toRemoveOp); + auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source); + auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); + newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); + rewriter.setInsertionPointToEnd(BB); + IRMapping mapper; + mapper.map(source, BB->getArgument(0)); + auto newInst = rewriter.clone(*inst, mapper); + spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults()); + inst->replaceAllUsesWith(newCompute->getResults()); + inst->erase(); + return true; } return false; } @@ -187,8 +215,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) { if (auto toRemoveOp = llvm::dyn_cast_if_present(inst)) { auto sources = toRemoveOp.getInputs(); rewriter.setInsertionPointAfter(toRemoveOp); - if (llvm::any_of( - sources, [](auto source) { return isa_and_present(source.getDefiningOp()); })) { + if (llvm::any_of(sources, + [](auto source) { return isa_and_present(source.getDefiningOp()); })) { auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources); SmallVector sourceTypes; SmallVector sourceLoc; @@ -277,8 +305,7 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) { while (keep) { keep = false; for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) { - keep |= encapsulator( - rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); }); + keep |= encapsulateSlice(rewriter, loc, &instruction); keep |= encapsulator( rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); }); @@ -324,8 +351,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) { rewriter.setInsertionPointAfter(compute.getOperation()); - auto newCompute = - spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); + auto newCompute = spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); newCompute.getProperties().setOperandSegmentSizes( {static_cast(compute.getWeights().size()), static_cast(compute.getInputs().size())}); diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp index 5dbecd2..a84a2ef 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -1,11 +1,19 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" +#include "Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -13,6 +21,96 @@ using namespace mlir; namespace onnx_mlir { namespace { +struct ArithConstToGlobalMemoryPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override { + static int i = 0; + Location loc = constantOp.getLoc(); + + if (hasWeightAlways(constantOp)) + return failure(); + + if (!isa(constantOp->getParentOp())) + return failure(); + + rewriter.setInsertionPoint(constantOp->getParentOfType()); + + auto constRankedTensorType = llvm::dyn_cast(constantOp.getType()); + + if (constRankedTensorType) { + mlir::MemRefType memRefType = + mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType()); + std::string argName = "const_" + std::to_string(i++); + memref::GlobalOp::create(rewriter, + loc, + rewriter.getStringAttr(argName), + rewriter.getStringAttr("private"), + TypeAttr::get(memRefType), + constantOp.getValueAttr(), + rewriter.getUnitAttr(), + {}); + + for (auto& constUses : constantOp->getUses()) { + auto constUsers = constUses.getOwner(); + + if (auto spatCompute = llvm::dyn_cast(constUsers)) { + + auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); + rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); + auto toTensor = bufferization::ToTensorOp::create( + rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + + rewriter.startOpModification(spatCompute.getOperation()); + BBArgValue.replaceAllUsesWith(toTensor); + spatCompute.getInputsMutable().erase(BBArgIndex); + spatCompute.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatCompute.getOperation()); + } + else { + llvm_unreachable("Who are using const globally"); + } + } + } + else if (constantOp.getType().isIntOrIndexOrFloat()) { + llvm::DenseMap mapSpatComputeToConst; + + for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) { + auto constUsers = constUses.getOwner(); + + if (auto spatCompute = llvm::dyn_cast(constUsers)) { + + auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); + rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); + auto newConst = rewriter.clone(*constantOp); + + rewriter.startOpModification(spatCompute.getOperation()); + BBArgValue.replaceAllUsesWith(newConst->getResult(0)); + spatCompute.getInputsMutable().erase(BBArgIndex); + spatCompute.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatCompute.getOperation()); + } + else { + auto parent = constUsers->getParentOfType(); + assert(parent && "Global Constant used direcly not within a compute"); + if (!mapSpatComputeToConst.contains(parent)) { + rewriter.setInsertionPoint(&parent.getBody().front().front()); + auto newConst = rewriter.clone(*constantOp); + mapSpatComputeToConst.insert({parent, newConst->getResult(0)}); + } + constUses.set(mapSpatComputeToConst[parent]); + } + } + } + auto parent = constantOp->getParentOp(); + rewriter.eraseOp(constantOp); + return success(); + } +}; + struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -42,13 +140,13 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern(patterns.getContext()); + patterns.add(patterns.getContext()); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index ba7beac..d9c7100 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" @@ -165,10 +166,7 @@ void SpatialToPimPass::runOnOperation() { RewritePatternSet patterns(ctx); populateGlobalTensorToMemrefPatterns(patterns); - if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { - signalPassFailure(); - return; - } + walkAndApplyPatterns(moduleOp, std::move(patterns)); } auto entryFunc = getPimEntryFunc(moduleOp); @@ -504,6 +502,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp}); }; + for (auto& op : funcOp.getBody().getOps()) if (auto computeOp = dyn_cast(op)) { assert(computeOp.getInputs().size() == 0 && "Already removed from mergeNode and global input handle"); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 976f418..6672d3b 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -95,23 +95,22 @@ void generateReport(func::FuncOp funcOp, const std::string& name) { auto expectedPrintedValue = currentComputeId + 1; bool rangePrinted = false; cI++; - for (; cI <= lastIndex; ++cI){ + for (; cI <= lastIndex; ++cI) { auto candidateToPrint = std::get<0>(collectedData[cI]); - if (candidateToPrint == expectedPrintedValue){ + if (candidateToPrint == expectedPrintedValue) { expectedPrintedValue = candidateToPrint + 1; rangePrinted = true; - } else { - if (rangePrinted) { + } + else { + if (rangePrinted) os << " - " << expectedPrintedValue - 1; - } os << " , " << candidateToPrint; rangePrinted = false; expectedPrintedValue = candidateToPrint + 1; } } - if (rangePrinted && currentComputeId != expectedPrintedValue - 1){ - os << " - " << expectedPrintedValue - 1; - } + if (rangePrinted && currentComputeId != expectedPrintedValue - 1) + os << " - " << expectedPrintedValue - 1; os << " :\n"; os << "\tNumber of instructions " << currentNumInst << "\n"; @@ -193,11 +192,34 @@ public: LogicalResult initialize(MLIRContext* context) override { return success(); } + void verifyOrderAssumption(std::vector& dominanceOrderCompute) { + uint64_t computeNumber = 0; + llvm::DenseSet visited; + mlir::func::FuncOp funcOp = getOperation(); + for (auto spatCompute : funcOp.getOps()) + computeNumber++; + + assert(computeNumber == dominanceOrderCompute.size()); + + for(auto domCompute : dominanceOrderCompute){ + visited.insert(domCompute); + for(auto domInput : domCompute.getInputs() ){ + if(auto domImputAsCompute = dyn_cast_if_present(domInput.getDefiningOp())){ + assert(visited.contains(domImputAsCompute) && "Dominance order violated\n"); + } + } + } + + } + void runOnOperation() override { DCPAnalysisResult& analysisResult = getAnalysis().getResult(); auto& lastComputeOfCpu = analysisResult.isLastComputeOfCpu; auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap; + func::FuncOp func = getOperation(); + verifyOrderAssumption(analysisResult.dominanceOrderCompute); + for (auto currentComputeNode : analysisResult.dominanceOrderCompute) { size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode); if (!cpuToNewComputeMap.contains(cpu)) { @@ -219,11 +241,19 @@ public: } for (auto computeNodeToRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) { - for (auto users : computeNodeToRemove->getUsers()) + if (!computeNodeToRemove->use_empty()) { + llvm::dbgs() << "Full module\n"; + computeNodeToRemove->getParentOfType()->dump(); + + llvm::dbgs() << "Compute with uses:\n"; + computeNodeToRemove.dump(); + } + for (auto users : computeNodeToRemove->getUsers()) { + llvm::dbgs() << "Users:\n"; users->dump(); + } computeNodeToRemove.erase(); } - func::FuncOp func = getOperation(); dumpModule(cast(func->getParentOp()), "spatial1_dcp_merged"); generateReport(func, "spatial1_dcp_merged_report"); }