reduce spatial compile-times in convolutions using a scf.for instead of materializing a huge number of instructions
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-04-10 18:50:25 +02:00
parent f3a36e9d43
commit f054e66ed0
18 changed files with 623 additions and 241 deletions

View File

@@ -84,8 +84,8 @@ PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
return deviceMem.try_emplace(id, memEntriesMap).first->second;
}
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
auto resolvedAddress = resolveContiguousAddress(value);
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge) const {
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
if (failed(resolvedAddress)) {
errs() << "Failed to resolve contiguous address for value: ";
value.print(errs());
@@ -199,47 +199,49 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const {
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
emitMemCopyOp("ld",
memory.getValueAddress(loadOp.getDeviceTarget()),
addressOf(loadOp.getDeviceTarget(), knowledge),
loadOp.getDeviceTargetOffset(),
memory.getValueAddress(loadOp.getHostSource()),
addressOf(loadOp.getHostSource(), knowledge),
loadOp.getHostSourceOffset(),
loadOp.getSize());
}
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const {
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
emitMemCopyOp("st",
memory.getValueAddress(storeOp.getHostTarget()),
addressOf(storeOp.getHostTarget(), knowledge),
storeOp.getHostTargetOffset(),
memory.getValueAddress(storeOp.getDeviceSource()),
addressOf(storeOp.getDeviceSource(), knowledge),
storeOp.getDeviceSourceOffset(),
storeOp.getSize());
}
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const {
emitMemCopyOp("lmv",
memory.getValueAddress(lmvOp.getTarget()),
addressOf(lmvOp.getTarget(), knowledge),
lmvOp.getTargetOffset(),
memory.getValueAddress(lmvOp.getSource()),
addressOf(lmvOp.getSource(), knowledge),
lmvOp.getSourceOffset(),
lmvOp.getSize(),
"len");
}
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
emitCommunicationOp(
"recv", memory.getValueAddress(receiveOp.getOutputBuffer()), receiveOp.getSourceCoreId(), receiveOp.getSize());
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
}
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp) const {
emitCommunicationOp("send", memory.getValueAddress(sendOp.getInput()), sendOp.getTargetCoreId(), sendOp.getSize());
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
}
template <typename MVMTy>
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix) {
emitMvmOp(
mvmId, memory.getValueAddress(mvmLikeOp.getOutputBuffer()), 0, memory.getValueAddress(mvmLikeOp.getInput()), 0);
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
MVMTy mvmLikeOp,
bool transposeMatrix,
const StaticValueKnowledge& knowledge) {
emitMvmOp(mvmId, addressOf(mvmLikeOp.getOutputBuffer(), knowledge), 0, addressOf(mvmLikeOp.getInput(), knowledge), 0);
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
}
@@ -249,10 +251,10 @@ static size_t getValueSizeInBytes(mlir::Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
auto outputBufferAddr = memory.getValueAddress(vvaddOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvaddOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvaddOp.getRhs());
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge);
auto rhsAddr = addressOf(vvaddOp.getRhs(), knowledge);
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
@@ -265,10 +267,10 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
auto outputBufferAddr = memory.getValueAddress(vvsubOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvsubOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvsubOp.getRhs());
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvsubOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvsubOp.getLhs(), knowledge);
auto rhsAddr = addressOf(vvsubOp.getRhs(), knowledge);
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
@@ -281,10 +283,10 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
auto outputBufferAddr = memory.getValueAddress(vvmulOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvmulOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvmulOp.getRhs());
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvmulOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvmulOp.getLhs(), knowledge);
auto rhsAddr = addressOf(vvmulOp.getRhs(), knowledge);
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
@@ -297,10 +299,10 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
auto outputBufferAddr = memory.getValueAddress(vvmaxOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvmaxOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvmaxOp.getRhs());
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvmaxOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvmaxOp.getLhs(), knowledge);
auto rhsAddr = addressOf(vvmaxOp.getRhs(), knowledge);
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
@@ -313,10 +315,10 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
auto outputBufferAddr = memory.getValueAddress(vvdmulOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvdmulOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvdmulOp.getRhs());
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvdmulOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvdmulOp.getLhs(), knowledge);
auto rhsAddr = addressOf(vvdmulOp.getRhs(), knowledge);
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
@@ -329,9 +331,9 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
auto outputBufferAddr = memory.getValueAddress(vavgOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vavgOp.getInput());
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vavgOp.getOutputBuffer(), knowledge);
auto inputAddr = addressOf(vavgOp.getInput(), knowledge);
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
@@ -344,9 +346,9 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
auto outputBufferAddr = memory.getValueAddress(vreluOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vreluOp.getInput());
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vreluOp.getOutputBuffer(), knowledge);
auto inputAddr = addressOf(vreluOp.getInput(), knowledge);
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
@@ -358,9 +360,9 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
auto outputBufferAddr = memory.getValueAddress(vtanhOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vtanhOp.getInput());
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vtanhOp.getOutputBuffer(), knowledge);
auto inputAddr = addressOf(vtanhOp.getInput(), knowledge);
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
@@ -372,9 +374,9 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
auto outputBufferAddr = memory.getValueAddress(vsigmOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vsigmOp.getInput());
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vsigmOp.getOutputBuffer(), knowledge);
auto inputAddr = addressOf(vsigmOp.getInput(), knowledge);
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
@@ -386,9 +388,9 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const {
auto outputBufferAddr = memory.getValueAddress(vsoftmaxOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vsoftmaxOp.getInput());
void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vsoftmaxOp.getOutputBuffer(), knowledge);
auto inputAddr = addressOf(vsoftmaxOp.getInput(), knowledge);
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
@@ -400,9 +402,9 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
auto srcAddr = memory.getValueAddress(transposeOp.getInput());
auto dstAddr = memory.getValueAddress(transposeOp.getOutputBuffer());
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
auto srcShape = srcType.getShape();
@@ -510,57 +512,58 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
}
/// Dispatch all operations in a core region to the appropriate code generator.
/// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is
/// fully resolved before the JSON instructions are emitted.
/// Returns the number of emitted instructions, or -1 on failure.
static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
size_t processedOperations = 0;
for (auto& op : coreOp.getBody().front()) {
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp, memref::ExpandShapeOp, memref::CollapseShapeOp>(op))
continue;
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
coreCodeGen.codeGenLoadOp(loadOp);
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
coreCodeGen.codeGenStoreOp(storeOp);
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
coreCodeGen.codeGenLmvOp(lmvOp);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
coreCodeGen.codeGenVVAddOp(vvaddOp);
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
coreCodeGen.codeGenVVSubOp(vvsubOp);
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
coreCodeGen.codeGenVVMulOp(vvmulOp);
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
coreCodeGen.codeGenVVMaxOp(vvmaxOp);
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
coreCodeGen.codeGenVVDMulOp(vvdmulOp);
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
coreCodeGen.codeGenVAvgOp(vavgOp);
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
coreCodeGen.codeGenVReluOp(vreluOp);
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
coreCodeGen.codeGenVTanhOp(vtanhOp);
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
coreCodeGen.codeGenVSigmOp(vsigmOp);
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp);
else {
op.emitError("Unsupported codegen for this operation");
op.dump();
return -1;
}
processedOperations++;
}
return processedOperations;
auto result =
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
coreCodeGen.codeGenLoadOp(loadOp, knowledge);
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
coreCodeGen.codeGenStoreOp(storeOp, knowledge);
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false, knowledge);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
coreCodeGen.codeGenVVAddOp(vvaddOp, knowledge);
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
coreCodeGen.codeGenVVSubOp(vvsubOp, knowledge);
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
coreCodeGen.codeGenVVMulOp(vvmulOp, knowledge);
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
coreCodeGen.codeGenVVMaxOp(vvmaxOp, knowledge);
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
coreCodeGen.codeGenVVDMulOp(vvdmulOp, knowledge);
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
coreCodeGen.codeGenVAvgOp(vavgOp, knowledge);
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
coreCodeGen.codeGenVReluOp(vreluOp, knowledge);
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
coreCodeGen.codeGenVTanhOp(vtanhOp, knowledge);
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
else {
op.emitError("Unsupported codegen for this operation");
op.dump();
return failure();
}
processedOperations++;
return success();
});
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
}
/// Write crossbar weight matrices as padded binary files for a single core.
@@ -739,7 +742,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
PimCodeGen coreCodeGen(memory, coreFileStream);
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
int64_t processedOperations = codeGenCoreOps(coreOp, coreCodeGen);
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
if (processedOperations < 0)
return CompilerFailure;
assert(processedOperations > 0);