big cleanup: remove remaining pim many operations, simplify bufferization logic
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-11 14:38:13 +02:00
parent b1272d2283
commit 5ff364027b
12 changed files with 390 additions and 1164 deletions
@@ -1,13 +1,12 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -34,35 +33,6 @@ static int64_t getValueSizeInBytes(Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
static void expandPimMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<pim::PimMapOp> mapOps;
funcOp.walk([&](pim::PimMapOp mapOp) { mapOps.push_back(mapOp); });
for (auto mapOp : mapOps) {
Block& body = mapOp.getBody().front();
auto yieldOp = cast<pim::PimYieldOp>(body.getTerminator());
SmallVector<Value> replacements;
replacements.reserve(mapOp.getInputs().size());
rewriter.setInsertionPoint(mapOp);
for (Value input : mapOp.getInputs()) {
IRMapping mapping;
mapping.map(body.getArgument(0), input);
for (Operation& op : body.without_terminator()) {
Operation* cloned = rewriter.clone(op, mapping);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapping.map(originalResult, clonedResult);
rewriter.setInsertionPointAfter(cloned);
}
replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0)));
}
rewriter.replaceOp(mapOp, replacements);
}
}
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
@@ -80,8 +50,6 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
if (funcOp.isExternal())
continue;
expandPimMapOps(funcOp, rewriter);
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
@@ -150,38 +118,11 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
SmallVector<Operation*> hostCompactOps;
for (Operation& op : funcOp.getBody().front())
if (isa<pim::PimExtractRowsOp, pim::PimConcatOp>(op))
if (isa<pim::PimConcatOp>(op))
hostCompactOps.push_back(&op);
for (Operation* op : hostCompactOps) {
rewriter.setInsertionPoint(op);
if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op)) {
auto inputType = dyn_cast<ShapedType>(extractRowsOp.getInput().getType());
if (!inputType || !inputType.hasStaticShape() || inputType.getRank() != 2) {
extractRowsOp.emitOpError("host-side extract_rows lowering requires a static rank-2 input");
hasFailure = true;
continue;
}
int64_t numCols = inputType.getDimSize(1);
SmallVector<Value> replacementRows;
replacementRows.reserve(extractRowsOp.getOutputs().size());
for (auto rowIndex : llvm::seq<size_t>(0, extractRowsOp.getOutputs().size())) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)),
rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
replacementRows.push_back(memref::SubViewOp::create(
rewriter, extractRowsOp.getLoc(), extractRowsOp.getInput(), offsets, sizes, strides)
.getResult());
}
extractRowsOp->replaceAllUsesWith(ValueRange(replacementRows));
extractRowsOp->erase();
continue;
}
auto concatOp = cast<pim::PimConcatOp>(op);
concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
hasFailure = true;