#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"

// NOTE(review): the next four includes lost their <...> header names during
// extraction; reconstructed from usage below (size_t, int32_t/int64_t,
// std::optional, std::pair/std::move) -- TODO confirm against the original.
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>

#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"

using namespace mlir;
using namespace onnx_mlir;
using namespace pim;

namespace onnx_mlir {
namespace {

#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"

/// Module pass that lowers the Spatial dialect (spat.*) to PIM dialect ops,
/// assigning PIM core ids, packing per-lane tensors, and materializing
/// explicit send/receive communication.
struct SpatialToPimPass
    : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)

  StringRef getArgument() const override { return "convert-spatial-to-pim"; }
  StringRef getDescription() const override {
    return "Lower Spatial ops to PIM-ready format";
  }

  SpatialToPimPass() = default;
  SpatialToPimPass(const SpatialToPimPass& pass) {}

  void runOnOperation() final;

private:
  // TODO(review): element type lost in extraction; the stripped token stream
  // (`SmallVector>`) shows a nested template -- confirm the original type.
  SmallVector<std::pair<Value, Value>> outputTensors;
  // Fallback core-id counter used when ops carry no explicit core-id attr.
  size_t coreId = 0;
  // Ops scheduled for deferred erasure once all their uses disappear.
  SmallVector<Operation*> operationsToRemove;

  void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
  LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
  void markOpToRemove(Operation* op);
  void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter);
  void runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter);
  void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
  void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};

} // namespace

/// Returns true when `op` is one of the "pass-through" ops that may sit on a
/// value's use chain between a compute result and its consumer.
static bool isChannelUseChainOp(Operation* op) {
  // TODO(review): the isa<...> type list was lost in extraction; it must
  // match the op kinds handled in mapIndicesThroughHelperChain
  // (extract_slice, transpose variants, reshape-like ops).
  return isa(op);
}

/// Clones constant-like producers of `op`'s operands into the current
/// insertion point so a subsequent clone of `op` resolves all operands.
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
  for (Value operand : op->getOperands()) {
    if (mapping.lookupOrNull(operand))
      continue;
    Operation* definingOp = operand.getDefiningOp();
    if (!definingOp)
      continue;
    // TODO(review): isa<...> list of clonable producer ops lost in extraction.
    if (!isa(definingOp))
      continue;
    Operation* clonedOp = rewriter.clone(*definingOp, mapping);
    for (auto [originalResult, newResult] :
         llvm::zip(definingOp->getResults(), clonedOp->getResults()))
      mapping.map(originalResult, newResult);
    rewriter.setInsertionPointAfter(clonedOp);
  }
}

/// Spatial core ids map 1:1 onto PIM core ids (narrowed to i32).
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) {
  return static_cast<int32_t>(spatialCoreId);
}

/// Core id for a compute op: an explicit attribute wins; otherwise the
/// monotonically increasing fallback counter is consumed.
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
  if (auto spatialCoreIdAttr =
          computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
    return static_cast<int32_t>(spatialCoreIdAttr.getInt());
  return static_cast<int32_t>(fallbackCoreId++);
}

/// Core ids for a batch op: explicit dense-i32 attribute, or one fallback id
/// per lane.
static SmallVector<int32_t> getPimCoreIdsForBatchOp(
    spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
  if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(
          onnx_mlir::kCoreIdsAttrName))
    return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(),
        coreIdsAttr.asArrayRef().end());
  SmallVector<int32_t> coreIds;
  coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
  for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane)
    coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
  return coreIds;
}

/// spat.channel_send -> pim.send with an explicit byte size and target core.
static void lowerChannelSend(spatial::SpatChannelSendOp sendOp, IRRewriter& rewriter) {
  auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput());
  auto
targetCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sendOp.getTargetCoreId())); rewriter.setInsertionPoint(sendOp); PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(), sizeAttr, targetCoreIdAttr); rewriter.eraseOp(sendOp); } static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) { if (receiveOp->use_empty()) { rewriter.eraseOp(receiveOp); return; } auto outputType = cast(receiveOp.getResult().getType()); rewriter.setInsertionPoint(receiveOp); auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult()); auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId())); Value received = PimReceiveOp::create(rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr) .getOutput(); rewriter.replaceOp(receiveOp, received); } static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) { rewriter.setInsertionPoint(sendManyOp); for (auto [input, targetCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds())) { PimSendOp::create(rewriter, sendManyOp.getLoc(), input, getTensorSizeInBytesAttr(rewriter, input), rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(targetCoreId))); } rewriter.eraseOp(sendManyOp); } static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) { rewriter.setInsertionPoint(receiveManyOp); SmallVector replacements; replacements.reserve(receiveManyOp.getNumResults()); for (auto [output, sourceCoreId] : llvm::zip(receiveManyOp.getOutputs(), receiveManyOp.getSourceCoreIds())) { auto outputType = cast(output.getType()); Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyOp.getLoc(), outputType).getResult(); replacements.push_back( PimReceiveOp::create(rewriter, 
receiveManyOp.getLoc(), output.getType(), outputBuffer, getTensorSizeInBytesAttr(rewriter, output), rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sourceCoreId))) .getOutput()); } rewriter.replaceOp(receiveManyOp, replacements); } static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp, int32_t laneCount, IRMapping& mapper, IRRewriter& rewriter) { SmallVector targetCoreIds; targetCoreIds.reserve(sendManyBatchOp.getTargetCoreIds().size()); for (int32_t targetCoreId : sendManyBatchOp.getTargetCoreIds()) targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); SmallVector mappedInputs; mappedInputs.reserve(sendManyBatchOp.getInputs().size()); for (Value input : sendManyBatchOp.getInputs()) mappedInputs.push_back(mapper.lookup(input)); for (auto [valueIndex, input] : llvm::enumerate(mappedInputs)) { SmallVector laneTargetCoreIds; laneTargetCoreIds.reserve(laneCount); for (int32_t lane = 0; lane < laneCount; ++lane) laneTargetCoreIds.push_back(targetCoreIds[valueIndex * laneCount + lane]); pim::PimSendBatchOp::create(rewriter, sendManyBatchOp.getLoc(), input, getTensorSizeInBytesAttr(rewriter, input), rewriter.getDenseI32ArrayAttr(laneTargetCoreIds)); } } static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp, int32_t laneCount, IRMapping& mapper, IRRewriter& rewriter) { SmallVector sourceCoreIds; sourceCoreIds.reserve(receiveManyBatchOp.getSourceCoreIds().size()); for (int32_t sourceCoreId : receiveManyBatchOp.getSourceCoreIds()) sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId)); for (auto [valueIndex, output] : llvm::enumerate(receiveManyBatchOp.getOutputs())) { auto outputType = cast(output.getType()); Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyBatchOp.getLoc(), outputType).getResult(); SmallVector laneSourceCoreIds; laneSourceCoreIds.reserve(laneCount); for (int32_t lane = 0; lane < laneCount; ++lane) 
laneSourceCoreIds.push_back(sourceCoreIds[valueIndex * laneCount + lane]); auto received = pim::PimReceiveBatchOp::create(rewriter, receiveManyBatchOp.getLoc(), output.getType(), outputBuffer, getTensorSizeInBytesAttr(rewriter, output), rewriter.getDenseI32ArrayAttr(laneSourceCoreIds)) .getOutput(); mapper.map(output, received); } } static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) { rewriter.setInsertionPoint(extractRowsOp); auto inputType = cast(extractRowsOp.getInput().getType()); SmallVector replacements; replacements.reserve(extractRowsOp.getNumResults()); for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) { auto outputType = cast(output.getType()); SmallVector offsets = { rewriter.getIndexAttr(static_cast(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)}; SmallVector sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)), rewriter.getIndexAttr(inputType.getDimSize(1))}; SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; replacements.push_back( tensor::ExtractSliceOp::create( rewriter, extractRowsOp.getLoc(), outputType, extractRowsOp.getInput(), offsets, sizes, strides) .getResult()); } rewriter.replaceOp(extractRowsOp, replacements); } static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) { rewriter.setInsertionPoint(concatOp); auto outputType = cast(concatOp.getOutput().getType()); Value outputBuffer = createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), outputType).getResult(); Value concatenated = pim::PimConcatOp::create(rewriter, concatOp.getLoc(), concatOp.getOutput().getType(), rewriter.getI64IntegerAttr(concatOp.getAxis()), concatOp.getInputs(), outputBuffer) .getOutput(); rewriter.replaceOp(concatOp, concatenated); } static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) { SmallVector mapOps; funcOp.walk([&](spatial::SpatMapOp mapOp) { if (mapOp->getParentOfType() || mapOp->getParentOfType()) 
mapOps.push_back(mapOp); }); for (auto mapOp : mapOps) { Block& body = mapOp.getBody().front(); auto yieldOp = cast(body.getTerminator()); SmallVector replacements; replacements.reserve(mapOp.getInputs().size()); rewriter.setInsertionPoint(mapOp); for (Value input : mapOp.getInputs()) { IRMapping mapping; mapping.map(body.getArgument(0), input); for (Operation& bodyOp : body.without_terminator()) { Operation* cloned = rewriter.clone(bodyOp, mapping); for (auto [originalResult, clonedResult] : llvm::zip(bodyOp.getResults(), cloned->getResults())) mapping.map(originalResult, clonedResult); rewriter.setInsertionPointAfter(cloned); } replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0))); } rewriter.replaceOp(mapOp, replacements); } } static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) { SmallVector packedShape(elementType.getShape().begin(), elementType.getShape().end()); packedShape[0] *= count; return RankedTensorType::get(packedShape, elementType.getElementType()); } static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) { if (values.empty()) return false; auto firstResult = dyn_cast(values.front()); if (!firstResult) return false; owner = firstResult.getOwner(); startIndex = firstResult.getResultNumber(); for (auto [index, value] : llvm::enumerate(values)) { auto result = dyn_cast(value); if (!result || result.getOwner() != owner || result.getResultNumber() != startIndex + index) return false; } return true; } static Value createPackedExtractRowsSlice( spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { auto rowType = dyn_cast(extractRowsOp.getOutputs()[startIndex].getType()); auto inputType = dyn_cast(extractRowsOp.getInput().getType()); if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0) return {}; int64_t rowsPerValue = rowType.getDimSize(0); if 
(ShapedType::isDynamic(rowsPerValue)) return {}; auto packedType = getPackedTensorType(rowType, static_cast(count)); SmallVector offsets; SmallVector sizes; SmallVector strides; offsets.reserve(inputType.getRank()); sizes.reserve(inputType.getRank()); strides.reserve(inputType.getRank()); offsets.push_back(rewriter.getIndexAttr(static_cast(startIndex) * rowsPerValue)); sizes.push_back(rewriter.getIndexAttr(static_cast(count) * rowsPerValue)); strides.push_back(rewriter.getIndexAttr(1)); for (int64_t dim = 1; dim < inputType.getRank(); ++dim) { offsets.push_back(rewriter.getIndexAttr(0)); sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim))); strides.push_back(rewriter.getIndexAttr(1)); } return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides) .getResult(); } static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter, Location loc) { Operation* owner = nullptr; unsigned startIndex = 0; if (!getContiguousOpResults(values, owner, startIndex)) return {}; if (auto extractRowsOp = dyn_cast(owner)) return createPackedExtractRowsSlice(extractRowsOp, startIndex, static_cast(values.size()), rewriter, loc); return {}; } static Value createPackedReceiveTensor(spatial::SpatChannelReceiveManyOp receiveManyOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { auto rowType = dyn_cast(receiveManyOp.getOutputs()[startIndex].getType()); if (!rowType || !rowType.hasStaticShape() || rowType.getRank() == 0) return {}; auto packedType = getPackedTensorType(rowType, static_cast(count)); auto outputBuffer = tensor::EmptyOp::create(rewriter, loc, packedType.getShape(), packedType.getElementType()); SmallVector sourceCoreIds; sourceCoreIds.reserve(count); ArrayRef allSourceCoreIds = receiveManyOp.getSourceCoreIds(); for (unsigned index = 0; index < count; ++index) sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(allSourceCoreIds[startIndex + index])); return 
pim::PimReceiveTensorOp::create( rewriter, loc, packedType, outputBuffer.getResult(), rewriter.getDenseI32ArrayAttr(sourceCoreIds)) .getOutput(); } static Value createPackedMapTensor( spatial::SpatMapOp mapOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { Value packedInput = createPackedTensorForValues(mapOp.getInputs().slice(startIndex, count), rewriter, loc); if (!packedInput) return {}; auto inputType = dyn_cast(mapOp.getInputs()[startIndex].getType()); auto outputType = dyn_cast(mapOp.getOutputs()[startIndex].getType()); if (!inputType || !outputType || !inputType.hasStaticShape() || !outputType.hasStaticShape() || inputType.getRank() == 0 || outputType.getRank() == 0) return {}; auto packedOutputType = getPackedTensorType(outputType, static_cast(count)); auto packedInit = tensor::EmptyOp::create(rewriter, loc, packedOutputType.getShape(), packedOutputType.getElementType()); auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); auto upper = arith::ConstantIndexOp::create(rewriter, loc, count); auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); auto loop = scf::ForOp::create(rewriter, loc, zero, upper, step, ValueRange {packedInit.getResult()}); { OpBuilder::InsertionGuard guard(rewriter); Block* loopBlock = loop.getBody(); rewriter.setInsertionPointToStart(loopBlock); Value iv = loopBlock->getArgument(0); Value acc = loopBlock->getArgument(1); int64_t inputRowsPerValue = inputType.getDimSize(0); Value inputRowOffset = iv; if (inputRowsPerValue != 1) { auto rowsPerValue = arith::ConstantIndexOp::create(rewriter, loc, inputRowsPerValue); inputRowOffset = arith::MulIOp::create(rewriter, loc, iv, rowsPerValue); } SmallVector extractOffsets; SmallVector extractSizes; SmallVector extractStrides; extractOffsets.push_back(inputRowOffset); extractSizes.push_back(rewriter.getIndexAttr(inputRowsPerValue)); extractStrides.push_back(rewriter.getIndexAttr(1)); for (int64_t dim = 1; dim < inputType.getRank(); ++dim) { 
extractOffsets.push_back(rewriter.getIndexAttr(0)); extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim))); extractStrides.push_back(rewriter.getIndexAttr(1)); } auto inputSlice = tensor::ExtractSliceOp::create( rewriter, loc, inputType, packedInput, extractOffsets, extractSizes, extractStrides); IRMapping mapping; Block& body = mapOp.getBody().front(); mapping.map(body.getArgument(0), inputSlice.getResult()); for (Operation& bodyOp : body.without_terminator()) { Operation* cloned = rewriter.clone(bodyOp, mapping); for (auto [originalResult, clonedResult] : llvm::zip(bodyOp.getResults(), cloned->getResults())) mapping.map(originalResult, clonedResult); rewriter.setInsertionPointAfter(cloned); } auto yieldOp = cast(body.getTerminator()); Value mappedOutput = mapping.lookupOrDefault(yieldOp.getOperand(0)); int64_t outputRowsPerValue = outputType.getDimSize(0); Value outputRowOffset = iv; if (outputRowsPerValue != 1) { auto rowsPerValue = arith::ConstantIndexOp::create(rewriter, loc, outputRowsPerValue); outputRowOffset = arith::MulIOp::create(rewriter, loc, iv, rowsPerValue); } SmallVector insertOffsets; SmallVector insertSizes; SmallVector insertStrides; insertOffsets.push_back(outputRowOffset); insertSizes.push_back(rewriter.getIndexAttr(outputRowsPerValue)); insertStrides.push_back(rewriter.getIndexAttr(1)); for (int64_t dim = 1; dim < outputType.getRank(); ++dim) { insertOffsets.push_back(rewriter.getIndexAttr(0)); insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(dim))); insertStrides.push_back(rewriter.getIndexAttr(1)); } auto inserted = tensor::InsertSliceOp::create(rewriter, loc, mappedOutput, acc, insertOffsets, insertSizes, insertStrides); scf::YieldOp::create(rewriter, loc, inserted.getResult()); } return loop.getResult(0); } static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) { SmallVector sendManyOps; funcOp.walk([&](spatial::SpatChannelSendManyOp sendManyOp) { 
sendManyOps.push_back(sendManyOp); }); for (auto sendManyOp : sendManyOps) { if (sendManyOp.getInputs().empty()) continue; rewriter.setInsertionPoint(sendManyOp); Value packedInput = createPackedTensorForValues(sendManyOp.getInputs(), rewriter, sendManyOp.getLoc()); if (!packedInput) continue; SmallVector targetCoreIds; targetCoreIds.reserve(sendManyOp.getTargetCoreIds().size()); for (int32_t targetCoreId : sendManyOp.getTargetCoreIds()) targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); pim::PimSendTensorOp::create( rewriter, sendManyOp.getLoc(), packedInput, rewriter.getDenseI32ArrayAttr(targetCoreIds)); rewriter.eraseOp(sendManyOp); } SmallVector concatOps; funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); }); for (auto concatOp : concatOps) { if (concatOp.getAxis() != 0 || concatOp.getInputs().empty()) continue; SmallVector packedInputs; bool changed = false; rewriter.setInsertionPoint(concatOp); for (unsigned index = 0; index < concatOp.getInputs().size();) { Value input = concatOp.getInputs()[index]; auto result = dyn_cast(input); if (!result) { packedInputs.push_back(input); ++index; continue; } Operation* owner = result.getOwner(); unsigned startIndex = result.getResultNumber(); unsigned endIndex = index + 1; while (endIndex < concatOp.getInputs().size()) { auto nextResult = dyn_cast(concatOp.getInputs()[endIndex]); if (!nextResult || nextResult.getOwner() != owner || nextResult.getResultNumber() != startIndex + (endIndex - index)) break; ++endIndex; } unsigned count = endIndex - index; Value packedInput; if (auto mapOp = dyn_cast(owner)) packedInput = createPackedMapTensor(mapOp, startIndex, count, rewriter, concatOp.getLoc()); else if (auto receiveManyOp = dyn_cast(owner)) packedInput = createPackedReceiveTensor(receiveManyOp, startIndex, count, rewriter, concatOp.getLoc()); else if (auto extractRowsOp = dyn_cast(owner)) packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, 
rewriter, concatOp.getLoc()); if (packedInput) { packedInputs.push_back(packedInput); changed = true; } else { for (unsigned oldIndex = index; oldIndex < endIndex; ++oldIndex) packedInputs.push_back(concatOp.getInputs()[oldIndex]); } index = endIndex; } if (!changed) continue; auto newConcat = pim::PimConcatOp::create( rewriter, concatOp.getLoc(), concatOp.getOutput().getType(), concatOp.getAxisAttr(), ValueRange(packedInputs), createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), cast(concatOp.getOutput().getType())) .getResult()); rewriter.replaceOp(concatOp, newConcat.getOutput()); } auto eraseUnusedOps = [&](auto tag) { using OpTy = decltype(tag); SmallVector ops; funcOp.walk([&](OpTy op) { ops.push_back(op); }); for (auto op : llvm::reverse(ops)) if (op->use_empty()) rewriter.eraseOp(op); }; eraseUnusedOps(spatial::SpatMapOp {}); eraseUnusedOps(spatial::SpatChannelReceiveManyOp {}); eraseUnusedOps(spatial::SpatExtractRowsOp {}); } static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, SmallVectorImpl& helperChain, bool requireReturnUse = true) { if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1) return failure(); if (requireReturnUse && (!computeOp.getResult(0).hasOneUse() || !isa(*computeOp.getResult(0).getUsers().begin()))) return failure(); Block& block = computeOp.getBody().front(); if (block.getNumArguments() != 1) return failure(); auto yieldOp = dyn_cast(block.getTerminator()); if (!yieldOp || yieldOp.getNumOperands() != 1) return failure(); SmallVector reverseChain; Value currentValue = yieldOp.getOperands().front(); Value blockArg = block.getArgument(0); while (currentValue != blockArg) { Operation* definingOp = currentValue.getDefiningOp(); if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp)) return failure(); reverseChain.push_back(definingOp); currentValue = definingOp->getOperand(0); } SmallPtrSet chainSet(reverseChain.begin(), reverseChain.end()); for (Operation& 
op : llvm::make_early_inc_range(block.without_terminator())) if (!chainSet.contains(&op) && !isa(op)) return failure(); helperChain.assign(reverseChain.rbegin(), reverseChain.rend()); return success(); } static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) { if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1) return false; if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) { return isa(user); })) return false; Block& block = computeOp.getBody().front(); if (block.getNumArguments() != 0) return false; auto yieldOp = dyn_cast(block.getTerminator()); if (!yieldOp || yieldOp.getNumOperands() != 1) return false; rewriter.setInsertionPoint(computeOp); IRMapping mapping; for (Operation& op : block.without_terminator()) { cloneMappedHelperOperands(&op, mapping, rewriter); Operation* clonedOp = rewriter.clone(op, mapping); for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults())) mapping.map(originalResult, newResult); rewriter.setInsertionPointAfter(clonedOp); } Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0)); computeOp.getResult(0).replaceAllUsesWith(replacement); return true; } struct ReturnUseInfo { size_t returnIndex; SmallVector helperChain; }; struct ConcatReturnUseInfo { size_t returnIndex; SmallVector sliceOffsets; SmallVector concatShape; SmallVector concatChain; SmallVector helperChain; }; static int64_t computeFlatElementIndex(ArrayRef indices, ArrayRef shape) { int64_t flatIndex = 0; for (size_t i = 0; i < shape.size(); ++i) { flatIndex *= shape[i]; flatIndex += indices[i]; } return flatIndex; } static SmallVector expandFlatElementIndex(int64_t flatIndex, ArrayRef shape) { SmallVector indices(shape.size(), 0); for (int64_t dim = static_cast(shape.size()) - 1; dim >= 0; --dim) { indices[dim] = flatIndex % shape[dim]; flatIndex /= shape[dim]; } return indices; } static std::optional analyzeReturnUse(Value value) 
{ auto uses = value.getUses(); if (rangeLength(uses) != 1) return std::nullopt; SmallVector helperChain; Value currentValue = value; Operation* currentUser = uses.begin()->getOwner(); while (isChannelUseChainOp(currentUser)) { helperChain.push_back(currentUser); auto currentUses = currentUser->getResult(0).getUses(); if (rangeLength(currentUses) != 1) return std::nullopt; currentValue = currentUser->getResult(0); currentUser = currentUses.begin()->getOwner(); } if (!isa(currentUser)) return std::nullopt; return ReturnUseInfo { currentValue.getUses().begin()->getOperandNumber(), std::move(helperChain), }; } static std::optional analyzeConcatReturnUse(Value value) { auto getConcatResult = [](Operation* op) -> Value { if (auto tensorConcat = dyn_cast(op)) return tensorConcat.getResult(); if (auto spatialConcat = dyn_cast(op)) return spatialConcat.getOutput(); if (auto pimConcat = dyn_cast(op)) return pimConcat.getOutput(); return {}; }; auto getConcatAxis = [](Operation* op) -> std::optional { if (auto tensorConcat = dyn_cast(op)) return tensorConcat.getDim(); if (auto spatialConcat = dyn_cast(op)) return spatialConcat.getAxis(); if (auto pimConcat = dyn_cast(op)) return pimConcat.getAxis(); return std::nullopt; }; auto getConcatOperands = [](Operation* op) -> OperandRange { if (auto tensorConcat = dyn_cast(op)) return tensorConcat.getOperands(); if (auto spatialConcat = dyn_cast(op)) return spatialConcat.getInputs(); return cast(op).getInputs(); }; auto uses = value.getUses(); if (rangeLength(uses) != 1 || !isa(uses.begin()->getOwner())) return std::nullopt; auto valueType = dyn_cast(value.getType()); if (!valueType || !valueType.hasStaticShape()) return std::nullopt; SmallVector sliceOffsets(valueType.getRank(), 0); SmallVector concatShape(valueType.getShape().begin(), valueType.getShape().end()); SmallVector concatChain; Value currentValue = value; Operation* currentUser = uses.begin()->getOwner(); while (isa(currentUser)) { concatChain.push_back(currentUser); 
size_t operandIndex = currentValue.getUses().begin()->getOperandNumber(); int64_t axis = *getConcatAxis(currentUser); for (Value operand : getConcatOperands(currentUser).take_front(operandIndex)) sliceOffsets[axis] += cast(operand.getType()).getShape()[axis]; Value concatResult = getConcatResult(currentUser); auto concatType = dyn_cast(concatResult.getType()); if (!concatType || !concatType.hasStaticShape()) return std::nullopt; concatShape.assign(concatType.getShape().begin(), concatType.getShape().end()); currentValue = concatResult; auto currentUses = currentValue.getUses(); if (rangeLength(currentUses) != 1) return std::nullopt; currentUser = currentUses.begin()->getOwner(); } SmallVector helperChain; if (auto helperCompute = dyn_cast(currentUser)) { if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue) return std::nullopt; if (failed(collectHelperComputeChain(helperCompute, helperChain))) return std::nullopt; currentValue = helperCompute.getResult(0); auto currentUses = currentValue.getUses(); if (rangeLength(currentUses) != 1) return std::nullopt; currentUser = currentUses.begin()->getOwner(); } while (isChannelUseChainOp(currentUser)) { helperChain.push_back(currentUser); auto currentUses = currentUser->getResult(0).getUses(); if (rangeLength(currentUses) != 1) return std::nullopt; currentValue = currentUser->getResult(0); currentUser = currentUses.begin()->getOwner(); } if (!isa(currentUser)) return std::nullopt; return ConcatReturnUseInfo { currentValue.getUses().begin()->getOperandNumber(), std::move(sliceOffsets), std::move(concatShape), std::move(concatChain), std::move(helperChain), }; } static LogicalResult mapIndicesThroughHelperChain(ArrayRef sourceIndices, ArrayRef sourceShape, ArrayRef helperChain, SmallVectorImpl& mappedIndices) { SmallVector currentIndices(sourceIndices.begin(), sourceIndices.end()); SmallVector currentShape(sourceShape.begin(), sourceShape.end()); auto reshapeToResultShape = 
[&](Operation* op) -> LogicalResult { auto resultType = dyn_cast(op->getResult(0).getType()); if (!resultType || !resultType.hasStaticShape()) return failure(); int64_t flatIndex = computeFlatElementIndex(currentIndices, currentShape); currentShape.assign(resultType.getShape().begin(), resultType.getShape().end()); currentIndices = expandFlatElementIndex(flatIndex, currentShape); return success(); }; for (Operation* op : helperChain) { if (auto extractSliceOp = dyn_cast(op)) { auto hasStaticValues = [](ArrayRef values) { return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); }); }; if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes()) || !hasStaticValues(extractSliceOp.getStaticStrides())) return failure(); SmallVector nextIndices; nextIndices.reserve(currentIndices.size()); for (auto [index, offset, size, stride] : llvm::zip_equal(currentIndices, extractSliceOp.getStaticOffsets(), extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides())) { if (stride != 1 || index < offset || index >= offset + size) return failure(); nextIndices.push_back(index - offset); } auto resultType = dyn_cast(extractSliceOp.getResult().getType()); if (!resultType || !resultType.hasStaticShape()) return failure(); currentIndices = std::move(nextIndices); currentShape.assign(resultType.getShape().begin(), resultType.getShape().end()); continue; } if (auto transposeOp = dyn_cast(op)) { SmallVector nextIndices(currentIndices.size()); SmallVector nextShape(currentShape.size()); for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange())) { int64_t sourceIndex = attr.getInt(); nextIndices[destIndex] = currentIndices[sourceIndex]; nextShape[destIndex] = currentShape[sourceIndex]; } currentIndices = std::move(nextIndices); currentShape = std::move(nextShape); continue; } if (auto transposeOp = dyn_cast(op)) { SmallVector nextIndices(currentIndices.size()); SmallVector 
// NOTE(review): every angle-bracketed template-argument list in this chunk
// appears to have been stripped during extraction (`cast(...)`, `isa(...)`,
// `dyn_cast(...)`, `SmallVector`, `funcOp.getOps()`, `static_cast(...)` are
// all missing their `<...>` arguments), so the chunk does not compile as-is.
// The tokens are reproduced unchanged below; restore the argument lists from
// version control. Comments hedge wherever an elided argument had to be
// inferred.

// ---- Tail of a definition that begins before this chunk. From the call site
// later in this file it is presumably `mapIndicesThroughHelperChain` (walks a
// helper-op chain, remapping per-element indices and the running shape) —
// TODO confirm against the full file. Reproduced verbatim; do not edit in
// isolation. ----
nextShape(currentShape.size());
    // Transpose case: destination dim `destIndex` takes its index and extent
    // from the permutation's `sourceIndex`.
    for (auto [destIndex, attr] :
         llvm::enumerate(transposeOp.getPermutation().getAsRange())) {
      int64_t sourceIndex = attr.getInt();
      nextIndices[destIndex] = currentIndices[sourceIndex];
      nextShape[destIndex] = currentShape[sourceIndex];
    }
    currentIndices = std::move(nextIndices);
    currentShape = std::move(nextShape);
    continue;
  }
  // Reshape-like ops (elided isa<...> argument) delegate to
  // reshapeToResultShape; any other op kind makes the chain unsupported.
  if (isa(op)) {
    if (failed(reshapeToResultShape(op)))
      return failure();
    continue;
  }
  return failure();
}
mappedIndices.assign(currentIndices.begin(), currentIndices.end());
return success();
}

