reimplement pool lowering

add pool validation
align PIM ops/codegen/parser with the ISA
move constant materialization to MLIR
rename the PIM verification/materialization passes
better folded-constant handling
This commit is contained in:
NiccoloN
2026-03-23 19:14:50 +01:00
parent 461bdd808d
commit 661170a9aa
30 changed files with 912 additions and 512 deletions

View File

@@ -5,7 +5,8 @@ add_pim_library(OMPimPasses
PimConstantFolding/Patterns/Constant.cpp
PimConstantFolding/PimConstantFoldingPass.cpp
PimConstantFolding/Patterns/Subview.cpp
PimHostVerificationPass.cpp
PimMaterializeConstantsPass.cpp
PimVerificationPass.cpp
EXCLUDE_FROM_OM_LIBS

View File

@@ -120,20 +120,8 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
rewriter.setInsertionPoint(coreOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
size_t elementByteWidth = initType.getElementTypeBitWidth() / 8;
if (elementByteWidth == 0)
return failure();
size_t totalBytes = initType.getNumElements() * elementByteWidth;
rewriter.setInsertionPoint(mapOp);
pim::PimMemCopyHostToDevOp::create(rewriter,
mapOp.getLoc(),
initType,
mapOp.getInit(),
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp);
rewriter.eraseOp(mapOp);
return success();
}

View File

@@ -0,0 +1,135 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Tells whether operand number `operandIndex` of `op` is the operand that an
/// explicit host/device copy treats as its host operand.
///
/// Only two ops qualify:
///   - pim::PimMemCopyHostToDevOp, operand #1
///   - pim::PimMemCopyDevToHostOp, operand #0
///
/// @param op            Operation owning the operand.
/// @param operandIndex  Zero-based operand position within `op`.
/// @return true for the two (op, index) pairs listed above, false otherwise.
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
  const bool hostToDevSource = isa<pim::PimMemCopyHostToDevOp>(op) && operandIndex == 1;
  const bool devToHostTarget = isa<pim::PimMemCopyDevToHostOp>(op) && operandIndex == 0;
  return hostToDevSource || devToHostTarget;
}
/// Computes the size in bytes of a statically shaped value.
///
/// The size is (number of elements * element bit width) / 8 using integer
/// division.
/// NOTE(review): for element widths that are not a multiple of 8 the division
/// truncates (e.g. total bit counts below 8 yield 0) — confirm whether such
/// element types can reach this helper.
///
/// @param value  Value whose type is inspected.
/// @return The byte size, or -1 when the type is not a ShapedType or does not
///         have a static shape.
static int64_t getValueSizeInBytes(Value value) {
  if (auto shapedType = dyn_cast<ShapedType>(value.getType()); shapedType && shapedType.hasStaticShape()) {
    const int64_t totalBits = shapedType.getNumElements() * shapedType.getElementTypeBitWidth();
    return totalBits / 8;
  }
  return -1;
}
/// Module pass that rewrites operands of ops inside pim::PimCoreOp bodies so
/// that values backed by a memref.get_global are first copied into a freshly
/// allocated buffer through an explicit pim::PimMemCopyHostToDevOp.
///
/// Traversal: every non-external func::FuncOp of the module, every
/// pim::PimCoreOp directly inside it, and every operation of the first block
/// of the core body except pim::PimHaltOp.
///
/// Copies are cached per core op, keyed by (resolved base value, byte offset,
/// operand memref type), so repeated uses of the same constant slice share one
/// copy. Any diagnostic emitted along the way makes the pass fail; on success
/// the module is dumped under the name "pim3_materialized".
struct PimMaterializeConstantsPass
    : PassWrapper<PimMaterializeConstantsPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMaterializeConstantsPass)

  StringRef getArgument() const override { return "materialize-pim-constants"; }

  StringRef getDescription() const override {
    return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
  }

  void runOnOperation() override {
    ModuleOp module = getOperation();
    OpBuilder builder(module.getContext());
    bool sawError = false;

    for (func::FuncOp func : module.getOps<func::FuncOp>()) {
      if (func.isExternal())
        continue;

      for (pim::PimCoreOp core : func.getOps<pim::PimCoreOp>()) {
        // Cache of already materialized copies: base value -> byte offset -> memref type -> copy result.
        DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> copyCache;

        // Early-inc iteration: new ops are inserted in front of the op being visited.
        for (Operation& bodyOp : llvm::make_early_inc_range(core.getBody().front())) {
          if (isa<pim::PimHaltOp>(bodyOp))
            continue;

          for (OpOperand& use : bodyOp.getOpOperands()) {
            Value source = use.get();

            // Only memref operands are candidates.
            if (!isa<BaseMemRefType>(source.getType()))
              continue;
            // Host operands of explicit copies are left untouched.
            if (isExplicitHostOperand(&bodyOp, use.getOperandNumber()))
              continue;

            auto address = resolveContiguousAddress(source);
            if (failed(address))
              continue;

            // Only values whose resolved base comes from a memref.get_global are materialized.
            auto globalGetter = dyn_cast_or_null<memref::GetGlobalOp>(address->base.getDefiningOp());
            if (!globalGetter)
              continue;

            auto sourceType = dyn_cast<MemRefType>(source.getType());
            const bool hasStaticMemRefType = sourceType && sourceType.hasStaticShape();
            if (!hasStaticMemRefType) {
              bodyOp.emitOpError("host constant materialization requires a static memref operand");
              sawError = true;
              continue;
            }

            // Reuse a previous copy of the same (base, offset, type) triple when there is one.
            auto& copiesByType = copyCache[address->base][address->byteOffset];
            if (Value previousCopy = copiesByType.lookup(sourceType)) {
              use.set(previousCopy);
              continue;
            }

            // The copy op carries its size and offset as 32-bit integer attributes.
            const int64_t byteCount = getValueSizeInBytes(source);
            const bool fitsIn32Bits =
                byteCount >= 0 && llvm::isInt<32>(byteCount) && llvm::isInt<32>(address->byteOffset);
            if (!fitsIn32Bits) {
              bodyOp.emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
              sawError = true;
              continue;
            }

            // Allocate with a default-layout type of the same shape and element type; cast back to
            // the operand's own type when the two differ.
            auto defaultLayoutType = MemRefType::get(sourceType.getShape(), sourceType.getElementType());
            builder.setInsertionPoint(&bodyOp);
            Value freshBuffer = memref::AllocOp::create(builder, bodyOp.getLoc(), defaultLayoutType);
            Value copyTarget = freshBuffer;
            if (defaultLayoutType != sourceType)
              copyTarget = memref::CastOp::create(builder, bodyOp.getLoc(), sourceType, freshBuffer);

            auto zeroOffsetAttr = builder.getI32IntegerAttr(0);
            auto globalOffsetAttr = builder.getI32IntegerAttr(static_cast<int32_t>(address->byteOffset));
            auto byteCountAttr = builder.getI32IntegerAttr(static_cast<int32_t>(byteCount));
            auto copyOp = pim::PimMemCopyHostToDevOp::create(builder,
                                                             bodyOp.getLoc(),
                                                             sourceType,
                                                             copyTarget,
                                                             globalGetter.getResult(),
                                                             zeroOffsetAttr,
                                                             globalOffsetAttr,
                                                             byteCountAttr);

            copiesByType[sourceType] = copyOp.getResult();
            use.set(copyOp.getResult());
          }
        }
      }
    }

    if (!sawError) {
      dumpModule(module, "pim3_materialized");
      return;
    }
    signalPassFailure();
  }
};
} // namespace
/// Factory for the constant materialization pass.
/// @return A newly constructed PimMaterializeConstantsPass owned by the caller.
std::unique_ptr<Pass> createPimMaterializeConstantsPass() {
  std::unique_ptr<Pass> pass = std::make_unique<PimMaterializeConstantsPass>();
  return pass;
}
} // namespace onnx_mlir

