#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_os_ostream.h" #include #include #include #include #include #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Patterns.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Compiler/CompilerOptions.hpp" using namespace mlir; using namespace onnx_mlir; using namespace pim; namespace onnx_mlir { namespace { #include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc" struct SpatialToPimPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass) StringRef getArgument() const override { return "convert-spatial-to-pim"; } StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; } SpatialToPimPass() = default; SpatialToPimPass(const SpatialToPimPass& pass) {} void runOnOperation() final; private: SmallVector> outputTensors; size_t coreId = 0; SmallVector operationsToRemove; void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter); LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); void markOpToRemove(Operation* op); void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter); void runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter); void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter); void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter); }; } // namespace static bool isChannelUseChainOp(Operation* op) { return isa(op); } static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) { for (Value operand : op->getOperands()) { if (mapping.lookupOrNull(operand)) continue; Operation* definingOp = operand.getDefiningOp(); if (!definingOp) continue; if (!isa(definingOp)) continue; Operation* clonedOp = rewriter.clone(*definingOp, mapping); for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults())) mapping.map(originalResult, newResult); rewriter.setInsertionPointAfter(clonedOp); } } static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast(spatialCoreId); } static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) { if (auto spatialCoreIdAttr = computeOp->getAttrOfType(onnx_mlir::kCoreIdAttrName)) return static_cast(spatialCoreIdAttr.getInt()); return 
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
  if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName))
    return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
  SmallVector<int32_t> coreIds;
  coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
  for (int32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane)
    coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
  return coreIds;
}

static void lowerChannelSend(spatial::SpatChannelSendOp sendOp, IRRewriter& rewriter) {
  auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput());
  auto targetCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sendOp.getTargetCoreId()));
  rewriter.setInsertionPoint(sendOp);
  PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(), sizeAttr, targetCoreIdAttr);
  rewriter.eraseOp(sendOp);
}

static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
  if (receiveOp->use_empty()) {
    rewriter.eraseOp(receiveOp);
    return;
  }
  auto outputType = cast<ShapedType>(receiveOp.getResult().getType());
  rewriter.setInsertionPoint(receiveOp);
  auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
  auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
  auto sourceCoreIdAttr =
      rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
  Value received = PimReceiveOp::create(
      rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
                       .getOutput();
  rewriter.replaceOp(receiveOp, received);
}

static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(sendManyOp);
  for (auto [input, targetCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds())) {
    auto sizeAttr = getTensorSizeInBytesAttr(rewriter, input);
    auto targetCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(targetCoreId));
    PimSendOp::create(rewriter, sendManyOp.getLoc(), input, sizeAttr, targetCoreIdAttr);
  }
  rewriter.eraseOp(sendManyOp);
}

static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) {
  SmallVector<Value> replacements;
  replacements.reserve(receiveManyOp.getNumResults());
  rewriter.setInsertionPoint(receiveManyOp);
  for (auto [output, sourceCoreId] : llvm::zip(receiveManyOp.getOutputs(), receiveManyOp.getSourceCoreIds())) {
    auto outputType = cast<ShapedType>(output.getType());
    auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyOp.getLoc(), outputType);
    auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
    auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sourceCoreId));
    Value received = PimReceiveOp::create(
        rewriter, receiveManyOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
                         .getOutput();
    replacements.push_back(received);
  }
  rewriter.replaceOp(receiveManyOp, ValueRange(replacements));
}

static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
  Value input = extractRowsOp.getInput();
  RankedTensorType inputType;
  if (auto tensorType = dyn_cast<RankedTensorType>(input.getType())) {
    inputType = tensorType;
  } else if (auto memRefType = dyn_cast<MemRefType>(input.getType())) {
    inputType = RankedTensorType::get(memRefType.getShape(), memRefType.getElementType());
    rewriter.setInsertionPoint(extractRowsOp);
    input = bufferization::ToTensorOp::create(
        rewriter, extractRowsOp.getLoc(), inputType, input, rewriter.getUnitAttr(), rewriter.getUnitAttr())
                .getResult();
  } else {
    extractRowsOp.emitOpError("requires a ranked tensor or memref input during Spatial-to-PIM lowering");
    return;
  }
  int64_t numCols = inputType.getDimSize(1);
  SmallVector<Value> replacements;
  replacements.reserve(extractRowsOp.getNumResults());
  rewriter.setInsertionPoint(extractRowsOp);
  for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
    auto outputType = dyn_cast<RankedTensorType>(output.getType());
    if (!outputType) {
      extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering");
      return;
    }
    SmallVector<OpFoldResult> offsets = {
        rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)), rewriter.getIndexAttr(0)};
    SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
    SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
    auto rowSlice = tensor::ExtractSliceOp::create(
        rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides);
    replacements.push_back(rowSlice.getResult());
  }
  rewriter.replaceOp(extractRowsOp, ValueRange(replacements));
}