/// Clones `helperChain` (a straight-line sequence of ops) immediately after
/// `sourceValue`, remapping each op's results so later clones see earlier
/// ones. On return `clonedValue` is the first result of the last cloned op,
/// or `sourceValue` itself when the chain is empty.
static void cloneHelperChain(
    Value sourceValue, ArrayRef helperChain, IRRewriter& rewriter, Value& clonedValue) {
  IRMapping mapping;
  // Map the source onto itself so operand lookups inside the chain resolve.
  mapping.map(sourceValue, sourceValue);
  clonedValue = sourceValue;
  rewriter.setInsertionPointAfterValue(sourceValue);
  for (Operation* op : helperChain) {
    // Pull in any out-of-chain helper operands first (see
    // cloneMappedHelperOperands earlier in this file).
    cloneMappedHelperOperands(op, mapping, rewriter);
    Operation* clonedOp = rewriter.clone(*op, mapping);
    for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
      mapping.map(originalResult, newResult);
    clonedValue = clonedOp->getResult(0);
    rewriter.setInsertionPointAfter(clonedOp);
  }
}

/// Emits a PimMemCopyDevToHostOp copying `sizeInBytes` bytes from
/// `sourceValue` at `deviceSourceOffset` into `outputTensor` at
/// `hostTargetOffset`; returns the copy op's output tensor.
static Value emitHostCopy(IRRewriter& rewriter, Location loc, Value outputTensor, Value sourceValue,
    int32_t hostTargetOffset, int32_t deviceSourceOffset, int32_t sizeInBytes) {
  return PimMemCopyDevToHostOp::create(rewriter, loc, outputTensor.getType(), outputTensor, sourceValue,
             rewriter.getI32IntegerAttr(hostTargetOffset), rewriter.getI32IntegerAttr(deviceSourceOffset),
             rewriter.getI32IntegerAttr(sizeInBytes))
      .getOutput();
}

