compact pim IR
Validate Operations / validate-operations (push) Successful in 22m15s

This commit is contained in:
NiccoloN
2026-05-06 17:16:51 +02:00
parent 7bb58e80de
commit f2fe147961
13 changed files with 2264 additions and 307 deletions
@@ -1,9 +1,12 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
@@ -31,6 +34,35 @@ static int64_t getValueSizeInBytes(Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
static void expandPimMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<pim::PimMapOp> mapOps;
funcOp.walk([&](pim::PimMapOp mapOp) { mapOps.push_back(mapOp); });
for (auto mapOp : mapOps) {
Block& body = mapOp.getBody().front();
auto yieldOp = cast<pim::PimYieldOp>(body.getTerminator());
SmallVector<Value> replacements;
replacements.reserve(mapOp.getInputs().size());
rewriter.setInsertionPoint(mapOp);
for (Value input : mapOp.getInputs()) {
IRMapping mapping;
mapping.map(body.getArgument(0), input);
for (Operation& op : body.without_terminator()) {
Operation* cloned = rewriter.clone(op, mapping);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapping.map(originalResult, clonedResult);
rewriter.setInsertionPointAfter(cloned);
}
replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0)));
}
rewriter.replaceOp(mapOp, replacements);
}
}
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
@@ -41,13 +73,15 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
OpBuilder rewriter(moduleOp.getContext());
IRRewriter rewriter(moduleOp.getContext());
bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
if (funcOp.isExternal())
continue;
expandPimMapOps(funcOp, rewriter);
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
@@ -113,6 +147,45 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
}
}
}
SmallVector<Operation*> hostCompactOps;
for (Operation& op : funcOp.getBody().front())
if (isa<pim::PimExtractRowsOp, pim::PimConcatOp>(op))
hostCompactOps.push_back(&op);
for (Operation* op : hostCompactOps) {
rewriter.setInsertionPoint(op);
if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op)) {
auto inputType = dyn_cast<ShapedType>(extractRowsOp.getInput().getType());
if (!inputType || !inputType.hasStaticShape() || inputType.getRank() != 2) {
extractRowsOp.emitOpError("host-side extract_rows lowering requires a static rank-2 input");
hasFailure = true;
continue;
}
int64_t numCols = inputType.getDimSize(1);
SmallVector<Value> replacementRows;
replacementRows.reserve(extractRowsOp.getOutputs().size());
for (auto rowIndex : llvm::seq<size_t>(0, extractRowsOp.getOutputs().size())) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)),
rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
replacementRows.push_back(memref::SubViewOp::create(
rewriter, extractRowsOp.getLoc(), extractRowsOp.getInput(), offsets, sizes, strides)
.getResult());
}
extractRowsOp->replaceAllUsesWith(ValueRange(replacementRows));
extractRowsOp->erase();
continue;
}
auto concatOp = cast<pim::PimConcatOp>(op);
concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
hasFailure = true;
}
}
if (hasFailure) {