static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(concatOp);
  Value concatenated =
      tensor::ConcatOp::create(rewriter, concatOp.getLoc(), concatOp.getAxis(), concatOp.getInputs()).getResult();
  rewriter.replaceOp(concatOp, concatenated);
}

// Collects the chain of "helper" ops between a single-input compute op and its
// yielded value; fails if the body contains anything else.
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
    SmallVectorImpl<Operation*>& helperChain, bool requireReturnUse = true) {
  if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
    return failure();
  if (requireReturnUse &&
      (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin())))
    return failure();
  Block& block = computeOp.getBody().front();
  if (block.getNumArguments() != 1)
    return failure();
  auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
  if (!yieldOp || yieldOp.getNumOperands() != 1)
    return failure();
  SmallVector<Operation*> reverseChain;
  Value currentValue = yieldOp.getOperands().front();
  Value blockArg = block.getArgument(0);
  while (currentValue != blockArg) {
    Operation* definingOp = currentValue.getDefiningOp();
    if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp))
      return failure();
    reverseChain.push_back(definingOp);
    currentValue = definingOp->getOperand(0);
  }
  SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
  for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
    if (!chainSet.contains(&op) && !isa(op))
      return failure();
  helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
  return success();
}

static bool inlineInputlessHelperComputeForBatchUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
  if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
    return false;
  if (!llvm::all_of(computeOp.getResult(0).getUsers(),
          [](Operation* user) { return isa<spatial::SpatComputeBatch>(user); }))
    return false;
  Block& block = computeOp.getBody().front();
  if (block.getNumArguments() != 0)
    return false;
  auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
  if (!yieldOp || yieldOp.getNumOperands() != 1)
    return false;
  rewriter.setInsertionPoint(computeOp);
  IRMapping mapping;
  for (Operation& op : block.without_terminator()) {
    cloneMappedHelperOperands(&op, mapping, rewriter);
    Operation* clonedOp = rewriter.clone(op, mapping);
    for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
      mapping.map(originalResult, newResult);
    rewriter.setInsertionPointAfter(clonedOp);
  }
  Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
  computeOp.getResult(0).replaceAllUsesWith(replacement);
  return true;
}

struct ReturnUseInfo {
  size_t returnIndex;
  SmallVector<Operation*> helperChain;
};

struct ConcatReturnUseInfo {
  size_t returnIndex;
  SmallVector<int64_t> sliceOffsets;
  SmallVector<int64_t> concatShape;
  SmallVector<Operation*> helperChain;
};

// Row-major linearization of a multi-dimensional index.
static int64_t computeFlatElementIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> shape) {
  int64_t flatIndex = 0;
  for (size_t i = 0; i < shape.size(); ++i) {
    flatIndex *= shape[i];
    flatIndex += indices[i];
  }
  return flatIndex;
}

// Inverse of computeFlatElementIndex for the given shape.
static SmallVector<int64_t> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
  SmallVector<int64_t> indices(shape.size(), 0);
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
    indices[dim] = flatIndex % shape[dim];
    flatIndex /= shape[dim];
  }
  return indices;
}

static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
  auto uses = value.getUses();
  if (rangeLength(uses) != 1)
    return std::nullopt;
  SmallVector<Operation*> helperChain;
  Value currentValue = value;
  Operation* currentUser = uses.begin()->getOwner();
  while (isChannelUseChainOp(currentUser)) {
    helperChain.push_back(currentUser);
    auto currentUses = currentUser->getResult(0).getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentValue = currentUser->getResult(0);
    currentUser = currentUses.begin()->getOwner();
  }
  if (!isa<func::ReturnOp>(currentUser))
    return std::nullopt;
  return ReturnUseInfo {
      currentValue.getUses().begin()->getOperandNumber(),
      std::move(helperChain),
  };
}