/// Pass driver. Lowers the Spatial entry function to PIM ops in phases:
/// generated patterns, global-tensor-to-memref rewrites, host result-buffer
/// setup, compute/compute_batch lowering, channel-op lowering, per-core
/// pattern application, deferred dead-op removal, a second lowering sweep,
/// and a final "no `spat` ops left" check.
void SpatialToPimPass::runOnOperation() {
  // Fallback core-id allocation starts at 1 (id 0 presumably reserved —
  // TODO confirm).
  coreId = 1;
  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = moduleOp.getContext();
  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    signalPassFailure();
    return;
  }
  func::FuncOp funcOp = *entryFunc;
  IRRewriter rewriter(&getContext());
  ConversionTarget target(*ctx);
  target.addLegalDialect();
  // Phase 1: tablegen-generated rewrite patterns, module-wide.
  {
    RewritePatternSet patterns(ctx);
    populateWithGenerated(patterns);
    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
      signalPassFailure();
      return;
    }
  }
  // Phase 2: rewrite global tensors into memrefs.
  {
    RewritePatternSet patterns(ctx);
    populateGlobalTensorToMemrefPatterns(patterns);
    walkAndApplyPatterns(moduleOp, std::move(patterns));
  }
  // Prepare host-side result buffers and device copies of core-local globals.
  auto returnOp = cast(funcOp.front().getTerminator());
  addResultBuffer(returnOp, rewriter);
  if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
    signalPassFailure();
    return;
  }
  // Lower every compute / compute_batch region. The ops themselves are only
  // marked for removal here; erasure is deferred until uses are gone (below).
  for (auto computeOp : funcOp.getOps()) {
    markOpToRemove(computeOp);
    runOnComputeOp(computeOp, rewriter);
  }
  for (auto computeBatchOp : funcOp.getOps()) {
    markOpToRemove(computeBatchOp);
    runOnComputeBatchOp(computeBatchOp, rewriter);
  }
  compactSpatialTensorGroups(funcOp, rewriter);
  lowerMapOps(funcOp, rewriter);
  // Lower top-level channel receives. Ops are snapshotted first because the
  // lowering mutates the op list we would otherwise be iterating.
  SmallVector receiveOps;
  for (auto op : funcOp.getOps())
    receiveOps.push_back(op);
  for (auto receiveOp : receiveOps) {
    // Skip receives whose every user is already queued for removal.
    bool onlyPendingRemovalUsers = llvm::all_of(
        receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
    if (onlyPendingRemovalUsers) {
      markOpToRemove(receiveOp);
      continue;
    }
    // NOTE(review): `llvm::all_of` over an empty user range is vacuously
    // true, so a fully-dead receive takes the branch above and this erase
    // path never runs.
    if (receiveOp->use_empty()) {
      rewriter.eraseOp(receiveOp);
      continue;
    }
    lowerChannelReceive(receiveOp, rewriter);
  }
  // Snapshot-then-lower the remaining top-level channel ops
  // (receive-many, send, send-many, extract-rows).
  SmallVector receiveManyOps;
  for (auto op : funcOp.getOps())
    receiveManyOps.push_back(op);
  for (auto receiveManyOp : receiveManyOps)
    lowerChannelReceiveMany(receiveManyOp, rewriter);
  SmallVector sendOps;
  for (auto op : funcOp.getOps())
    sendOps.push_back(op);
  for (auto sendOp : sendOps)
    lowerChannelSend(sendOp, rewriter);
  SmallVector sendManyOps;
  for (auto op : funcOp.getOps())
    sendManyOps.push_back(op);
  for (auto sendManyOp : sendManyOps)
    lowerChannelSendMany(sendManyOp, rewriter);
  SmallVector extractRowsOps;
  for (auto op : funcOp.getOps())
    extractRowsOps.push_back(op);
  for (auto extractRowsOp : extractRowsOps)
    lowerExtractRows(extractRowsOp, rewriter);
  // Re-run the generated patterns inside each freshly created pim.core /
  // pim.core_batch body; the pattern set is frozen once and reused.
  {
    RewritePatternSet coreBodyPatterns(ctx);
    populateWithGenerated(coreBodyPatterns);
    FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
    SmallVector coreOps;
    funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
    for (auto coreOp : coreOps) {
      if (failed(applyPartialConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
        signalPassFailure();
        return;
      }
    }
    SmallVector coreBatchOps;
    funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
    for (auto coreBatchOp : coreBatchOps) {
      if (failed(applyPartialConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
        signalPassFailure();
        return;
      }
    }
  }
  // One more greedy pattern sweep over the whole function.
  RewritePatternSet channelPatterns(ctx);
  populateWithGenerated(channelPatterns);
  if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
    signalPassFailure();
    return;
  }
  enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
  replaceReturnOpOperands(returnOp, rewriter);
  // Erase the tracked ops use-first: repeatedly sweep the pending list,
  // erasing anything whose uses are gone. A full sweep with no progress
  // means a cycle or a missed dependency — dump the stragglers and abort.
  SmallVector pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
  while (!pendingRemovals.empty()) {
    bool erasedAnyOp = false;
    for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) {
      Operation* opToRemove = *it;
      if (!opToRemove->use_empty()) {
        ++it;
        continue;
      }
      rewriter.eraseOp(opToRemove);
      it = pendingRemovals.erase(it);
      erasedAnyOp = true;
    }
    if (erasedAnyOp)
      continue;
    for (auto opToRemove : pendingRemovals) {
      opToRemove->dump();
      for (auto user : opToRemove->getUsers())
        user->dump();
    }
    assert(false && "tracked op removal reached a cycle or missed dependency");
  }
  // The removals above may have exposed more Spatial ops; run the same
  // compaction and lowering sweeps once more (via walk this time).
  compactSpatialTensorGroups(funcOp, rewriter);
  SmallVector remainingConcatOps;
  funcOp.walk([&](spatial::SpatConcatOp op) { remainingConcatOps.push_back(op); });
  for (auto concatOp : remainingConcatOps)
    lowerConcat(concatOp, rewriter);
  SmallVector remainingReceiveOps;
  funcOp.walk([&](spatial::SpatChannelReceiveOp op) { remainingReceiveOps.push_back(op); });
  for (auto receiveOp : remainingReceiveOps)
    lowerChannelReceive(receiveOp, rewriter);
  SmallVector remainingReceiveManyOps;
  funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { remainingReceiveManyOps.push_back(op); });
  for (auto receiveManyOp : remainingReceiveManyOps)
    lowerChannelReceiveMany(receiveManyOp, rewriter);
  SmallVector remainingSendOps;
  funcOp.walk([&](spatial::SpatChannelSendOp op) { remainingSendOps.push_back(op); });
  for (auto sendOp : remainingSendOps)
    lowerChannelSend(sendOp, rewriter);
  SmallVector remainingSendManyOps;
  funcOp.walk([&](spatial::SpatChannelSendManyOp op) { remainingSendManyOps.push_back(op); });
  for (auto sendManyOp : remainingSendManyOps)
    lowerChannelSendMany(sendManyOp, rewriter);
  SmallVector remainingExtractRowsOps;
  funcOp.walk([&](spatial::SpatExtractRowsOp op) { remainingExtractRowsOps.push_back(op); });
  for (auto extractRowsOp : remainingExtractRowsOps)
    lowerExtractRows(extractRowsOp, rewriter);
  // Final legality check: no op from the `spat` dialect may survive.
  bool hasSpatialOps = false;
  moduleOp.walk([&](Operation* op) {
    if (op->getDialect()->getNamespace() == "spat")
      hasSpatialOps = true;
  });
  if (hasSpatialOps) {
    moduleOp.emitError("SpatialToPim left illegal Spatial operations in the module");
    signalPassFailure();
    return;
  }
  // Dump to file for debug
  dumpModule(moduleOp, "pim0");
}

