194 lines
8.0 KiB
C++
194 lines
8.0 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
|
|
#include "llvm/ADT/StringRef.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Returns an owned copy of the core-id list attached to `coreBatchOp`.
///
/// The ids are read from the DenseI32ArrayAttr stored under `kCoreIdsAttrName`. The attribute is
/// mandatory: a missing attribute is reported through an assertion only.
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
  auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
  assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");

  // Take the view once, then copy it into storage owned by the caller.
  ArrayRef<int32_t> idsView = coreIdsAttr.asArrayRef();
  SmallVector<int32_t> ownedIds(idsView.begin(), idsView.end());
  return ownedIds;
}
|
|
|
|
/// Extracts the core ids that belong to one lane from a lane-interleaved id list.
///
/// `coreIds` is read as consecutive chunks of `laneCount` entries. From every complete chunk the
/// entry at offset `lane` is taken, so the result holds `coreIds.size() / laneCount` ids in chunk
/// order. Trailing entries that do not fill a whole chunk are ignored.
///
/// \param coreIds   Flat, lane-interleaved id list.
/// \param laneCount Number of entries per chunk; must be non-zero.
/// \param lane      Offset inside each chunk; must be smaller than `laneCount`.
/// \return The ids selected for `lane`; empty when `laneCount` is zero.
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
  assert(laneCount != 0 && "lane count must be non-zero");
  assert(lane < laneCount && "lane index must be smaller than the lane count");

  SmallVector<int32_t> laneCoreIds;
  // Keep release builds away from the division by zero the assertion above guards against.
  if (laneCount == 0)
    return laneCoreIds;

  // The chunk count is loop-invariant, so compute it once.
  const size_t chunkCount = coreIds.size() / laneCount;
  laneCoreIds.reserve(chunkCount);
  for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
    laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
  return laneCoreIds;
}
|
|
|
|
/// Resolves `value` to its counterpart in the IR being built, cloning captured definitions on demand.
///
/// When `mapper` already holds an entry for `value`, that entry is returned. Otherwise `value` is
/// expected to be the result of an operation located outside `oldBlock`: the operands of that
/// operation are resolved recursively, the operation is cloned at the builder's insertion point,
/// each of its results is recorded in `mapper`, and the mapped counterpart of `value` is returned.
///
/// Unmapped block arguments and unmapped values defined inside `oldBlock` are rejected through
/// assertions only, so those checks are not present in builds that define NDEBUG.
static Value getOrCloneCapturedValue(OpBuilder& builder, Block& oldBlock, Value value, IRMapping& mapper) {
  // Fast path: the value has already been bound or cloned.
  Value alreadyMapped = mapper.lookupOrNull(value);
  if (alreadyMapped)
    return alreadyMapped;

  // A block argument cannot be produced by cloning an operation, so an unmapped one is a bug.
  if (auto blockArgument = dyn_cast<BlockArgument>(value)) {
    assert(blockArgument.getOwner() != &oldBlock && "expected block argument to be mapped before cloning");
    assert(false && "unexpected captured block argument while scalarizing pim.core_batch");
  }

  Operation* definingOp = value.getDefiningOp();
  assert(definingOp && "expected captured value to be defined by an operation");
  assert(definingOp->getBlock() != &oldBlock && "expected in-block value to be mapped before cloning");

  // Resolve the producer's own operands first so the clone below can remap them.
  for (Value dependency : definingOp->getOperands())
    (void) getOrCloneCapturedValue(builder, oldBlock, dependency, mapper);

  Operation* cloned = builder.clone(*definingOp, mapper);
  const unsigned resultCount = definingOp->getNumResults();
  for (unsigned resultIndex = 0; resultIndex != resultCount; ++resultIndex)
    mapper.map(definingOp->getResult(resultIndex), cloned->getResult(resultIndex));

  return mapper.lookup(value);
}
|
|
|
|
/// Emits the single-lane form of a pim.core_batch body at the builder's insertion point.
///
/// The arguments of the batch body block are bound as follows:
///   - index-typed arguments: a host index constant holding `lane`;
///   - argument positions 1..weightCount: weight argument (position - 1) of the op that owns the
///     insertion block, which is cast to pim::PimCoreOp;
///   - later positions: the batch op's input operand (position - 1 - weightCount).
///
/// Every operation of the body except pim.halt is then re-emitted in order. The batch
/// send/receive ops and the batch host-to-device copy are replaced by their scalar counterparts:
/// per-lane core-id arrays are indexed with `lane`, tensor variants receive the per-lane id list
/// from getLaneChunkCoreIds, and the copy offsets become host index constants. Any other operation
/// is cloned through the value mapping.
///
/// \param builder        Insertion point that receives the scalarized operations.
/// \param coreBatchOp    Batch op whose first body block is scalarized.
/// \param lane           Lane to extract.
/// \param constantFolder Forwarded to getOrCreateHostIndexConstant.
static void cloneScalarizedLaneBody(OpBuilder& builder,
                                    pim::PimCoreBatchOp coreBatchOp,
                                    unsigned lane,
                                    OperationFolder& constantFolder) {
  Block& oldBlock = coreBatchOp.getBody().front();
  Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
  const size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
  const size_t weightCount = coreBatchOp.getWeights().size();

  // Step 1: bind every block argument of the batch body to its scalar replacement.
  IRMapping mapper;
  for (BlockArgument blockArg : oldBlock.getArguments()) {
    const size_t argIndex = blockArg.getArgNumber();
    Value replacement;
    if (blockArg.getType().isIndex()) {
      replacement = getOrCreateHostIndexConstant(anchorOp, static_cast<int64_t>(lane), constantFolder);
    }
    else if (argIndex <= weightCount) {
      // Weight positions are shifted by one relative to the weight arguments of the scalar core.
      auto scalarCoreOp = cast<pim::PimCoreOp>(anchorOp);
      replacement = scalarCoreOp.getWeightArgument(argIndex - 1);
    }
    else {
      size_t inputIndex = argIndex - 1 - weightCount;
      assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range");
      replacement = coreBatchOp.getInputs()[inputIndex];
    }
    mapper.map(blockArg, replacement);
  }

  // Step 2: re-emit the body, swapping batch ops for their scalar forms.
  for (Operation& op : oldBlock) {
    // pim.halt is not copied here.
    if (isa<pim::PimHaltOp>(op))
      continue;

    // Make sure every operand has a mapped counterpart before the op itself is emitted.
    for (Value operand : op.getOperands())
      (void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);

    if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
      Value payload = mapper.lookup(sendBatchOp.getInput());
      auto targetCore = getOrCreateHostIndexConstant(anchorOp, sendBatchOp.getTargetCoreIds()[lane], constantFolder);
      pim::PimSendOp::create(builder, sendBatchOp.getLoc(), payload, sendBatchOp.getSizeAttr(), targetCore);
    }
    else if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
      Value payload = mapper.lookup(sendTensorBatchOp.getInput());
      SmallVector<int32_t> laneTargets = getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane);
      pim::PimSendTensorOp::create(
          builder, sendTensorBatchOp.getLoc(), payload, builder.getDenseI32ArrayAttr(laneTargets));
    }
    else if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
      Value destination = mapper.lookup(receiveBatchOp.getOutputBuffer());
      auto sourceCore =
          getOrCreateHostIndexConstant(anchorOp, receiveBatchOp.getSourceCoreIds()[lane], constantFolder);
      auto scalarReceive = pim::PimReceiveOp::create(builder,
                                                     receiveBatchOp.getLoc(),
                                                     receiveBatchOp.getOutput().getType(),
                                                     destination,
                                                     receiveBatchOp.getSizeAttr(),
                                                     sourceCore);
      mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
    }
    else if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
      Value destination = mapper.lookup(receiveTensorBatchOp.getOutputBuffer());
      SmallVector<int32_t> laneSources =
          getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane);
      auto scalarReceive = pim::PimReceiveTensorOp::create(builder,
                                                           receiveTensorBatchOp.getLoc(),
                                                           receiveTensorBatchOp.getOutput().getType(),
                                                           destination,
                                                           builder.getDenseI32ArrayAttr(laneSources));
      mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
    }
    else if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
      auto deviceOffset =
          getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder);
      auto hostOffset = getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getHostSourceOffset(), constantFolder);
      Value deviceTarget = mapper.lookup(memcpBatchOp.getDeviceTarget());
      Value hostSource = mapper.lookup(memcpBatchOp.getHostSource());
      auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder,
                                                           memcpBatchOp.getLoc(),
                                                           memcpBatchOp.getOutput().getType(),
                                                           deviceOffset,
                                                           hostOffset,
                                                           deviceTarget,
                                                           hostSource,
                                                           memcpBatchOp.getSizeAttr());
      mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
    }
    else {
      // No scalar counterpart: clone the op as-is and record its results for later users.
      Operation* cloned = builder.clone(op, mapper);
      const unsigned resultCount = op.getNumResults();
      for (unsigned resultIndex = 0; resultIndex != resultCount; ++resultIndex)
        mapper.map(op.getResult(resultIndex), cloned->getResult(resultIndex));
    }
  }
}
|
|
|
|
} // namespace
|
|
|
|
/// Builds a temporary scalar pim.core op from selected lanes of a pim.core_batch op and hands it
/// to `callback`.
///
/// The scalar core is created inside a scratch ModuleOp held by a local OwningOpRef, so it only
/// exists for the duration of this call; `callback` must not keep the op beyond its own return.
/// The core takes the batch op's weights as operands and the core id of the first requested lane.
/// Its body block gets one argument per weight (same type and location), followed by the
/// scalarized body of every requested lane in the given order. A pim.halt is appended when the
/// block does not already end with one.
///
/// \param coreBatchOp Batch op to scalarize.
/// \param lanes       Non-empty list of lane indices; all of them must map to the same core id.
/// \param callback    Invoked once with the scalar core.
/// \return The result of `callback`.
LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
                                           ArrayRef<unsigned> lanes,
                                           llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
  assert(!lanes.empty() && "expected at least one batch lane");

  OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
  OpBuilder builder(scratchModule->getContext());
  OperationFolder constantFolder(scratchModule->getContext());
  builder.setInsertionPointToStart(scratchModule->getBody());

  SmallVector<Value> weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end());
  auto coreIds = getBatchCoreIds(coreBatchOp);
  int32_t coreId = coreIds[lanes.front()];
  // A single assert expression leaves no empty loop or unused loop variable behind when
  // assertions are compiled out.
  assert(llvm::all_of(lanes, [&](unsigned lane) { return coreIds[lane] == coreId; })
         && "all grouped lanes must target the same core");

  auto scalarCore =
      pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));

  // The body block mirrors the weights: one argument per weight, same type and location.
  SmallVector<Type> weightTypes;
  SmallVector<Location> weightLocs;
  weightTypes.reserve(weights.size());
  weightLocs.reserve(weights.size());
  for (Value weight : weights) {
    weightTypes.push_back(weight.getType());
    weightLocs.push_back(weight.getLoc());
  }
  Block* block =
      builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end(), TypeRange(weightTypes), weightLocs);
  builder.setInsertionPointToEnd(block);

  for (unsigned lane : lanes)
    cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);

  // Lane bodies are cloned without their pim.halt, so terminate the block here if needed.
  if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
    pim::PimHaltOp::create(builder, coreBatchOp.getLoc());

  return callback(scalarCore);
}
|
|
|
|
/// Single-lane convenience form of withScalarCoreFromBatchLanes.
///
/// Forwards `coreBatchOp` and `callback` unchanged, with a one-element lane list containing only
/// `lane`, and returns whatever the multi-lane overload returns.
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
                                          unsigned lane,
                                          llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
  // Non-owning one-element view over the by-value parameter; it stays valid for the whole call.
  ArrayRef<unsigned> singleLane(&lane, 1);
  return withScalarCoreFromBatchLanes(coreBatchOp, singleLane, callback);
}
|
|
|
|
} // namespace onnx_mlir
|