fix missed failing tests for channels
moderate refactor
This commit is contained in:
193
src/PIM/Pass/PimCodegen/VerificationPass.cpp
Normal file
193
src/PIM/Pass/PimCodegen/VerificationPass.cpp
Normal file
@@ -0,0 +1,193 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isAddressOnlyHostOp(Operation* op) {
|
||||
return isa<arith::ConstantOp,
|
||||
memref::AllocOp,
|
||||
memref::GetGlobalOp,
|
||||
memref::SubViewOp,
|
||||
memref::CastOp,
|
||||
memref::CollapseShapeOp,
|
||||
memref::ExpandShapeOp,
|
||||
spatial::SpatChannelNewOp>(op);
|
||||
}
|
||||
|
||||
static bool isCodegenAddressableValue(Value value) {
|
||||
auto resolvedAddress = resolveContiguousAddress(value);
|
||||
if (failed(resolvedAddress))
|
||||
return false;
|
||||
return isa<BlockArgument>(resolvedAddress->base)
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
}
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
||||
|
||||
StringRef getArgument() const override { return "verify-pim-pass"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
|
||||
}
|
||||
|
||||
VerificationPass() {}
|
||||
VerificationPass(const VerificationPass& pass) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
bool hasFailure = false;
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
if (funcOp.isExternal())
|
||||
continue;
|
||||
|
||||
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
||||
if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp)))
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
|
||||
if (failed(verifyReturnOp(returnOp)))
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isAddressOnlyHostOp(&op)) {
|
||||
op.emitOpError("illegal host-side runtime op remains after PIM bufferization; "
|
||||
"fold it to constants or lower it into pim.core");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (failed(verifyAddressOnlyHostOp(&op)))
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFailure)
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
private:
|
||||
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, pim::PimCoreOp coreOp) {
|
||||
bool hasFailure = false;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
|
||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must be materialized as memref.get_global before JSON codegen";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must come from a constant memref.global with an initial value";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
|
||||
bool hasFailure = false;
|
||||
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) {
|
||||
return walkPimCoreBlock(
|
||||
coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
bool hasFailure = false;
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
||||
if (failed(resolvedAddress)) {
|
||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isExplicitHostOperand(&op, operandIndex)) {
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
op.emitOpError() << "host operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||
op.emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
return success(!hasFailure);
|
||||
});
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
||||
return verifyAddressOnlySource(op, subviewOp.getSource());
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(op))
|
||||
return verifyAddressOnlySource(op, castOp.getSource());
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, collapseOp.getSrc());
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, expandOp.getSrc());
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
|
||||
if (isCodegenAddressableValue(source))
|
||||
return success();
|
||||
|
||||
op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<VerificationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user