compact pim IR
Validate Operations / validate-operations (push) Successful in 22m15s

This commit is contained in:
NiccoloN
2026-05-06 17:16:51 +02:00
parent 7bb58e80de
commit f2fe147961
13 changed files with 2264 additions and 307 deletions
+120 -5
View File
@@ -37,6 +37,11 @@ using namespace llvm;
using namespace mlir;
using namespace onnx_mlir;
// Returns the total size in bytes of a shaped value: #elements * element width.
// NOTE(review): the integer division floors for sub-byte element widths
// (e.g. i1) — assumes such element types never reach this helper; confirm.
static size_t getValueSizeInBytes(mlir::Value value) {
  const auto shapedType = cast<ShapedType>(value.getType());
  const size_t totalBits =
      static_cast<size_t>(shapedType.getNumElements()) * shapedType.getElementTypeBitWidth();
  return totalBits / 8;
}
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape());
@@ -382,10 +387,75 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
}
// Lowers a receive_many op: emits one "recv" per (output buffer, source core)
// pair, pairing the i-th output buffer with the i-th source core id. The
// transfer size is derived from the buffer's static type.
void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const {
  auto buffers = receiveManyOp.getOutputBuffers();
  auto sources = receiveManyOp.getSourceCoreIds();
  for (auto [buffer, sourceCoreId] : llvm::zip(buffers, sources)) {
    const auto bufferAddr = addressOf(buffer, knowledge);
    emitCommunicationOp("recv", bufferAddr, sourceCoreId, getValueSizeInBytes(buffer));
  }
}
// Lowers a single send op: emits one "send" from the input's resolved address
// to the op's target core, using the size carried by the op itself.
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
  const auto inputAddr = addressOf(sendOp.getInput(), knowledge);
  emitCommunicationOp("send", inputAddr, sendOp.getTargetCoreId(), sendOp.getSize());
}
// Lowers a send_many op: emits one "send" per (input, target core) pair,
// pairing the i-th input with the i-th target core id. The transfer size is
// derived from the input's static type.
void PimCodeGen::codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const {
  auto inputs = sendManyOp.getInputs();
  auto targets = sendManyOp.getTargetCoreIds();
  for (auto [input, targetCoreId] : llvm::zip(inputs, targets)) {
    const auto inputAddr = addressOf(input, knowledge);
    emitCommunicationOp("send", inputAddr, targetCoreId, getValueSizeInBytes(input));
  }
}
// Lowers extract_rows: copies row i of the rank-2 input into the i-th output
// buffer via an "lmv" local memory copy. Requires a static rank-2 input so
// both the row size and each row's byte offset are compile-time constants.
void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const {
  auto inputType = cast<ShapedType>(extractRowsOp.getInput().getType());
  assert(inputType.hasStaticShape() && inputType.getRank() == 2 && "extract_rows codegen requires static rank-2 input");
  // One row occupies (#columns) * (element size in bytes) contiguous bytes.
  const size_t elementSize = inputType.getElementTypeBitWidth() / 8;
  const size_t rowSizeInBytes = static_cast<size_t>(inputType.getDimSize(1)) * elementSize;
  const size_t inputAddr = addressOf(extractRowsOp.getInput(), knowledge);
  size_t rowIndex = 0;
  for (mlir::Value outputBuffer : extractRowsOp.getOutputBuffers()) {
    const size_t srcOffset = rowIndex * rowSizeInBytes;
    // Destination offset is 0: each output buffer holds exactly one row.
    emitMemCopyOp("lmv", addressOf(outputBuffer, knowledge), 0, inputAddr, srcOffset, rowSizeInBytes, "len");
    ++rowIndex;
  }
}
// Lowers pim.concat: copies each input tensor into its slice of the output
// buffer along `axis` using row-major offset arithmetic.
//
// Row-major layout: an element at (outer, c, inner) — where `outer` flattens
// the dims before the axis and `inner` the dims after — lives at linear
// element offset (outer * dimAxis + c) * innerCount. Each input contributes
// `inputConcatDim` contiguous axis positions, so per outer index one
// contiguous block of inputConcatDim * innerCount elements can be moved with
// a single "lmv".
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
  auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
  assert(outputType.hasStaticShape() && "concat codegen requires static output shape");
  int64_t axis = concatOp.getAxis();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  // Fixed: validate the axis before it is used to index the shape below
  // (a negative or out-of-range axis previously read out of bounds).
  assert(axis >= 0 && axis < static_cast<int64_t>(outputShape.size()) &&
         "concat codegen requires 0 <= axis < output rank");
  size_t elementSize = outputType.getElementTypeBitWidth() / 8;
  size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
  // Product of dims before the axis: number of independent outer slices.
  size_t outerCount = 1;
  for (int64_t dim = 0; dim < axis; ++dim)
    outerCount *= static_cast<size_t>(outputShape[dim]);
  // Product of dims after the axis: contiguous elements per axis step.
  size_t innerCount = 1;
  for (size_t dim = static_cast<size_t>(axis) + 1; dim < outputShape.size(); ++dim)
    innerCount *= static_cast<size_t>(outputShape[dim]);
  size_t outputConcatDim = static_cast<size_t>(outputShape[axis]);
  size_t concatOffset = 0;  // running position along the axis in the output
  for (mlir::Value input : concatOp.getInputs()) {
    auto inputType = cast<ShapedType>(input.getType());
    assert(inputType.hasStaticShape() && "concat codegen requires static input shapes");
    size_t inputConcatDim = static_cast<size_t>(inputType.getDimSize(axis));
    size_t blockSizeInBytes = inputConcatDim * innerCount * elementSize;
    size_t inputAddr = addressOf(input, knowledge);
    for (size_t outerIndex = 0; outerIndex < outerCount; ++outerIndex) {
      size_t dstOffset = (outerIndex * outputConcatDim + concatOffset) * innerCount * elementSize;
      size_t srcOffset = outerIndex * inputConcatDim * innerCount * elementSize;
      emitMemCopyOp("lmv", outputAddr, dstOffset, inputAddr, srcOffset, blockSizeInBytes, "len");
    }
    concatOffset += inputConcatDim;
  }
  // Fixed: the inputs must tile the output exactly along the axis; a mismatch
  // would silently emit copies past (or short of) the output buffer.
  assert(concatOffset == outputConcatDim &&
         "concat inputs do not fill the output along the concat axis");
}
template <typename MVMTy>
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
MVMTy mvmLikeOp,
@@ -396,11 +466,6 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
}
// NOTE(review): this helper is byte-identical to the one defined near the top
// of this diff; these lines appear to be the deletion side of the hunk (the
// helper's old location before it was hoisted above its new callers). If both
// copies actually remain in the file, one must be removed — a second
// file-local definition with the same signature is a redefinition error.
static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge);
@@ -682,6 +747,25 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
continue;
}
if (auto sendManyBatchOp = dyn_cast<pim::PimSendManyBatchOp>(op)) {
SmallVector<int32_t> laneTargetCoreIds;
laneTargetCoreIds.reserve(sendManyBatchOp.getInputs().size());
for (auto valueIndex : llvm::seq<size_t>(0, sendManyBatchOp.getInputs().size()))
laneTargetCoreIds.push_back(
sendManyBatchOp.getTargetCoreIds()[valueIndex * laneCount + static_cast<size_t>(lane)]);
SmallVector<mlir::Value> mappedInputs;
mappedInputs.reserve(sendManyBatchOp.getInputs().size());
for (mlir::Value input : sendManyBatchOp.getInputs())
mappedInputs.push_back(mapper.lookup(input));
pim::PimSendManyOp::create(builder,
sendManyBatchOp.getLoc(),
builder.getDenseI32ArrayAttr(laneTargetCoreIds),
ValueRange(mappedInputs));
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(builder,
@@ -694,6 +778,29 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
continue;
}
if (auto receiveManyBatchOp = dyn_cast<pim::PimReceiveManyBatchOp>(op)) {
SmallVector<int32_t> laneSourceCoreIds;
laneSourceCoreIds.reserve(receiveManyBatchOp.getOutputs().size());
for (auto valueIndex : llvm::seq<size_t>(0, receiveManyBatchOp.getOutputs().size()))
laneSourceCoreIds.push_back(
receiveManyBatchOp.getSourceCoreIds()[valueIndex * laneCount + static_cast<size_t>(lane)]);
SmallVector<mlir::Value> mappedOutputBuffers;
mappedOutputBuffers.reserve(receiveManyBatchOp.getOutputBuffers().size());
for (mlir::Value outputBuffer : receiveManyBatchOp.getOutputBuffers())
mappedOutputBuffers.push_back(mapper.lookup(outputBuffer));
auto scalarReceiveMany =
pim::PimReceiveManyOp::create(builder,
receiveManyBatchOp.getLoc(),
receiveManyBatchOp->getResultTypes(),
ValueRange(mappedOutputBuffers),
builder.getDenseI32ArrayAttr(laneSourceCoreIds));
for (auto [originalOutput, scalarOutput] : llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs()))
mapper.map(originalOutput, scalarOutput);
continue;
}
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
if (!hostSource)
@@ -812,8 +919,16 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(op))
coreCodeGen.codeGenReceiveManyOp(receiveManyOp, knowledge);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp, knowledge);
else if (auto sendManyOp = dyn_cast<pim::PimSendManyOp>(op))
coreCodeGen.codeGenSendManyOp(sendManyOp, knowledge);
else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op))
coreCodeGen.codeGenExtractRowsOp(extractRowsOp, knowledge);
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))