Files
Raptor/src/PIM/Compiler/PimBatchEmission.cpp
T
NiccoloN 43ed3914b8
Validate Operations / validate-operations (push) Has been cancelled
better MaterializeMergeSchedule.cpp (something still broken downstream)
2026-05-22 06:56:39 +02:00

194 lines
8.0 KiB
C++

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Returns a copy of the core-id list stored on `coreBatchOp` under
/// `kCoreIdsAttrName`.
///
/// The attribute is mandatory: its absence trips an assertion.
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
  auto idsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
  assert(idsAttr && "pim.core_batch requires coreIds array attribute");

  // Copy out of the attribute storage so the caller owns the ids.
  ArrayRef<int32_t> ids = idsAttr.asArrayRef();
  return SmallVector<int32_t>(ids.begin(), ids.end());
}
/// Extracts the core ids that belong to `lane` from a chunked core-id list.
///
/// `coreIds` is read as consecutive chunks of `laneCount` entries each; the
/// result holds entry `lane` of every complete chunk, in chunk order. Entries
/// past the last complete chunk are ignored.
///
/// \param coreIds   flat list of core ids, `laneCount` entries per chunk.
/// \param laneCount number of lanes per chunk; must be non-zero.
/// \param lane      lane to select inside each chunk; must be `< laneCount`.
/// \returns one core id per complete chunk.
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
  // A zero lane count would divide by zero below, and an out-of-range lane
  // would silently pick up an entry that belongs to the following chunk.
  assert(laneCount != 0 && "lane count must be non-zero");
  assert(lane < laneCount && "lane index out of range for chunked core ids");

  const size_t chunkCount = coreIds.size() / laneCount;
  SmallVector<int32_t> laneCoreIds;
  laneCoreIds.reserve(chunkCount);
  for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
    laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
  return laneCoreIds;
}
/// Returns the value that stands in for `value` inside the scalarized body.
///
/// If `mapper` already has an entry for `value`, that entry is returned.
/// Otherwise `value` must be the result of an operation located outside
/// `oldBlock`: the definitions of that operation's operands are materialized
/// first (recursively), then the operation itself is cloned at the builder's
/// insertion point and its results are recorded in `mapper`.
///
/// \param builder  builder whose insertion point receives the clones.
/// \param oldBlock body block being scalarized; values defined in it must
///                 already be present in `mapper`.
/// \param value    value to resolve.
/// \param mapper   mapping from original values to their replacements; updated
///                 with every result cloned here.
/// \returns the replacement for `value`.
static Value getOrCloneCapturedValue(OpBuilder& builder, Block& oldBlock, Value value, IRMapping& mapper) {
  if (Value mapped = mapper.lookupOrNull(value))
    return mapped;

  // An unmapped block argument can never be cloned. Fail loudly in every build
  // mode: with a plain assert, NDEBUG builds would fall through and dereference
  // the null defining op of the block argument.
  if (auto blockArgument = dyn_cast<BlockArgument>(value)) {
    if (blockArgument.getOwner() == &oldBlock)
      llvm::report_fatal_error("expected block argument to be mapped before cloning");
    llvm::report_fatal_error("unexpected captured block argument while scalarizing pim.core_batch");
  }

  Operation* definingOp = value.getDefiningOp();
  assert(definingOp && "expected captured value to be defined by an operation");
  assert(definingOp->getBlock() != &oldBlock && "expected in-block value to be mapped before cloning");

  // Make sure every operand has a replacement before cloning the definition,
  // so the clone is wired to the already-materialized values.
  for (Value operand : definingOp->getOperands())
    (void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);

  Operation* cloned = builder.clone(*definingOp, mapper);
  for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
    mapper.map(originalResult, clonedResult);
  return mapper.lookup(value);
}
/// Emits, at the builder's insertion point, a scalar copy of the body of
/// `coreBatchOp` specialized for a single `lane`.
///
/// The builder must be positioned inside the body block of a `pim::PimCoreOp`;
/// that op is used both as the anchor for host index constants and as the
/// source of the weight block arguments.
///
/// Block arguments of the batch body are replaced as follows:
///   - index-typed arguments become a host index constant holding `lane`;
///   - arguments 1..weightCount become the matching weight argument of the
///     enclosing scalar core op;
///   - the remaining arguments become the matching input operand of
///     `coreBatchOp`.
///
/// Batch communication and host-to-device copy ops are rewritten to their
/// scalar counterparts, `pim.halt` is dropped, and every other op is cloned.
///
/// \param builder        builder positioned inside the scalar core body.
/// \param coreBatchOp    batch op whose body is being scalarized.
/// \param lane           lane selected from the batch.
/// \param constantFolder folder used when creating host index constants.
static void cloneScalarizedLaneBody(OpBuilder& builder,
                                    pim::PimCoreBatchOp coreBatchOp,
                                    unsigned lane,
                                    OperationFolder& constantFolder) {
  Block& batchBody = coreBatchOp.getBody().front();
  Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
  const size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
  const size_t weightCount = coreBatchOp.getWeights().size();

  // Shorthand for a host-side index constant anchored at the scalar core op.
  auto hostIndexConstant = [&](auto constantValue) {
    return getOrCreateHostIndexConstant(anchorOp, constantValue, constantFolder);
  };

  IRMapping mapper;
  for (size_t argIndex = 0, argCount = batchBody.getNumArguments(); argIndex < argCount; ++argIndex) {
    BlockArgument blockArg = batchBody.getArgument(argIndex);
    if (blockArg.getType().isIndex()) {
      mapper.map(blockArg, hostIndexConstant(static_cast<int64_t>(lane)));
    }
    else if (argIndex <= weightCount) {
      // Weight arguments are shifted by one relative to the scalar core's
      // weight arguments.
      auto scalarCoreOp = cast<pim::PimCoreOp>(anchorOp);
      mapper.map(blockArg, scalarCoreOp.getWeightArgument(argIndex - 1));
    }
    else {
      const size_t inputIndex = argIndex - 1 - weightCount;
      assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range");
      mapper.map(blockArg, coreBatchOp.getInputs()[inputIndex]);
    }
  }

  for (Operation& op : batchBody) {
    // The caller decides where the terminator goes.
    if (isa<pim::PimHaltOp>(op))
      continue;

    // Values captured from outside the batch body must exist before the op
    // that uses them is rewritten or cloned.
    for (Value operand : op.getOperands())
      (void) getOrCloneCapturedValue(builder, batchBody, operand, mapper);

    if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
      pim::PimSendOp::create(builder,
                             sendBatchOp.getLoc(),
                             mapper.lookup(sendBatchOp.getInput()),
                             sendBatchOp.getSizeAttr(),
                             hostIndexConstant(sendBatchOp.getTargetCoreIds()[lane]));
    }
    else if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
      auto targetCoreIds = getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane);
      pim::PimSendTensorOp::create(builder,
                                   sendTensorBatchOp.getLoc(),
                                   mapper.lookup(sendTensorBatchOp.getInput()),
                                   builder.getDenseI32ArrayAttr(targetCoreIds));
    }
    else if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
      auto scalarReceive = pim::PimReceiveOp::create(builder,
                                                     receiveBatchOp.getLoc(),
                                                     receiveBatchOp.getOutput().getType(),
                                                     mapper.lookup(receiveBatchOp.getOutputBuffer()),
                                                     receiveBatchOp.getSizeAttr(),
                                                     hostIndexConstant(receiveBatchOp.getSourceCoreIds()[lane]));
      mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
    }
    else if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
      auto sourceCoreIds = getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane);
      auto scalarReceive = pim::PimReceiveTensorOp::create(builder,
                                                           receiveTensorBatchOp.getLoc(),
                                                           receiveTensorBatchOp.getOutput().getType(),
                                                           mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
                                                           builder.getDenseI32ArrayAttr(sourceCoreIds));
      mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
    }
    else if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
      auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder,
                                                           memcpBatchOp.getLoc(),
                                                           memcpBatchOp.getOutput().getType(),
                                                           hostIndexConstant(memcpBatchOp.getDeviceTargetOffset()),
                                                           hostIndexConstant(memcpBatchOp.getHostSourceOffset()),
                                                           mapper.lookup(memcpBatchOp.getDeviceTarget()),
                                                           mapper.lookup(memcpBatchOp.getHostSource()),
                                                           memcpBatchOp.getSizeAttr());
      mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
    }
    else {
      // No dedicated scalar form: clone as-is and record the new results.
      Operation* cloned = builder.clone(op, mapper);
      mapper.map(op.getResults(), cloned->getResults());
    }
  }
}
} // namespace
/// Builds a temporary scalar `pim.core` op out of the given lanes of
/// `coreBatchOp` and hands it to `callback`.
///
/// The scalar core is created inside a scratch module owned by this function,
/// takes the batch op's weights as operands, and carries the core id shared by
/// all requested lanes. Its body is the concatenation of the scalarized bodies
/// of `lanes`, in order, terminated by a `pim.halt`.
///
/// \param coreBatchOp batch op to scalarize.
/// \param lanes       non-empty list of lanes; all must map to the same core id.
/// \param callback    invoked with the scalar core op. The op belongs to the
///                    scratch module held by this function, so it should not
///                    be retained past the callback.
/// \returns whatever `callback` returns.
LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
                                           ArrayRef<unsigned> lanes,
                                           llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
  assert(!lanes.empty() && "expected at least one batch lane");

  Location loc = coreBatchOp.getLoc();
  OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(loc);
  MLIRContext* context = scratchModule->getContext();
  OpBuilder builder(context);
  OperationFolder constantFolder(context);
  builder.setInsertionPointToStart(scratchModule->getBody());

  SmallVector<Value> weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end());

  // Every requested lane has to live on the same core, since a single scalar
  // core op is emitted for all of them.
  const SmallVector<int32_t> coreIds = getBatchCoreIds(coreBatchOp);
  const int32_t coreId = coreIds[lanes.front()];
  assert(llvm::all_of(lanes, [&](unsigned lane) { return coreIds[lane] == coreId; })
         && "all grouped lanes must target the same core");

  auto scalarCore = pim::PimCoreOp::create(builder, loc, ValueRange(weights), builder.getI32IntegerAttr(coreId));

  // The body block gets one argument per weight, typed and located like it.
  SmallVector<Type> argTypes;
  SmallVector<Location> argLocs;
  argTypes.reserve(weights.size());
  argLocs.reserve(weights.size());
  for (Value weight : weights) {
    argTypes.push_back(weight.getType());
    argLocs.push_back(weight.getLoc());
  }

  Region& coreBody = scalarCore.getBody();
  Block* coreBlock = builder.createBlock(&coreBody, coreBody.end(), TypeRange(argTypes), argLocs);
  builder.setInsertionPointToEnd(coreBlock);

  for (unsigned lane : lanes)
    cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);

  // Lane bodies are cloned without their halt; close the block exactly once.
  const bool endsWithHalt = !coreBlock->empty() && isa<pim::PimHaltOp>(coreBlock->back());
  if (!endsWithHalt)
    pim::PimHaltOp::create(builder, loc);

  return callback(scalarCore);
}
/// Single-lane convenience wrapper around `withScalarCoreFromBatchLanes`.
///
/// \param coreBatchOp batch op to scalarize.
/// \param lane        the one lane to extract.
/// \param callback    invoked with the resulting scalar core op.
/// \returns whatever `callback` returns.
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
                                          unsigned lane,
                                          llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
  // View the by-value parameter as a one-element lane list; it outlives the call.
  ArrayRef<unsigned> singleLane(&lane, 1);
  return withScalarCoreFromBatchLanes(coreBatchOp, singleLane, callback);
}
} // namespace onnx_mlir