/// Lowers one `spat.compute` region into a `pim.core` op:
///  * channel receives feeding the block arguments become PimReceiveOps
///    inside the body,
///  * each used yielded result that flows (possibly through helper ops or
///    concats) into the function return is materialized as device-to-host
///    copies into the matching `outputTensors` entry,
///  * the terminator is replaced (elided replaceOpWithNewOp<...> argument —
///    presumably PimHaltOp) and the body is spliced into a new `pim.core`
///    carrying the resolved core id.
/// The `spat.compute` itself is left behind with a stub halt-only block so
/// it stays structurally valid until the caller's deferred-removal loop
/// erases it.
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
  Location loc = computeOp->getLoc();
  // Inputless helper computes feeding weight-like users are inlined instead
  // of becoming cores.
  if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
    return;
  // If this compute is itself part of a helper chain, skip it here — it is
  // consumed when the chain's user is lowered.
  SmallVector helperChain;
  if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
    return;
  auto& block = computeOp.getRegion().front();
  auto yieldOp = cast(block.getTerminator());
  // Turn channel receives that feed block arguments into in-body
  // PimReceiveOps placed right before the argument's earliest user.
  for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
    auto receiveOp = dyn_cast_or_null(computeOp.getInputs()[argIndex].getDefiningOp());
    if (!receiveOp || blockArg.use_empty())
      continue;
    rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
    auto outputType = cast(blockArg.getType());
    auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
    auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
    auto sourceCoreIdAttr =
        rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
    Value received = PimReceiveOp::create(
        rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
                         .getOutput();
    blockArg.replaceAllUsesWith(received);
    markOpToRemove(receiveOp);
  }
  if (computeOp.getNumResults() != yieldOp.getNumOperands())
    llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
  // Materialize each used result as device-to-host traffic.
  for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
    if (result.use_empty())
      continue;
    auto yieldType = cast(yieldValue.getType());
    // Case 1: result reaches the return through a helper-op chain. Clone the
    // chain onto the yielded value and copy the final value into the host
    // output buffer for that return index.
    if (auto returnUse = analyzeReturnUse(result)) {
      Value storedValue = yieldValue;
      cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
      for (Operation* op : returnUse->helperChain)
        markOpToRemove(op);
      auto storedType = cast(storedValue.getType());
      size_t elementSize = storedType.getElementTypeBitWidth() / 8;
      if (auto storedOp = storedValue.getDefiningOp())
        rewriter.setInsertionPointAfter(storedOp);
      Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
      emitHostCopy(rewriter, loc, outputTensor, storedValue, 0, 0,
          static_cast(storedType.getNumElements() * elementSize));
      continue;
    }
    auto resultUses = result.getUses();
    if (rangeLength(resultUses) == 1) {
      OpOperand& resultUse = *resultUses.begin();
      Operation* resultUser = resultUse.getOwner();
      // Case 2: result is returned directly — copy into the host buffer at
      // the operand's position within the return (elided isa<...> argument;
      // presumably func::ReturnOp).
      if (isa(resultUser)) {
        size_t resultIndexInReturn = resultUse.getOperandNumber();
        size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
        rewriter.setInsertionPointAfterValue(yieldValue);
        Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
        emitHostCopy(rewriter, loc, outputTensor, yieldValue, 0, 0,
            static_cast(yieldType.getNumElements() * elementSize));
        continue;
      }
      // Case 3: single user of another op kind (elided isa<...> argument) is
      // deliberately left untouched.
      if (isa(resultUser))
        continue;
    }
    // Case 4: result feeds the return through concat(s), optionally followed
    // by helper ops.
    if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
      size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
      for (Operation* concatOp : concatReturnUse->concatChain)
        markOpToRemove(concatOp);
      // No helpers after the concat: one contiguous copy at the slice's flat
      // element offset inside the output tensor.
      if (concatReturnUse->helperChain.empty()) {
        rewriter.setInsertionPointAfterValue(yieldValue);
        Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
        auto outputType = cast(outputTensor.getType());
        int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
        emitHostCopy(rewriter, loc, outputTensor, yieldValue, static_cast(flatOffset * elementSize), 0,
            static_cast(yieldType.getNumElements() * elementSize));
        continue;
      }
      auto storedType = dyn_cast(yieldValue.getType());
      if (!storedType) {
        computeOp.emitOpError(
            "has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
        signalPassFailure();
        return;
      }
      // Helpers permute/reshape the concatenated tensor, so elements are not
      // contiguous in the output: copy element by element, mapping each
      // source index through the helper chain to its destination index.
      rewriter.setInsertionPointAfterValue(yieldValue);
      Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
      auto outputType = cast(outputTensor.getType());
      for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
        SmallVector sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
        // Shift into the concatenated tensor's coordinate space.
        for (auto [dim, idx] : llvm::enumerate(sourceIndices))
          sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;
        SmallVector destinationIndices;
        if (failed(mapIndicesThroughHelperChain(
                sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
          computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
          signalPassFailure();
          return;
        }
        // Slice the single element out of the yielded tensor ...
        SmallVector extractOffsets;
        SmallVector extractSizes;
        SmallVector extractStrides;
        extractOffsets.reserve(storedType.getRank());
        extractSizes.reserve(storedType.getRank());
        extractStrides.reserve(storedType.getRank());
        for (int64_t idx : expandFlatElementIndex(linearIndex, storedType.getShape())) {
          extractOffsets.push_back(rewriter.getIndexAttr(idx));
          extractSizes.push_back(rewriter.getIndexAttr(1));
          extractStrides.push_back(rewriter.getIndexAttr(1));
        }
        auto scalarTensorType =
            RankedTensorType::get(SmallVector(storedType.getRank(), 1), storedType.getElementType());
        auto elementSlice = tensor::ExtractSliceOp::create(
            rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
        rewriter.setInsertionPointAfter(elementSlice);
        // ... and copy it to its mapped flat position; thread the copy's
        // result forward so subsequent copies chain on the updated tensor.
        int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
        outputTensor = emitHostCopy(rewriter, loc, outputTensor, elementSlice.getResult(),
            static_cast(destinationFlatOffset * elementSize), 0, static_cast(elementSize));
      }
      continue;
    }
    computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering");
    signalPassFailure();
    return;
  }
  // Use `HaltOp` instead of `YieldOp` (elided replaceOpWithNewOp<...> arg).
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp(yieldOp);
  // Replace `spat.compute` with `pim.core`.
  SmallVector computeWeights;
  if (!computeOp.getWeights().empty())
    computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
  rewriter.setInsertionPointAfter(computeOp);
  auto coreOp = PimCoreOp::create(rewriter, loc, ValueRange(computeWeights),
      rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
  auto& coreOpBlocks = coreOp.getBody().getBlocks();
  // Arguments were rewired to receives above; forward any stragglers to the
  // compute's inputs before dropping the arguments and moving the body.
  for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
    if (!blockArg.use_empty())
      blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
  block.eraseArguments(0, block.getNumArguments());
  coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
  // Leave a minimal halt-only block in the (now empty) compute region so the
  // op stays verifiable until the deferred-removal loop erases it; the raw
  // `new Block()` is owned by the region after push_back (MLIR convention).
  Block* tempComputeBlock = new Block();
  computeOp.getBody().push_back(tempComputeBlock);
  rewriter.setInsertionPointToEnd(tempComputeBlock);
  PimHaltOp::create(rewriter, computeOp.getLoc());
}

/// Lowers one `spat.compute_batch` region into a `pim.core_batch` op.
/// Only the channelized form is supported: the op must have no results and
/// an operand-less `spat.yield`. Block arguments and tensors captured from
/// outside the region are staged host-to-device; channel batch ops are
/// translated to their pim.* counterparts; everything else is cloned.
void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter) {
  // Opt-in tracing via environment variable.
  if (std::getenv("PIM_BATCH_LOWER_DEBUG"))
    llvm::errs() << "lowering compute_batch lanes=" << computeBatchOp.getLaneCount() << "\n";
  if (computeBatchOp.getNumResults() != 0) {
    computeBatchOp.emitOpError(
        "batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
    signalPassFailure();
    return;
  }
  Location loc = computeBatchOp.getLoc();
  Block& oldBlock = computeBatchOp.getBody().front();
  auto oldYield = cast(oldBlock.getTerminator());
  if (oldYield.getNumOperands() != 0) {
    computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
    signalPassFailure();
    return;
  }
  // Build the pim.core_batch with the same lane count, weights and inputs,
  // explicit operand segment sizes, and the per-lane core-ids attribute.
  SmallVector coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
  SmallVector batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
  SmallVector batchInputs;
  if (!computeBatchOp.getInputs().empty())
    batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());
  rewriter.setInsertionPointAfter(computeBatchOp);
  auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter, loc,
      rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()), ValueRange(batchWeights), ValueRange(batchInputs));
  coreBatchOp.getProperties().setOperandSegmentSizes(
      {static_cast(batchWeights.size()), static_cast(batchInputs.size())});
  coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
  // Recreate the body block with identical argument types and locations.
  SmallVector blockArgTypes;
  SmallVector blockArgLocs;
  for (BlockArgument arg : oldBlock.getArguments()) {
    blockArgTypes.push_back(arg.getType());
    blockArgLocs.push_back(arg.getLoc());
  }
  Block* newBlock = rewriter.createBlock(
      &coreBatchOp.getBody(), coreBatchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
  IRMapping mapper;
  rewriter.setInsertionPointToStart(newBlock);
  // Each incoming argument is staged through a host-to-device batch copy;
  // body ops then consume the device-side value via `mapper`.
  for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
    auto newArgType = cast(newArg.getType());
    auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
    auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer,
        newArg, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(0),
        getTensorSizeInBytesAttr(rewriter, newArg))
                      .getOutput();
    mapper.map(oldArg, copied);
  }
  // Lazily stages a tensor captured from outside the region to the device
  // (memoized through `mapper`).
  auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
    if (auto mapped = mapper.lookupOrNull(capturedTensor))
      return mapped;
    auto capturedType = cast(capturedTensor.getType());
    auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
    auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer,
        capturedTensor, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(0),
        getTensorSizeInBytesAttr(rewriter, capturedTensor))
                      .getOutput();
    mapper.map(capturedTensor, copied);
    return copied;
  };
  rewriter.setInsertionPointToEnd(newBlock);
  for (Operation& op : oldBlock) {
    // Drop the (empty) terminator; a halt is appended after the loop. The
    // elided isa<...> argument presumably names the spat yield op.
    if (isa(op))
      continue;
    // Channel batch ops translate 1:1 to their pim.* batch counterparts.
    if (auto sendBatchOp = dyn_cast(op)) {
      pim::PimSendBatchOp::create(rewriter, loc, mapper.lookup(sendBatchOp.getInput()),
          getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
          sendBatchOp.getTargetCoreIdsAttr());
      continue;
    }
    if (auto sendManyBatchOp = dyn_cast(op)) {
      lowerChannelSendManyBatch(sendManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter);
      continue;
    }
    if (auto receiveBatchOp = dyn_cast(op)) {
      auto outputType = cast(receiveBatchOp.getOutput().getType());
      auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
      auto received = pim::PimReceiveBatchOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer,
          getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()), receiveBatchOp.getSourceCoreIdsAttr())
                          .getOutput();
      mapper.map(receiveBatchOp.getOutput(), received);
      continue;
    }
    if (auto receiveManyBatchOp = dyn_cast(op)) {
      lowerChannelReceiveManyBatch(receiveManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter);
      continue;
    }
    // A to_tensor whose buffer comes from an op of the elided
    // isa_and_present<...> kind is cloned, then its result is staged to the
    // device like a capture.
    if (auto toTensorOp = dyn_cast(op)) {
      if (isa_and_present(toTensorOp.getBuffer().getDefiningOp())) {
        Operation* cloned = rewriter.clone(op, mapper);
        auto clonedTensor = cloned->getResult(0);
        auto clonedType = cast(clonedTensor.getType());
        auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
        auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer,
            clonedTensor, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(0),
            getTensorSizeInBytesAttr(rewriter, clonedTensor))
                          .getOutput();
        mapper.map(toTensorOp.getResult(), copied);
        continue;
      }
    }
    // Any unmapped tensor operand defined outside the old block is a capture:
    // give it a device-side copy before cloning the op.
    for (Value operand : op.getOperands()) {
      if (!isa(operand.getType()) || mapper.contains(operand))
        continue;
      Operation* definingOp = operand.getDefiningOp();
      if (definingOp && definingOp->getBlock() == &oldBlock)
        continue;
      materializeCapturedTensor(operand);
    }
    Operation* cloned = rewriter.clone(op, mapper);
    for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
      mapper.map(originalResult, clonedResult);
  }
  rewriter.setInsertionPointToEnd(newBlock);
  PimHaltOp::create(rewriter, loc);
}

