#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Returns a copy of the per-lane core ids stored on a `pim.core_batch` op
/// under `onnx_mlir::kCoreIdsAttrName`.
///
/// Asserts that the attribute is present.
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
  auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
  assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
  ArrayRef<int32_t> coreIds = coreIdsAttr.asArrayRef();
  return SmallVector<int32_t>(coreIds.begin(), coreIds.end());
}

/// Selects the core ids belonging to `lane` from a lane-interleaved list.
///
/// `coreIds` is read as consecutive chunks of `laneCount` entries, and the
/// entry at offset `lane` of each chunk is collected:
///   coreIds[lane], coreIds[laneCount + lane], coreIds[2 * laneCount + lane], ...
/// Trailing entries that do not form a complete chunk are ignored.
///
/// @param coreIds   Interleaved core ids (chunk-major, lane-minor).
/// @param laneCount Number of lanes per chunk; must be non-zero.
/// @param lane      Lane to extract; must be < laneCount.
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
  assert(laneCount != 0 && "lane count must be non-zero");
  assert(lane < laneCount && "lane index out of range");
  const size_t chunkCount = coreIds.size() / laneCount;
  SmallVector<int32_t> laneCoreIds;
  laneCoreIds.reserve(chunkCount);
  for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
    laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
  return laneCoreIds;
}

/// Rewrites every batched send/receive/memcopy op nested in `scalarCore`
/// into its scalar counterpart for a single `lane`:
///  - send/receive batch ops keep the peer core id at index `lane`;
///  - tensor send/receive batch ops keep the lane's chunk of peer core ids
///    (see getLaneChunkCoreIds);
///  - host-to-device memcopy batch ops are re-created as scalar copies with
///    the same operands and attributes.
///
/// Matching ops are collected first and rewritten afterwards, so the walk is
/// not disturbed by erasing or replacing ops.
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
  IRRewriter rewriter(scalarCore.getContext());
  SmallVector<Operation*> batchOps;
  // NOTE(review): batch op class names are inferred from their scalar
  // counterparts; confirm against the Pim dialect definitions.
  scalarCore.walk([&](Operation* op) {
    if (isa<pim::PimSendBatchOp,
            pim::PimSendTensorBatchOp,
            pim::PimReceiveBatchOp,
            pim::PimReceiveTensorBatchOp,
            pim::PimMemCopyHostToDevBatchOp>(op))
      batchOps.push_back(op);
  });

  for (Operation* op : batchOps) {
    rewriter.setInsertionPoint(op);

    if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
      pim::PimSendOp::create(rewriter,
                             sendBatchOp.getLoc(),
                             sendBatchOp.getInput(),
                             sendBatchOp.getSizeAttr(),
                             rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
      rewriter.eraseOp(op);
      continue;
    }

    if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
      pim::PimSendTensorOp::create(
          rewriter,
          sendTensorBatchOp.getLoc(),
          sendTensorBatchOp.getInput(),
          rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
      rewriter.eraseOp(op);
      continue;
    }

    if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
      auto scalarReceive =
          pim::PimReceiveOp::create(rewriter,
                                    receiveBatchOp.getLoc(),
                                    receiveBatchOp.getOutput().getType(),
                                    receiveBatchOp.getOutputBuffer(),
                                    receiveBatchOp.getSizeAttr(),
                                    rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
      rewriter.replaceOp(op, scalarReceive->getResults());
      continue;
    }

    if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
      auto scalarReceive = pim::PimReceiveTensorOp::create(
          rewriter,
          receiveTensorBatchOp.getLoc(),
          receiveTensorBatchOp.getOutput().getType(),
          receiveTensorBatchOp.getOutputBuffer(),
          rewriter.getDenseI32ArrayAttr(
              getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
      rewriter.replaceOp(op, scalarReceive->getResults());
      continue;
    }

    // Only the host-to-device memcopy batch op remains after the cases above.
    auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
    auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
                                                         memcpBatchOp.getLoc(),
                                                         memcpBatchOp.getOutput().getType(),
                                                         memcpBatchOp.getDeviceTarget(),
                                                         memcpBatchOp.getHostSource(),
                                                         memcpBatchOp.getDeviceTargetOffsetAttr(),
                                                         memcpBatchOp.getHostSourceOffsetAttr(),
                                                         memcpBatchOp.getSizeAttr());
    rewriter.replaceOp(op, scalarCopy->getResults());
  }
}

} // namespace

/// Materializes lane `lane` of a `pim.core_batch` op as a standalone scalar
/// `pim.core` op and hands it to `callback`.
///
/// The scalar core is built inside a scratch module owned by this function:
///  - its operands are the lane's weights; weights are laid out lane-major,
///    so lane `l` owns the contiguous slice
///    [l * weightsPerLane, (l + 1) * weightsPerLane);
///  - its core id attribute is coreIds[lane];
///  - its body is a clone of the batch body; if the batch body has exactly
///    one block argument, that argument is replaced by inputs[lane];
///  - a `pim.halt` is appended when the cloned body does not already end
///    with one;
///  - batched send/receive/memcopy ops are rewritten to their scalar forms.
///
/// The scratch module, and the scalar core inside it, is destroyed when this
/// function returns, so `callback` must not keep references to the op.
///
/// @return failure() (with an op error) if `lane` is not a valid lane of
///         `coreBatchOp`; otherwise the result of `callback`.
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
                                          unsigned lane,
                                          llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
  size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
  // Guard the divisions and per-lane indexing below.
  if (laneCount == 0 || lane >= laneCount)
    return coreBatchOp.emitOpError() << "lane " << lane << " is out of range for lane count " << laneCount;

  auto coreIds = getBatchCoreIds(coreBatchOp);
  assert(lane < coreIds.size() && "coreIds has fewer entries than lanes");

  OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
  OpBuilder builder(scratchModule->getContext());
  builder.setInsertionPointToStart(scratchModule->getBody());

  size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
  SmallVector<Value> laneWeights;
  laneWeights.reserve(weightsPerLane);
  for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
    laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);

  auto scalarCore = pim::PimCoreOp::create(
      builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
  Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());

  Block& batchBody = coreBatchOp.getBody().front();
  IRMapping mapper;
  if (batchBody.getNumArguments() == 1) {
    assert(lane < coreBatchOp.getInputs().size() && "pim.core_batch has fewer inputs than lanes");
    mapper.map(batchBody.getArgument(0), coreBatchOp.getInputs()[lane]);
  }

  builder.setInsertionPointToEnd(block);
  // OpBuilder::clone records each original->cloned result pair in `mapper`,
  // so later ops that use earlier results are remapped automatically.
  for (Operation& op : batchBody)
    builder.clone(op, mapper);

  if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
    pim::PimHaltOp::create(builder, coreBatchOp.getLoc());

  scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
  return callback(scalarCore);
}

} // namespace onnx_mlir