#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Returns a copy of the `coreIds` array attribute of a `pim.core_batch` op.
/// The result is indexed by batch lane: entry `i` is the core id for lane `i`.
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
  auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
  assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
  ArrayRef<int32_t> coreIds = coreIdsAttr.asArrayRef();
  return SmallVector<int32_t>(coreIds.begin(), coreIds.end());
}

/// Extracts the core ids that belong to `lane` from a lane-interleaved list.
///
/// `coreIds` is read chunk by chunk: chunk `c` covers
/// `coreIds[c * laneCount .. (c + 1) * laneCount)`, and the entry at offset
/// `lane` inside each chunk belongs to this lane. The result has one id per
/// complete chunk, in chunk order. Trailing ids that do not fill a whole chunk
/// are ignored.
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
  assert(laneCount != 0 && "pim.core_batch must have at least one lane");
  assert(lane < laneCount && "batch lane index out of range");
  SmallVector<int32_t> laneCoreIds;
  // Guard the division below so builds without asserts do not divide by zero.
  if (laneCount == 0)
    return laneCoreIds;
  size_t chunkCount = coreIds.size() / laneCount;
  laneCoreIds.reserve(chunkCount);
  for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
    laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
  return laneCoreIds;
}

/// Returns the value that stands for `value` in the scalarized body.
///
/// If `value` is already in `mapper`, that mapping is returned. Otherwise
/// `value` must be produced by an op outside `oldBlock`. That op is cloned at
/// the builder's insertion point, after recursively doing the same for its
/// operands, and the clone's result is returned.
///
/// Block arguments of `oldBlock` and values defined inside it must already be
/// mapped. A captured value that is a block argument of some other block has
/// no defining op to clone; this is reported as a fatal error.
///
/// NOTE(review): operands used only inside regions of a cloned op are not
/// visited here. A captured op with regions that itself captures outer values
/// would keep references to the original IR. Confirm this cannot occur.
static Value getOrCloneCapturedValue(OpBuilder& builder, Block& oldBlock, Value value, IRMapping& mapper) {
  if (Value mapped = mapper.lookupOrNull(value))
    return mapped;

  if (auto blockArgument = dyn_cast<BlockArgument>(value)) {
    assert(blockArgument.getOwner() != &oldBlock && "expected block argument to be mapped before cloning");
    // There is no defining op to clone. Abort in every build mode instead of
    // dereferencing a null defining op below when asserts are compiled out.
    llvm::report_fatal_error("unexpected captured block argument while scalarizing pim.core_batch");
  }

  Operation* definingOp = value.getDefiningOp();
  assert(definingOp && "expected captured value to be defined by an operation");
  assert(definingOp->getBlock() != &oldBlock && "expected in-block value to be mapped before cloning");

  // Make every operand of the producer available before cloning the producer.
  for (Value operand : definingOp->getOperands())
    (void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);

  // OpBuilder::clone(Operation&, IRMapping&) records original->clone result
  // mappings in `mapper`, so `value` can be looked up directly.
  builder.clone(*definingOp, mapper);
  return mapper.lookup(value);
}
/// Clones the body of `coreBatchOp`, specialized for batch lane `lane`, at the
/// builder's current insertion point.
///
/// `anchorOp` is the parent op of the insertion block. It is passed to
/// `getOrCreateHostIndexConstant` as the anchor for index constants, and it is
/// cast to the scalar core op type to obtain weight arguments. When called
/// from `withScalarCoreFromBatchLanes`, it is the scalar `pim.core` being built.
///
/// Batch send/receive/copy ops are rewritten into their scalar counterparts
/// using this lane's entry of their per-lane core-id attributes. Other ops are
/// cloned through `mapper`.
static void cloneScalarizedLaneBody(OpBuilder& builder, pim::PimCoreBatchOp coreBatchOp, unsigned lane, OperationFolder& constantFolder) {
  Block& oldBlock = coreBatchOp.getBody().front();
  Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
  size_t laneCount = static_cast(coreBatchOp.getLaneCount());
  size_t weightCount = coreBatchOp.getWeights().size();
  IRMapping mapper;
  // Map the batch body's block arguments. The code assumes this layout:
  // [index-typed lane id][weightCount weights][inputs...].
  for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
    // Index-typed arguments become a host index constant holding `lane`.
    // NOTE(review): this matches ANY index-typed argument, not just position 0.
    if (blockArg.getType().isIndex()) {
      mapper.map(blockArg, getOrCreateHostIndexConstant(anchorOp, static_cast(lane), constantFolder));
      continue;
    }
    // Arguments 1..weightCount map to the scalar core's weight arguments.
    // NOTE(review): if argument 0 were not index-typed, `argIndex - 1` would
    // wrap around. Confirm the op verifier guarantees the layout above.
    if (argIndex <= weightCount) {
      auto scalarCoreOp = cast(anchorOp);
      mapper.map(blockArg, scalarCoreOp.getWeightArgument(argIndex - 1));
      continue;
    }
    // The remaining arguments map directly to the batch op's input operands.
    size_t inputIndex = argIndex - 1 - weightCount;
    assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range");
    mapper.map(blockArg, coreBatchOp.getInputs()[inputIndex]);
  }
  for (Operation& op : oldBlock) {
    // Skip ops of this type. Presumably this is the halt terminator, since the
    // caller appends a single halt after all lanes have been cloned.
    if (isa(op))
      continue;
    // Clone any producers captured from outside the batch body.
    // NOTE(review): only direct operands are visited. Outer values used only
    // inside nested regions of `op` stay unmapped and would still refer to the
    // original IR after `builder.clone` below. Confirm this cannot happen.
    for (Value operand : op.getOperands())
      (void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
    // Batch send -> scalar send to this lane's target core.
    if (auto sendBatchOp = dyn_cast(op)) {
      pim::PimSendOp::create(
          builder,
          sendBatchOp.getLoc(),
          mapper.lookup(sendBatchOp.getInput()),
          sendBatchOp.getSizeAttr(),
          getOrCreateHostIndexConstant(anchorOp, sendBatchOp.getTargetCoreIds()[lane], constantFolder));
      continue;
    }
    // Batch tensor send -> scalar tensor send; targets are this lane's
    // entries from each chunk of the interleaved core-id list.
    if (auto sendTensorBatchOp = dyn_cast(op)) {
      pim::PimSendTensorOp::create(
          builder,
          sendTensorBatchOp.getLoc(),
          mapper.lookup(sendTensorBatchOp.getInput()),
          builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
      continue;
    }
    // Batch receive -> scalar receive from this lane's source core. The result
    // is mapped so that later ops use the scalar receive's output.
    if (auto receiveBatchOp = dyn_cast(op)) {
      auto scalarReceive = pim::PimReceiveOp::create(
          builder,
          receiveBatchOp.getLoc(),
          receiveBatchOp.getOutput().getType(),
          mapper.lookup(receiveBatchOp.getOutputBuffer()),
          receiveBatchOp.getSizeAttr(),
          getOrCreateHostIndexConstant(anchorOp, receiveBatchOp.getSourceCoreIds()[lane], constantFolder));
      mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
      continue;
    }
    // Batch tensor receive -> scalar tensor receive. Sources are this lane's
    // entries from each chunk of the interleaved core-id list.
    if (auto receiveTensorBatchOp = dyn_cast(op)) {
      auto scalarReceive = pim::PimReceiveTensorOp::create(
          builder,
          receiveTensorBatchOp.getLoc(),
          receiveTensorBatchOp.getOutput().getType(),
          mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
          builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
      mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
      continue;
    }
    // Batch copy -> scalar host-to-device copy. Both offsets are taken from
    // the batch op unchanged (they do not depend on `lane`) and materialized
    // as host index constants.
    if (auto memcpBatchOp = dyn_cast(op)) {
      auto scalarCopy = pim::PimMemCopyHostToDevOp::create(
          builder,
          memcpBatchOp.getLoc(),
          memcpBatchOp.getOutput().getType(),
          getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
          getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
          mapper.lookup(memcpBatchOp.getDeviceTarget()),
          mapper.lookup(memcpBatchOp.getHostSource()),
          memcpBatchOp.getSizeAttr());
      mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
      continue;
    }
    // Any other op is cloned as-is through the mapping.
    // NOTE(review): OpBuilder::clone already records result mappings in
    // `mapper`, so this loop re-maps the same pairs (harmless).
    Operation* cloned = builder.clone(op, mapper);
    for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
      mapper.map(originalResult, clonedResult);
  }
}

} // namespace

/// Builds a scalar `pim.core` in a temporary module and passes it to `callback`.
///
/// The scalar core executes the bodies of the given `lanes` of `coreBatchOp`
/// one after another, in the order given. All listed lanes must map to the
/// same core id; this is only asserted.
///
/// The core takes the batch op's weights as operands, and a block argument of
/// matching type and location is created for each weight. A `pim::PimHaltOp`
/// is appended unless the block already ends with the op type checked by the
/// `isa` below (presumably the halt op).
///
/// The scratch module owns the scalar core and is destroyed when this function
/// returns, so `callback` must not keep references to it.
/// @return the result of `callback`.
///
/// NOTE(review): each lane is cloned with a fresh IRMapping, so values captured
/// from outside the batch body are cloned once per lane. Lane indices are also
/// not bounds-checked against `coreIds`.
LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp, ArrayRef lanes, llvm::function_ref callback) {
  assert(!lanes.empty() && "expected at least one batch lane");
  OwningOpRef scratchModule = ModuleOp::create(coreBatchOp.getLoc());
  OpBuilder builder(scratchModule->getContext());
  OperationFolder constantFolder(scratchModule->getContext());
  builder.setInsertionPointToStart(scratchModule->getBody());
  SmallVector weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end());
  auto coreIds = getBatchCoreIds(coreBatchOp);
  int32_t coreId = coreIds[lanes.front()];
  // The assert is the entire loop body, so this check vanishes under NDEBUG.
  for (unsigned lane : lanes)
    assert(coreIds[lane] == coreId && "all grouped lanes must target the same core");
  auto scalarCore = pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
  // One block argument per weight, with the weight's type and location.
  SmallVector weightTypes;
  SmallVector weightLocs;
  weightTypes.reserve(weights.size());
  weightLocs.reserve(weights.size());
  for (Value weight : weights) {
    weightTypes.push_back(weight.getType());
    weightLocs.push_back(weight.getLoc());
  }
  Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end(), TypeRange(weightTypes), weightLocs);
  // Inserting inside the scalar core makes it the `anchorOp` seen by
  // cloneScalarizedLaneBody.
  builder.setInsertionPointToEnd(block);
  for (unsigned lane : lanes)
    cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);
  if (block->empty() || !isa(block->back()))
    pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
  return callback(scalarCore);
}

/// Single-lane convenience wrapper around withScalarCoreFromBatchLanes.
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, unsigned lane, llvm::function_ref callback) {
  return withScalarCoreFromBatchLanes(coreBatchOp, ArrayRef {lane}, callback);
}

} // namespace onnx_mlir