static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
  auto uses = value.getUses();
  if (rangeLength(uses) != 1 || !isa<tensor::ConcatOp>(uses.begin()->getOwner()))
    return std::nullopt;
  auto valueType = dyn_cast<RankedTensorType>(value.getType());
  if (!valueType || !valueType.hasStaticShape())
    return std::nullopt;
  SmallVector<int64_t> sliceOffsets(valueType.getRank(), 0);
  SmallVector<int64_t> concatShape(valueType.getShape().begin(), valueType.getShape().end());
  Value currentValue = value;
  Operation* currentUser = uses.begin()->getOwner();
  while (auto concatOp = dyn_cast<tensor::ConcatOp>(currentUser)) {
    size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
    int64_t axis = concatOp.getDim();
    for (Value operand : concatOp.getOperands().take_front(operandIndex))
      sliceOffsets[axis] += cast<RankedTensorType>(operand.getType()).getShape()[axis];
    auto concatType = dyn_cast<RankedTensorType>(concatOp.getResult().getType());
    if (!concatType || !concatType.hasStaticShape())
      return std::nullopt;
    concatShape.assign(concatType.getShape().begin(), concatType.getShape().end());
    currentValue = concatOp.getResult();
    auto currentUses = currentValue.getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentUser = currentUses.begin()->getOwner();
  }
  SmallVector<Operation*> helperChain;
  if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
    if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
      return std::nullopt;
    if (failed(collectHelperComputeChain(helperCompute, helperChain)))
      return std::nullopt;
    currentValue = helperCompute.getResult(0);
    auto currentUses = currentValue.getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentUser = currentUses.begin()->getOwner();
  }
  while (isChannelUseChainOp(currentUser)) {
    helperChain.push_back(currentUser);
    auto currentUses = currentUser->getResult(0).getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentValue = currentUser->getResult(0);
    currentUser = currentUses.begin()->getOwner();
  }
  if (!isa<func::ReturnOp>(currentUser))
    return std::nullopt;
  return ConcatReturnUseInfo {
      currentValue.getUses().begin()->getOperandNumber(),
      std::move(sliceOffsets),
      std::move(concatShape),
      std::move(helperChain),
  };
}

// Maps element indices of the concatenated value through the helper chain
// (slices, transposes, reshapes) to indices in the returned tensor.
static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndices, ArrayRef<int64_t> sourceShape,
    ArrayRef<Operation*> helperChain, SmallVectorImpl<int64_t>& mappedIndices) {
  SmallVector<int64_t> currentIndices(sourceIndices.begin(), sourceIndices.end());
  SmallVector<int64_t> currentShape(sourceShape.begin(), sourceShape.end());
  auto reshapeToResultShape = [&](Operation* op) -> LogicalResult {
    auto resultType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
    if (!resultType || !resultType.hasStaticShape())
      return failure();
    int64_t flatIndex = computeFlatElementIndex(currentIndices, currentShape);
    currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
    currentIndices = expandFlatElementIndex(flatIndex, currentShape);
    return success();
  };
  for (Operation* op : helperChain) {
    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
      auto hasStaticValues = [](ArrayRef<int64_t> values) {
        return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
      };
      if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes()) ||
          !hasStaticValues(extractSliceOp.getStaticStrides()))
        return failure();
      SmallVector<int64_t> nextIndices;
      nextIndices.reserve(currentIndices.size());
      for (auto [index, offset, size, stride] : llvm::zip_equal(currentIndices, extractSliceOp.getStaticOffsets(),
               extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides())) {
        if (stride != 1 || index < offset || index >= offset + size)
          return failure();
        nextIndices.push_back(index - offset);
      }
      auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getResult().getType());
      if (!resultType || !resultType.hasStaticShape())
        return failure();
      currentIndices = std::move(nextIndices);
      currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
      continue;
    }
    if (auto transposeOp = dyn_cast(op)) {
      SmallVector<int64_t> nextIndices(currentIndices.size());
      SmallVector<int64_t> nextShape(currentShape.size());
      for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange<IntegerAttr>())) {
        int64_t sourceIndex = attr.getInt();
        nextIndices[destIndex] = currentIndices[sourceIndex];
        nextShape[destIndex] = currentShape[sourceIndex];
      }
      currentIndices = std::move(nextIndices);
      currentShape = std::move(nextShape);
      continue;
    }
    if (auto transposeOp = dyn_cast(op)) {
      SmallVector<int64_t> nextIndices(currentIndices.size());
      SmallVector<int64_t> nextShape(currentShape.size());
      for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermutation().getAsRange<IntegerAttr>())) {
        int64_t sourceIndex = attr.getInt();
        nextIndices[destIndex] = currentIndices[sourceIndex];
        nextShape[destIndex] = currentShape[sourceIndex];
      }
      currentIndices = std::move(nextIndices);
      currentShape = std::move(nextShape);
      continue;
    }
    if (isa(op)) {
      if (failed(reshapeToResultShape(op)))
        return failure();
      continue;
    }
    return failure();
  }
  mappedIndices.assign(currentIndices.begin(), currentIndices.end());
  return success();
}

