better MaterializeMergeSchedule.cpp with %lane indexed batch computes
support for tensors of index values
This commit is contained in:
@@ -41,23 +41,10 @@ using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace onnx_mlir::compact_asm;
|
||||
|
||||
static size_t getElementTypeSizeInBytes(mlir::Type elementType) {
|
||||
if (elementType.isIndex())
|
||||
return sizeof(int64_t);
|
||||
if (elementType.isIntOrFloat())
|
||||
return elementType.getIntOrFloatBitWidth() / 8;
|
||||
llvm_unreachable("unsupported shaped element type");
|
||||
}
|
||||
|
||||
static size_t getValueSizeInBytes(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
return type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
|
||||
}
|
||||
|
||||
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
assert("Only static shape is supported" && type.hasStaticShape());
|
||||
size_t allocSize = type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
|
||||
size_t allocSize = getShapedTypeSizeInBytes(type);
|
||||
MemEntry memEntry = {0, allocSize};
|
||||
return &memEntries.emplace_back(memEntry, value).first;
|
||||
}
|
||||
@@ -450,7 +437,8 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
|
||||
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
||||
const StaticValueKnowledge& knowledge) const {
|
||||
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
|
||||
size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size();
|
||||
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(receiveTensorOp.getOutputBuffer().getType()))
|
||||
/ receiveTensorOp.getSourceCoreIds().size();
|
||||
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
|
||||
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
|
||||
}
|
||||
@@ -463,7 +451,8 @@ void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge
|
||||
|
||||
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
|
||||
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
|
||||
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size();
|
||||
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendTensorOp.getInput().getType()))
|
||||
/ sendTensorOp.getTargetCoreIds().size();
|
||||
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
|
||||
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
|
||||
}
|
||||
@@ -474,7 +463,7 @@ void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKno
|
||||
|
||||
int64_t axis = concatOp.getAxis();
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
size_t elementSize = outputType.getElementTypeBitWidth() / 8;
|
||||
size_t elementSize = getElementTypeSizeInBytes(outputType.getElementType());
|
||||
size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
|
||||
|
||||
size_t outerCount = 1;
|
||||
@@ -526,7 +515,7 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvaddOp.getLhs()));
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvaddOp.getLhs().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -541,7 +530,7 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvsubOp.getLhs()));
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvsubOp.getLhs().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -556,7 +545,7 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmulOp.getLhs()));
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmulOp.getLhs().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -571,7 +560,7 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmaxOp.getLhs()));
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmaxOp.getLhs().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -586,7 +575,7 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvdmulOp.getLhs()));
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvdmulOp.getLhs().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -601,7 +590,7 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 1;
|
||||
instruction.generic1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vavgOp.getInput()));
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vavgOp.getInput().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -614,7 +603,7 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle
|
||||
instruction.opcode = pim_binary::Opcode::vrelu;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vreluOp.getInput()));
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vreluOp.getInput().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -627,7 +616,7 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle
|
||||
instruction.opcode = pim_binary::Opcode::vtanh;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vtanhOp.getInput()));
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vtanhOp.getInput().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -640,7 +629,7 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle
|
||||
instruction.opcode = pim_binary::Opcode::vsigm;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsigmOp.getInput()));
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsigmOp.getInput().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -653,7 +642,8 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
|
||||
instruction.opcode = pim_binary::Opcode::vsoftmax;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsoftmaxOp.getInput()));
|
||||
instruction.generic3 =
|
||||
static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsoftmaxOp.getInput().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -666,7 +656,7 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const Stati
|
||||
auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
|
||||
auto srcShape = srcType.getShape();
|
||||
size_t rank = srcShape.size();
|
||||
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
|
||||
size_t elementSize = getElementTypeSizeInBytes(srcType.getElementType());
|
||||
size_t totalElements = srcType.getNumElements();
|
||||
|
||||
// Read permutation. Destination dim i corresponds to source dim perm[i].
|
||||
|
||||
Reference in New Issue
Block a user