fast pim bufferization using tensors
Validate Operations / validate-operations (push) Successful in 24m29s

This commit is contained in:
NiccoloN
2026-05-08 14:21:45 +02:00
parent 58e6587697
commit b1272d2283
7 changed files with 541 additions and 81 deletions
+36 -22
View File
@@ -178,7 +178,6 @@ void PimMemory::report(llvm::raw_ostream& file) {
}
}
void PimMemory::remove(mlir::Value val) {
if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())
globalMemEntriesMap.erase(removeIter);
@@ -370,11 +369,21 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
}
void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const {
for (auto [outputBuffer, sourceCoreId] : llvm::zip(receiveManyOp.getOutputBuffers(), receiveManyOp.getSourceCoreIds()))
void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp,
const StaticValueKnowledge& knowledge) const {
for (auto [outputBuffer, sourceCoreId] :
llvm::zip(receiveManyOp.getOutputBuffers(), receiveManyOp.getSourceCoreIds()))
emitCommunicationOp("recv", addressOf(outputBuffer, knowledge), sourceCoreId, getValueSizeInBytes(outputBuffer));
}
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
const StaticValueKnowledge& knowledge) const {
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size();
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
}
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
}
@@ -384,7 +393,15 @@ void PimCodeGen::codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticVa
emitCommunicationOp("send", addressOf(input, knowledge), targetCoreId, getValueSizeInBytes(input));
}
void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const {
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size();
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
}
void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp,
const StaticValueKnowledge& knowledge) const {
auto inputType = cast<ShapedType>(extractRowsOp.getInput().getType());
assert(inputType.hasStaticShape() && inputType.getRank() == 2 && "extract_rows codegen requires static rank-2 input");
@@ -393,13 +410,8 @@ void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const
size_t inputAddr = addressOf(extractRowsOp.getInput(), knowledge);
for (auto [rowIndex, outputBuffer] : llvm::enumerate(extractRowsOp.getOutputBuffers()))
emitMemCopyOp("lmv",
addressOf(outputBuffer, knowledge),
0,
inputAddr,
rowIndex * rowSizeInBytes,
rowSizeInBytes,
"len");
emitMemCopyOp(
"lmv", addressOf(outputBuffer, knowledge), 0, inputAddr, rowIndex * rowSizeInBytes, rowSizeInBytes, "len");
}
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
@@ -742,10 +754,8 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
for (mlir::Value input : sendManyBatchOp.getInputs())
mappedInputs.push_back(mapper.lookup(input));
pim::PimSendManyOp::create(builder,
sendManyBatchOp.getLoc(),
builder.getDenseI32ArrayAttr(laneTargetCoreIds),
ValueRange(mappedInputs));
pim::PimSendManyOp::create(
builder, sendManyBatchOp.getLoc(), builder.getDenseI32ArrayAttr(laneTargetCoreIds), ValueRange(mappedInputs));
continue;
}
@@ -773,13 +783,13 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
for (mlir::Value outputBuffer : receiveManyBatchOp.getOutputBuffers())
mappedOutputBuffers.push_back(mapper.lookup(outputBuffer));
auto scalarReceiveMany =
pim::PimReceiveManyOp::create(builder,
receiveManyBatchOp.getLoc(),
receiveManyBatchOp->getResultTypes(),
ValueRange(mappedOutputBuffers),
builder.getDenseI32ArrayAttr(laneSourceCoreIds));
for (auto [originalOutput, scalarOutput] : llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs()))
auto scalarReceiveMany = pim::PimReceiveManyOp::create(builder,
receiveManyBatchOp.getLoc(),
receiveManyBatchOp->getResultTypes(),
ValueRange(mappedOutputBuffers),
builder.getDenseI32ArrayAttr(laneSourceCoreIds));
for (auto [originalOutput, scalarOutput] :
llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs()))
mapper.map(originalOutput, scalarOutput);
continue;
}
@@ -904,10 +914,14 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(op))
coreCodeGen.codeGenReceiveManyOp(receiveManyOp, knowledge);
else if (auto receiveTensorOp = dyn_cast<pim::PimReceiveTensorOp>(op))
coreCodeGen.codeGenReceiveTensorOp(receiveTensorOp, knowledge);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp, knowledge);
else if (auto sendManyOp = dyn_cast<pim::PimSendManyOp>(op))
coreCodeGen.codeGenSendManyOp(sendManyOp, knowledge);
else if (auto sendTensorOp = dyn_cast<pim::PimSendTensorOp>(op))
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op))
coreCodeGen.codeGenExtractRowsOp(extractRowsOp, knowledge);
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
+3
View File
@@ -1,6 +1,7 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/JSON.h"
@@ -117,8 +118,10 @@ public:
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const;
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;