static void cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter,
    Value& clonedValue) {
  IRMapping mapping;
  mapping.map(sourceValue, sourceValue);
  clonedValue = sourceValue;
  rewriter.setInsertionPointAfterValue(sourceValue);
  for (Operation* op : helperChain) {
    cloneMappedHelperOperands(op, mapping, rewriter);
    Operation* clonedOp = rewriter.clone(*op, mapping);
    for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
      mapping.map(originalResult, newResult);
    clonedValue = clonedOp->getResult(0);
    rewriter.setInsertionPointAfter(clonedOp);
  }
}

static Value emitHostCopy(IRRewriter& rewriter, Location loc, Value outputTensor, Value sourceValue,
    int32_t hostTargetOffset, int32_t deviceSourceOffset, int32_t sizeInBytes) {
  return PimMemCopyDevToHostOp::create(rewriter, loc, outputTensor.getType(), outputTensor, sourceValue,
      rewriter.getI32IntegerAttr(hostTargetOffset), rewriter.getI32IntegerAttr(deviceSourceOffset),
      rewriter.getI32IntegerAttr(sizeInBytes))
      .getOutput();
}

void SpatialToPimPass::runOnOperation() {
  coreId = 1;
  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = moduleOp.getContext();
  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    signalPassFailure();
    return;
  }
  func::FuncOp funcOp = *entryFunc;
  IRRewriter rewriter(&getContext());
  ConversionTarget target(*ctx);
  target.addLegalDialect();
  {
    RewritePatternSet patterns(ctx);
    populateWithGenerated(patterns);
    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
      signalPassFailure();
      return;
    }
  }
  {
    RewritePatternSet patterns(ctx);
    populateGlobalTensorToMemrefPatterns(patterns);
    walkAndApplyPatterns(moduleOp, std::move(patterns));
  }
  auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
  addResultBuffer(returnOp, rewriter);
  if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
    signalPassFailure();
    return;
  }
  SmallVector<spatial::SpatConcatOp> concatOps;
  funcOp.walk([&](spatial::SpatConcatOp op) { concatOps.push_back(op); });
  for (auto concatOp : concatOps)
    lowerConcat(concatOp, rewriter);
  for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
    markOpToRemove(computeOp);
    runOnComputeOp(computeOp, rewriter);
  }
  for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
    markOpToRemove(computeBatchOp);
    runOnComputeBatchOp(computeBatchOp, rewriter);
  }
  SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
  for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
    receiveOps.push_back(op);
  for (auto receiveOp : receiveOps) {
    bool onlyPendingRemovalUsers = llvm::all_of(receiveOp->getUsers(),
        [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
    if (onlyPendingRemovalUsers) {
      markOpToRemove(receiveOp);
      continue;
    }
    if (receiveOp->use_empty()) {
      rewriter.eraseOp(receiveOp);
      continue;
    }
    lowerChannelReceive(receiveOp, rewriter);
  }
  SmallVector<spatial::SpatChannelReceiveManyOp> receiveManyOps;
  for (auto op : funcOp.getOps<spatial::SpatChannelReceiveManyOp>())
    receiveManyOps.push_back(op);
  for (auto receiveManyOp : receiveManyOps)
    lowerChannelReceiveMany(receiveManyOp, rewriter);
  SmallVector<spatial::SpatChannelSendOp> sendOps;
  for (auto op : funcOp.getOps<spatial::SpatChannelSendOp>())
    sendOps.push_back(op);
  for (auto sendOp : sendOps)
    lowerChannelSend(sendOp, rewriter);
  SmallVector<spatial::SpatChannelSendManyOp> sendManyOps;
  for (auto op : funcOp.getOps<spatial::SpatChannelSendManyOp>())
    sendManyOps.push_back(op);
  for (auto sendManyOp : sendManyOps)
    lowerChannelSendMany(sendManyOp, rewriter);
  SmallVector<spatial::SpatExtractRowsOp> extractRowsOps;
  for (auto op : funcOp.getOps<spatial::SpatExtractRowsOp>())
    extractRowsOps.push_back(op);
  for (auto extractRowsOp : extractRowsOps)
    lowerExtractRows(extractRowsOp, rewriter);
  RewritePatternSet channelPatterns(ctx);
  populateWithGenerated(channelPatterns);
  if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
    signalPassFailure();
    return;
  }
  enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
  replaceReturnOpOperands(returnOp, rewriter);
  SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
  while (!pendingRemovals.empty()) {
    bool erasedAnyOp = false;
    for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) {
      Operation* opToRemove = *it;
      if (!opToRemove->use_empty()) {
        ++it;
        continue;
      }
      rewriter.eraseOp(opToRemove);
      it = pendingRemovals.erase(it);
      erasedAnyOp = true;
    }
    if (erasedAnyOp)
      continue;
    // No progress: every remaining op still has uses. Dump the offenders and abort.
    for (auto opToRemove : pendingRemovals) {
      opToRemove->dump();
      for (auto user : opToRemove->getUsers())
        user->dump();
    }
    assert(false && "tracked op removal reached a cycle or missed dependency");
  }
  SmallVector<spatial::SpatConcatOp> remainingConcatOps;
  funcOp.walk([&](spatial::SpatConcatOp op) { remainingConcatOps.push_back(op); });
  for (auto concatOp : remainingConcatOps)
    lowerConcat(concatOp, rewriter);
  SmallVector<spatial::SpatChannelReceiveOp> remainingReceiveOps;
  funcOp.walk([&](spatial::SpatChannelReceiveOp op) { remainingReceiveOps.push_back(op); });
  for (auto receiveOp : remainingReceiveOps)
    lowerChannelReceive(receiveOp, rewriter);
  SmallVector<spatial::SpatChannelReceiveManyOp> remainingReceiveManyOps;
  funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { remainingReceiveManyOps.push_back(op); });
  for (auto receiveManyOp : remainingReceiveManyOps)
    lowerChannelReceiveMany(receiveManyOp, rewriter);
  SmallVector<spatial::SpatChannelSendOp> remainingSendOps;
  funcOp.walk([&](spatial::SpatChannelSendOp op) { remainingSendOps.push_back(op); });
  for (auto sendOp : remainingSendOps)
    lowerChannelSend(sendOp, rewriter);
  SmallVector<spatial::SpatChannelSendManyOp> remainingSendManyOps;
  funcOp.walk([&](spatial::SpatChannelSendManyOp op) { remainingSendManyOps.push_back(op); });
  for (auto sendManyOp : remainingSendManyOps)
    lowerChannelSendMany(sendManyOp, rewriter);
  SmallVector<spatial::SpatExtractRowsOp> remainingExtractRowsOps;
  funcOp.walk([&](spatial::SpatExtractRowsOp op) { remainingExtractRowsOps.push_back(op); });
  for (auto extractRowsOp : remainingExtractRowsOps)
    lowerExtractRows(extractRowsOp, rewriter);
  bool hasSpatialOps = false;
  moduleOp.walk([&](Operation* op) {
    if (op->getDialect()->getNamespace() == "spat")
      hasSpatialOps = true;
  });
  if (hasSpatialOps) {
    moduleOp.emitError("SpatialToPim left illegal Spatial operations in the module");
    signalPassFailure();
    return;
  }
  // Dump to file for debug
  dumpModule(moduleOp, "pim0");
}