/// Widens every PimVMMOp output tensor whose second dimension is not already
/// `crossbarSize` up to {rows, crossbarSize}, then slices the original shape
/// back out after the op and redirects all other users to the slice.
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Walks backwards through the destination-passing-style chain that produced
  // `value`, retyping each tied init operand to `newType`. The `self`
  // parameter is the lambda itself (manual recursion).
  auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
    auto* definingOp = value.getDefiningOp();
    if (!definingOp)
      return;
    auto dpsDefiningOp = dyn_cast(definingOp);
    if (!dpsDefiningOp)
      return;
    auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast(value));
    if (!tiedOperand)
      return;
    Value tiedValue = tiedOperand->get();
    assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
    tiedValue.setType(newType);
    self(tiedValue, newType, self);
  };
  funcOp.walk([&](PimVMMOp vmmOp) {
    auto outTensorOperand = vmmOp.getOutputBuffer();
    auto resultTensor = vmmOp.getOutput();
    auto outShape = getTensorShape(outTensorOperand);
    assert(isHVectorShape(outShape));
    if (outShape[1] != static_cast(crossbarSize)) {
      auto newShape = SmallVector {outShape[0], static_cast(crossbarSize)};
      auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
      if (outTensorOperand == vmmOp.getInput()) {
        // The input doubles as the output buffer: give the op a fresh widened
        // buffer instead of retyping the shared value.
        rewriter.setInsertionPoint(vmmOp);
        auto newOutputBuffer = tensor::EmptyOp::create(
            rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
        vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
      } else {
        enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
        outTensorOperand.setType(newType);
      }
      resultTensor.setType(newType);
      // Slice the original {rows, cols} region back out of the widened result.
      IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
      IntegerAttr oneAttr = rewriter.getIndexAttr(1);
      IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
      IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
      SmallVector offsets = {zeroAttr, zeroAttr};
      SmallVector sizes = {oldShapeZeroAttr, oldShapeOneAttr};
      SmallVector strides = {oneAttr, oneAttr};
      rewriter.setInsertionPointAfter(vmmOp);
      auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
      SmallPtrSet exceptions = {vmmOp, sliceOp};
      resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
    }
  });
}

