// PIM verification pass ("verify-pim-pass").
//
// NOTE(review): in this copy of the file the whole translation unit is
// collapsed onto a few physical lines, and the template argument lists appear
// to have been stripped (e.g. after `isa`, `dyn_cast`, `getOps`,
// `getDefiningOp`, `PassWrapper`, `template`, `std::unique_ptr` and
// `std::make_unique`). The concrete op/type lists those calls test against are
// therefore not visible here; restore them from the upstream file before
// relying on, or compiling, this text.
//
// Contents, all inside namespace onnx_mlir (helpers and the pass live in an
// anonymous namespace):
//
//  * isAddressOnlyHostOp(op): a single `isa` test on `op`. The accepted op
//    kinds are not visible in this text.
//
//  * isCodegenAddressableValue(value): calls resolveContiguousAddress(value)
//    and returns false when that fails. Otherwise returns true when either the
//    resolved `base` value or the operation defining `base` passes an `isa`
//    test (tested kinds not visible).
//
//  * isExplicitHostOperand(op, operandIndex): true for operand index 1 of two
//    op kinds and for operand index 0 of a third op kind; false otherwise
//    (op kinds not visible).
//
//  * struct VerificationPass: a pass over a ModuleOp whose argument string is
//    "verify-pim-pass". runOnOperation() visits every non-external func::FuncOp
//    yielded by moduleOp.getOps() and walks only the operations of the FIRST
//    block of the function body (`getBody().front()`). Each operation is
//    dispatched in this order:
//      - the op bound to `coreOp`      -> verifyCoreWeights + verifyCoreOperands
//      - the op bound to `coreBatchOp` -> verifyCoreWeights + verifyCoreOperands
//      - the op bound to `returnOp`    -> verifyReturnOp
//    Failures are accumulated in `hasFailure` rather than aborting on the first
//    one; signalPassFailure() is called once at the end.
//
//    NOTE(review): `failed(verifyCoreWeights(..)) || failed(verifyCoreOperands(..))`
//    short-circuits, so when weight verification fails the operand verification
//    for that core is not run and its diagnostics are not emitted. Confirm this
//    is intended, given that every other check here accumulates all errors.
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/STLExtras.h" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; namespace onnx_mlir { namespace { static bool isAddressOnlyHostOp(Operation* op) { return isa(op); } static bool isCodegenAddressableValue(Value value) { auto resolvedAddress = resolveContiguousAddress(value); if (failed(resolvedAddress)) return false; return isa(resolvedAddress->base) || isa(resolvedAddress->base.getDefiningOp()); } static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { if (isa(op)) return operandIndex == 1; if (isa(op)) return operandIndex == 1; if (isa(op)) return operandIndex == 0; return false; } struct VerificationPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass) StringRef getArgument() const override { return "verify-pim-pass"; } StringRef getDescription() const override { return "Verify that bufferized PIM IR contains only explicit host/device transfers"; } VerificationPass() {} VerificationPass(const VerificationPass& pass) {} void runOnOperation() override { ModuleOp moduleOp = getOperation(); bool hasFailure = false; for (func::FuncOp funcOp : moduleOp.getOps()) { if (funcOp.isExternal()) continue; for (Operation& op : funcOp.getBody().front().getOperations()) { if (auto coreOp = dyn_cast(&op)) { if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp))) hasFailure = true; continue; } if (auto coreBatchOp = dyn_cast(&op)) { if (failed(verifyCoreWeights(moduleOp, coreBatchOp)) || failed(verifyCoreOperands(coreBatchOp))) hasFailure = true; continue; } if (auto returnOp = dyn_cast(&op)) { if (failed(verifyReturnOp(returnOp))) hasFailure = true; continue; 
// (still inside the per-operation loop of runOnOperation)
// Any remaining operation must satisfy isAddressOnlyHostOp(); otherwise an
// "illegal host-side runtime op" error is emitted on it. Operations that do
// satisfy it are further checked by verifyAddressOnlyHostOp().
//
// Private helpers defined next:
//
//  * verifyCoreWeights(moduleOp, coreOp): for every entry of
//    coreOp.getWeights() (indexed with llvm::enumerate) requires that
//      1. the weight has a defining op of the kind requested from
//         getDefiningOp (kind not visible; the diagnostic names
//         memref.get_global),
//      2. lookupGlobalForGetGlobal(moduleOp, getGlobalOp) returns a global, and
//      3. that global has both getConstant() and getInitialValue() set.
//    Emits one error per offending weight and returns failure if any was
//    emitted.
//
//  * verifyReturnOp(returnOp): every operand of the return must satisfy
//    isCodegenAddressableValue(); emits one error per offending result.
//
//  * verifyCoreOperands(coreOp): calls walkPimCoreBlock on the first block of
//    the core body with a default-constructed StaticValueKnowledge and a
//    per-operation callback. The callback skips operands whose type does not
//    pass an `isa` test (tested type not visible) and requires
//    resolveContiguousAddress(operand, knowledge) to succeed for the rest.
} if (!isAddressOnlyHostOp(&op)) { op.emitOpError("illegal host-side runtime op remains after PIM bufferization; " "fold it to constants or lower it into pim.core"); hasFailure = true; continue; } if (failed(verifyAddressOnlyHostOp(&op))) hasFailure = true; } } if (hasFailure) signalPassFailure(); } private: template static LogicalResult verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp) { bool hasFailure = false; for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) { auto getGlobalOp = weight.template getDefiningOp(); if (!getGlobalOp) { coreOp.emitOpError() << "weight #" << weightIndex << " must be materialized as memref.get_global before JSON codegen"; hasFailure = true; continue; } auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); if (!globalOp) { coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global"; hasFailure = true; continue; } if (!globalOp.getConstant() || !globalOp.getInitialValue()) { coreOp.emitOpError() << "weight #" << weightIndex << " must come from a constant memref.global with an initial value"; hasFailure = true; } } return success(!hasFailure); } static LogicalResult verifyReturnOp(func::ReturnOp returnOp) { bool hasFailure = false; for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) { if (!isCodegenAddressableValue(operand)) { returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage"; hasFailure = true; } } return success(!hasFailure); } template static LogicalResult verifyCoreOperands(CoreOpTy coreOp) { return walkPimCoreBlock( coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) { bool hasFailure = false; for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) { if (!isa(operand.getType())) continue; auto resolvedAddress = resolveContiguousAddress(operand, knowledge); if (failed(resolvedAddress)) { op.emitOpError() << "operand #" 
// (still inside the verifyCoreOperands callback, finishing the "operand #"
// diagnostic for an operand whose address could not be resolved)
//
// For operands whose address did resolve:
//  - if isExplicitHostOperand(&op, operandIndex) is true, the operand must
//    also satisfy isCodegenAddressableValue(); nothing further is checked for
//    it;
//  - otherwise the operation defining the resolved base must pass an `isa`
//    test (tested kinds not visible), else the "must be backed by device-local
//    memory" error is emitted.
//
// NOTE(review): `resolvedAddress->base.getDefiningOp()` is handed straight to
// `isa`. If the resolved base can ever be a value without a defining operation
// (getDefiningOp() returning null), `isa` would be applied to a null pointer;
// confirm that cannot happen here or use a null-tolerant variant.
//
//  * verifyAddressOnlyHostOp(op): tries four `dyn_cast`s in turn (op kinds not
//    visible) and, on the first match, checks the matched op's
//    getSource()/getSrc() value with verifyAddressOnlySource(); returns
//    success for any op that matches none of them.
//
//  * verifyAddressOnlySource(op, source): success when
//    isCodegenAddressableValue(source); otherwise emits an error on `op` and
//    returns failure.
//
//  * createPimVerificationPass(): factory, declared outside the anonymous
//    namespace, that returns the result of std::make_unique (the constructed
//    type is not visible in this text; presumably VerificationPass).
<< operandIndex << " is not backed by contiguous addressable storage"; hasFailure = true; continue; } if (isExplicitHostOperand(&op, operandIndex)) { if (!isCodegenAddressableValue(operand)) { op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage"; hasFailure = true; } continue; } if (!isa(resolvedAddress->base.getDefiningOp())) { op.emitOpError() << "operand #" << operandIndex << " must be backed by device-local memory; materialize host values with pim.memcp_hd"; hasFailure = true; } } return success(!hasFailure); }); } static LogicalResult verifyAddressOnlyHostOp(Operation* op) { if (auto subviewOp = dyn_cast(op)) return verifyAddressOnlySource(op, subviewOp.getSource()); if (auto castOp = dyn_cast(op)) return verifyAddressOnlySource(op, castOp.getSource()); if (auto collapseOp = dyn_cast(op)) return verifyAddressOnlySource(op, collapseOp.getSrc()); if (auto expandOp = dyn_cast(op)) return verifyAddressOnlySource(op, expandOp.getSrc()); return success(); } static LogicalResult verifyAddressOnlySource(Operation* op, Value source) { if (isCodegenAddressableValue(source)) return success(); op->emitOpError("depends on a value that is not backed by contiguous addressable storage"); return failure(); } }; } // namespace std::unique_ptr createPimVerificationPass() { return std::make_unique(); } } // namespace onnx_mlir