void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
  Location loc = computeOp->getLoc();
  if (inlineInputlessHelperComputeForBatchUsers(computeOp, rewriter))
    return;
  SmallVector<Operation*> helperChain;
  if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
    return;
  auto& block = computeOp.getRegion().front();
  auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
  for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
    auto receiveOp =
        dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
    if (!receiveOp || blockArg.use_empty())
      continue;
    rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
    auto outputType = cast<ShapedType>(blockArg.getType());
    auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
    auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
    auto sourceCoreIdAttr =
        rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
    Value received = PimReceiveOp::create(
        rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
                         .getOutput();
    blockArg.replaceAllUsesWith(received);
    markOpToRemove(receiveOp);
  }
  if (computeOp.getNumResults() != yieldOp.getNumOperands())
    llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
  for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
    if (result.use_empty())
      continue;
    auto yieldType = cast<ShapedType>(yieldValue.getType());
    if (auto returnUse = analyzeReturnUse(result)) {
      Value storedValue = yieldValue;
      cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
      for (Operation* op : returnUse->helperChain)
        markOpToRemove(op);
      auto storedType = cast<ShapedType>(storedValue.getType());
      size_t elementSize = storedType.getElementTypeBitWidth() / 8;
      if (auto storedOp = storedValue.getDefiningOp())
        rewriter.setInsertionPointAfter(storedOp);
      Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
      emitHostCopy(rewriter, loc, outputTensor, storedValue, 0, 0,
          static_cast<int32_t>(storedType.getNumElements() * elementSize));
      continue;
    }
    auto resultUses = result.getUses();
    if (rangeLength(resultUses) == 1) {
      OpOperand& resultUse = *resultUses.begin();
      Operation* resultUser = resultUse.getOwner();
      if (isa<func::ReturnOp>(resultUser)) {
        size_t resultIndexInReturn = resultUse.getOperandNumber();
        size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
        rewriter.setInsertionPointAfterValue(yieldValue);
        Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
        emitHostCopy(rewriter, loc, outputTensor, yieldValue, 0, 0,
            static_cast<int32_t>(yieldType.getNumElements() * elementSize));
        continue;
      }
      if (isa(resultUser))
        continue;
    }
    if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
      size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
      if (concatReturnUse->helperChain.empty()) {
        // Contiguous case: the result is a static slice of the returned concat,
        // so a single device-to-host copy at the right byte offset suffices.
        rewriter.setInsertionPointAfterValue(yieldValue);
        Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
        auto outputType = cast<ShapedType>(outputTensor.getType());
        int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
        emitHostCopy(rewriter, loc, outputTensor, yieldValue, static_cast<int32_t>(flatOffset * elementSize), 0,
            static_cast<int32_t>(yieldType.getNumElements() * elementSize));
        continue;
      }
      auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
      if (!storedType) {
        computeOp.emitOpError(
            "has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
        signalPassFailure();
        return;
      }
      // Non-contiguous case: helper ops permute elements between the concat and
      // the return, so copy element by element through the mapped indices.
      rewriter.setInsertionPointAfterValue(yieldValue);
      Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
      auto outputType = cast<ShapedType>(outputTensor.getType());
      for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
        SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
        for (auto [dim, idx] : llvm::enumerate(sourceIndices))
          sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;
        SmallVector<int64_t> destinationIndices;
        if (failed(mapIndicesThroughHelperChain(sourceIndices, concatReturnUse->concatShape,
                concatReturnUse->helperChain, destinationIndices))) {
          computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
          signalPassFailure();
          return;
        }
        SmallVector<OpFoldResult> extractOffsets;
        SmallVector<OpFoldResult> extractSizes;
        SmallVector<OpFoldResult> extractStrides;
        extractOffsets.reserve(storedType.getRank());
        extractSizes.reserve(storedType.getRank());
        extractStrides.reserve(storedType.getRank());
        for (int64_t idx : expandFlatElementIndex(linearIndex, storedType.getShape())) {
          extractOffsets.push_back(rewriter.getIndexAttr(idx));
          extractSizes.push_back(rewriter.getIndexAttr(1));
          extractStrides.push_back(rewriter.getIndexAttr(1));
        }
        auto scalarTensorType =
            RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
        auto elementSlice = tensor::ExtractSliceOp::create(
            rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
        rewriter.setInsertionPointAfter(elementSlice);
        int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
        outputTensor = emitHostCopy(rewriter, loc, outputTensor, elementSlice.getResult(),
            static_cast<int32_t>(destinationFlatOffset * elementSize), 0, static_cast<int32_t>(elementSize));
      }
      continue;
    }
    computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering");
    signalPassFailure();
    return;
  }
  // Use `HaltOp` instead of `YieldOp`
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
  // Replace `spat.compute` with `pim.core`
  rewriter.setInsertionPointAfter(computeOp);
  auto coreOp = PimCoreOp::create(rewriter, loc, computeOp.getWeights(),
      rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
  auto& coreOpBlocks = coreOp.getBody().getBlocks();
  for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
    if (!blockArg.use_empty())
      blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
  block.eraseArguments(0, block.getNumArguments());
  coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
  // Give the now-empty compute op a placeholder terminator block; the op itself
  // was already marked for removal.
  Block* tempComputeBlock = new Block();
  computeOp.getBody().push_back(tempComputeBlock);
  rewriter.setInsertionPointToEnd(tempComputeBlock);
  PimHaltOp::create(rewriter, computeOp.getLoc());
}

