big cleanup: remove remaining pim many operations, simplify bufferization logic
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-11 14:38:13 +02:00
parent b1272d2283
commit 5ff364027b
12 changed files with 390 additions and 1164 deletions
+89 -190
View File
@@ -42,6 +42,79 @@ static size_t getValueSizeInBytes(mlir::Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
struct DenseWeightView {
DenseElementsAttr denseAttr;
SmallVector<int64_t> shape;
SmallVector<int64_t> strides;
int64_t offset = 0;
};
static SmallVector<int64_t> computeRowMajorStridesForShape(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t index = static_cast<int64_t>(shape.size()) - 2; index >= 0; --index)
strides[index] = strides[index + 1] * shape[index + 1];
return strides;
}
static bool allStaticSubviewParts(memref::SubViewOp subview) {
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
}
static FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
SmallVector<memref::SubViewOp> subviews;
mlir::Value current = weight;
memref::GetGlobalOp getGlobalOp;
while (true) {
Operation* defOp = current.getDefiningOp();
if (!defOp)
return failure();
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
break;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
if (!allStaticSubviewParts(subview))
return failure();
subviews.push_back(subview);
current = subview.getSource();
continue;
}
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
current = cast.getSource();
continue;
}
return failure();
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
DenseWeightView view;
view.denseAttr = denseAttr;
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStridesForShape(view.shape);
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getStaticStrides().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
}
return view;
}
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape());
@@ -97,11 +170,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
if (!allocOp->getParentOfType<pim::PimCoreOp>())
gatherMemEntry(allocOp.getResult());
});
funcOp.walk([&](pim::PimEmptyManyOp emptyManyOp) {
if (!emptyManyOp->getParentOfType<pim::PimCoreOp>() && !emptyManyOp->getParentOfType<pim::PimCoreBatchOp>())
for (mlir::Value output : emptyManyOp.getOutputs())
gatherMemEntry(output);
});
allocateGatheredMemory();
@@ -111,10 +179,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
void PimMemory::allocateCore(Operation* op) {
op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });
op->walk([&](pim::PimEmptyManyOp emptyManyOp) {
for (mlir::Value output : emptyManyOp.getOutputs())
gatherMemEntry(output);
});
allocateGatheredMemory();
}
@@ -369,13 +433,6 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
}
void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp,
const StaticValueKnowledge& knowledge) const {
for (auto [outputBuffer, sourceCoreId] :
llvm::zip(receiveManyOp.getOutputBuffers(), receiveManyOp.getSourceCoreIds()))
emitCommunicationOp("recv", addressOf(outputBuffer, knowledge), sourceCoreId, getValueSizeInBytes(outputBuffer));
}
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
const StaticValueKnowledge& knowledge) const {
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
@@ -388,11 +445,6 @@ void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
}
void PimCodeGen::codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const {
for (auto [input, targetCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds()))
emitCommunicationOp("send", addressOf(input, knowledge), targetCoreId, getValueSizeInBytes(input));
}
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size();
@@ -400,20 +452,6 @@ void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const St
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
}
void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp,
const StaticValueKnowledge& knowledge) const {
auto inputType = cast<ShapedType>(extractRowsOp.getInput().getType());
assert(inputType.hasStaticShape() && inputType.getRank() == 2 && "extract_rows codegen requires static rank-2 input");
size_t elementSize = inputType.getElementTypeBitWidth() / 8;
size_t rowSizeInBytes = static_cast<size_t>(inputType.getDimSize(1)) * elementSize;
size_t inputAddr = addressOf(extractRowsOp.getInput(), knowledge);
for (auto [rowIndex, outputBuffer] : llvm::enumerate(extractRowsOp.getOutputBuffers()))
emitMemCopyOp(
"lmv", addressOf(outputBuffer, knowledge), 0, inputAddr, rowIndex * rowSizeInBytes, rowSizeInBytes, "len");
}
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
assert(outputType.hasStaticShape() && "concat codegen requires static output shape");
@@ -742,23 +780,6 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
continue;
}
if (auto sendManyBatchOp = dyn_cast<pim::PimSendManyBatchOp>(op)) {
SmallVector<int32_t> laneTargetCoreIds;
laneTargetCoreIds.reserve(sendManyBatchOp.getInputs().size());
for (auto valueIndex : llvm::seq<size_t>(0, sendManyBatchOp.getInputs().size()))
laneTargetCoreIds.push_back(
sendManyBatchOp.getTargetCoreIds()[valueIndex * laneCount + static_cast<size_t>(lane)]);
SmallVector<mlir::Value> mappedInputs;
mappedInputs.reserve(sendManyBatchOp.getInputs().size());
for (mlir::Value input : sendManyBatchOp.getInputs())
mappedInputs.push_back(mapper.lookup(input));
pim::PimSendManyOp::create(
builder, sendManyBatchOp.getLoc(), builder.getDenseI32ArrayAttr(laneTargetCoreIds), ValueRange(mappedInputs));
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(builder,
@@ -771,29 +792,6 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
continue;
}
if (auto receiveManyBatchOp = dyn_cast<pim::PimReceiveManyBatchOp>(op)) {
SmallVector<int32_t> laneSourceCoreIds;
laneSourceCoreIds.reserve(receiveManyBatchOp.getOutputs().size());
for (auto valueIndex : llvm::seq<size_t>(0, receiveManyBatchOp.getOutputs().size()))
laneSourceCoreIds.push_back(
receiveManyBatchOp.getSourceCoreIds()[valueIndex * laneCount + static_cast<size_t>(lane)]);
SmallVector<mlir::Value> mappedOutputBuffers;
mappedOutputBuffers.reserve(receiveManyBatchOp.getOutputBuffers().size());
for (mlir::Value outputBuffer : receiveManyBatchOp.getOutputBuffers())
mappedOutputBuffers.push_back(mapper.lookup(outputBuffer));
auto scalarReceiveMany = pim::PimReceiveManyOp::create(builder,
receiveManyBatchOp.getLoc(),
receiveManyBatchOp->getResultTypes(),
ValueRange(mappedOutputBuffers),
builder.getDenseI32ArrayAttr(laneSourceCoreIds));
for (auto [originalOutput, scalarOutput] :
llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs()))
mapper.map(originalOutput, scalarOutput);
continue;
}
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
if (!hostSource)
@@ -912,18 +910,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(op))
coreCodeGen.codeGenReceiveManyOp(receiveManyOp, knowledge);
else if (auto receiveTensorOp = dyn_cast<pim::PimReceiveTensorOp>(op))
coreCodeGen.codeGenReceiveTensorOp(receiveTensorOp, knowledge);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp, knowledge);
else if (auto sendManyOp = dyn_cast<pim::PimSendManyOp>(op))
coreCodeGen.codeGenSendManyOp(sendManyOp, knowledge);
else if (auto sendTensorOp = dyn_cast<pim::PimSendTensorOp>(op))
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op))
coreCodeGen.codeGenExtractRowsOp(extractRowsOp, knowledge);
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
@@ -954,8 +946,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
else if (isa<pim::PimEmptyManyOp>(op))
return success();
else {
op.emitError("Unsupported codegen for this operation");
op.dump();
@@ -967,84 +957,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
}
/// Write crossbar weight matrices as padded binary files for a single core.
static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
pim::PimCoreOp coreOp,
StringRef coreWeightsDirPath,
json::Array& xbarsPerGroup) {
int64_t xbarSize = crossbarSize.getValue();
std::error_code errorCode;
size_t weightIndex = 0;
for (auto weight : coreOp.getWeights()) {
xbarsPerGroup.push_back(weightIndex);
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto initialValue = globalOp.getInitialValue();
if (!initialValue) {
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr) {
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto type = denseAttr.getType();
auto shape = type.getShape();
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
auto weightFilePath = (coreWeightsDirPath + "/crossbar_" + std::to_string(weightIndex) + ".bin").str();
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t index = row * numCols + col;
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
weightIndex++;
}
return CompilerSuccess;
}
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
@@ -1079,45 +991,31 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
}
mlir::Value weight = coreOp.getWeights()[index];
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
auto weightView = resolveDenseWeightView(moduleOp, weight);
if (failed(weightView)) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
assert(!getGlobalOp && "Weight is not from a memref.get_global");
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(index));
assert(!globalOp && "Could not find memref.global");
}
if (mapCoreWeightToFileName[coreId].contains(weight))
continue;
auto initialValue = globalOp.getInitialValue();
if (!initialValue) {
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(index));
assert(!initialValue && "memref.global has no initial value");
}
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr) {
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(index));
assert(!denseAttr && "memref.global initial value is not dense");
}
if (mapGlobalOpToFileName.contains(globalOp)) {
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
std::pair<mlir::Value, std::string> weightToFile = {weight, fileName};
mapCoreWeightToFileName[coreId].insert(weightToFile);
mapCoreWeightToFileName[coreId].insert({weight, fileName});
continue;
}
auto type = denseAttr.getType();
auto shape = type.getShape();
DenseElementsAttr denseAttr = weightView->denseAttr;
ArrayRef<int64_t> shape = weightView->shape;
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
@@ -1132,8 +1030,8 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t index = row * numCols + col;
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
@@ -1144,7 +1042,8 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
}
weightFileStream.close();
mapGlobalOpToFileName.insert({globalOp, newFileName});
if (globalOp)
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
}
}
-5
View File
@@ -24,8 +24,6 @@ class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
size_t maxSize = 0; // 0 for unbounded memory
size_t startAddress = 0;
size_t minAlignment = 4;
size_t firstAvailableAddress = 0;
@@ -117,12 +115,9 @@ public:
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const;
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
template <typename MVMTy>