diff --git a/README.md b/README.md index 8507a2d..e6aeae2 100644 --- a/README.md +++ b/README.md @@ -168,8 +168,8 @@ Each validation run writes artifacts in the model workspace, for example under The compiler currently dumps dialect snapshots such as `spatial0.mlir`, `spatial1_dcp_merged.mlir`, `pim0.mlir`, `pim1_buff.mlir`, -`pim2_coalesced.mlir`, `pim3_folded.mlir`, and -`pim4_materialized.mlir` when an output directory is available. +`pim2_coalesced.mlir`, and `pim3_folded.mlir` when an output directory is +available. To rerun the simulator manually with tracing after validation has produced a `raptor/pim/` directory: diff --git a/src/PIM/CMakeLists.txt b/src/PIM/CMakeLists.txt index 0bd0305..31588b2 100644 --- a/src/PIM/CMakeLists.txt +++ b/src/PIM/CMakeLists.txt @@ -123,7 +123,6 @@ add_pim_library(OMPIMAccel OMPimBufferization OMPimMemoryCoalescing OMPimHostConstantFolding - OMPimHostConstantMaterialization OMPimVerification MLIRTensorInferTypeOpInterfaceImpl ) diff --git a/src/PIM/Common/IR/WeightUtils.cpp b/src/PIM/Common/IR/WeightUtils.cpp index 0662b85..395396f 100644 --- a/src/PIM/Common/IR/WeightUtils.cpp +++ b/src/PIM/Common/IR/WeightUtils.cpp @@ -47,6 +47,16 @@ CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) { return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs)); } +llvm::FailureOr> getStaticMemRefTypeStrides(mlir::MemRefType type) { + llvm::SmallVector strides; + int64_t offset = 0; + if (failed(type.getStridesAndOffset(strides, offset))) + return mlir::failure(); + if (llvm::is_contained(strides, mlir::ShapedType::kDynamic)) + return mlir::failure(); + return strides; +} + template bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { auto weightArg = parentOp.getWeightArgument(weightIndex); @@ -162,6 +172,11 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static mlir::Value current = weight; while (true) { + if (mlir::Value directAlias = knowledge.aliases.lookup(current); directAlias && directAlias != current) { + current = directAlias; + continue; + } + if (auto defOp = current.getDefiningOp()) { if (auto getGlobalOp = mlir::dyn_cast(defOp)) { auto moduleOp = weightOwner ? weightOwner->getParentOfType() : mlir::ModuleOp {}; @@ -181,8 +196,6 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static CompiledIndexExpr offsetExpr = makeConstantExpr(0); for (mlir::Operation* viewOp : llvm::reverse(viewOps)) { if (auto subview = mlir::dyn_cast(viewOp)) { - llvm::SmallVector nextStrides; - nextStrides.reserve(subview.getMixedOffsets().size()); for (auto [offset, stride, sourceStride] : llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) { CompiledIndexExpr offsetValue = makeConstantExpr(0); @@ -202,29 +215,47 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static return mlir::failure(); } offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride)); - nextStrides.push_back(stride * sourceStride); } - view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end()); - view.strides = std::move(nextStrides); + auto resultType = mlir::cast(subview.getResult().getType()); + auto resultStrides = getStaticMemRefTypeStrides(resultType); + if (failed(resultStrides)) + return mlir::failure(); + view.shape.assign(resultType.getShape().begin(), resultType.getShape().end()); + view.strides = std::move(*resultStrides); continue; } if (auto collapse = mlir::dyn_cast(viewOp)) { - if (view.strides != computeRowMajorStrides(view.shape)) - return mlir::failure(); auto resultType = mlir::cast(collapse.getResult().getType()); + auto resultStrides = getStaticMemRefTypeStrides(resultType); + if (failed(resultStrides)) + return mlir::failure(); view.shape.assign(resultType.getShape().begin(), resultType.getShape().end()); - view.strides = computeRowMajorStrides(view.shape); + view.strides = std::move(*resultStrides); continue; } if (auto expand = mlir::dyn_cast(viewOp)) { - if (view.strides != computeRowMajorStrides(view.shape)) - return mlir::failure(); auto resultType = mlir::cast(expand.getResult().getType()); + auto resultStrides = getStaticMemRefTypeStrides(resultType); + if (failed(resultStrides)) + return mlir::failure(); view.shape.assign(resultType.getShape().begin(), resultType.getShape().end()); - view.strides = computeRowMajorStrides(view.shape); + view.strides = std::move(*resultStrides); + continue; } + + if (auto castOp = mlir::dyn_cast(viewOp)) { + auto resultType = mlir::cast(castOp.getResult().getType()); + auto resultStrides = getStaticMemRefTypeStrides(resultType); + if (failed(resultStrides)) + return mlir::failure(); + view.shape.assign(resultType.getShape().begin(), resultType.getShape().end()); + view.strides = std::move(*resultStrides); + continue; + } + + return mlir::failure(); } auto resolvedOffset = offsetExpr.evaluate(knowledge); @@ -234,18 +265,26 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static return view; } - if (mlir::isa(defOp)) { + if (auto subview = mlir::dyn_cast(defOp)) { viewOps.push_back(defOp); - if (auto subview = mlir::dyn_cast(defOp)) - current = subview.getSource(); - else if (auto collapse = mlir::dyn_cast(defOp)) - current = collapse.getSrc(); - else - current = mlir::cast(defOp).getSrc(); + current = subview.getSource(); + continue; + } + + if (auto collapse = mlir::dyn_cast(defOp)) { + viewOps.push_back(defOp); + current = collapse.getSrc(); + continue; + } + + if (auto expand = mlir::dyn_cast(defOp)) { + viewOps.push_back(defOp); + current = expand.getSrc(); continue; } if (auto castOp = mlir::dyn_cast(defOp)) { + viewOps.push_back(defOp); current = castOp.getSource(); continue; } @@ -253,6 +292,11 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static return mlir::failure(); } + if (mlir::Value loopAlias = resolveLoopCarriedAlias(current, knowledge); loopAlias && loopAlias != current) { + current = loopAlias; + continue; + } + auto weightIndex = resolveWeightIndex(weightOwner, current); if (!weightIndex) return mlir::failure(); diff --git a/src/PIM/Common/Support/Diagnostics.hpp b/src/PIM/Common/Support/Diagnostics.hpp index 11e3d78..2eeebf7 100644 --- a/src/PIM/Common/Support/Diagnostics.hpp +++ b/src/PIM/Common/Support/Diagnostics.hpp @@ -28,6 +28,8 @@ struct CappedDiagnosticReporter { op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription; } + void noteFailures(int64_t count) { numFailures += count; } + bool hasFailure() const { return numFailures != 0; } private: diff --git a/src/PIM/Compiler/CMakeLists.txt b/src/PIM/Compiler/CMakeLists.txt index 5e66728..7d6bcc7 100644 --- a/src/PIM/Compiler/CMakeLists.txt +++ b/src/PIM/Compiler/CMakeLists.txt @@ -31,7 +31,6 @@ add_pim_library(OMPimCompilerUtils OMPimBufferization OMPimMemoryCoalescing OMPimHostConstantFolding - OMPimHostConstantMaterialization OMPimVerification OMPimPasses OMONNXToSpatial diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index cb99ffd..669ff57 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -32,6 +32,7 @@ #include "Common/IR/CompactAsmUtils.hpp" #include "Common/PimCommon.hpp" +#include "Common/Support/Diagnostics.hpp" #include "Common/Support/CheckedArithmetic.hpp" #include "Common/Support/ReportUtils.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp" @@ -996,12 +997,44 @@ static SmallVector collectTopLevelCoreLikeOps(func::FuncOp funcOp) { } struct CoreEmissionResult { + static constexpr size_t kMaxStoredCodegenDiagnostics = 8; + + struct DiagnosticRecord { + Operation* op = nullptr; + std::string message; + }; + OnnxMlirCompilerErrorCodes status = CompilerSuccess; MemoryReportRow reportRow; llvm::SmallVector usedWeights; MemoryPlanArtifacts livenessArtifacts; + llvm::SmallVector diagnostics; + size_t diagnosticCount = 0; + + void recordDiagnostic(Operation* op, StringRef message) { + ++diagnosticCount; + if (diagnostics.size() < kMaxStoredCodegenDiagnostics) + diagnostics.push_back({op, message.str()}); + } }; +static StaticValueKnowledge seedCoreCodegenKnowledge(pim::PimCoreOp coreOp) { + StaticValueKnowledge knowledge; + for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) + knowledge.aliases[coreOp.getWeightArgument(index)] = weight; + return knowledge; +} + +static StaticValueKnowledge seedCoreBatchCodegenKnowledge(pim::PimCoreBatchOp coreBatchOp, unsigned lane) { + StaticValueKnowledge knowledge; + knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane; + for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights())) + knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight; + for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs())) + knowledge.aliases[coreBatchOp.getInputArgument(index)] = input; + return knowledge; +} + template class ScopedMapBindings { using KeyTy = typename MapTy::key_type; @@ -1422,7 +1455,20 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: const StaticValueKnowledge& knowledge) -> llvm::FailureOr { auto weightView = onnx_mlir::resolveWeightView(job.coreLikeOp, vmmOp.getWeight(), knowledge); if (failed(weightView)) { - vmmOp.emitOpError("requires a statically resolvable dense global weight view during PIM codegen"); + std::string message; + llvm::raw_string_ostream os(message); + os << "requires a statically resolvable dense global weight view during PIM codegen; weight=" + << vmmOp.getWeight() << " type=" << vmmOp.getWeight().getType(); + result.recordDiagnostic(vmmOp, os.str()); + return failure(); + } + if (weightView->shape.size() != 2) { + std::string message; + llvm::raw_string_ostream os(message); + os << "requires a rank-2 matrix weight view during PIM codegen; resolved shape=["; + llvm::interleaveComma(weightView->shape, os); + os << "] weight=" << vmmOp.getWeight() << " type=" << vmmOp.getWeight().getType(); + result.recordDiagnostic(vmmOp, os.str()); return failure(); } @@ -1463,13 +1509,13 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId); deviceMemory.allocateCore(coreOp); - int64_t processedOperations = codeGenCoreOps( - coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation(), resolveWeightSlot); + StaticValueKnowledge knowledge = seedCoreCodegenKnowledge(coreOp); + int64_t processedOperations = + codeGenCoreOps(coreOp.getBody().front(), coreCodeGen, knowledge, coreOp.getOperation(), resolveWeightSlot); if (processedOperations < 0) { result.status = CompilerFailure; return result; } - assert(processedOperations > 0); result.reportRow = deviceMemory.getReportRow(); result.usedWeights = std::move(usedWeights); result.livenessArtifacts = deviceMemory.getLivenessArtifacts(); @@ -1480,10 +1526,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId); for (unsigned lane : job.lanes) { - StaticValueKnowledge knowledge; - knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane; - for (unsigned i = 0; i < coreBatchOp.getInputs().size(); ++i) - knowledge.aliases[coreBatchOp.getInputArgument(i)] = coreBatchOp.getInputs()[i]; + StaticValueKnowledge knowledge = seedCoreBatchCodegenKnowledge(coreBatchOp, lane); deviceMemory.allocateCore(coreBatchOp, lane); coreCodeGen.setBatchLane(lane); @@ -1498,7 +1541,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: result.status = CompilerFailure; return result; } - assert(processedOperations > 0); } result.reportRow = deviceMemory.getReportRow(); @@ -1522,6 +1564,21 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: mlir::parallelFor( moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { jobResults[index] = emitJob(jobs[index]); }); + pim::CappedDiagnosticReporter diagnostics; + Operation* summaryAnchor = nullptr; + for (const CoreEmissionResult& result : jobResults) { + if (!summaryAnchor && !result.diagnostics.empty()) + summaryAnchor = result.diagnostics.front().op; + for (const CoreEmissionResult::DiagnosticRecord& diagnostic : result.diagnostics) { + diagnostics.report(diagnostic.op, [&](Operation* op) { op->emitError() << diagnostic.message; }); + } + size_t unreportedCount = result.diagnosticCount - result.diagnostics.size(); + diagnostics.noteFailures(static_cast(unreportedCount)); + } + if (diagnostics.hasFailure()) + diagnostics.emitSuppressedSummary(summaryAnchor ? summaryAnchor : moduleOp.getOperation(), + "PIM codegen diagnostic(s)"); + for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) if (jobResults[jobIndex].status != CompilerSuccess) return jobResults[jobIndex].status; diff --git a/src/PIM/Compiler/PimCompilerUtils.cpp b/src/PIM/Compiler/PimCompilerUtils.cpp index 2d111cc..5035379 100644 --- a/src/PIM/Compiler/PimCompilerUtils.cpp +++ b/src/PIM/Compiler/PimCompilerUtils.cpp @@ -46,7 +46,6 @@ void addPassesPim(OwningOpRef& module, if (pimEmissionTarget >= EmitPimCodegen) { pm.addPass(createPimHostConstantFoldingPass()); pm.addPass(createMessagePass("Pim host constants folded")); - pm.addPass(createPimMaterializeHostConstantsPass()); pm.addPass(createPimMemoryCoalescingPass()); pm.addPass(createPimVerificationPass()); pm.addPass(createMessagePass("Pim verified")); diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp index 11dccf4..c61ffb1 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -375,6 +375,57 @@ static void cloneHelperChain(Value sourceValue, } } +static bool isHostStaticReturnValue(Value value) { + llvm::SmallPtrSet visited; + while (Operation* definingOp = value.getDefiningOp()) { + if (!visited.insert(definingOp).second) + return false; + if (isa(definingOp) || definingOp->hasTrait()) + return true; + if (!isReturnHelperChainOp(definingOp) || definingOp->getNumOperands() != 1) + return false; + value = definingOp->getOperand(0); + } + return false; +} + +static FailureOr +materializeHostStaticReturnValue(IRRewriter& rewriter, Value value, OperationFolder& constantFolder) { + llvm::SmallVector chain; + llvm::SmallPtrSet visited; + while (Operation* definingOp = value.getDefiningOp()) { + if (!visited.insert(definingOp).second) + return failure(); + chain.push_back(definingOp); + if (isa(definingOp) || definingOp->hasTrait()) + break; + if (!isReturnHelperChainOp(definingOp) || definingOp->getNumOperands() != 1) + return failure(); + value = definingOp->getOperand(0); + } + + if (chain.empty()) + return failure(); + + IRMapping mapping; + Value clonedValue; + for (Operation* op : llvm::reverse(chain)) { + if (auto constantOp = dyn_cast(op)) { + clonedValue = getOrCreateConstantLike(constantFolder, constantOp); + mapping.map(op->getResult(0), clonedValue); + continue; + } + + Operation* clonedOp = rewriter.clone(*op, mapping); + for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults())) + mapping.map(originalResult, newResult); + clonedValue = clonedOp->getResult(0); + rewriter.setInsertionPointAfter(clonedOp); + } + + return clonedValue; +} + static FailureOr emitHostCopy(IRRewriter& rewriter, Location loc, Value outputTensor, @@ -444,7 +495,30 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low OperationFolder constantFolder(producerOp->getContext()); auto storedTensorType = cast(storedValue.getType()); + auto materializeDirectHostReturn = [&](size_t returnIndex, + Value sourceValue, + ArrayRef helperChain) -> ReturnPathLoweringResult { + rewriter.setInsertionPointAfter(producerOp); + auto hostStaticValue = materializeHostStaticReturnValue(rewriter, sourceValue, constantFolder); + if (failed(hostStaticValue)) + return ReturnPathLoweringResult::Failure; + + Value hostReturnValue = *hostStaticValue; + if (!helperChain.empty()) + cloneHelperChain(hostReturnValue, helperChain, rewriter, constantFolder, hostReturnValue); + + outputTensors[returnIndex] = + [hostReturnValue](IRRewriter& rewriter, Location loc) -> Value { return hostReturnValue; }; + return ReturnPathLoweringResult::Handled; + }; + if (auto returnUse = analyzeReturnUse(producedValue)) { + if (isHostStaticReturnValue(storedValue)) { + for (Operation* op : returnUse->helperChain) + markOpToRemove(op); + return materializeDirectHostReturn(returnUse->returnIndex, storedValue, returnUse->helperChain); + } + Value currentStoredValue = storedValue; cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue); for (Operation* op : returnUse->helperChain) @@ -470,6 +544,8 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low if (isa(resultUser)) { size_t resultIndexInReturn = resultUse.getOperandNumber(); + if (isHostStaticReturnValue(storedValue)) + return materializeDirectHostReturn(resultIndexInReturn, storedValue, {}); auto byteSize = pim::getCheckedShapedTypeSizeInBytes(storedTensorType, producerOp, "return-path host copy byte size"); if (failed(byteSize)) diff --git a/src/PIM/Dialect/Pim/CMakeLists.txt b/src/PIM/Dialect/Pim/CMakeLists.txt index 0e3d18d..6eebe9a 100644 --- a/src/PIM/Dialect/Pim/CMakeLists.txt +++ b/src/PIM/Dialect/Pim/CMakeLists.txt @@ -4,7 +4,6 @@ add_onnx_mlir_dialect_doc(pim Pim.td) add_subdirectory(Transforms/Bufferization) add_subdirectory(Transforms/MemoryCoalescing) add_subdirectory(Transforms/HostConstantFolding) -add_subdirectory(Transforms/HostConstantMaterialization) add_subdirectory(Transforms/Verification) add_pim_library(PimOps diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp index de22489..b553766 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp @@ -6,6 +6,7 @@ #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp" using namespace mlir; using namespace bufferization; @@ -13,7 +14,9 @@ using namespace bufferization; namespace onnx_mlir::pim { FailureOr materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { - if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue))) + bool isContiguous = + succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)); + if (isContiguous && isDeviceLocalPimAddress(memrefValue)) return memrefValue; auto shapedType = cast(memrefValue.getType()); @@ -29,13 +32,21 @@ FailureOr materializeContiguousInputMemRef(Value memrefValue, Location lo if (failed(sizeAttr)) return failure(); + if (isHostBackedPimAddress(memrefValue)) { + return PimMemCopyHostToDevOp::create( + rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr) + .getOutput(); + } + return PimMemCopyOp::create( rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr) .getOutput(); } Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) { - if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue))) + bool isContiguous = + succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)); + if (isContiguous && isDeviceLocalPimAddress(memrefValue)) return memrefValue; auto shapedType = cast(memrefValue.getType()); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp index 321347e..129078d 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp @@ -1,9 +1,70 @@ #include "Dialect/Pim/Transforms/Bufferization/Common.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace mlir; +static bool isCoreBatchInputArgument(Value value) { + auto blockArg = dyn_cast(value); + if (!blockArg) + return false; + + auto coreBatchOp = dyn_cast_or_null(blockArg.getOwner()->getParentOp()); + if (!coreBatchOp) + return false; + + unsigned firstInputArg = 1 + coreBatchOp.getWeights().size(); + return static_cast(blockArg.getArgNumber()) >= firstInputArg; +} + +static FailureOr getPimStorageBase(Value value, const onnx_mlir::StaticValueKnowledge& knowledge) { + llvm::SmallPtrSet visited; + while (value && visited.insert(value).second) { + Value alias = resolveLoopCarriedAlias(value, knowledge); + if (alias) + value = alias; + + if (auto aliased = knowledge.aliases.lookup(value)) { + value = aliased; + continue; + } + + if (auto base = onnx_mlir::pim::getPimAddressBase(value, knowledge); succeeded(base)) + return base; + + if (isa(value)) + return value; + + Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return value; + + if (auto subviewOp = dyn_cast(definingOp)) { + value = subviewOp.getSource(); + continue; + } + if (auto collapseOp = dyn_cast(definingOp)) { + value = collapseOp.getSrc(); + continue; + } + if (auto expandOp = dyn_cast(definingOp)) { + value = expandOp.getSrc(); + continue; + } + if (auto castOp = dyn_cast(definingOp)) { + value = castOp.getSource(); + continue; + } + + return value; + } + + if (value) + return value; + return failure(); +} + FailureOr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Operation* anchor, Value memref) { auto type = mlir::cast(memref.getType()); auto byteSize = getCheckedShapedTypeSizeInBytes(type, anchor, "memref byte size"); @@ -11,3 +72,40 @@ FailureOr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& build return failure(); return getCheckedI32Attr(builder, anchor, *byteSize, "memref byte size"); } + +FailureOr onnx_mlir::pim::getPimAddressBase(Value value, const StaticValueKnowledge& knowledge) { + Value alias = resolveLoopCarriedAlias(value, knowledge); + if (alias) + value = alias; + + auto resolved = resolveContiguousAddress(value, knowledge); + if (succeeded(resolved)) + return resolved->base; + + auto compiled = compileContiguousAddressExpr(value); + if (failed(compiled)) { + if (isa(value)) + return value; + return failure(); + } + return compiled->base; +} + +bool onnx_mlir::pim::isHostBackedPimAddress(Value value, const StaticValueKnowledge& knowledge) { + auto base = getPimStorageBase(value, knowledge); + if (failed(base)) + return false; + + if (isCoreBatchInputArgument(*base)) + return true; + + return isa_and_nonnull(base->getDefiningOp()); +} + +bool onnx_mlir::pim::isDeviceLocalPimAddress(Value value, const StaticValueKnowledge& knowledge) { + auto base = getPimStorageBase(value, knowledge); + if (failed(base)) + return false; + + return isa_and_nonnull(base->getDefiningOp()); +} diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp index 6816d62..37b1295 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp @@ -2,11 +2,19 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" + namespace onnx_mlir { namespace pim { mlir::FailureOr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Operation* anchor, mlir::Value memref); +mlir::FailureOr getPimAddressBase(mlir::Value value, const StaticValueKnowledge& knowledge = {}); + +bool isHostBackedPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}); + +bool isDeviceLocalPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}); + } // namespace pim } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index 01e35c4..bb8e35b 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -8,6 +8,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" #include "Common/PimCommon.hpp" @@ -15,6 +16,7 @@ #include "Dialect/Pim/PimOps.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp" #include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp" +#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Compiler/CompilerOptions.hpp" @@ -27,24 +29,71 @@ namespace onnx_mlir { namespace { -struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct MemRefCopyWorkItem { + memref::CopyOp copyOp; + StaticValueKnowledge knowledge; +}; - LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override { - if (!copyOp->getParentOfType() && !copyOp->getParentOfType()) - return failure(); +static StaticValueKnowledge seedCoreKnowledge(pim::PimCoreOp coreOp) { + StaticValueKnowledge knowledge; + for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) + knowledge.aliases[coreOp.getWeightArgument(index)] = weight; + return knowledge; +} - auto sourceType = dyn_cast(copyOp.getSource().getType()); - auto targetType = dyn_cast(copyOp.getTarget().getType()); - if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape()) - return failure(); - if (sourceType.getElementType() != targetType.getElementType()) - return failure(); +static StaticValueKnowledge seedCoreBatchKnowledge(pim::PimCoreBatchOp coreBatchOp, unsigned lane) { + StaticValueKnowledge knowledge; + knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane; + for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights())) + knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight; + for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs())) + knowledge.aliases[coreBatchOp.getInputArgument(index)] = input; + return knowledge; +} - Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0); - auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource()); - if (failed(sizeAttr)) - return failure(); +static LogicalResult +lowerMemRefCopyToPimCopy(memref::CopyOp copyOp, PatternRewriter& rewriter, const StaticValueKnowledge& knowledge) { + if (!copyOp->getParentOfType() && !copyOp->getParentOfType()) + return failure(); + + auto sourceType = dyn_cast(copyOp.getSource().getType()); + auto targetType = dyn_cast(copyOp.getTarget().getType()); + if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape()) + return failure(); + if (sourceType.getElementType() != targetType.getElementType()) + return failure(); + + Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0); + auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource()); + if (failed(sizeAttr)) + return failure(); + + bool sourceIsHost = isHostBackedPimAddress(copyOp.getSource(), knowledge); + bool targetIsHost = isHostBackedPimAddress(copyOp.getTarget(), knowledge); + bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getSource(), knowledge); + bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getTarget(), knowledge); + + if (targetIsDevice && sourceIsHost) { + pim::PimMemCopyHostToDevOp::create(rewriter, + copyOp.getLoc(), + copyOp.getTarget().getType(), + zeroOffset, + zeroOffset, + copyOp.getTarget(), + copyOp.getSource(), + *sizeAttr); + } + else if (targetIsHost && sourceIsDevice) { + pim::PimMemCopyDevToHostOp::create(rewriter, + copyOp.getLoc(), + copyOp.getTarget().getType(), + zeroOffset, + zeroOffset, + copyOp.getTarget(), + copyOp.getSource(), + *sizeAttr); + } + else if (targetIsDevice && sourceIsDevice) { pim::PimMemCopyOp::create(rewriter, copyOp.getLoc(), copyOp.getTarget().getType(), @@ -53,10 +102,19 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern { copyOp.getTarget(), copyOp.getSource(), *sizeAttr); - rewriter.eraseOp(copyOp); - return success(); } -}; + else { + copyOp.emitOpError() << "failed to classify memref.copy endpoints: source=" << copyOp.getSource() + << " type=" << copyOp.getSource().getType() << " host=" << sourceIsHost + << " device=" << sourceIsDevice << ", target=" << copyOp.getTarget() + << " type=" << copyOp.getTarget().getType() << " host=" << targetIsHost + << " device=" << targetIsDevice; + return failure(); + } + + rewriter.eraseOp(copyOp); + return success(); +} struct PimBufferizationPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass) @@ -100,25 +158,46 @@ void PimBufferizationPass::runOnOperation() { } MLIRContext* ctx = moduleOp.getContext(); - RewritePatternSet memrefCopyPatterns(ctx); - memrefCopyPatterns.add(ctx); - FrozenRewritePatternSet frozenMemrefCopyPatterns(std::move(memrefCopyPatterns)); - PatternApplicator memrefCopyApplicator(frozenMemrefCopyPatterns); - memrefCopyApplicator.applyDefaultCostModel(); PatternRewriter rewriter(ctx); - SmallVector copyWorklist; - moduleOp.walk([&](memref::CopyOp copyOp) { - if (copyOp->getParentOfType() || copyOp->getParentOfType()) - copyWorklist.push_back(copyOp); + SmallVector copyWorklist; + llvm::SmallPtrSet seenCopyOps; + auto addCopyOp = [&](memref::CopyOp copyOp, const StaticValueKnowledge& knowledge) { + if (seenCopyOps.insert(copyOp.getOperation()).second) + copyWorklist.push_back({copyOp, knowledge}); + }; + + moduleOp.walk([&](pim::PimCoreOp coreOp) { + StaticValueKnowledge knowledge = seedCoreKnowledge(coreOp); + (void) walkPimCoreBlockStructurally( + coreOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) { + if (auto copyOp = dyn_cast(&op)) + addCopyOp(copyOp, opKnowledge); + return success(); + }); + }); + moduleOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { + llvm::SmallVector lanes; + lanes.push_back(0); + if (coreBatchOp.getLaneCount() > 1) + lanes.push_back(static_cast(coreBatchOp.getLaneCount() - 1)); + for (unsigned lane : lanes) { + StaticValueKnowledge knowledge = seedCoreBatchKnowledge(coreBatchOp, lane); + (void) walkPimCoreBlockStructurally( + coreBatchOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) { + if (auto copyOp = dyn_cast(&op)) + addCopyOp(copyOp, opKnowledge); + return success(); + }); + } }); bool hasFailed = false; - for (memref::CopyOp copyOp : copyWorklist) { - if (failed(applyPatternsOnce(copyOp, memrefCopyApplicator, rewriter))) { - copyOp.emitOpError("failed to lower memref.copy inside PIM core body"); + for (const MemRefCopyWorkItem& workItem : copyWorklist) { + memref::CopyOp copyOp = workItem.copyOp; + rewriter.setInsertionPoint(copyOp); + if (failed(lowerMemRefCopyToPimCopy(copyOp, rewriter, workItem.knowledge))) hasFailed = true; - } } if (hasFailed) { signalPassFailure(); diff --git a/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns/Constant.cpp b/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns/Constant.cpp index 2fa900f..fd870b8 100644 --- a/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns/Constant.cpp +++ b/src/PIM/Dialect/Pim/Transforms/HostConstantFolding/Patterns/Constant.cpp @@ -128,7 +128,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern { auto sizeAttr = pim::getCheckedI32Attr(rewriter, mapOp, *sizeInBytes, "host constant folding byte size"); if (failed(sizeAttr)) return failure(); - pim::PimMemCopyOp::create( + pim::PimMemCopyHostToDevOp::create( rewriter, mapOp.getLoc(), initType, zeroOffset, zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), *sizeAttr); rewriter.eraseOp(mapOp); return success(); diff --git a/src/PIM/Dialect/Pim/Transforms/HostConstantMaterialization/CMakeLists.txt b/src/PIM/Dialect/Pim/Transforms/HostConstantMaterialization/CMakeLists.txt deleted file mode 100644 index c2105c2..0000000 --- a/src/PIM/Dialect/Pim/Transforms/HostConstantMaterialization/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -add_pim_library(OMPimHostConstantMaterialization - MaterializeHostConstantsPass.cpp - - EXCLUDE_FROM_OM_LIBS - - LINK_LIBS PUBLIC - OMPimCommon - PimOps -) diff --git a/src/PIM/Dialect/Pim/Transforms/HostConstantMaterialization/MaterializeHostConstantsPass.cpp b/src/PIM/Dialect/Pim/Transforms/HostConstantMaterialization/MaterializeHostConstantsPass.cpp deleted file mode 100644 index 97850a3..0000000 --- a/src/PIM/Dialect/Pim/Transforms/HostConstantMaterialization/MaterializeHostConstantsPass.cpp +++ /dev/null @@ -1,161 +0,0 @@ -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Dominance.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/MathExtras.h" - -#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" -#include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" - -using namespace mlir; - -namespace onnx_mlir { - -namespace { - -template -static void materializeHostConstantsInCore(CoreOpTy coreOp, - IRRewriter& rewriter, - OperationFolder& constantFolder, - bool& hasFailure) { - DenseMap>> materializedValues; - DominanceInfo dominance(coreOp); - SmallVector ops; - coreOp.getBody().front().walk([&](Operation* op) { - if (!isa(op)) - ops.push_back(op); - }); - - for (Operation* op : ops) { - if (auto loadOp = dyn_cast(op); loadOp && loadOp.getType().isIndex()) - continue; - - for (OpOperand& operand : op->getOpOperands()) { - Value originalValue = operand.get(); - if (!isa(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber())) - continue; - - auto resolvedAddress = resolveContiguousAddress(originalValue); - if (failed(resolvedAddress)) - continue; - - auto getGlobalOp = dyn_cast_or_null(resolvedAddress->base.getDefiningOp()); - if (!getGlobalOp) - continue; - - auto originalType = dyn_cast(originalValue.getType()); - if (!originalType || !originalType.hasStaticShape()) { - op->emitOpError("host constant materialization requires a static memref operand"); - hasFailure = true; - continue; - } - - auto& cachedByOffset = materializedValues[resolvedAddress->base]; - auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset]; - auto cachedValue = cachedByType.find(originalType); - if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) { - operand.set(cachedValue->second); - continue; - } - - auto type = dyn_cast(originalValue.getType()); - auto totalBytes = type ? pim::getCheckedShapedTypeSizeInBytes(type, op, "host constant materialization byte size") - : FailureOr(failure()); - auto totalBytesAttr = - succeeded(totalBytes) - ? pim::getCheckedI32Attr(rewriter, op, *totalBytes, "host constant materialization byte size") - : FailureOr(failure()); - if (failed(totalBytesAttr) - || failed(pim::checkedSize(resolvedAddress->byteOffset, op, "host constant materialization byte offset"))) { - hasFailure = true; - continue; - } - - auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType()); - - rewriter.setInsertionPoint(op); - Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType); - Value deviceDst = localAlloc; - if (contiguousType != originalType) - deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc); - - Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0); - Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset); - Value copiedValue = pim::PimMemCopyHostToDevOp::create(rewriter, - op->getLoc(), - originalType, - zeroOffset, - hostOffset, - deviceDst, - getGlobalOp.getResult(), - *totalBytesAttr) - .getOutput(); - - cachedByType[originalType] = copiedValue; - operand.set(copiedValue); - } - } -} - -struct MaterializeHostConstantsPass : PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass) - - StringRef getArgument() const override { return "materialize-pim-host-constants"; } - StringRef getDescription() const override { - return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops"; - } - - void runOnOperation() override { - ModuleOp moduleOp = getOperation(); - IRRewriter rewriter(moduleOp.getContext()); - OperationFolder constantFolder(moduleOp.getContext()); - bool hasFailure = false; - - for (func::FuncOp funcOp : moduleOp.getOps()) { - if (funcOp.isExternal()) - continue; - - for (pim::PimCoreOp coreOp : funcOp.getOps()) - materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure); - - for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps()) - materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure); - - SmallVector hostCompactOps; - for (Operation& op : funcOp.getBody().front()) - if (isa(op)) - hostCompactOps.push_back(&op); - - for (Operation* op : hostCompactOps) { - rewriter.setInsertionPoint(op); - auto concatOp = cast(op); - concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization"); - hasFailure = true; - } - } - - if (hasFailure) { - moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above"); - signalPassFailure(); - return; - } - - dumpModule(moduleOp, "pim4_materialized"); - } -}; - -} // namespace - -std::unique_ptr createPimMaterializeHostConstantsPass() { - return std::make_unique(); -} - -} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Transforms/Verification/VerificationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Verification/VerificationPass.cpp index 68969d6..eb07f1c 100644 --- a/src/PIM/Dialect/Pim/Transforms/Verification/VerificationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Verification/VerificationPass.cpp @@ -46,19 +46,6 @@ static bool isCodegenAddressableValue(Value value) { || isa(compiledAddress->base.getDefiningOp()); } -static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) { - auto resolvedAddress = resolveContiguousAddress(value, knowledge); - if (succeeded(resolvedAddress)) - return isa(resolvedAddress->base) - || isa(resolvedAddress->base.getDefiningOp()); - - auto compiledAddress = compileContiguousAddressExpr(value); - if (failed(compiledAddress)) - return false; - return isa(compiledAddress->base) - || isa(compiledAddress->base.getDefiningOp()); -} - static bool isConstantGlobalView(Value value) { while (true) { Operation* defOp = value.getDefiningOp(); @@ -138,6 +125,24 @@ static bool isSupportedCoreInstructionOp(Operation* op) { memref::GetGlobalOp>(op); } +static bool isHostAddressableValue(Value value, const StaticValueKnowledge& knowledge) { + auto resolvedAddress = resolveContiguousAddress(value, knowledge); + Value base; + if (succeeded(resolvedAddress)) { + base = resolvedAddress->base; + } + else { + auto compiledAddress = compileContiguousAddressExpr(value); + if (failed(compiledAddress)) + return false; + base = compiledAddress->base; + } + + if (isa(base)) + return true; + return isa_and_nonnull(base.getDefiningOp()); +} + struct VerificationPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass) @@ -311,10 +316,10 @@ private: } if (isExplicitHostMemCopyOperand(&op, operandIndex)) { - if (!isCodegenAddressableValue(operand, knowledge)) { + if (!isHostAddressableValue(operand, knowledge)) { diagnostics.report(&op, [&](Operation* illegalOp) { illegalOp->emitOpError() << "host operand #" << operandIndex - << " is not backed by contiguous addressable storage"; + << " must be backed by host-addressable storage"; }); hasFailure = true; } diff --git a/src/PIM/Pass/PIMPasses.h b/src/PIM/Pass/PIMPasses.h index fe817ba..515f459 100644 --- a/src/PIM/Pass/PIMPasses.h +++ b/src/PIM/Pass/PIMPasses.h @@ -21,8 +21,6 @@ std::unique_ptr createMergeComputeNodesPass(); std::unique_ptr createPimHostConstantFoldingPass(); -std::unique_ptr createPimMaterializeHostConstantsPass(); - std::unique_ptr createPimVerificationPass(); std::unique_ptr createEmitPimCodePass(); diff --git a/src/PIM/PimAccelerator.cpp b/src/PIM/PimAccelerator.cpp index 98ab342..5afc09a 100644 --- a/src/PIM/PimAccelerator.cpp +++ b/src/PIM/PimAccelerator.cpp @@ -78,7 +78,6 @@ void PimAccelerator::registerPasses(int optLevel) const { registerPass(createPimMemoryCoalescingPass); registerPass(createMergeComputeNodesPass); registerPass(createPimHostConstantFoldingPass); - registerPass(createPimMaterializeHostConstantsPass); registerPass(createPimVerificationPass); registerPass(createEmitPimCodePass); } diff --git a/validation/raptor.py b/validation/raptor.py index 7f3d6f7..939bcec 100644 --- a/validation/raptor.py +++ b/validation/raptor.py @@ -11,7 +11,6 @@ PIM_PASS_LABELS = ( ("SpatialToPimPass", "Spatial to PIM"), ("PimBufferizationPass", "Bufferize PIM"), ("HostConstantFoldingPass", "Fold Host Constants"), - ("MaterializeHostConstantsPass", "Materialize Host Constants"), ("VerificationPass", "Verify PIM"), ("EmitPimCodePass", "Emit PIM Code"), )