View File

@@ -17,7 +17,9 @@ std::unique_ptr<mlir::Pass> createBufferizePimPass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimHostVerificationPass();
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass();
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();

View File

@@ -35,16 +35,24 @@ static bool isCodegenAddressableValue(Value value) {
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
}
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass)
/// Tells whether operand number `operandIndex` of `op` is the operand that an
/// explicit host/device copy treats as its host operand: operand #1 of a
/// pim::PimMemCopyHostToDevOp or operand #0 of a pim::PimMemCopyDevToHostOp.
///
/// @param op            Operation owning the operand.
/// @param operandIndex  Zero-based operand position within `op`.
/// @return true for those two (op, index) pairs, false otherwise.
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
  const bool hostToDevSource = isa<pim::PimMemCopyHostToDevOp>(op) && operandIndex == 1;
  const bool devToHostTarget = isa<pim::PimMemCopyDevToHostOp>(op) && operandIndex == 0;
  return hostToDevSource || devToHostTarget;
}
StringRef getArgument() const override { return "verify-pim-host-pass"; }
struct PimVerificationPass : PassWrapper<PimVerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimVerificationPass)
StringRef getArgument() const override { return "verify-pim-pass"; }
StringRef getDescription() const override {
return "Verify that no runtime host-side code remains in bufferized PIM IR";
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
}
PimHostVerificationPass() {}
PimHostVerificationPass(const PimHostVerificationPass& pass) {}
PimVerificationPass() {}
PimVerificationPass(const PimVerificationPass& pass) {}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
@@ -132,11 +140,27 @@ private:
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
if (!isa<BaseMemRefType>(operand.getType()))
continue;
if (succeeded(resolveContiguousAddress(operand)))
continue;
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
auto resolvedAddress = resolveContiguousAddress(operand);
if (failed(resolvedAddress)) {
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
continue;
}
if (isExplicitHostOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand)) {
op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
}
continue;
}
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
op.emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
hasFailure = true;
}
}
}
return success(!hasFailure);
@@ -165,6 +189,6 @@ private:
} // namespace
std::unique_ptr<Pass> createPimHostVerificationPass() { return std::make_unique<PimHostVerificationPass>(); }
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<PimVerificationPass>(); }
} // namespace onnx_mlir