void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter) {
  if (std::getenv("PIM_BATCH_LOWER_DEBUG"))
    llvm::errs() << "lowering compute_batch lanes=" << computeBatchOp.getLaneCount() << "\n";
  if (computeBatchOp.getNumResults() != 0) {
    computeBatchOp.emitOpError(
        "batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
    signalPassFailure();
    return;
  }
  Location loc = computeBatchOp.getLoc();
  Block& oldBlock = computeBatchOp.getBody().front();
  auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
  if (oldYield.getNumOperands() != 0) {
    computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
    signalPassFailure();
    return;
  }
  SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
  rewriter.setInsertionPointAfter(computeBatchOp);
  auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter, loc,
      rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()), computeBatchOp.getWeights(),
      computeBatchOp.getInputs());
  coreBatchOp.getProperties().setOperandSegmentSizes(
      {static_cast<int32_t>(computeBatchOp.getWeights().size()),
          static_cast<int32_t>(computeBatchOp.getInputs().size())});
  coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
  SmallVector<Type> blockArgTypes;
  SmallVector<Location> blockArgLocs;
  for (BlockArgument arg : oldBlock.getArguments()) {
    blockArgTypes.push_back(arg.getType());
    blockArgLocs.push_back(arg.getLoc());
  }
  Block* newBlock = rewriter.createBlock(&coreBatchOp.getBody(), coreBatchOp.getBody().end(),
      TypeRange(blockArgTypes), blockArgLocs);
  IRMapping mapper;
  rewriter.setInsertionPointToStart(newBlock);
  for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
    auto newArgType = cast<ShapedType>(newArg.getType());
    auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
    auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer,
        newArg, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(0),
        getTensorSizeInBytesAttr(rewriter, newArg))
                      .getOutput();
    mapper.map(oldArg, copied);
  }
  auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
    if (auto mapped = mapper.lookupOrNull(capturedTensor))
      return mapped;
    auto capturedType = cast<ShapedType>(capturedTensor.getType());
    auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
    auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer,
        capturedTensor, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(0),
        getTensorSizeInBytesAttr(rewriter, capturedTensor))
                      .getOutput();
    mapper.map(capturedTensor, copied);
    return copied;
  };
  rewriter.setInsertionPointToEnd(newBlock);
  for (Operation& op : oldBlock) {
    if (isa(op))
      continue;
    if (auto sendBatchOp = dyn_cast(op)) {
      pim::PimSendBatchOp::create(rewriter, loc, mapper.lookup(sendBatchOp.getInput()),
          getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
          sendBatchOp.getTargetCoreIdsAttr());
      continue;
    }
    if (auto receiveBatchOp = dyn_cast(op)) {
      auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
      auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
      auto received = pim::PimReceiveBatchOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer,
          getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()), receiveBatchOp.getSourceCoreIdsAttr())
                          .getOutput();
      mapper.map(receiveBatchOp.getOutput(), received);
      continue;
    }
    if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
      if (isa_and_present(toTensorOp.getBuffer().getDefiningOp())) {
        Operation* cloned = rewriter.clone(op, mapper);
        auto clonedTensor = cloned->getResult(0);
        auto clonedType = cast<ShapedType>(clonedTensor.getType());
        auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
        auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer,
            clonedTensor, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(0),
            getTensorSizeInBytesAttr(rewriter, clonedTensor))
                          .getOutput();
        mapper.map(toTensorOp.getResult(), copied);
        continue;
      }
    }
    // Any other captured host tensor is copied to the device before its first use.
    for (Value operand : op.getOperands()) {
      if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
        continue;
      Operation* definingOp = operand.getDefiningOp();
      if (definingOp && definingOp->getBlock() == &oldBlock)
        continue;
      materializeCapturedTensor(operand);
    }
    Operation* cloned = rewriter.clone(op, mapper);
    for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
      mapper.map(originalResult, clonedResult);
  }
  rewriter.setInsertionPointToEnd(newBlock);
  PimHaltOp::create(rewriter, loc);
}

