diff --git a/.gitignore b/.gitignore index dc780e7..8bb3643 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ **/.vscode .claude +.codex AGENTS.md CMakeUserPresets.json diff --git a/src/PIM/Common/PimCommon.cpp b/src/PIM/Common/PimCommon.cpp index 116eb32..0615d35 100644 --- a/src/PIM/Common/PimCommon.cpp +++ b/src/PIM/Common/PimCommon.cpp @@ -1,4 +1,5 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index b15074b..07bebf3 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -5,9 +5,11 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Value.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" #include "llvm/Support/raw_ostream.h" @@ -55,9 +57,23 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) { void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { SmallDenseMap globalConstants; SmallVector, 16> globalAliases; + SmallVector args; + + + for (mlir::Value arg : funcOp.getArguments()){ + gatherMemEntry(arg); + args.push_back(arg); + } + funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { if (!hasWeightAlways(getGlobalOp)) { auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (globalMemrefOp.getName().starts_with("arg")){ + StringRef indexStr = globalMemrefOp.getName().substr(4); + int index = 0; + llvm::to_integer(indexStr,index, 10); + globalAliases.push_back({getGlobalOp.getResult(), args[index]}); + } auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult()); if (inserted) gatherMemEntry(getGlobalOp.getResult()); @@ -66,8 +82,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { } }); - for (mlir::Value arg : funcOp.getArguments()) - gatherMemEntry(arg); funcOp.walk([&](memref::AllocOp allocOp) { if (!allocOp->getParentOfType()) @@ -133,6 +147,12 @@ json::Object PimCodeGen::createEmptyOffset() { return offset; } +size_t PimCodeGen::remapCoreId(size_t coreId) const { + auto it = emittedCoreIds.find(coreId); + assert(it != emittedCoreIds.end() && "Missing emitted core id remapping"); + return it->second; +} + static json::Object createRs1OnlyOffset() { json::Object offset; offset["offset_select"] = 1; @@ -192,7 +212,7 @@ void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t json::Object json; json["op"] = opName; json["rd"] = 0; - json["core"] = coreId; + json["core"] = remapCoreId(coreId); json["size"] = size; json["offset"] = createEmptyOffset(); emitInstruction(std::move(json)); @@ -414,6 +434,9 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa emitInstruction(std::move(json)); } +void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const { +} + void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const { auto srcAddr = addressOf(transposeOp.getInput(), knowledge); auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge); @@ -583,6 +606,29 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor return scalarCore; } +static void aliasMaterializedHostGlobals( + ModuleOp moduleOp, func::FuncOp funcOp, pim::PimCoreOp coreOp, PimAcceleratorMemory& memory) { + coreOp.walk([&](memref::GetGlobalOp getGlobalOp) { + if (hasWeightAlways(getGlobalOp) || memory.memEntriesMap.contains(getGlobalOp.getResult())) + return; + + auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (!targetGlobal) + return; + + mlir::Value aliasedValue; + funcOp.walk([&](memref::GetGlobalOp candidate) { + if (aliasedValue || candidate == getGlobalOp || !memory.memEntriesMap.contains(candidate.getResult())) + return; + if (lookupGlobalForGetGlobal(moduleOp, candidate) == targetGlobal) + aliasedValue = candidate.getResult(); + }); + + if (aliasedValue) + memory.memEntriesMap[getGlobalOp.getResult()] = memory.memEntriesMap[aliasedValue]; + }); +} + /// Write global constant data into a binary memory image at their allocated addresses. static OnnxMlirCompilerErrorCodes writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) { @@ -677,6 +723,8 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge); else if (auto vsoftmaxOp = dyn_cast(op)) coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge); + else if (auto getGlobalOp = dyn_cast(op)) + coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge); else { op.emitError("Unsupported codegen for this operation"); op.dump(); @@ -880,13 +928,14 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { /// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses). static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp, PimAcceleratorMemory& memory, - size_t coreCount, + size_t maxCoreId, json::Object xbarsPerArrayGroup, StringRef outputDirPath) { json::Object configJson; - // +1 because pimsim-nn also considers the host as a core - configJson["core_cnt"] = coreCount + 1; + // pimsim-nn indexes cores directly by their numeric core ID, with the host + // occupying core 0. + configJson["core_cnt"] = maxCoreId + 1; // TODO: Should this be based on the floating point type used in the model? // The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision @@ -960,12 +1009,31 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: // For each core, specify the number of crossbar per array group. // This implementation always assigns one crossbar per group. json::Object xbarsPerArrayGroup; - size_t coreCount = 0; + size_t maxCoreId = 0; // Create Weight Folder auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath); SmallVector coreLikeOps = collectTopLevelCoreLikeOps(funcOp); + llvm::DenseMap emittedCoreIds; + size_t nextEmittedCoreId = 1; + + for (Operation* op : coreLikeOps) { + if (auto coreOp = dyn_cast(op)) { + size_t originalCoreId = static_cast(coreOp.getCoreId()); + if (!emittedCoreIds.contains(originalCoreId)) + emittedCoreIds[originalCoreId] = nextEmittedCoreId++; + continue; + } + + auto coreBatchOp = cast(op); + auto batchCoreIds = getBatchCoreIds(coreBatchOp); + for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) { + size_t originalCoreId = static_cast(batchCoreIds[lane]); + if (!emittedCoreIds.contains(originalCoreId)) + emittedCoreIds[originalCoreId] = nextEmittedCoreId++; + } + } for (Operation* op : coreLikeOps) { SmallVector scalarCores; @@ -979,8 +1047,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: } for (pim::PimCoreOp coreOp : scalarCores) { - auto coreId = coreOp.getCoreId(); - coreCount++; + size_t originalCoreId = static_cast(coreOp.getCoreId()); + size_t coreId = emittedCoreIds.lookup(originalCoreId); + maxCoreId = std::max(maxCoreId, coreId); std::error_code errorCode; auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json"; @@ -991,7 +1060,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: } coreFileStream << '['; - PimCodeGen coreCodeGen(memory, coreFileStream); + PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds); + aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory); memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp); int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen); @@ -1009,7 +1079,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: return InvalidOutputFileAccess; } - auto& mapWeightToFile = mapCoreWeightToFileName[static_cast(coreId)]; + auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId]; json::Array xbarsPerGroup; for (unsigned index : getUsedWeightIndices(coreOp)) { if (index >= coreOp.getWeights().size()) { @@ -1037,5 +1107,5 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: coreOp.erase(); } - return writeConfigJson(funcOp, memory, coreCount, std::move(xbarsPerArrayGroup), outputDirPath); + return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath); } diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index bad6b55..38e2c3f 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -1,5 +1,6 @@ #pragma once +#include "llvm/ADT/DenseMap.h" #include "llvm-project/clang/include/clang/Basic/LLVM.h" #include "llvm/Support/JSON.h" @@ -58,10 +59,12 @@ public: class PimCodeGen { PimAcceleratorMemory& memory; llvm::raw_fd_ostream& coreFileStream; + const llvm::DenseMap& emittedCoreIds; size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const { return memory.getValueAddress(value, knowledge); } + size_t remapCoreId(size_t coreId) const; static llvm::json::Object createEmptyOffset(); void emitInstruction(llvm::json::Object instruction) const; @@ -83,8 +86,10 @@ class PimCodeGen { void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const; public: - PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson) - : memory(memory), coreFileStream(coreJson) {} + PimCodeGen(PimAcceleratorMemory& memory, + llvm::raw_fd_ostream& coreJson, + const llvm::DenseMap& emittedCoreIds) + : memory(memory), coreFileStream(coreJson), emittedCoreIds(emittedCoreIds) {} void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const; void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const; @@ -106,6 +111,7 @@ public: void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const; void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const; void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const; + void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const; void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const; }; diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 90d91fd..20fc99e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -1,3 +1,4 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -11,6 +12,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_os_ostream.h" #include @@ -183,6 +185,7 @@ void ONNXToSpatialPass::runOnOperation() { llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n"; annotateWeightsConstants(*entryFunc); + encapsulateGlobalInstruction(*entryFunc); if (failed(promoteConstantInputsToWeights(*entryFunc))) { @@ -199,19 +202,36 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func if (T toRemoveOp = llvm::dyn_cast_if_present(inst)) { Value source = funcSource(toRemoveOp); rewriter.setInsertionPointAfter(toRemoveOp); - if (isa_and_present(source.getDefiningOp())) { - auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source); - auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); - newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); - rewriter.setInsertionPointToEnd(BB); - IRMapping mapper; - mapper.map(source, BB->getArgument(0)); - auto newInst = rewriter.clone(*inst, mapper); - spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults()); - inst->replaceAllUsesWith(newCompute->getResults()); - inst->erase(); - return true; - } + auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source); + auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); + newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); + rewriter.setInsertionPointToEnd(BB); + IRMapping mapper; + mapper.map(source, BB->getArgument(0)); + auto newInst = rewriter.clone(*inst, mapper); + spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults()); + inst->replaceAllUsesWith(newCompute->getResults()); + inst->erase(); + return true; + } + return false; +} + +bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) { + if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present(inst)) { + auto source = toRemoveOp.getSource(); + rewriter.setInsertionPointAfter(toRemoveOp); + auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source); + auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); + newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); + rewriter.setInsertionPointToEnd(BB); + IRMapping mapper; + mapper.map(source, BB->getArgument(0)); + auto newInst = rewriter.clone(*inst, mapper); + spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults()); + inst->replaceAllUsesWith(newCompute->getResults()); + inst->erase(); + return true; } return false; } @@ -245,6 +265,24 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) { inst->erase(); return true; } + auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources); + SmallVector sourceTypes; + SmallVector sourceLoc; + for (auto source : sources) { + sourceTypes.push_back(source.getType()); + sourceLoc.push_back(loc); + } + auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc); + newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()}); + rewriter.setInsertionPointToEnd(BB); + IRMapping mapper; + for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments())) + mapper.map(source, bbArg); + auto newConcat = rewriter.clone(*inst, mapper); + spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults()); + inst->replaceAllUsesWith(newCompute->getResults()); + inst->erase(); + return true; } return false; } @@ -306,6 +344,89 @@ static FailureOr materializeWeightLikeValueInBlock(Value value, IRRewrite return cast(mapped); } +bool sourceOpernadHasWeightAlways(Operation* op) { + if (op == nullptr) + return false; + + Operation* source = nullptr; + do { + + if (isa(*op)) { + return false; + } + else if (auto extractSliceOp = dyn_cast(*op)) { + auto tmpSource = extractSliceOp.getSource(); + auto definingOp = tmpSource.getDefiningOp(); + if (definingOp) + op = definingOp; + else + return false; + } + else if (auto extractRowsOp = dyn_cast(*op)) { + auto tmpSource = extractRowsOp.getInput(); + auto definingOp = tmpSource.getDefiningOp(); + if (definingOp) + op = definingOp; + else + return false; + } + else if (auto expandShapeOp = dyn_cast(*op)) { + auto tmpSource = expandShapeOp.getSrc(); + auto definingOp = tmpSource.getDefiningOp(); + if (definingOp) + op = definingOp; + else + return false; + } + else if (auto transposeOp = dyn_cast(*op)) { + auto tmpSource = transposeOp.getData(); + auto definingOp = tmpSource.getDefiningOp(); + if (definingOp) + op = definingOp; + else + return false; + } + else if (auto collapseShapeOp = dyn_cast(*op)) { + auto tmpSource = collapseShapeOp.getSrc(); + auto definingOp = tmpSource.getDefiningOp(); + if (definingOp) + op = definingOp; + else + return false; + } + else if (auto constantOp = dyn_cast(*op)) { + source = constantOp; + } + else if (auto concatOp = dyn_cast(*op)) { + bool res = false; + for (auto operand : concatOp.getOperands()) { + res |= hasWeightAlways(operand.getDefiningOp()); + if (res) + return res; + } + return res; + } + else if (auto concatOp = dyn_cast(*op)) { + bool res = false; + for (auto operand : concatOp.getOperands()) { + res |= hasWeightAlways(operand.getDefiningOp()); + if (res) + return res; + } + return res; + } + else { + op->dump(); + llvm_unreachable("Global instruction not handle in func"); + } + } + while (source == nullptr); + + if (hasWeightAlways(source)) + return true; + return false; +} + // TODO what we want to keep in global? void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) { Location loc = funcOp.getLoc(); @@ -314,8 +435,14 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) { while (keep) { keep = false; for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) { - keep |= encapsulator( - rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); }); + + if (isa( + instruction) + || isa(instruction) + || sourceOpernadHasWeightAlways(&instruction)) + continue; + + keep |= encapsulateSlice(rewriter, loc, &instruction); keep |= encapsulator( rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); }); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp index 8ed7fcd..a9ba74a 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp @@ -23,7 +23,10 @@ static Value extractSliceAt( sizes.push_back(rewriter.getIndexAttr(dim)); offsets[axis] = rewriter.getIndexAttr(offset); sizes[axis] = rewriter.getIndexAttr(size); - return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides); + SmallVector resultShape(inputType.getShape()); + resultShape[axis] = size; + auto resultType = RankedTensorType::get(resultShape, inputType.getElementType()); + return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides); } struct Split : OpConversionPattern { @@ -49,12 +52,7 @@ struct Split : OpConversionPattern { if (!resultType || !resultType.hasStaticShape()) return failure(); int64_t sliceSize = resultType.getShape()[axis]; - auto computeOp = - createSpatCompute<1>(rewriter, splitOp.getLoc(), TypeRange {resultType}, {}, adaptor.getInput(), [&](Value x) { - Value output = extractSliceAt(x, axis, offset, sliceSize, rewriter, splitOp.getLoc()); - spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), output); - }); - outputs.push_back(computeOp.getResult(0)); + outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc())); offset += sliceSize; } diff --git a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt index d8222c8..351d96d 100644 --- a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt +++ b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt @@ -5,6 +5,7 @@ add_public_tablegen_target(SpatialToPimIncGen) add_pim_library(OMSpatialToPim SpatialToPimPass.cpp Common.cpp + Patterns.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp new file mode 100644 index 0000000..ec513ff --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -0,0 +1,385 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" + +#include "Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +struct MoveExtractSliceIntoCompute final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override { + Location loc = extractSliceOp.getLoc(); + + if (!isa(extractSliceOp->getParentOp())) + return failure(); + + for (auto& uses : extractSliceOp->getUses()) { + if (isa(uses.getOwner())) { + auto spatCompute = cast(uses.getOwner()); + if (spatCompute.getInputs().empty()) + return failure(); + if (uses.getOperandNumber() < spatCompute.getInputs().getBeginOperandIndex()) + return failure(); + } + else if (isa_and_present(uses.getOwner()->getParentOp())) { + return failure(); + } + } + + llvm::DenseMap mapSpatToExtract; + + for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) { + + if (auto spatCompute = dyn_cast(uses.getOwner())) { + auto BBArgIndex = uses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); + + if (BBArgValue.use_empty()) + continue; + + rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); + if (!mapSpatToExtract.contains(spatCompute.getOperation())) { + auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation()); + mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)}); + } + + rewriter.startOpModification(spatCompute.getOperation()); + BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatCompute.getOperation()]); + spatCompute.getInputsMutable().erase(BBArgIndex); + spatCompute.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatCompute.getOperation()); + } + else if (auto spatComputeBatch = dyn_cast(uses.getOwner())) { + auto BBArgIndex = uses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex(); + auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); + + if (BBArgValue.use_empty()) + continue; + + rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); + if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) { + auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation()); + mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)}); + } + + rewriter.startOpModification(spatComputeBatch.getOperation()); + BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatComputeBatch.getOperation()]); + spatComputeBatch.getInputsMutable().erase(BBArgIndex); + spatComputeBatch.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatComputeBatch.getOperation()); + } + else { + { + if (auto spatCompute = uses.getOwner()->getParentOfType()) { + rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); + if (!mapSpatToExtract.contains(spatCompute.getOperation())) { + auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation()); + mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)}); + } + + rewriter.startOpModification(spatCompute.getOperation()); + uses.set(mapSpatToExtract[spatCompute.getOperation()]); + rewriter.finalizeOpModification(spatCompute.getOperation()); + } + else if (auto spatComputeBatch = uses.getOwner()->getParentOfType()) { + rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); + if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) { + auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation()); + mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)}); + } + + rewriter.startOpModification(spatComputeBatch.getOperation()); + uses.set(mapSpatToExtract[spatComputeBatch.getOperation()]); + rewriter.finalizeOpModification(spatComputeBatch.getOperation()); + } + } + } + } + + rewriter.eraseOp(extractSliceOp); + return success(); + } +}; + +struct ArithConstToGlobalMemoryPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override { + static int i = 0; + Location loc = constantOp.getLoc(); + + if (hasWeightAlways(constantOp)) + return failure(); + + if (!isa(constantOp->getParentOp())) + return failure(); + + if (llvm::all_of(constantOp->getUsers(), [](Operation* op) { + if (isa(op)) + return false; + if (isa(op->getParentOp())) + return true; + return false; + })) + return failure(); + + rewriter.setInsertionPoint(constantOp->getParentOfType()); + + auto constRankedTensorType = llvm::dyn_cast(constantOp.getType()); + + if (constRankedTensorType) { + mlir::MemRefType memRefType = + mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType()); + std::string argName = "const_" + std::to_string(i++); + memref::GlobalOp::create(rewriter, + loc, + rewriter.getStringAttr(argName), + rewriter.getStringAttr("private"), + TypeAttr::get(memRefType), + constantOp.getValueAttr(), + rewriter.getUnitAttr(), + {}); + + llvm::DenseMap mapSpatComputeToConst; + + for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) { + auto constUsers = constUses.getOwner(); + + if (auto spatCompute = llvm::dyn_cast(constUsers)) { + + auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); + rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); + if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) { + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); + auto toTensor = bufferization::ToTensorOp::create( + rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()}); + } + + rewriter.startOpModification(spatCompute.getOperation()); + BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatCompute.getOperation()]); + spatCompute.getInputsMutable().erase(BBArgIndex); + spatCompute.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatCompute.getOperation()); + } + else if (auto spatComputeBatch = llvm::dyn_cast(constUsers)) { + + auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex(); + auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); + rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); + if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) { + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); + auto toTensor = bufferization::ToTensorOp::create( + rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()}); + } + + rewriter.startOpModification(spatComputeBatch.getOperation()); + BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatComputeBatch.getOperation()]); + spatComputeBatch.getInputsMutable().erase(BBArgIndex); + spatComputeBatch.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatComputeBatch.getOperation()); + } + else { + { + + if (auto spatCompute = constUses.getOwner()->getParentOfType()) { + rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); + if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) { + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); + auto toTensor = bufferization::ToTensorOp::create( + rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()}); + } + + rewriter.startOpModification(spatCompute.getOperation()); + constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]); + rewriter.finalizeOpModification(spatCompute.getOperation()); + } + else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType()) { + rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); + if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) { + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); + auto toTensor = bufferization::ToTensorOp::create( + rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()}); + } + + rewriter.startOpModification(spatComputeBatch.getOperation()); + constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]); + rewriter.finalizeOpModification(spatComputeBatch.getOperation()); + } + } + } + } + } + else if (constantOp.getType().isIntOrIndexOrFloat()) { + llvm::DenseMap mapSpatComputeToConst; + + for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) { + auto constUsers = constUses.getOwner(); + + if (auto spatCompute = llvm::dyn_cast(constUsers)) { + + auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); + rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); + auto newConst = rewriter.clone(*constantOp); + + rewriter.startOpModification(spatCompute.getOperation()); + BBArgValue.replaceAllUsesWith(newConst->getResult(0)); + spatCompute.getInputsMutable().erase(BBArgIndex); + spatCompute.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatCompute.getOperation()); + } + else if (auto spatComputeBatch = llvm::dyn_cast(constUsers)) { + + auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex(); + auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); + rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); + auto newConst = rewriter.clone(*constantOp); + + rewriter.startOpModification(spatComputeBatch.getOperation()); + BBArgValue.replaceAllUsesWith(newConst->getResult(0)); + spatComputeBatch.getInputsMutable().erase(BBArgIndex); + spatComputeBatch.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatComputeBatch.getOperation()); + } + else { + if (auto parent = constUsers->getParentOfType()) { + if (!mapSpatComputeToConst.contains(parent)) { + rewriter.setInsertionPoint(&parent.getBody().front().front()); + auto newConst = rewriter.clone(*constantOp); + mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)}); + } + constUses.set(mapSpatComputeToConst[parent.getOperation()]); + } + else { + auto batchParent = constUsers->getParentOfType(); + assert(batchParent && "Global Constant used direcly not within a compute"); + if (!mapSpatComputeToConst.contains(batchParent.getOperation())) { + rewriter.setInsertionPoint(&batchParent.getBody().front().front()); + auto newConst = rewriter.clone(*constantOp); + mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)}); + } + constUses.set(mapSpatComputeToConst[batchParent.getOperation()]); + } + } + } + } + auto parent = constantOp->getParentOp(); + rewriter.eraseOp(constantOp); + return success(); + } +}; + +struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override { + + if (funcOp.getArguments().empty()) + return failure(); + + if (llvm::all_of(funcOp.getArguments(), + [](mlir::BlockArgument blockArgument) { return blockArgument.use_empty(); })) + return failure(); + + Location loc = funcOp.getLoc(); + + for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) { + if (arg.getUses().empty()) + continue; + + rewriter.setInsertionPoint(funcOp.getOperation()); + + assert(isa(arg.getType())); + + auto argRankedTensorType = llvm::dyn_cast(arg.getType()); + mlir::MemRefType memRefType = + mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType()); + + std::string argName = "arg_" + std::to_string(index); + + memref::GlobalOp::create(rewriter, + loc, + rewriter.getStringAttr(argName), + rewriter.getStringAttr("private"), + TypeAttr::get(memRefType), + {}, + {}, + {}); + + for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) { + auto argUser = argUses.getOwner(); + if (auto spatCompute = dyn_cast(argUser)) { + auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); + rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); + auto toTensor = bufferization::ToTensorOp::create( + rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + + rewriter.startOpModification(spatCompute.getOperation()); + BBArgValue.replaceAllUsesWith(toTensor); + spatCompute.getInputsMutable().erase(BBArgIndex); + spatCompute.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatCompute.getOperation()); + } + else if (auto spatComputeBatch = dyn_cast(argUser)) { + auto BBArgIndex = argUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex(); + auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); + rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); + auto toTensor = bufferization::ToTensorOp::create( + rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + + rewriter.startOpModification(spatComputeBatch.getOperation()); + BBArgValue.replaceAllUsesWith(toTensor); + spatComputeBatch.getInputsMutable().erase(BBArgIndex); + spatComputeBatch.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatComputeBatch.getOperation()); + } + else { + rewriter.setInsertionPoint(argUser); + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); + auto toTensor = bufferization::ToTensorOp::create( + rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + rewriter.startOpModification(argUser); + argUses.set(toTensor); + rewriter.finalizeOpModification(argUser); + } + } + } + + return success(); + } +}; + +} // namespace +void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) { + patterns.add( + patterns.getContext()); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.hpp b/src/PIM/Conversion/SpatialToPim/Patterns.hpp new file mode 100644 index 0000000..e34f6ab --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/Patterns.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "mlir/IR/PatternMatch.h" + + +namespace onnx_mlir { + +void populateGlobalTensorToMemrefPatterns(mlir::RewritePatternSet& patterns); + +} diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index ed3454b..cc56f01 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -1,20 +1,26 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_os_ostream.h" #include @@ -24,6 +30,7 @@ #include #include "Conversion/ONNXToSpatial/Common.hpp" +#include "Patterns.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -53,7 +60,7 @@ struct SpatialToPimPass : PassWrapper> void runOnOperation() final; private: - SmallVector outputTensors; + SmallVector> outputTensors; size_t coreId = 0; SmallVector operationsToRemove; @@ -179,7 +186,22 @@ static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan } static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) { - auto inputType = cast(extractRowsOp.getInput().getType()); + Value input = extractRowsOp.getInput(); + RankedTensorType inputType; + if (auto tensorType = dyn_cast(input.getType())) { + inputType = tensorType; + } + else if (auto memRefType = dyn_cast(input.getType())) { + inputType = RankedTensorType::get(memRefType.getShape(), memRefType.getElementType()); + rewriter.setInsertionPoint(extractRowsOp); + input = bufferization::ToTensorOp::create( + rewriter, extractRowsOp.getLoc(), inputType, input, rewriter.getUnitAttr(), rewriter.getUnitAttr()) + .getResult(); + } + else { + extractRowsOp.emitOpError("requires a ranked tensor or memref input during Spatial-to-PIM lowering"); + return; + } int64_t numCols = inputType.getDimSize(1); SmallVector replacements; @@ -187,11 +209,16 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite rewriter.setInsertionPoint(extractRowsOp); for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) { + auto outputType = dyn_cast(output.getType()); + if (!outputType) { + extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering"); + return; + } SmallVector offsets = {rewriter.getIndexAttr(static_cast(rowIndex)), rewriter.getIndexAttr(0)}; SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)}; SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; auto rowSlice = tensor::ExtractSliceOp::create( - rewriter, extractRowsOp.getLoc(), cast(output.getType()), extractRowsOp.getInput(), offsets, sizes, strides); + rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides); replacements.push_back(rowSlice.getResult()); } @@ -205,6 +232,75 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) { rewriter.replaceOp(concatOp, concatenated); } +static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, + SmallVectorImpl& helperChain, + bool requireReturnUse = true) { + if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1) + return failure(); + if (requireReturnUse + && (!computeOp.getResult(0).hasOneUse() || !isa(*computeOp.getResult(0).getUsers().begin()))) + return failure(); + + Block& block = computeOp.getBody().front(); + if (block.getNumArguments() != 1) + return failure(); + + auto yieldOp = dyn_cast(block.getTerminator()); + if (!yieldOp || yieldOp.getNumOperands() != 1) + return failure(); + + SmallVector reverseChain; + Value currentValue = yieldOp.getOperands().front(); + Value blockArg = block.getArgument(0); + + while (currentValue != blockArg) { + Operation* definingOp = currentValue.getDefiningOp(); + if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp)) + return failure(); + reverseChain.push_back(definingOp); + currentValue = definingOp->getOperand(0); + } + + SmallPtrSet chainSet(reverseChain.begin(), reverseChain.end()); + for (Operation& op : llvm::make_early_inc_range(block.without_terminator())) + if (!chainSet.contains(&op) + && !isa(op)) + return failure(); + + helperChain.assign(reverseChain.rbegin(), reverseChain.rend()); + return success(); +} + +static bool inlineInputlessHelperComputeForBatchUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) { + if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1) + return false; + if (!llvm::all_of(computeOp.getResult(0).getUsers(), + [](Operation* user) { return isa(user); })) + return false; + + Block& block = computeOp.getBody().front(); + if (block.getNumArguments() != 0) + return false; + + auto yieldOp = dyn_cast(block.getTerminator()); + if (!yieldOp || yieldOp.getNumOperands() != 1) + return false; + + rewriter.setInsertionPoint(computeOp); + IRMapping mapping; + for (Operation& op : block.without_terminator()) { + cloneMappedHelperOperands(&op, mapping, rewriter); + Operation* clonedOp = rewriter.clone(op, mapping); + for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults())) + mapping.map(originalResult, newResult); + rewriter.setInsertionPointAfter(clonedOp); + } + + Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0)); + computeOp.getResult(0).replaceAllUsesWith(replacement); + return true; +} + struct ReturnUseInfo { size_t returnIndex; SmallVector helperChain; @@ -295,6 +391,20 @@ static std::optional analyzeConcatReturnUse(Value value) { } SmallVector helperChain; + if (auto helperCompute = dyn_cast(currentUser)) { + if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue) + return std::nullopt; + + if (failed(collectHelperComputeChain(helperCompute, helperChain))) + return std::nullopt; + + currentValue = helperCompute.getResult(0); + auto currentUses = currentValue.getUses(); + if (rangeLength(currentUses) != 1) + return std::nullopt; + currentUser = currentUses.begin()->getOwner(); + } + while (isChannelUseChainOp(currentUser)) { helperChain.push_back(currentUser); auto currentUses = currentUser->getResult(0).getUses(); @@ -419,21 +529,22 @@ static void cloneHelperChain(Value sourceValue, } } -static void emitHostCopy(IRRewriter& rewriter, - Location loc, - Value outputTensor, - Value sourceValue, - int32_t hostTargetOffset, - int32_t deviceSourceOffset, - int32_t sizeInBytes) { - PimMemCopyDevToHostOp::create(rewriter, - loc, - outputTensor.getType(), - outputTensor, - sourceValue, - rewriter.getI32IntegerAttr(hostTargetOffset), - rewriter.getI32IntegerAttr(deviceSourceOffset), - rewriter.getI32IntegerAttr(sizeInBytes)); +static Value emitHostCopy(IRRewriter& rewriter, + Location loc, + Value outputTensor, + Value sourceValue, + int32_t hostTargetOffset, + int32_t deviceSourceOffset, + int32_t sizeInBytes) { + return PimMemCopyDevToHostOp::create(rewriter, + loc, + outputTensor.getType(), + outputTensor, + sourceValue, + rewriter.getI32IntegerAttr(hostTargetOffset), + rewriter.getI32IntegerAttr(deviceSourceOffset), + rewriter.getI32IntegerAttr(sizeInBytes)) + .getOutput(); } void SpatialToPimPass::runOnOperation() { @@ -458,12 +569,21 @@ void SpatialToPimPass::runOnOperation() { scf::SCFDialect, BuiltinDialect>(); - RewritePatternSet patterns(ctx); - populateWithGenerated(patterns); + { + RewritePatternSet patterns(ctx); + populateWithGenerated(patterns); - if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { - signalPassFailure(); - return; + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + return; + } + } + + { + RewritePatternSet patterns(ctx); + populateGlobalTensorToMemrefPatterns(patterns); + + walkAndApplyPatterns(moduleOp, std::move(patterns)); } auto returnOp = cast(funcOp.front().getTerminator()); @@ -489,7 +609,8 @@ void SpatialToPimPass::runOnOperation() { } SmallVector receiveOps; - funcOp.walk([&](spatial::SpatChannelReceiveOp op) { receiveOps.push_back(op); }); + for (auto op : funcOp.getOps()) + receiveOps.push_back(op); for (auto receiveOp : receiveOps) { bool onlyPendingRemovalUsers = llvm::all_of( receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); }); @@ -505,22 +626,26 @@ void SpatialToPimPass::runOnOperation() { } SmallVector receiveManyOps; - funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { receiveManyOps.push_back(op); }); + for (auto op : funcOp.getOps()) + receiveManyOps.push_back(op); for (auto receiveManyOp : receiveManyOps) lowerChannelReceiveMany(receiveManyOp, rewriter); SmallVector sendOps; - funcOp.walk([&](spatial::SpatChannelSendOp op) { sendOps.push_back(op); }); + for (auto op : funcOp.getOps()) + sendOps.push_back(op); for (auto sendOp : sendOps) lowerChannelSend(sendOp, rewriter); SmallVector sendManyOps; - funcOp.walk([&](spatial::SpatChannelSendManyOp op) { sendManyOps.push_back(op); }); + for (auto op : funcOp.getOps()) + sendManyOps.push_back(op); for (auto sendManyOp : sendManyOps) lowerChannelSendMany(sendManyOp, rewriter); SmallVector extractRowsOps; - funcOp.walk([&](spatial::SpatExtractRowsOp op) { extractRowsOps.push_back(op); }); + for (auto op : funcOp.getOps()) + extractRowsOps.push_back(op); for (auto extractRowsOp : extractRowsOps) lowerExtractRows(extractRowsOp, rewriter); @@ -560,6 +685,36 @@ void SpatialToPimPass::runOnOperation() { assert(false && "tracked op removal reached a cycle or missed dependency"); } + SmallVector remainingConcatOps; + funcOp.walk([&](spatial::SpatConcatOp op) { remainingConcatOps.push_back(op); }); + for (auto concatOp : remainingConcatOps) + lowerConcat(concatOp, rewriter); + + SmallVector remainingReceiveOps; + funcOp.walk([&](spatial::SpatChannelReceiveOp op) { remainingReceiveOps.push_back(op); }); + for (auto receiveOp : remainingReceiveOps) + lowerChannelReceive(receiveOp, rewriter); + + SmallVector remainingReceiveManyOps; + funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { remainingReceiveManyOps.push_back(op); }); + for (auto receiveManyOp : remainingReceiveManyOps) + lowerChannelReceiveMany(receiveManyOp, rewriter); + + SmallVector remainingSendOps; + funcOp.walk([&](spatial::SpatChannelSendOp op) { remainingSendOps.push_back(op); }); + for (auto sendOp : remainingSendOps) + lowerChannelSend(sendOp, rewriter); + + SmallVector remainingSendManyOps; + funcOp.walk([&](spatial::SpatChannelSendManyOp op) { remainingSendManyOps.push_back(op); }); + for (auto sendManyOp : remainingSendManyOps) + lowerChannelSendMany(sendManyOp, rewriter); + + SmallVector remainingExtractRowsOps; + funcOp.walk([&](spatial::SpatExtractRowsOp op) { remainingExtractRowsOps.push_back(op); }); + for (auto extractRowsOp : remainingExtractRowsOps) + lowerExtractRows(extractRowsOp, rewriter); + // Dump to file for debug bool hasSpatialOps = false; moduleOp.walk([&](Operation* op) { @@ -579,6 +734,13 @@ void SpatialToPimPass::runOnOperation() { void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) { Location loc = computeOp->getLoc(); + if (inlineInputlessHelperComputeForBatchUsers(computeOp, rewriter)) + return; + + SmallVector helperChain; + if (succeeded(collectHelperComputeChain(computeOp, helperChain))) + return; + auto& block = computeOp.getRegion().front(); auto yieldOp = cast(block.getTerminator()); @@ -616,9 +778,9 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter auto storedType = cast(storedValue.getType()); size_t elementSize = storedType.getElementTypeBitWidth() / 8; - Value outputTensor = outputTensors[returnUse->returnIndex]; if (auto storedOp = storedValue.getDefiningOp()) rewriter.setInsertionPointAfter(storedOp); + Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc); emitHostCopy(rewriter, loc, outputTensor, @@ -637,8 +799,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter if (isa(resultUser)) { size_t resultIndexInReturn = resultUse.getOperandNumber(); size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8; - Value outputTensor = outputTensors[resultIndexInReturn]; rewriter.setInsertionPointAfterValue(yieldValue); + Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc); emitHostCopy(rewriter, loc, outputTensor, @@ -654,13 +816,13 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter } if (auto concatReturnUse = analyzeConcatReturnUse(result)) { - Value outputTensor = outputTensors[concatReturnUse->returnIndex]; - auto outputType = cast(outputTensor.getType()); size_t elementSize = yieldType.getElementTypeBitWidth() / 8; if (concatReturnUse->helperChain.empty()) { - int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape()); rewriter.setInsertionPointAfterValue(yieldValue); + Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc); + auto outputType = cast(outputTensor.getType()); + int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape()); emitHostCopy(rewriter, loc, outputTensor, @@ -671,7 +833,15 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter continue; } - auto storedType = cast(yieldValue.getType()); + auto storedType = dyn_cast(yieldValue.getType()); + if (!storedType) { + computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering"); + signalPassFailure(); + return; + } + rewriter.setInsertionPointAfterValue(yieldValue); + Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc); + auto outputType = cast(outputTensor.getType()); for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) { SmallVector sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape()); for (auto [dim, idx] : llvm::enumerate(sourceIndices)) @@ -701,19 +871,18 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter auto scalarTensorType = RankedTensorType::get(SmallVector(storedType.getRank(), 1), storedType.getElementType()); - rewriter.setInsertionPointAfterValue(yieldValue); auto elementSlice = tensor::ExtractSliceOp::create( rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides); rewriter.setInsertionPointAfter(elementSlice); int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape()); - emitHostCopy(rewriter, - loc, - outputTensor, - elementSlice.getResult(), - static_cast(destinationFlatOffset * elementSize), - 0, - static_cast(elementSize)); + outputTensor = emitHostCopy(rewriter, + loc, + outputTensor, + elementSlice.getResult(), + static_cast(destinationFlatOffset * elementSize), + 0, + static_cast(elementSize)); } continue; } @@ -848,6 +1017,26 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc continue; } + if (auto toTensorOp = dyn_cast(op)) { + if (isa_and_present(toTensorOp.getBuffer().getDefiningOp())) { + Operation* cloned = rewriter.clone(op, mapper); + auto clonedTensor = cloned->getResult(0); + auto clonedType = cast(clonedTensor.getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType); + auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, + loc, + outputBuffer.getType(), + outputBuffer, + clonedTensor, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + getTensorSizeInBytesAttr(rewriter, clonedTensor)) + .getOutput(); + mapper.map(toTensorOp.getResult(), copied); + continue; + } + } + for (Value operand : op.getOperands()) { if (!isa(operand.getType()) || mapper.contains(operand)) continue; @@ -922,17 +1111,33 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) { outputTensors.reserve(returnOp->getNumOperands()); - rewriter.setInsertionPointToStart(returnOp->getBlock()); - for (auto returnValue : returnOp->getOperands()) { + for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) { Operation* returnValueDefiningOp = returnValue.getDefiningOp(); if (returnValueDefiningOp->hasTrait()) { assert(!hasWeightAlways(returnValueDefiningOp)); - outputTensors.push_back(returnValue); + outputTensors.push_back([returnValue](IRRewriter& rewriter, Location loc) -> Value { return returnValue; }); } else { - auto newOutputTensor = - createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast(returnValue.getType())); - outputTensors.push_back(newOutputTensor); + auto outRankedTensorType = llvm::dyn_cast(returnValue.getType()); + auto memRefType = mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType()); + + std::string outputName = "output_" + std::to_string(index); + rewriter.setInsertionPoint(returnOp.getParentOp()); + memref::GlobalOp::create(rewriter, + returnOp.getLoc(), + rewriter.getStringAttr(outputName), + rewriter.getStringAttr("private"), + TypeAttr::get(memRefType), + {}, + {}, + {}); + outputTensors.push_back( + [memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value { + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName); + auto toTensor = bufferization::ToTensorOp::create( + rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + return toTensor.getResult(); + }); } } } @@ -940,11 +1145,11 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { Location loc = funcOp.getLoc(); - auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) { - auto tensorType = cast(valueToReplace.getType()); + auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) { + auto tensorType = cast(inputTensor.getType()); Type elementType = tensorType.getElementType(); size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8; - rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace)); + rewriter.setInsertionPointAfter(inputTensor.getDefiningOp()); auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType); @@ -953,86 +1158,27 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu loc, tensorType, deviceTensor, - hostTensor, + inputTensor, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(static_cast(elementsOffset * elementByteSize)), rewriter.getI32IntegerAttr(static_cast(tensorType.getNumElements() * elementByteSize))); - rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult()); + rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp}); }; - // Replace input tensors with memRefs - SmallVector inputTensors; - for (size_t i = 0; i < funcOp.getNumArguments(); i++) { - BlockArgument tensorArg = funcOp.getArgument(i); - DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i); - ShapedType tensorArgType = cast(tensorArg.getType()); - MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType()); - - if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc))) - return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering"); - BlockArgument memRefArg = funcOp.getArgument(i + 1); - - Block& block = funcOp.getBody().front(); - rewriter.setInsertionPoint(&block.front()); - auto toTensorOp = - bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr()); - inputTensors.push_back(toTensorOp); - - tensorArg.replaceAllUsesWith(toTensorOp); - if (failed(funcOp.eraseArgument(i))) - return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering"); - } - - llvm::SmallSet sliceOpsToRemove; for (auto& op : funcOp.getBody().getOps()) if (auto computeOp = dyn_cast(op)) { - unsigned numComputeWeights = computeOp.getWeights().size(); - for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) { - TypedValue tensorSource; - int64_t elementsOffset = 0; - - if (auto sliceOp = dyn_cast(computeOpInput.getDefiningOp())) { - tensorSource = cast>(sliceOp.getSource()); - - if (isa(tensorSource.getDefiningOp())) - continue; - - ArrayRef sourceShape = tensorSource.getType().getShape(); - ArrayRef sliceOffsets = sliceOp.getStaticOffsets(); - ArrayRef sliceSizes = sliceOp.getStaticSizes(); - ArrayRef sliceStrides = sliceOp.getStaticStrides(); - assert("Extracting slice non-contiguous in memory" - && isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides)); - - for (size_t i = 0; i < sliceOffsets.size(); i++) { - int64_t partialOffset = sliceOffsets[i]; - if (partialOffset != 0) - for (size_t j = i + 1; j < sourceShape.size(); j++) - partialOffset *= sourceShape[j]; - elementsOffset += partialOffset; - } - - computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource); - sliceOpsToRemove.insert(sliceOp); + if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0) + continue; + for (auto getGlobal : computeOp.getOps()) { + if (getGlobal.getName().starts_with("arg") || getGlobal.getName().starts_with("const_")) { + assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute"); + auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin(); + insertMemCopyHostToDev(toTensorOpValue, 0); } - else - tensorSource = cast>(computeOpInput); - - // Values already produced inside the device-side graph must not be - // copied back through a host-to-device staging step here. - if (isa(tensorSource.getDefiningOp())) - continue; - - BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx); - insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset); } } - for (auto sliceOp : sliceOpsToRemove) - if (sliceOp->getUses().empty()) - rewriter.eraseOp(sliceOp); - return success(); } @@ -1050,7 +1196,7 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) { Operation* onlyUser = *op->getUsers().begin(); isExclusivelyOwnedByReturnChain = - isa(onlyUser) || isChannelUseChainOp(onlyUser); + isa(onlyUser) || isChannelUseChainOp(onlyUser); } if (!isExclusivelyOwnedByReturnChain) return; @@ -1062,6 +1208,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri return; } + if (auto computeOp = dyn_cast(op)) { + markOpToRemove(computeOp); + for (Value input : computeOp.getInputs()) + markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain); + return; + } + if (auto concatOp = dyn_cast(op)) { markOpToRemove(concatOp); for (Value operand : concatOp.getOperands()) @@ -1070,12 +1223,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri }; SmallVector originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end()); + auto loc = returnOp.getLoc(); for (auto it : llvm::enumerate(originalOperands)) { size_t orderWithinReturn = it.index(); Operation* returnOperand = it.value().getDefiningOp(); - - rewriter.modifyOpInPlace(returnOp, - [&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); }); + rewriter.setInsertionPoint(returnOp); + Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc); + rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); }); markOwnedReturnChain(returnOperand, markOwnedReturnChain); } } diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index f203f5f..7a2b7d8 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -24,7 +24,7 @@ def PimTensor : // Execution //===----------------------------------------------------------------------===// -def PimCoreOp : PimOp<"core", [SingleBlock]> { +def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> { let summary = "Execute a block on a PIM core"; let regions = (region SizedRegion<1>:$body); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index cb5b9b2..866453f 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -3,12 +3,17 @@ #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Threading.h" #include "mlir/Pass/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" + #include "Common/PimCommon.hpp" #include "Compiler/PimCodeGen.hpp" #include "Dialect/Pim/PimOps.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Compiler/CompilerOptions.hpp" @@ -40,14 +45,44 @@ private: void PimBufferizationPass::runOnOperation() { auto moduleOp = getOperation(); + // Refactor this into a function + { + auto funcOp = getPimEntryFunc(moduleOp); - // One-Shot-Bufferization - bufferization::OneShotBufferizationOptions options; - options.allowUnknownOps = true; - bufferization::BufferizationState state; - if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) { - moduleOp.emitError("Failed to bufferize PIM and Spatial ops"); - signalPassFailure(); + auto coreOps = llvm::to_vector(funcOp->getOps()); + MLIRContext* ctx = moduleOp.getContext(); + // failableParallelForEach will run the lambda in parallel and stop if any thread fails + LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](pim::PimCoreOp coreOp) { + // Again, allocate state LOCALLY per thread/function + bufferization::OneShotBufferizationOptions options; + options.allowUnknownOps = true; + bufferization::BufferizationState state; + if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) { + coreOp.emitError("Failed to bufferize PIM and Spatial ops"); + return failure(); + } + return success(); + }); + + if (failed(result)) { + moduleOp.emitError("Failed to bufferize-parallel PIM and Spatial ops"); + signalPassFailure(); + } + + funcOp->walk([&](bufferization::ToTensorOp toTensorOp) { + if (llvm::isa_and_present(toTensorOp->getParentOp())) + toTensorOp->setAttr("restrict", UnitAttr::get(ctx)); + }); + + // One-Shot-Bufferization + bufferization::OneShotBufferizationOptions options; + options.allowUnknownOps = true; + bufferization::BufferizationState state; + + if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) { + moduleOp.emitError("Failed to bufferize PIM and Spatial ops"); + signalPassFailure(); + } } MLIRContext* ctx = moduleOp.getContext(); @@ -57,7 +92,18 @@ void PimBufferizationPass::runOnOperation() { RewritePatternSet patterns(ctx); populateWithGenerated(patterns); - if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + // Only convert memref.copy → pim.memcp inside pim.core / pim.core_batch bodies. + // Host-level copies (e.g. from split/slice ops) must remain as memref.copy for CPU lowering. + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + bool hasFailed = false; + moduleOp.walk([&](Operation* op) { + if (!isa(op)) + return WalkResult::advance(); + if (failed(applyPartialConversion(op, target, frozenPatterns))) + hasFailed = true; + return WalkResult::skip(); + }); + if (hasFailed) { signalPassFailure(); return; } diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp index 4c39ac1..a69a8a2 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp @@ -116,10 +116,9 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern { auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill"); OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(coreOp); - auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName()); rewriter.setInsertionPoint(mapOp); + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName()); auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8; pim::PimMemCopyOp::create(rewriter, mapOp.getLoc(), @@ -258,9 +257,18 @@ struct FoldConstantTransposePattern final : OpRewritePattern(); - if (!sourceGetGlobal) - return failure(); + if (!sourceGetGlobal) { + memcpHd = transposeOp.getInput().getDefiningOp(); + if (!memcpHd) + return failure(); + sourceGetGlobal = memcpHd.getHostSource().getDefiningOp(); + if (!sourceGetGlobal) + return failure(); + } auto moduleOp = transposeOp->getParentOfType(); if (!moduleOp) @@ -298,13 +306,26 @@ struct FoldConstantTransposePattern final : OpRewritePatterngetUsers().empty() - && llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa(user); }); + && llvm::all_of(transposeOp->getUsers(), [](Operation* user) { + return isa(user); + }); if (isAlwaysWeight) { markWeightAlways(newGlobal); markWeightAlways(newGetGlobal); } + auto outputAllocOp = transposeOp.getOutputBuffer().getDefiningOp(); rewriter.replaceOp(transposeOp, newGetGlobal.getResult()); + + if (memcpHd && memcpHd.use_empty()) { + auto deviceAllocOp = memcpHd.getDeviceTarget().getDefiningOp(); + rewriter.eraseOp(memcpHd); + if (deviceAllocOp && deviceAllocOp->use_empty()) + rewriter.eraseOp(deviceAllocOp); + } + if (outputAllocOp && outputAllocOp->use_empty()) + rewriter.eraseOp(outputAllocOp); + return success(); } }; @@ -341,18 +362,25 @@ struct FoldConstantAllocPattern final : OpRewritePattern { continue; } - if (!isa(user)) + if (!isa(user)) return failure(); } if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) { - return llvm::all_of(castOp->getUsers(), [](Operation* user) { return isa(user); }); + return llvm::all_of(castOp->getUsers(), [](Operation* user) { + return isa(user); + }); })) { allLiveUsersAreCoreOps = false; } if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) { - return isa(user); + return isa(user); })) { return failure(); } @@ -389,6 +417,83 @@ struct FoldConstantAllocPattern final : OpRewritePattern { } }; +struct FoldConstantHostCopyPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override { + if (copyOp->getParentOfType()) + return failure(); + + auto allocOp = copyOp.getTarget().getDefiningOp(); + if (!allocOp) + return failure(); + auto allocType = dyn_cast(allocOp.getType()); + if (!allocType || !allocType.hasStaticShape()) + return failure(); + + auto srcSubview = getStaticSubviewInfo(copyOp.getSource()); + Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource()); + + auto moduleOp = copyOp->getParentOfType(); + if (!moduleOp) + return failure(); + + auto denseAttr = getDenseGlobalValue(moduleOp, globalSource); + if (failed(denseAttr)) + return failure(); + + DenseElementsAttr foldedAttr; + if (succeeded(srcSubview)) { + if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })) + return failure(); + auto staticOffsets = getStaticSubviewOffsets(*srcSubview); + if (failed(staticOffsets)) + return failure(); + + auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape()); + if (failed(maybeFoldedAttr)) + return failure(); + foldedAttr = *maybeFoldedAttr; + } + else { + auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType()); + if (resultTensorType != denseAttr->getType()) + return failure(); + foldedAttr = *denseAttr; + } + + bool allLiveUsersAreCores = true; + for (Operation* user : allocOp->getUsers()) { + if (user == copyOp) + continue; + if (isa(user)) + continue; + if (isa(user)) + continue; + if (isa(user)) { + allLiveUsersAreCores = false; + continue; + } + return failure(); + } + + auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_host_copy"); + if (allLiveUsersAreCores) + markWeightAlways(newGlobal); + + rewriter.setInsertionPoint(allocOp); + auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName()); + if (allLiveUsersAreCores) + markWeightAlways(newGetGlobal); + + rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult()); + rewriter.eraseOp(copyOp); + if (allocOp.use_empty()) + rewriter.eraseOp(allocOp); + return success(); + } +}; + struct FoldConstantMemCpPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -443,7 +548,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { continue; if (isa(user)) continue; - if (isa(user)) + if (isa(user)) continue; if (isa(user)) { allLiveUsersAreCores = false; @@ -473,7 +578,11 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) { patterns - .add( + .add( patterns.getContext()); } diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index 3f53abb..d91ca76 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -23,7 +23,27 @@ static bool isAddressOnlyHostOp(Operation* op) { memref::SubViewOp, memref::CastOp, memref::CollapseShapeOp, - memref::ExpandShapeOp>(op); + memref::ExpandShapeOp, + memref::CopyOp>(op); +} + +// Looser than isCodegenAddressableValue: follows view ops without requiring contiguity. +// Used for memref.copy operands which may be non-contiguous subviews. +static bool isBaseAddressableValue(Value value) { + while (true) { + if (isa(value)) + return true; + Operation* defOp = value.getDefiningOp(); + if (!defOp) + return false; + if (isa(defOp)) + return true; + if (auto subview = dyn_cast(defOp)) { value = subview.getSource(); continue; } + if (auto cast = dyn_cast(defOp)) { value = cast.getSource(); continue; } + if (auto collapse = dyn_cast(defOp)) { value = collapse.getSrc(); continue; } + if (auto expand = dyn_cast(defOp)) { value = expand.getSrc(); continue; } + return false; + } } static bool isCodegenAddressableValue(Value value) { @@ -183,6 +203,13 @@ private: return verifyAddressOnlySource(op, collapseOp.getSrc()); if (auto expandOp = dyn_cast(op)) return verifyAddressOnlySource(op, expandOp.getSrc()); + if (auto copyOp = dyn_cast(op)) { + if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) { + op->emitOpError("depends on a value that is not backed by addressable storage"); + return failure(); + } + return success(); + } return success(); } diff --git a/validation/validate_one.py b/validation/validate_one.py index e75a687..4cc4ea9 100644 --- a/validation/validate_one.py +++ b/validation/validate_one.py @@ -37,7 +37,7 @@ class ValidationResult: class ProgressReporter: - def __init__(self, total_models, stages_per_model=STAGE_COUNT): + def __init__(self, total_models, stages_per_model=STAGE_COUNT, enabled=None): self.total_models = total_models self.stages_per_model = stages_per_model self.total_steps = max(1, total_models * stages_per_model) @@ -45,7 +45,7 @@ class ProgressReporter: self.passed_models = 0 self.failed_models = 0 self.current_label = "" - self.enabled = True + self.enabled = sys.stdout.isatty() if enabled is None else enabled self.columns = shutil.get_terminal_size((100, 20)).columns self.suspended = False