diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index 397874b..458f521 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -7,8 +7,8 @@ #include "mlir/IR/Matchers.h" #include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -152,8 +152,8 @@ static Value createHostTargetOffset(IRRewriter& rewriter, } // namespace -LogicalResult -lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) { +LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, + IRRewriter& rewriter) { Location loc = computeBatchOp.getLoc(); Block& oldBlock = computeBatchOp.getBody().front(); auto oldYield = dyn_cast(oldBlock.getTerminator()); @@ -167,7 +167,7 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& "resultful compute_batch lowering currently requires a spat.in_parallel terminator"); } - SmallVector coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId); + SmallVector coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId); SmallVector batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); SmallVector batchInputs; if (!computeBatchOp.getInputs().empty()) @@ -192,7 +192,7 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& return computeBatchOp.emitOpError( "resultful compute_batch lowering currently requires each result to be used directly by func.return"); - hostOutputTensors[resultIndex] = state.outputTensors[*returnOperandIndex](rewriter, loc); + hostOutputTensors[resultIndex] = outputTensors[*returnOperandIndex](rewriter, loc); result.replaceAllUsesWith(hostOutputTensors[resultIndex]); } } diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp deleted file mode 100644 index 3afc4b0..0000000 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once - -#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp" - -namespace onnx_mlir { - -mlir::LogicalResult -lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter); - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index bf33989..989dfea 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -6,9 +6,9 @@ #include "mlir/IR/Matchers.h" #include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -148,15 +148,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute } // namespace -void markOpToRemove(CoreLoweringState& state, Operation* op) { - if (!llvm::is_contained(state.operationsToRemove, op)) - state.operationsToRemove.push_back(op); -} - -LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) { +LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp, + IRRewriter& rewriter, + OperationFolder& constantFolder) { Location loc = computeOp->getLoc(); - if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, state.constantFolder)) + if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, constantFolder)) return success(); SmallVector helperChain; @@ -179,7 +176,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId()) .getOutput(); blockArg.replaceAllUsesWith(received); - markOpToRemove(state, receiveOp); + markOpToRemove(receiveOp); continue; } @@ -200,7 +197,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& rewriter.getDenseI32ArrayAttr(*sourceCoreIds)) .getOutput(); blockArg.replaceAllUsesWith(received); - markOpToRemove(state, receiveTensorOp); + markOpToRemove(receiveTensorOp); } } @@ -211,9 +208,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& if (result.use_empty()) continue; - ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove}; ReturnPathLoweringResult returnPathResult = - lowerComputeResultReturnPath(computeOp, cast(result), yieldValue, returnPathState, rewriter); + lowerComputeResultReturnPath(computeOp, cast(result), yieldValue, rewriter); if (returnPathResult == ReturnPathLoweringResult::Failure) return failure(); if (returnPathResult == ReturnPathLoweringResult::Handled) @@ -240,7 +236,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& auto coreOp = PimCoreOp::create(rewriter, loc, ValueRange(computeWeights), - rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId))); + rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId))); rewriter.setInsertionPointToStart(&block); auto& coreOpBlocks = coreOp.getBody().getBlocks(); for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) { @@ -249,7 +245,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& continue; if (auto constantOp = input.getDefiningOp()) { - blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, state.constantFolder)); + blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder)); continue; } @@ -261,8 +257,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& PimMemCopyHostToDevOp::create(rewriter, loc, outputBuffer.getType(), - getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder), - getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder), + getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder), + getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder), outputBuffer, input, getTensorSizeInBytesAttr(rewriter, input)) diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp deleted file mode 100644 index 7e7d214..0000000 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp +++ /dev/null @@ -1,23 +0,0 @@ -#pragma once - -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/FoldUtils.h" - -#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - -namespace onnx_mlir { - -struct CoreLoweringState { - size_t& nextCoreId; - llvm::SmallVectorImpl& outputTensors; - llvm::SmallVectorImpl& operationsToRemove; - mlir::OperationFolder& constantFolder; -}; - -void markOpToRemove(CoreLoweringState& state, mlir::Operation* op); - -mlir::LogicalResult -lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter); - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp index 681bf1f..82a7aef 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -9,9 +9,9 @@ #include "mlir/Transforms/FoldUtils.h" #include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -44,11 +44,6 @@ static bool isReturnHelperChainOp(Operation* op) { pim::PimTransposeOp>(op); } -static void markOpToRemove(ReturnPathState& state, Operation* op) { - if (!llvm::is_contained(state.operationsToRemove, op)) - state.operationsToRemove.push_back(op); -} - static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) { std::string name = baseName.str(); unsigned suffix = 0; @@ -390,9 +385,7 @@ static Value emitHostCopy(IRRewriter& rewriter, } // namespace -void addReturnOutputBuffers(func::ReturnOp returnOp, - IRRewriter& rewriter, - SmallVectorImpl& outputTensors) { +void raptor::SpatialToPimPass::addReturnOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) { outputTensors.reserve(returnOp->getNumOperands()); for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) { Value currentReturnValue = returnValue; @@ -427,8 +420,8 @@ void addReturnOutputBuffers(func::ReturnOp returnOp, } } -ReturnPathLoweringResult lowerProducedValueReturnPath( - Operation* producerOp, Value producedValue, Value storedValue, ReturnPathState& state, IRRewriter& rewriter) { +raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerProducedValueReturnPath( + Operation* producerOp, Value producedValue, Value storedValue, IRRewriter& rewriter) { Location loc = producerOp->getLoc(); OperationFolder constantFolder(producerOp->getContext()); auto storedTensorType = cast(storedValue.getType()); @@ -437,13 +430,13 @@ ReturnPathLoweringResult lowerProducedValueReturnPath( Value currentStoredValue = storedValue; cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue); for (Operation* op : returnUse->helperChain) - markOpToRemove(state, op); + markOpToRemove(op); auto storedType = cast(currentStoredValue.getType()); size_t elementSize = storedType.getElementTypeBitWidth() / 8; if (auto storedOp = currentStoredValue.getDefiningOp()) rewriter.setInsertionPointAfter(storedOp); - Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc); + Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc); emitHostCopy(rewriter, loc, outputTensor, @@ -464,7 +457,7 @@ ReturnPathLoweringResult lowerProducedValueReturnPath( size_t resultIndexInReturn = resultUse.getOperandNumber(); size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8; rewriter.setInsertionPointAfterValue(storedValue); - Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc); + Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc); emitHostCopy(rewriter, loc, outputTensor, @@ -480,11 +473,11 @@ ReturnPathLoweringResult lowerProducedValueReturnPath( if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) { size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8; for (Operation* concatOp : concatReturnUse->concatChain) - markOpToRemove(state, concatOp); + markOpToRemove(concatOp); if (concatReturnUse->helperChain.empty()) { rewriter.setInsertionPointAfterValue(storedValue); - Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc); + Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc); auto outputType = cast(outputTensor.getType()); int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape()); emitHostCopy(rewriter, @@ -505,7 +498,7 @@ ReturnPathLoweringResult lowerProducedValueReturnPath( return ReturnPathLoweringResult::Failure; } rewriter.setInsertionPointAfterValue(storedValue); - Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc); + Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc); auto outputType = cast(outputTensor.getType()); for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) { SmallVector sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape()); @@ -553,12 +546,15 @@ ReturnPathLoweringResult lowerProducedValueReturnPath( return ReturnPathLoweringResult::NotReturnPath; } -ReturnPathLoweringResult lowerComputeResultReturnPath( - spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) { - return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, state, rewriter); +raptor::SpatialToPimPass::ReturnPathLoweringResult +raptor::SpatialToPimPass::lowerComputeResultReturnPath(spatial::SpatCompute computeOp, + OpResult result, + Value yieldValue, + IRRewriter& rewriter) { + return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter); } -void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) { +void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) { auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void { if (!op) return; @@ -575,13 +571,13 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite if (isReturnHelperChainOp(op)) { Value source = op->getOperand(0); - markOpToRemove(state, op); + markOpToRemove(op); markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain); return; } if (auto computeOp = dyn_cast(op)) { - markOpToRemove(state, computeOp); + markOpToRemove(computeOp); if (!computeOp.getInputs().empty()) for (Value input : computeOp.getInputs()) markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain); @@ -589,33 +585,33 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite } if (auto concatOp = dyn_cast(op)) { - markOpToRemove(state, concatOp); + markOpToRemove(concatOp); for (Value operand : concatOp.getOperands()) markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); return; } if (auto concatOp = dyn_cast(op)) { - markOpToRemove(state, concatOp); + markOpToRemove(concatOp); for (Value operand : concatOp.getInputs()) markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); return; } if (auto concatOp = dyn_cast(op)) { - markOpToRemove(state, concatOp); + markOpToRemove(concatOp); for (Value operand : concatOp.getInputs()) markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); return; } if (auto receiveOp = dyn_cast(op)) { - markOpToRemove(state, receiveOp); + markOpToRemove(receiveOp); return; } if (auto receiveTensorOp = dyn_cast(op)) - markOpToRemove(state, receiveTensorOp); + markOpToRemove(receiveTensorOp); }; SmallVector originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end()); @@ -624,7 +620,7 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite size_t orderWithinReturn = it.index(); Operation* returnOperand = it.value().getDefiningOp(); rewriter.setInsertionPoint(returnOp); - Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc); + Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc); rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); }); markOwnedReturnChain(returnOperand, markOwnedReturnChain); } diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp deleted file mode 100644 index fe86724..0000000 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp +++ /dev/null @@ -1,43 +0,0 @@ -#pragma once - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/PatternMatch.h" - -#include - -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - -namespace onnx_mlir { - -using OutputTensorFactory = std::function; - -struct ReturnPathState { - llvm::SmallVectorImpl& outputTensors; - llvm::SmallVectorImpl& operationsToRemove; -}; - -enum class ReturnPathLoweringResult { - Handled, - NotReturnPath, - Failure -}; - -void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, - mlir::IRRewriter& rewriter, - llvm::SmallVectorImpl& outputTensors); - -ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp, - mlir::OpResult result, - mlir::Value yieldValue, - ReturnPathState& state, - mlir::IRRewriter& rewriter); - -ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp, - mlir::Value producedValue, - mlir::Value storedValue, - ReturnPathState& state, - mlir::IRRewriter& rewriter); - -void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state); - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index a8807a7..5130722 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -23,54 +23,28 @@ #include #include +#include "Common/PimCommon.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Accelerators/PIM/Pass/PIMPasses.h" +#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" +#include "Conversion/SpatialToPim/Common.hpp" +#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp" +#include "Conversion/SpatialToPim/PhaseVerification.hpp" +#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp" +#include "Dialect/Pim/PimOps.hpp" +#include "Dialect/Spatial/SpatialOps.hpp" +#include "Pass/PIMPasses.h" +#include "SpatialToPimPass.hpp" using namespace mlir; using namespace onnx_mlir; using namespace pim; namespace onnx_mlir { - -namespace { +namespace raptor { #include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc" -struct SpatialToPimPass : PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass) - StringRef getArgument() const override { return "convert-spatial-to-pim"; } - StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; } - - SpatialToPimPass() = default; - SpatialToPimPass(const SpatialToPimPass& pass) {} - - void runOnOperation() final; - -private: - SmallVector outputTensors; - size_t coreId = 0; - SmallVector operationsToRemove; - - LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); - - void markOpToRemove(Operation* op); - void eraseOpsToRemove(); - - void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter); -}; - -} // namespace +} // namespace raptor static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) { auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType(); @@ -150,8 +124,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput(); } -void SpatialToPimPass::runOnOperation() { +void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { coreId = 0; + outputTensors.clear(); + operationsToRemove.clear(); ModuleOp moduleOp = getOperation(); MLIRContext* ctx = moduleOp.getContext(); @@ -197,18 +173,16 @@ void SpatialToPimPass::runOnOperation() { walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns)); auto returnOp = cast(funcOp.front().getTerminator()); - addReturnOutputBuffers(returnOp, rewriter, outputTensors); - ReturnPathState returnPathState {outputTensors, operationsToRemove}; + addReturnOutputBuffers(returnOp, rewriter); if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) { funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering"); signalPassFailure(); return; } - CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove, constantFolder}; for (auto computeOp : funcOp.getOps()) { markOpToRemove(computeOp); - if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) { + if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) { computeOp.emitOpError("failed to lower spat.compute to pim.core"); signalPassFailure(); return; @@ -217,7 +191,7 @@ void SpatialToPimPass::runOnOperation() { for (auto computeBatchOp : funcOp.getOps()) { markOpToRemove(computeBatchOp); - if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) { + if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) { computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch"); signalPassFailure(); return; @@ -266,7 +240,7 @@ void SpatialToPimPass::runOnOperation() { } enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter); - replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState); + replaceReturnWithOutputBuffers(returnOp, rewriter); eraseOpsToRemove(); RewritePatternSet finalTensorPackingPatterns(ctx); @@ -309,7 +283,7 @@ void SpatialToPimPass::runOnOperation() { dumpModule(moduleOp, "pim0"); } -void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { +void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { OperationFolder constantFolder(funcOp.getContext()); funcOp.walk([&](PimVMMOp vmmOp) { auto outputType = cast(vmmOp.getOutput().getType()); @@ -343,7 +317,8 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I }); } -LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { +LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, + IRRewriter& rewriter) { Location loc = funcOp.getLoc(); OperationFolder constantFolder(funcOp.getContext()); @@ -387,18 +362,18 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu return success(); } -void SpatialToPimPass::markOpToRemove(Operation* op) { +void raptor::SpatialToPimPass::markOpToRemove(Operation* op) { if (!llvm::is_contained(operationsToRemove, op)) operationsToRemove.push_back(op); } -void SpatialToPimPass::eraseOpsToRemove() { +void raptor::SpatialToPimPass::eraseOpsToRemove() { for (Operation* op : operationsToRemove) { op->dropAllUses(); op->erase(); } } -std::unique_ptr createSpatialToPimPass() { return std::make_unique(); } +std::unique_ptr createSpatialToPimPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp new file mode 100644 index 0000000..7e17fd0 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp @@ -0,0 +1,75 @@ +#pragma once + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/FoldUtils.h" + +#include "llvm/ADT/StringRef.h" + +#include + +#include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "Conversion/SpatialToPim/Common.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +namespace onnx_mlir { +namespace raptor { + +struct SpatialToPimPass : mlir::PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass) + llvm::StringRef getArgument() const override { return "convert-spatial-to-pim"; } + llvm::StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; } + + SpatialToPimPass() = default; + SpatialToPimPass(const SpatialToPimPass& pass) {} + + void runOnOperation() final; + +private: + using OutputTensorFactory = std::function; + + llvm::SmallVector outputTensors; + size_t coreId = 0; + llvm::SmallVector operationsToRemove; + + mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, + mlir::IRRewriter& rewriter); + mlir::LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, + mlir::IRRewriter& rewriter, + mlir::OperationFolder& constantFolder); + mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, + mlir::IRRewriter& rewriter); + + enum class ReturnPathLoweringResult { + Handled, + NotReturnPath, + Failure + }; + + void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter); + ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp, + mlir::OpResult result, + mlir::Value yieldValue, + mlir::IRRewriter& rewriter); + ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp, + mlir::Value producedValue, + mlir::Value storedValue, + mlir::IRRewriter& rewriter); + void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter); + + void markOpToRemove(mlir::Operation* op); + void eraseOpsToRemove(); + + void enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter); +}; + +} // namespace raptor + +} // namespace onnx_mlir