void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
  auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
    auto* definingOp = value.getDefiningOp();
    if (!definingOp)
      return;
    auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp);
    if (!dpsDefiningOp)
      return;
    auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
    if (!tiedOperand)
      return;
    Value tiedValue = tiedOperand->get();
    assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
    tiedValue.setType(newType);
    self(tiedValue, newType, self);
  };
  funcOp.walk([&](PimVMMOp vmmOp) {
    auto outTensorOperand = vmmOp.getOutputBuffer();
    auto resultTensor = vmmOp.getOutput();
    auto outShape = getTensorShape(outTensorOperand);
    assert(isHVectorShape(outShape));
    if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
      auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
      auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
      if (outTensorOperand == vmmOp.getInput()) {
        rewriter.setInsertionPoint(vmmOp);
        auto newOutputBuffer = tensor::EmptyOp::create(
            rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
        vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
      } else {
        enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
        outTensorOperand.setType(newType);
      }
      resultTensor.setType(newType);
      IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
      IntegerAttr oneAttr = rewriter.getIndexAttr(1);
      IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
      IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
      SmallVector<OpFoldResult> offsets = {zeroAttr, zeroAttr};
      SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
      SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
      rewriter.setInsertionPointAfter(vmmOp);
      auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
      SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
      resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
    }
  });
}