/// Populates `outputTensors` with one factory per function-return operand.
/// Operands produced by ops with the elided hasTrait<...> trait (presumably
/// constant-like) are handed back as-is; every other operand gets a private
/// `memref.global` named "output_<i>" whose factory re-materializes a
/// get_global + to_tensor pair at the requested insertion point.
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  outputTensors.reserve(returnOp->getNumOperands());
  for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
    Value currentReturnValue = returnValue;
    // NOTE(review): assumes every return operand has a defining op (is not a
    // block argument) — the pointer is dereferenced unchecked. TODO confirm
    // the entry function guarantees this.
    Operation* returnValueDefiningOp = currentReturnValue.getDefiningOp();
    if (returnValueDefiningOp->hasTrait()) {
      assert(!hasWeightAlways(returnValueDefiningOp));
      outputTensors.push_back(
          [currentReturnValue](IRRewriter& rewriter, Location loc) -> Value { return currentReturnValue; });
    } else {
      // NOTE(review): dyn_cast result is used unchecked; a non-ranked return
      // type would fault here. TODO confirm upstream guarantees rankedness.
      auto outRankedTensorType = llvm::dyn_cast(currentReturnValue.getType());
      auto memRefType = mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
      std::string outputName = "output_" + std::to_string(index);
      rewriter.setInsertionPoint(returnOp.getParentOp());
      memref::GlobalOp::create(rewriter, returnOp.getLoc(), rewriter.getStringAttr(outputName),
          rewriter.getStringAttr("private"), TypeAttr::get(memRefType), {}, {}, {});
      outputTensors.push_back(
          [memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value {
            auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
            auto toTensor = bufferization::ToTensorOp::create(
                rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
            return toTensor.getResult();
          });
    }
  }
}

