// Pass "materialize-pim-host-constants". Per its own description string it
// materializes explicit host-to-device copies for constant globals used by
// PIM runtime ops.
//
// NOTE(review): this copy of the file appears whitespace-collapsed (all
// #include directives share one physical line) and the angle-bracket template
// argument lists look stripped (e.g. `template static void`, `dyn_cast(op)`,
// `SmallVector ops`, `static_cast(totalBytes)`). Restore them from the
// original source before building -- TODO confirm against upstream.
//
// materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure)
//   coreOp          op whose body's first block is walked to collect ops
//                   (ops matching the `isa` filter in the walk are skipped).
//   rewriter        used to create the alloc / cast / host-to-device copy ops,
//                   inserted immediately before the op being rewritten.
//   constantFolder  passed to getOrCreateIndexConstant for the two offsets.
//   hasFailure      out-flag; set to true whenever an error is emitted. It is
//                   never reset to false here.
//
//   The op list is snapshotted before any rewriting, so ops created below are
//   not themselves visited. An op is skipped entirely when the `loadOp` cast
//   succeeds and its result type is `index`.
//
//   For every remaining operand:
//     * skip it if its type fails the `isa` check or
//       isExplicitHostMemCopyOperand(op, operandNumber) is true;
//     * skip it if resolveContiguousAddress fails, or if the resolved base has
//       no defining op of the kind `getGlobalOp` is cast to;
//     * emit "requires a static memref operand" and flag failure when the
//       operand type cast fails or the shape is not static;
//     * reuse a previously created copy, cached by (resolved base, byte
//       offset, operand type), but only when that cached value properly
//       dominates the current op;
//     * otherwise compute the byte size via getShapedTypeSizeInBytes, require
//       that both the size and the byte offset pass llvm::isInt<32>, allocate
//       a memref built from only the original shape and element type, cast it
//       back to the original type when the two types differ, create a
//       PimMemCopyHostToDevOp (device offset 0, host offset = resolved byte
//       offset, size as an I32 attribute), cache its output and rewrite the
//       operand to use it.
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/MathExtras.h" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace mlir; namespace onnx_mlir { namespace { template static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter, OperationFolder& constantFolder, bool& hasFailure) { DenseMap>> materializedValues; DominanceInfo dominance(coreOp); SmallVector ops; coreOp.getBody().front().walk([&](Operation* op) { if (!isa(op)) ops.push_back(op); }); for (Operation* op : ops) { if (auto loadOp = dyn_cast(op); loadOp && loadOp.getType().isIndex()) continue; for (OpOperand& operand : op->getOpOperands()) { Value originalValue = operand.get(); if (!isa(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber())) continue; auto resolvedAddress = resolveContiguousAddress(originalValue); if (failed(resolvedAddress)) continue; auto getGlobalOp = dyn_cast_or_null(resolvedAddress->base.getDefiningOp()); if (!getGlobalOp) continue; auto originalType = dyn_cast(originalValue.getType()); if (!originalType || !originalType.hasStaticShape()) { op->emitOpError("host constant materialization requires a static memref operand"); hasFailure = true; continue; } auto& cachedByOffset = materializedValues[resolvedAddress->base]; auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset]; auto cachedValue = cachedByType.find(originalType); if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) { operand.set(cachedValue->second); continue; } int64_t totalBytes = -1; if (auto type = 
// (continued from the previous physical line) totalBytes stays -1 unless the
// operand type cast succeeds with a static shape; a negative size, or a size
// or byte offset failing llvm::isInt<32>, is reported as an error and the
// operand is left untouched.
// NOTE(review): the static-shape test here repeats the earlier check on
// originalType -- presumably redundant; confirm once template args are restored.
//
// MaterializeHostConstantsPass::runOnOperation (starts on this line): for each
// non-external function in the module, runs materializeHostConstantsInCore over
// the ops yielded by the two funcOp.getOps() loops (bound to pim::PimCoreOp and
// pim::PimCoreBatchOp respectively).
dyn_cast(originalValue.getType()); type && type.hasStaticShape()) totalBytes = static_cast(getShapedTypeSizeInBytes(type)); if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) { op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets"); hasFailure = true; continue; } auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType()); rewriter.setInsertionPoint(op); Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType); Value deviceDst = localAlloc; if (contiguousType != originalType) deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc); Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0); Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset); Value copiedValue = pim::PimMemCopyHostToDevOp::create(rewriter, op->getLoc(), originalType, zeroOffset, hostOffset, deviceDst, getGlobalOp.getResult(), rewriter.getI32IntegerAttr(static_cast(totalBytes))) .getOutput(); cachedByType[originalType] = copiedValue; operand.set(copiedValue); } } } struct MaterializeHostConstantsPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass) StringRef getArgument() const override { return "materialize-pim-host-constants"; } StringRef getDescription() const override { return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops"; } void runOnOperation() override { ModuleOp moduleOp = getOperation(); IRRewriter rewriter(moduleOp.getContext()); OperationFolder constantFolder(moduleOp.getContext()); bool hasFailure = false; for (func::FuncOp funcOp : moduleOp.getOps()) { if (funcOp.isExternal()) continue; for (pim::PimCoreOp coreOp : funcOp.getOps()) materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure); for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps()) 
// (continued) After the core ops are processed, every op in the function's
// first block that matches the `isa` filter is reported with the
// "host-side concat must be folded away or lowered into pim.core" error and
// marks the pass as failed; nothing is rewritten for those ops.
// NOTE(review): rewriter.setInsertionPoint(op) in that loop is not followed by
// any op creation in the visible code -- looks like dead code; confirm.
// If any failure was recorded, a module-level error is emitted and
// signalPassFailure() is called; otherwise dumpModule(moduleOp,
// "pim4_materialized") runs.
// NOTE(review): in this collapsed layout the `// namespace` comment below
// swallows the rest of the physical line, including the definition of
// createPimMaterializeHostConstantsPass() and the closing brace of
// namespace onnx_mlir. Restore the original line breaks before compiling.
materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure); SmallVector hostCompactOps; for (Operation& op : funcOp.getBody().front()) if (isa(op)) hostCompactOps.push_back(&op); for (Operation* op : hostCompactOps) { rewriter.setInsertionPoint(op); auto concatOp = cast(op); concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization"); hasFailure = true; } } if (hasFailure) { moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above"); signalPassFailure(); return; } dumpModule(moduleOp, "pim4_materialized"); } }; } // namespace std::unique_ptr createPimMaterializeHostConstantsPass() { return std::make_unique(); } } // namespace onnx_mlir