void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  outputTensors.reserve(returnOp->getNumOperands());
  for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
    Operation* returnValueDefiningOp = returnValue.getDefiningOp();
    if (returnValueDefiningOp->hasTrait()) {
      assert(!hasWeightAlways(returnValueDefiningOp));
      outputTensors.push_back([returnValue](IRRewriter& rewriter, Location loc) -> Value { return returnValue; });
    } else {
      // Computed results get a private global that serves as the host-side
      // output buffer.
      auto outRankedTensorType = llvm::dyn_cast<RankedTensorType>(returnValue.getType());
      auto memRefType = mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
      std::string outputName = "output_" + std::to_string(index);
      rewriter.setInsertionPoint(returnOp.getParentOp());
      memref::GlobalOp::create(rewriter, returnOp.getLoc(), rewriter.getStringAttr(outputName),
          rewriter.getStringAttr("private"), TypeAttr::get(memRefType), {}, {}, {});
      outputTensors.push_back(
          [memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value {
            auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
            auto toTensor = bufferization::ToTensorOp::create(
                rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
            return toTensor.getResult();
          });
    }
  }
}

LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();
  auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
    auto tensorType = cast<ShapedType>(inputTensor.getType());
    Type elementType = tensorType.getElementType();
    size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
    rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
    auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
    auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(rewriter, loc, tensorType, deviceTensor, inputTensor,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
    rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
  };
  for (auto& op : funcOp.getBody().getOps())
    if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
      if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
        continue;
      for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
        if (getGlobal.getName().starts_with("arg") || getGlobal.getName().starts_with("const_")) {
          assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute");
          auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin();
          insertMemCopyHostToDev(toTensorOpValue, 0);
        }
      }
    }
  return success();
}

void SpatialToPimPass::markOpToRemove(Operation* op) {
  if (!llvm::is_contained(operationsToRemove, op))
    operationsToRemove.push_back(op);
}

void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  // Marks the chain of ops that exist only to feed the return so they can be
  // erased once the return operands point at the host output buffers.
  auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
    if (!op)
      return;
    bool isExclusivelyOwnedByReturnChain = op->use_empty();
    if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
      Operation* onlyUser = *op->getUsers().begin();
      isExclusivelyOwnedByReturnChain = isa(onlyUser) || isChannelUseChainOp(onlyUser);
    }
    if (!isExclusivelyOwnedByReturnChain)
      return;
    if (isChannelUseChainOp(op)) {
      Value source = op->getOperand(0);
      markOpToRemove(op);
      markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
      return;
    }
    if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
      markOpToRemove(computeOp);
      for (Value input : computeOp.getInputs())
        markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
      return;
    }
    if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
      markOpToRemove(concatOp);
      for (Value operand : concatOp.getOperands())
        markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
    }
  };
  SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
  auto loc = returnOp.getLoc();
  for (auto it : llvm::enumerate(originalOperands)) {
    size_t orderWithinReturn = it.index();
    Operation* returnOperand = it.value().getDefiningOp();
    rewriter.setInsertionPoint(returnOp);
    Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
    rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
    markOwnedReturnChain(returnOperand, markOwnedReturnChain);
  }
}

std::unique_ptr<mlir::Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }

} // namespace onnx_mlir