This commit is contained in:
@@ -37,6 +37,11 @@ using namespace llvm;
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
|
||||
/// Returns the total storage size in bytes of a statically shaped value.
///
/// \param value  an SSA value whose type must be a ShapedType (the cast
///               asserts otherwise).
/// \return number of bytes needed to hold all elements, rounded UP to a
///         whole byte.
///
/// The previous form `getNumElements() * getElementTypeBitWidth() / 8`
/// truncated: for sub-byte element types (e.g. i1) a small tensor could
/// report 0 bytes, which would emit zero-length send/recv transfers.
/// Multiplying first and rounding up keeps byte-aligned types unchanged
/// while never under-reporting sub-byte payloads.
static size_t getValueSizeInBytes(mlir::Value value) {
  auto type = cast<ShapedType>(value.getType());
  size_t totalBits =
      static_cast<size_t>(type.getNumElements()) * type.getElementTypeBitWidth();
  return (totalBits + 7) / 8; // round up to a whole byte
}
|
||||
|
||||
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
assert("Only static shape is supported" && type.hasStaticShape());
|
||||
@@ -382,10 +387,75 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
|
||||
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
||||
}
|
||||
|
||||
/// Lowers pim.receive_many: emits one "recv" communication instruction per
/// (output buffer, source core id) pair carried by the op. The transfer size
/// is derived from each buffer's static type.
void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const {
  auto buffers = receiveManyOp.getOutputBuffers();
  auto coreIds = receiveManyOp.getSourceCoreIds();
  for (auto [buffer, coreId] : llvm::zip(buffers, coreIds)) {
    auto bufferAddr = addressOf(buffer, knowledge);
    emitCommunicationOp("recv", bufferAddr, coreId, getValueSizeInBytes(buffer));
  }
}
|
||||
|
||||
/// Lowers pim.send: emits a single "send" communication instruction moving
/// `size` bytes from the input's resolved address to the target core.
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
  auto inputAddr = addressOf(sendOp.getInput(), knowledge);
  emitCommunicationOp("send", inputAddr, sendOp.getTargetCoreId(), sendOp.getSize());
}
|
||||
|
||||
/// Lowers pim.send_many: emits one "send" communication instruction per
/// (input value, target core id) pair carried by the op. The transfer size
/// is derived from each input's static type.
void PimCodeGen::codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const {
  auto inputs = sendManyOp.getInputs();
  auto coreIds = sendManyOp.getTargetCoreIds();
  for (auto [value, coreId] : llvm::zip(inputs, coreIds)) {
    auto valueAddr = addressOf(value, knowledge);
    emitCommunicationOp("send", valueAddr, coreId, getValueSizeInBytes(value));
  }
}
|
||||
|
||||
/// Lowers pim.extract_rows: copies each row of the rank-2 input into its own
/// per-row output buffer via one "lmv" memory-copy instruction per row.
/// Row i is read at byte offset i * rowBytes from the input's address and
/// written at offset 0 of the i-th output buffer.
void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const {
  mlir::Value input = extractRowsOp.getInput();
  auto inputType = cast<ShapedType>(input.getType());
  assert(inputType.hasStaticShape() && inputType.getRank() == 2 && "extract_rows codegen requires static rank-2 input");

  size_t bytesPerElement = inputType.getElementTypeBitWidth() / 8;
  size_t bytesPerRow = static_cast<size_t>(inputType.getDimSize(1)) * bytesPerElement;
  size_t srcAddr = addressOf(input, knowledge);

  size_t row = 0;
  for (mlir::Value destBuffer : extractRowsOp.getOutputBuffers()) {
    emitMemCopyOp("lmv",
                  addressOf(destBuffer, knowledge),
                  /*dstOffset=*/0,
                  srcAddr,
                  row * bytesPerRow,
                  bytesPerRow,
                  "len");
    ++row;
  }
}
|
||||
|
||||
/// Lowers pim.concat to a sequence of "lmv" memory copies into the
/// pre-allocated output buffer.
///
/// The output is modeled as outerCount x outputShape[axis] x innerCount
/// elements (dims before the axis collapsed into outerCount, dims after it
/// into innerCount). Each input contributes a contiguous block of
/// inputConcatDim * innerCount elements per outer slice; blocks are placed
/// at a running `concatOffset` along the concat axis.
///
/// Requires static shapes for the output and every input (asserted).
/// NOTE(review): `axis` is assumed to be a valid non-negative dim index for
/// the output shape — presumably guaranteed by op verification; confirm.
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
  auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
  assert(outputType.hasStaticShape() && "concat codegen requires static output shape");

  int64_t axis = concatOp.getAxis();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  // NOTE(review): bit-width / 8 truncates for sub-byte element types (i1);
  // assumed byte-aligned elements here — confirm against supported dtypes.
  size_t elementSize = outputType.getElementTypeBitWidth() / 8;
  size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);

  // Product of the dimensions before the concat axis: number of outer slices.
  size_t outerCount = 1;
  for (int64_t dim = 0; dim < axis; ++dim)
    outerCount *= static_cast<size_t>(outputShape[dim]);

  // Product of the dimensions after the concat axis: elements per row along
  // the axis (the contiguous unit being copied).
  size_t innerCount = 1;
  for (size_t dim = static_cast<size_t>(axis) + 1; dim < outputShape.size(); ++dim)
    innerCount *= static_cast<size_t>(outputShape[dim]);

  size_t outputConcatDim = static_cast<size_t>(outputShape[axis]);
  // Running position along the concat axis where the next input lands.
  size_t concatOffset = 0;
  for (mlir::Value input : concatOp.getInputs()) {
    auto inputType = cast<ShapedType>(input.getType());
    assert(inputType.hasStaticShape() && "concat codegen requires static input shapes");

    size_t inputConcatDim = static_cast<size_t>(inputType.getDimSize(axis));
    // Bytes this input contributes to each outer slice (contiguous in the
    // input, since inputs are concatenated only along `axis`).
    size_t blockSizeInBytes = inputConcatDim * innerCount * elementSize;
    size_t inputAddr = addressOf(input, knowledge);

    // One copy per outer slice: the input's slices are contiguous, but in
    // the output they are strided by the full outputConcatDim.
    for (size_t outerIndex = 0; outerIndex < outerCount; ++outerIndex) {
      size_t dstOffset = (outerIndex * outputConcatDim + concatOffset) * innerCount * elementSize;
      size_t srcOffset = outerIndex * inputConcatDim * innerCount * elementSize;
      emitMemCopyOp("lmv", outputAddr, dstOffset, inputAddr, srcOffset, blockSizeInBytes, "len");
    }

    concatOffset += inputConcatDim;
  }
}
|
||||
|
||||
template <typename MVMTy>
|
||||
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
|
||||
MVMTy mvmLikeOp,
|
||||
@@ -396,11 +466,6 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
|
||||
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
|
||||
}
|
||||
|
||||
// NOTE(review): byte-for-byte duplicate of the static getValueSizeInBytes
// helper defined earlier in this file; two definitions of the same static
// function in one TU will not compile, so one copy should be removed —
// confirm which occurrence is intended to survive.
static size_t getValueSizeInBytes(mlir::Value value) {
  auto type = cast<ShapedType>(value.getType());
  // Truncating division: sub-byte element types (e.g. i1) under-report —
  // assumed byte-aligned element widths here; confirm.
  return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
|
||||
|
||||
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge);
|
||||
auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge);
|
||||
@@ -682,6 +747,25 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendManyBatchOp = dyn_cast<pim::PimSendManyBatchOp>(op)) {
|
||||
SmallVector<int32_t> laneTargetCoreIds;
|
||||
laneTargetCoreIds.reserve(sendManyBatchOp.getInputs().size());
|
||||
for (auto valueIndex : llvm::seq<size_t>(0, sendManyBatchOp.getInputs().size()))
|
||||
laneTargetCoreIds.push_back(
|
||||
sendManyBatchOp.getTargetCoreIds()[valueIndex * laneCount + static_cast<size_t>(lane)]);
|
||||
|
||||
SmallVector<mlir::Value> mappedInputs;
|
||||
mappedInputs.reserve(sendManyBatchOp.getInputs().size());
|
||||
for (mlir::Value input : sendManyBatchOp.getInputs())
|
||||
mappedInputs.push_back(mapper.lookup(input));
|
||||
|
||||
pim::PimSendManyOp::create(builder,
|
||||
sendManyBatchOp.getLoc(),
|
||||
builder.getDenseI32ArrayAttr(laneTargetCoreIds),
|
||||
ValueRange(mappedInputs));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||
auto scalarReceive =
|
||||
pim::PimReceiveOp::create(builder,
|
||||
@@ -694,6 +778,29 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveManyBatchOp = dyn_cast<pim::PimReceiveManyBatchOp>(op)) {
|
||||
SmallVector<int32_t> laneSourceCoreIds;
|
||||
laneSourceCoreIds.reserve(receiveManyBatchOp.getOutputs().size());
|
||||
for (auto valueIndex : llvm::seq<size_t>(0, receiveManyBatchOp.getOutputs().size()))
|
||||
laneSourceCoreIds.push_back(
|
||||
receiveManyBatchOp.getSourceCoreIds()[valueIndex * laneCount + static_cast<size_t>(lane)]);
|
||||
|
||||
SmallVector<mlir::Value> mappedOutputBuffers;
|
||||
mappedOutputBuffers.reserve(receiveManyBatchOp.getOutputBuffers().size());
|
||||
for (mlir::Value outputBuffer : receiveManyBatchOp.getOutputBuffers())
|
||||
mappedOutputBuffers.push_back(mapper.lookup(outputBuffer));
|
||||
|
||||
auto scalarReceiveMany =
|
||||
pim::PimReceiveManyOp::create(builder,
|
||||
receiveManyBatchOp.getLoc(),
|
||||
receiveManyBatchOp->getResultTypes(),
|
||||
ValueRange(mappedOutputBuffers),
|
||||
builder.getDenseI32ArrayAttr(laneSourceCoreIds));
|
||||
for (auto [originalOutput, scalarOutput] : llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs()))
|
||||
mapper.map(originalOutput, scalarOutput);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||
mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
|
||||
if (!hostSource)
|
||||
@@ -812,8 +919,16 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
|
||||
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
|
||||
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
|
||||
else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(op))
|
||||
coreCodeGen.codeGenReceiveManyOp(receiveManyOp, knowledge);
|
||||
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
|
||||
coreCodeGen.codeGenSendOp(sendOp, knowledge);
|
||||
else if (auto sendManyOp = dyn_cast<pim::PimSendManyOp>(op))
|
||||
coreCodeGen.codeGenSendManyOp(sendManyOp, knowledge);
|
||||
else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op))
|
||||
coreCodeGen.codeGenExtractRowsOp(extractRowsOp, knowledge);
|
||||
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
|
||||
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
|
||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
|
||||
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
|
||||
|
||||
Reference in New Issue
Block a user