/// For every inputless, argument-less compute region, inserts a
/// host-to-device copy after each get_global named "arg*" or "const_*" (via
/// its single user's first result) so core-local data is device-resident
/// before the body runs. Always returns success at present.
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();
  // Replaces all uses of `inputTensor` with the result of a
  // PimMemCopyHostToDevOp copying the whole tensor, with the device offset
  // given in elements and converted to bytes here.
  auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
    auto tensorType = cast(inputTensor.getType());
    Type elementType = tensorType.getElementType();
    size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
    rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
    auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
    auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
        rewriter, loc, tensorType, deviceTensor, inputTensor, rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast(elementsOffset * elementByteSize)),
        rewriter.getI32IntegerAttr(static_cast(tensorType.getNumElements() * elementByteSize)));
    rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
  };
  for (auto& op : funcOp.getBody().getOps())
    if (auto computeOp = dyn_cast(op)) {
      // Only computes with no inputs AND no block arguments are treated as
      // holding core-local globals.
      if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
        continue;
      for (auto getGlobal : computeOp.getOps()) {
        if (getGlobal.getName().starts_with("arg") || getGlobal.getName().starts_with("const_")) {
          assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute");
          // The single user's first result (presumably a to_tensor — the op
          // type is elided; TODO confirm) is what the body consumes.
          auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin();
          insertMemCopyHostToDev(toTensorOpValue, 0);
        }
      }
    }
  return success();
}

/// Queues `op` for deferred erasure, skipping duplicates.
/// Membership is a linear scan over `operationsToRemove`.
void SpatialToPimPass::markOpToRemove(Operation* op) {
  if (!llvm::is_contained(operationsToRemove, op))
    operationsToRemove.push_back(op);
}

/// Redirects every return operand to the corresponding host output tensor
/// from `outputTensors`, then marks the now-unreferenced producer chain
/// (channel-chain ops, computes, concat-like ops) for deferred removal.
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  // Recursively marks `op` and its producers, but only while each op is
  // exclusively owned by the return chain: either no uses at all, or exactly
  // one use whose owner is of the elided isa<...> kind (presumably the
  // return op) or another channel-chain op. The second parameter is the
  // lambda itself (manual recursion).
  auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
    if (!op)
      return;
    bool isExclusivelyOwnedByReturnChain = op->use_empty();
    if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
      Operation* onlyUser = *op->getUsers().begin();
      isExclusivelyOwnedByReturnChain = isa(onlyUser) || isChannelUseChainOp(onlyUser);
    }
    if (!isExclusivelyOwnedByReturnChain)
      return;
    // Channel-chain ops recurse through their first operand's producer.
    if (isChannelUseChainOp(op)) {
      Value source = op->getOperand(0);
      markOpToRemove(op);
      markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
      return;
    }
    if (auto computeOp = dyn_cast(op)) {
      markOpToRemove(computeOp);
      if (!computeOp.getInputs().empty())
        for (Value input : computeOp.getInputs())
          markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
      return;
    }
    // NOTE(review): the next three branches originally dispatched on three
    // DIFFERENT concat-like op types (the dyn_cast<...> arguments were
    // stripped by extraction — note the getOperands() vs getInputs()
    // accessors); each recurses into that op's sources.
    if (auto concatOp = dyn_cast(op)) {
      markOpToRemove(concatOp);
      for (Value operand : concatOp.getOperands())
        markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
      return;
    }
    if (auto concatOp = dyn_cast(op)) {
      markOpToRemove(concatOp);
      for (Value operand : concatOp.getInputs())
        markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
      return;
    }
    if (auto concatOp = dyn_cast(op)) {
      markOpToRemove(concatOp);
      for (Value operand : concatOp.getInputs())
        markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
    }
  };
  // Snapshot the operands first: setOperand below mutates the return op.
  SmallVector originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
  auto loc = returnOp.getLoc();
  for (auto it : llvm::enumerate(originalOperands)) {
    size_t orderWithinReturn = it.index();
    Operation* returnOperand = it.value().getDefiningOp();
    rewriter.setInsertionPoint(returnOp);
    Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
    rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
    markOwnedReturnChain(returnOperand, markOwnedReturnChain);
  }
}

/// Pass factory for registration (the elided unique_ptr / make_unique
/// template arguments name the pass type).
std::unique_ptr createSpatialToPimPass() { return std::make_unique(); }

} // namespace onnx_mlir