better MaterializeMergeSchedule.cpp with %lane indexed batch computes

support for tensors of index values
This commit is contained in:
NiccoloN
2026-05-22 21:52:28 +02:00
parent 495186503c
commit c77ffa9c56
20 changed files with 398 additions and 300 deletions
+18 -28
View File
@@ -41,23 +41,10 @@ using namespace mlir;
using namespace onnx_mlir;
using namespace onnx_mlir::compact_asm;
static size_t getElementTypeSizeInBytes(mlir::Type elementType) {
if (elementType.isIndex())
return sizeof(int64_t);
if (elementType.isIntOrFloat())
return elementType.getIntOrFloatBitWidth() / 8;
llvm_unreachable("unsupported shaped element type");
}
static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
}
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape());
size_t allocSize = type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
size_t allocSize = getShapedTypeSizeInBytes(type);
MemEntry memEntry = {0, allocSize};
return &memEntries.emplace_back(memEntry, value).first;
}
@@ -450,7 +437,8 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
const StaticValueKnowledge& knowledge) const {
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size();
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(receiveTensorOp.getOutputBuffer().getType()))
/ receiveTensorOp.getSourceCoreIds().size();
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
}
@@ -463,7 +451,8 @@ void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size();
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendTensorOp.getInput().getType()))
/ sendTensorOp.getTargetCoreIds().size();
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
}
@@ -474,7 +463,7 @@ void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKno
int64_t axis = concatOp.getAxis();
ArrayRef<int64_t> outputShape = outputType.getShape();
size_t elementSize = outputType.getElementTypeBitWidth() / 8;
size_t elementSize = getElementTypeSizeInBytes(outputType.getElementType());
size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
size_t outerCount = 1;
@@ -526,7 +515,7 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle
instruction.rd = 0;
instruction.r1 = 1;
instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvaddOp.getLhs()));
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvaddOp.getLhs().getType())));
emitInstruction(instruction);
}
@@ -541,7 +530,7 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle
instruction.rd = 0;
instruction.r1 = 1;
instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvsubOp.getLhs()));
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvsubOp.getLhs().getType())));
emitInstruction(instruction);
}
@@ -556,7 +545,7 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle
instruction.rd = 0;
instruction.r1 = 1;
instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmulOp.getLhs()));
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmulOp.getLhs().getType())));
emitInstruction(instruction);
}
@@ -571,7 +560,7 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle
instruction.rd = 0;
instruction.r1 = 1;
instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmaxOp.getLhs()));
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmaxOp.getLhs().getType())));
emitInstruction(instruction);
}
@@ -586,7 +575,7 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno
instruction.rd = 0;
instruction.r1 = 1;
instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvdmulOp.getLhs()));
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvdmulOp.getLhs().getType())));
emitInstruction(instruction);
}
@@ -601,7 +590,7 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge
instruction.r1 = 1;
instruction.r2OrImm = 1;
instruction.generic1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vavgOp.getInput()));
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vavgOp.getInput().getType())));
emitInstruction(instruction);
}
@@ -614,7 +603,7 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vrelu;
instruction.rd = 0;
instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vreluOp.getInput()));
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vreluOp.getInput().getType())));
emitInstruction(instruction);
}
@@ -627,7 +616,7 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vtanh;
instruction.rd = 0;
instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vtanhOp.getInput()));
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vtanhOp.getInput().getType())));
emitInstruction(instruction);
}
@@ -640,7 +629,7 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vsigm;
instruction.rd = 0;
instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsigmOp.getInput()));
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsigmOp.getInput().getType())));
emitInstruction(instruction);
}
@@ -653,7 +642,8 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
instruction.opcode = pim_binary::Opcode::vsoftmax;
instruction.rd = 0;
instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsoftmaxOp.getInput()));
instruction.generic3 =
static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsoftmaxOp.getInput().getType())));
emitInstruction(instruction);
}
@@ -666,7 +656,7 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const Stati
auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
auto srcShape = srcType.getShape();
size_t rank = srcShape.size();
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
size_t elementSize = getElementTypeSizeInBytes(srcType.getElementType());
size_t totalElements = srcType.getNumElements();
// Read permutation. Destination dim i corresponds to source dim perm[i].