refactorone
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-20 19:06:41 +02:00
parent f56c4159b5
commit a50e77ff38
50 changed files with 3420 additions and 1187 deletions
+92 -69
View File
@@ -1,7 +1,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
@@ -24,113 +28,132 @@ static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_
return laneCoreIds;
}
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
IRRewriter rewriter(scalarCore.getContext());
SmallVector<Operation*> batchOps;
scalarCore.walk([&](Operation* op) {
if (isa<pim::PimSendBatchOp,
pim::PimSendTensorBatchOp,
pim::PimReceiveBatchOp,
pim::PimReceiveTensorBatchOp,
pim::PimMemCopyHostToDevBatchOp>(op)) {
batchOps.push_back(op);
}
});
static void cloneScalarizedLaneBody(OpBuilder& builder,
pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
OperationFolder& constantFolder) {
Block& oldBlock = coreBatchOp.getBody().front();
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightCount = coreBatchOp.getWeights().size();
for (Operation* op : batchOps) {
rewriter.setInsertionPoint(op);
IRMapping mapper;
for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
if (blockArg.getType().isIndex()) {
mapper.map(blockArg, getOrCreateHostIndexConstant(coreBatchOp, static_cast<int64_t>(lane), constantFolder));
continue;
}
if (argIndex <= weightCount) {
mapper.map(blockArg, coreBatchOp.getWeights()[argIndex - 1]);
continue;
}
size_t inputIndex = argIndex - 1 - weightCount;
assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range");
mapper.map(blockArg, coreBatchOp.getInputs()[inputIndex]);
}
for (Operation& op : oldBlock) {
if (isa<pim::PimHaltOp>(op))
continue;
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(rewriter,
sendBatchOp.getLoc(),
sendBatchOp.getInput(),
sendBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
rewriter.eraseOp(op);
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
pim::PimSendOp::create(
builder,
sendBatchOp.getLoc(),
mapper.lookup(sendBatchOp.getInput()),
sendBatchOp.getSizeAttr(),
getOrCreateHostIndexConstant(anchorOp, sendBatchOp.getTargetCoreIds()[lane], constantFolder));
continue;
}
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
pim::PimSendTensorOp::create(
rewriter,
builder,
sendTensorBatchOp.getLoc(),
sendTensorBatchOp.getInput(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
rewriter.eraseOp(op);
mapper.lookup(sendTensorBatchOp.getInput()),
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(rewriter,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
receiveBatchOp.getOutputBuffer(),
receiveBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
rewriter.replaceOp(op, scalarReceive->getResults());
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
auto scalarReceive = pim::PimReceiveOp::create(
builder,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
mapper.lookup(receiveBatchOp.getOutputBuffer()),
receiveBatchOp.getSizeAttr(),
getOrCreateHostIndexConstant(anchorOp, receiveBatchOp.getSourceCoreIds()[lane], constantFolder));
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveTensorOp::create(
rewriter,
builder,
receiveTensorBatchOp.getLoc(),
receiveTensorBatchOp.getOutput().getType(),
receiveTensorBatchOp.getOutputBuffer(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
rewriter.replaceOp(op, scalarReceive->getResults());
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
continue;
}
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
memcpBatchOp.getDeviceTarget(),
memcpBatchOp.getHostSource(),
memcpBatchOp.getDeviceTargetOffsetAttr(),
memcpBatchOp.getHostSourceOffsetAttr(),
memcpBatchOp.getSizeAttr());
rewriter.replaceOp(op, scalarCopy->getResults());
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(
builder,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
mapper.lookup(memcpBatchOp.getDeviceTarget()),
mapper.lookup(memcpBatchOp.getHostSource()),
memcpBatchOp.getSizeAttr());
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
continue;
}
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
}
} // namespace
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
ArrayRef<unsigned> lanes,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
assert(!lanes.empty() && "expected at least one batch lane");
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
OpBuilder builder(scratchModule->getContext());
OperationFolder constantFolder(scratchModule->getContext());
builder.setInsertionPointToStart(scratchModule->getBody());
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
SmallVector<Value> laneWeights;
laneWeights.reserve(weightsPerLane);
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
SmallVector<Value> weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end());
auto coreIds = getBatchCoreIds(coreBatchOp);
auto scalarCore = pim::PimCoreOp::create(
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
int32_t coreId = coreIds[lanes.front()];
for (unsigned lane : lanes)
assert(coreIds[lane] == coreId && "all grouped lanes must target the same core");
auto scalarCore =
pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
IRMapping mapper;
if (coreBatchOp.getBody().front().getNumArguments() == 1)
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
builder.setInsertionPointToEnd(block);
for (Operation& op : coreBatchOp.getBody().front()) {
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
for (unsigned lane : lanes)
cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
return callback(scalarCore);
}
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
return withScalarCoreFromBatchLanes(coreBatchOp, ArrayRef<unsigned> {lane}, callback);
}
} // namespace onnx_mlir
+3
View File
@@ -9,5 +9,8 @@ namespace onnx_mlir {
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
mlir::LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
llvm::ArrayRef<unsigned> lanes,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
} // namespace onnx_mlir
+66 -18
View File
@@ -41,15 +41,23 @@ using namespace mlir;
using namespace onnx_mlir;
using namespace onnx_mlir::compact_asm;
static size_t getElementTypeSizeInBytes(mlir::Type elementType) {
if (elementType.isIndex())
return sizeof(int64_t);
if (elementType.isIntOrFloat())
return elementType.getIntOrFloatBitWidth() / 8;
llvm_unreachable("unsupported shaped element type");
}
static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
return type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
}
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape());
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
size_t allocSize = type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
MemEntry memEntry = {0, allocSize};
return &memEntries.emplace_back(memEntry, value).first;
}
@@ -398,20 +406,28 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
}
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
auto deviceTargetOffset = resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge);
auto hostSourceOffset = resolveIndexValue(loadOp.getHostSourceOffset(), knowledge);
assert(succeeded(deviceTargetOffset) && succeeded(hostSourceOffset)
&& "pim.memcp_hd offsets must be statically resolvable during codegen");
emitMemCopyOp("ld",
addressOf(loadOp.getDeviceTarget(), knowledge),
loadOp.getDeviceTargetOffset(),
*deviceTargetOffset,
addressOf(loadOp.getHostSource(), knowledge),
loadOp.getHostSourceOffset(),
*hostSourceOffset,
loadOp.getSize());
}
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
auto hostTargetOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge);
auto deviceSourceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge);
assert(succeeded(hostTargetOffset) && succeeded(deviceSourceOffset)
&& "pim.memcp_dh offsets must be statically resolvable during codegen");
emitMemCopyOp("st",
addressOf(storeOp.getHostTarget(), knowledge),
storeOp.getHostTargetOffset(),
*hostTargetOffset,
addressOf(storeOp.getDeviceSource(), knowledge),
storeOp.getDeviceSourceOffset(),
*deviceSourceOffset,
storeOp.getSize());
}
@@ -426,8 +442,9 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledg
}
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
emitCommunicationOp(
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
assert(succeeded(sourceCoreId) && "pim.receive source core id must be statically resolvable during codegen");
emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize());
}
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
@@ -439,7 +456,9 @@ void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
}
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen");
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize());
}
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
@@ -728,12 +747,19 @@ std::string getMemorySizeAsString(size_t size) {
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
auto addWeight = [&](mlir::Value weight) {
if (!coreOp)
return;
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
if (coreOp.getWeightArgument(weightIndex) != weight)
continue;
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
return;
}
};
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
llvm::sort(indices);
return indices;
}
@@ -795,6 +821,15 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
/// fully resolved before the JSON instructions are emitted.
/// Returns the number of emitted instructions, or -1 on failure.
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
auto resolveWeightIndex = [&](pim::PimVMMOp vmmOp) -> std::optional<unsigned> {
auto coreOp = vmmOp->getParentOfType<pim::PimCoreOp>();
if (!coreOp)
return std::nullopt;
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
return weightIndex;
return std::nullopt;
};
size_t processedOperations = 0;
auto result =
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
@@ -814,8 +849,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) {
auto weightIndex = resolveWeightIndex(vmmOp);
if (!weightIndex)
return failure();
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightIndex, vmmOp, true, knowledge);
}
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
@@ -1004,10 +1043,19 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
reportedCoreIds.reserve(batchCoreIds.size());
MemoryReportRow batchRow;
std::optional<MemoryReportRow> batchPerCoreRow;
llvm::DenseMap<size_t, SmallVector<unsigned>> lanesByCoreId;
SmallVector<size_t> orderedOriginalCoreIds;
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
auto [it, inserted] = lanesByCoreId.try_emplace(originalCoreId);
if (inserted)
orderedOriginalCoreIds.push_back(originalCoreId);
it->second.push_back(lane);
}
for (size_t originalCoreId : orderedOriginalCoreIds) {
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
if (failed(withScalarCoreFromBatchLanes(coreBatchOp, lanesByCoreId[originalCoreId], [&](pim::PimCoreOp coreOp) {
size_t coreId = emittedCoreIds.lookup(originalCoreId);
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
MemoryReportRow laneRow;
+12 -4
View File
@@ -128,12 +128,20 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
auto addWeight = [&](mlir::Value weight) {
if (!coreOp)
return;
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
if (coreOp.getWeightArgument(weightIndex) != weight)
continue;
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
return;
}
};
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
llvm::sort(indices);
return indices;
}