Remove the useless MaterializeHostConstantsPass.cpp and fix the earlier lowering instead
Validate Operations / validate-operations (push) Has been cancelled

avoid spammy pim codegen diagnostics
This commit is contained in:
NiccoloN
2026-06-05 10:06:28 +02:00
parent 27410207c4
commit 1e9e61f5a9
20 changed files with 458 additions and 256 deletions
-1
View File
@@ -4,7 +4,6 @@ add_onnx_mlir_dialect_doc(pim Pim.td)
add_subdirectory(Transforms/Bufferization)
add_subdirectory(Transforms/MemoryCoalescing)
add_subdirectory(Transforms/HostConstantFolding)
add_subdirectory(Transforms/HostConstantMaterialization)
add_subdirectory(Transforms/Verification)
add_pim_library(PimOps
@@ -6,6 +6,7 @@
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp"
using namespace mlir;
using namespace bufferization;
@@ -13,7 +14,9 @@ using namespace bufferization;
namespace onnx_mlir::pim {
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
bool isContiguous =
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
@@ -29,13 +32,21 @@ FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location lo
if (failed(sizeAttr))
return failure();
if (isHostBackedPimAddress(memrefValue)) {
return PimMemCopyHostToDevOp::create(
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
.getOutput();
}
return PimMemCopyOp::create(
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
.getOutput();
}
Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
bool isContiguous =
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
@@ -1,9 +1,70 @@
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
/// Returns true when `value` is a block argument of a block directly owned by
/// a PimCoreBatchOp and its position is at or past the first input argument.
static bool isCoreBatchInputArgument(Value value) {
  if (auto argument = dyn_cast<BlockArgument>(value)) {
    Operation* parentOp = argument.getOwner()->getParentOp();
    if (auto batchOp = dyn_cast_or_null<onnx_mlir::pim::PimCoreBatchOp>(parentOp)) {
      // One leading argument (presumably the lane argument -- TODO confirm)
      // plus one argument per weight precede the input arguments.
      unsigned inputArgsBegin = 1 + batchOp.getWeights().size();
      return static_cast<unsigned>(argument.getArgNumber()) >= inputArgsBegin;
    }
  }
  return false;
}
/// Walks `value` back towards the value that provides its storage.
///
/// Each step follows loop-carried aliases and the aliases recorded in
/// `knowledge`, then asks getPimAddressBase for a base. If no base can be
/// computed, the walk looks through memref subview / collapse_shape /
/// expand_shape / cast ops. The walk stops at the first value it cannot look
/// through (or has already visited) and returns it; failure is returned only
/// when the walk ends on a null value.
static FailureOr<Value> getPimStorageBase(Value value, const onnx_mlir::StaticValueKnowledge& knowledge) {
  llvm::SmallPtrSet<Value, 8> seen;
  while (value && seen.insert(value).second) {
    // Prefer the loop-carried alias when one is known.
    if (Value loopAlias = resolveLoopCarriedAlias(value, knowledge))
      value = loopAlias;

    // A recorded alias restarts the walk from the aliased value.
    if (auto knownAlias = knowledge.aliases.lookup(value)) {
      value = knownAlias;
      continue;
    }

    auto addressBase = onnx_mlir::pim::getPimAddressBase(value, knowledge);
    if (succeeded(addressBase))
      return addressBase;

    // Block arguments have no defining op, so they terminate the walk here.
    Operation* producer = value.getDefiningOp();
    if (!producer)
      return value;

    // Step through view-like ops to the memref they are derived from.
    Value viewSource;
    if (auto subview = dyn_cast<memref::SubViewOp>(producer))
      viewSource = subview.getSource();
    else if (auto collapse = dyn_cast<memref::CollapseShapeOp>(producer))
      viewSource = collapse.getSrc();
    else if (auto expand = dyn_cast<memref::ExpandShapeOp>(producer))
      viewSource = expand.getSrc();
    else if (auto cast = dyn_cast<memref::CastOp>(producer))
      viewSource = cast.getSource();
    else
      return value;
    value = viewSource;
  }

  if (!value)
    return failure();
  return value;
}
FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Operation* anchor, Value memref) {
auto type = mlir::cast<MemRefType>(memref.getType());
auto byteSize = getCheckedShapedTypeSizeInBytes(type, anchor, "memref byte size");
@@ -11,3 +72,40 @@ FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& build
return failure();
return getCheckedI32Attr(builder, anchor, *byteSize, "memref byte size");
}
/// Returns the base value of the contiguous address computed for `value`.
///
/// A loop-carried alias, when known, is substituted first. The base comes from
/// resolveContiguousAddress if that succeeds, otherwise from
/// compileContiguousAddressExpr. When both fail, a block argument is returned
/// as its own base; any other value yields failure.
FailureOr<Value> onnx_mlir::pim::getPimAddressBase(Value value, const StaticValueKnowledge& knowledge) {
  if (Value loopAlias = resolveLoopCarriedAlias(value, knowledge))
    value = loopAlias;

  if (auto resolvedAddress = resolveContiguousAddress(value, knowledge); succeeded(resolvedAddress))
    return resolvedAddress->base;

  if (auto compiledAddress = compileContiguousAddressExpr(value); succeeded(compiledAddress))
    return compiledAddress->base;

  if (isa<BlockArgument>(value))
    return value;
  return failure();
}
/// Returns true when the storage base of `value` is either an input block
/// argument of a PimCoreBatchOp or the result of a memref.get_global.
/// Returns false when no storage base can be determined.
bool onnx_mlir::pim::isHostBackedPimAddress(Value value, const StaticValueKnowledge& knowledge) {
  FailureOr<Value> storageBase = getPimStorageBase(value, knowledge);
  if (failed(storageBase))
    return false;

  Value root = *storageBase;
  return isCoreBatchInputArgument(root) || isa_and_nonnull<memref::GetGlobalOp>(root.getDefiningOp());
}
/// Returns true when the storage base of `value` is the result of a
/// memref.alloc. Returns false when no storage base can be determined.
bool onnx_mlir::pim::isDeviceLocalPimAddress(Value value, const StaticValueKnowledge& knowledge) {
  FailureOr<Value> storageBase = getPimStorageBase(value, knowledge);
  if (succeeded(storageBase)) {
    Value root = *storageBase;
    return isa_and_nonnull<memref::AllocOp>(root.getDefiningOp());
  }
  return false;
}
@@ -2,11 +2,19 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
namespace onnx_mlir {
namespace pim {
/// Builds an integer attribute holding the size in bytes of `memref`'s type.
/// `anchor` is the operation passed to the checked-arithmetic helpers.
/// Returns failure when the checked size computation or attribute creation fails.
mlir::FailureOr<mlir::IntegerAttr>
getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Operation* anchor, mlir::Value memref);
/// Returns the base value of the contiguous address computed for `value`
/// under `knowledge`. A block argument for which no address can be computed is
/// returned as its own base; any other unresolvable value yields failure.
mlir::FailureOr<mlir::Value> getPimAddressBase(mlir::Value value, const StaticValueKnowledge& knowledge = {});
/// Returns true when the storage behind `value` is a PimCoreBatchOp input
/// block argument or a memref.get_global result.
bool isHostBackedPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {});
/// Returns true when the storage behind `value` is a memref.alloc result.
bool isDeviceLocalPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {});
} // namespace pim
} // namespace onnx_mlir
@@ -8,6 +8,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "Common/PimCommon.hpp"
@@ -15,6 +16,7 @@
#include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
@@ -27,24 +29,71 @@ namespace onnx_mlir {
namespace {
struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
using OpRewritePattern::OpRewritePattern;
// A memref.copy queued for lowering, together with the StaticValueKnowledge
// that was in effect when the copy was collected.
struct MemRefCopyWorkItem {
// The copy operation to lower.
memref::CopyOp copyOp;
// Knowledge used to classify the copy's source and target when lowering.
StaticValueKnowledge knowledge;
};
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
return failure();
/// Builds the initial knowledge for a PimCoreOp body: every weight block
/// argument is recorded as an alias of the corresponding weight operand.
static StaticValueKnowledge seedCoreKnowledge(pim::PimCoreOp coreOp) {
  StaticValueKnowledge seeded;
  for (auto indexedWeight : llvm::enumerate(coreOp.getWeights()))
    seeded.aliases[coreOp.getWeightArgument(indexedWeight.index())] = indexedWeight.value();
  return seeded;
}
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape())
return failure();
if (sourceType.getElementType() != targetType.getElementType())
return failure();
/// Builds the initial knowledge for one lane of a PimCoreBatchOp body: the
/// lane argument is bound to `lane`, and every weight and input block argument
/// is recorded as an alias of the corresponding operand.
static StaticValueKnowledge seedCoreBatchKnowledge(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
  StaticValueKnowledge seeded;
  seeded.indexValues[coreBatchOp.getLaneArgument()] = lane;

  for (auto indexedWeight : llvm::enumerate(coreBatchOp.getWeights()))
    seeded.aliases[coreBatchOp.getWeightArgument(indexedWeight.index())] = indexedWeight.value();

  for (auto indexedInput : llvm::enumerate(coreBatchOp.getInputs()))
    seeded.aliases[coreBatchOp.getInputArgument(indexedInput.index())] = indexedInput.value();

  return seeded;
}
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
if (failed(sizeAttr))
return failure();
static LogicalResult
lowerMemRefCopyToPimCopy(memref::CopyOp copyOp, PatternRewriter& rewriter, const StaticValueKnowledge& knowledge) {
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
return failure();
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape())
return failure();
if (sourceType.getElementType() != targetType.getElementType())
return failure();
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
if (failed(sizeAttr))
return failure();
bool sourceIsHost = isHostBackedPimAddress(copyOp.getSource(), knowledge);
bool targetIsHost = isHostBackedPimAddress(copyOp.getTarget(), knowledge);
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getSource(), knowledge);
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getTarget(), knowledge);
if (targetIsDevice && sourceIsHost) {
pim::PimMemCopyHostToDevOp::create(rewriter,
copyOp.getLoc(),
copyOp.getTarget().getType(),
zeroOffset,
zeroOffset,
copyOp.getTarget(),
copyOp.getSource(),
*sizeAttr);
}
else if (targetIsHost && sourceIsDevice) {
pim::PimMemCopyDevToHostOp::create(rewriter,
copyOp.getLoc(),
copyOp.getTarget().getType(),
zeroOffset,
zeroOffset,
copyOp.getTarget(),
copyOp.getSource(),
*sizeAttr);
}
else if (targetIsDevice && sourceIsDevice) {
pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(),
copyOp.getTarget().getType(),
@@ -53,10 +102,19 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
copyOp.getTarget(),
copyOp.getSource(),
*sizeAttr);
rewriter.eraseOp(copyOp);
return success();
}
};
else {
copyOp.emitOpError() << "failed to classify memref.copy endpoints: source=" << copyOp.getSource()
<< " type=" << copyOp.getSource().getType() << " host=" << sourceIsHost
<< " device=" << sourceIsDevice << ", target=" << copyOp.getTarget()
<< " type=" << copyOp.getTarget().getType() << " host=" << targetIsHost
<< " device=" << targetIsDevice;
return failure();
}
rewriter.eraseOp(copyOp);
return success();
}
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
@@ -100,25 +158,46 @@ void PimBufferizationPass::runOnOperation() {
}
MLIRContext* ctx = moduleOp.getContext();
RewritePatternSet memrefCopyPatterns(ctx);
memrefCopyPatterns.add<MemRefCopyToPimMemCopyPattern>(ctx);
FrozenRewritePatternSet frozenMemrefCopyPatterns(std::move(memrefCopyPatterns));
PatternApplicator memrefCopyApplicator(frozenMemrefCopyPatterns);
memrefCopyApplicator.applyDefaultCostModel();
PatternRewriter rewriter(ctx);
SmallVector<memref::CopyOp> copyWorklist;
moduleOp.walk([&](memref::CopyOp copyOp) {
if (copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>())
copyWorklist.push_back(copyOp);
SmallVector<MemRefCopyWorkItem> copyWorklist;
llvm::SmallPtrSet<Operation*, 16> seenCopyOps;
auto addCopyOp = [&](memref::CopyOp copyOp, const StaticValueKnowledge& knowledge) {
if (seenCopyOps.insert(copyOp.getOperation()).second)
copyWorklist.push_back({copyOp, knowledge});
};
moduleOp.walk([&](pim::PimCoreOp coreOp) {
StaticValueKnowledge knowledge = seedCoreKnowledge(coreOp);
(void) walkPimCoreBlockStructurally(
coreOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) {
if (auto copyOp = dyn_cast<memref::CopyOp>(&op))
addCopyOp(copyOp, opKnowledge);
return success();
});
});
moduleOp.walk([&](pim::PimCoreBatchOp coreBatchOp) {
llvm::SmallVector<unsigned, 2> lanes;
lanes.push_back(0);
if (coreBatchOp.getLaneCount() > 1)
lanes.push_back(static_cast<unsigned>(coreBatchOp.getLaneCount() - 1));
for (unsigned lane : lanes) {
StaticValueKnowledge knowledge = seedCoreBatchKnowledge(coreBatchOp, lane);
(void) walkPimCoreBlockStructurally(
coreBatchOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) {
if (auto copyOp = dyn_cast<memref::CopyOp>(&op))
addCopyOp(copyOp, opKnowledge);
return success();
});
}
});
bool hasFailed = false;
for (memref::CopyOp copyOp : copyWorklist) {
if (failed(applyPatternsOnce(copyOp, memrefCopyApplicator, rewriter))) {
copyOp.emitOpError("failed to lower memref.copy inside PIM core body");
for (const MemRefCopyWorkItem& workItem : copyWorklist) {
memref::CopyOp copyOp = workItem.copyOp;
rewriter.setInsertionPoint(copyOp);
if (failed(lowerMemRefCopyToPimCopy(copyOp, rewriter, workItem.knowledge)))
hasFailed = true;
}
}
if (hasFailed) {
signalPassFailure();
@@ -128,7 +128,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
auto sizeAttr = pim::getCheckedI32Attr(rewriter, mapOp, *sizeInBytes, "host constant folding byte size");
if (failed(sizeAttr))
return failure();
pim::PimMemCopyOp::create(
pim::PimMemCopyHostToDevOp::create(
rewriter, mapOp.getLoc(), initType, zeroOffset, zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), *sizeAttr);
rewriter.eraseOp(mapOp);
return success();
@@ -1,9 +0,0 @@
add_pim_library(OMPimHostConstantMaterialization
MaterializeHostConstantsPass.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
OMPimCommon
PimOps
)
@@ -1,161 +0,0 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Rewrites memref operands inside `coreOp` whose contiguous address resolves
/// to a memref.get_global so that they use a freshly allocated buffer filled by
/// a PimMemCopyHostToDevOp.
///
/// Copies are cached per (base value, byte offset, memref type) and reused
/// when the cached copy properly dominates the user. `hasFailure` is set to
/// true when an operand cannot be materialized; processing then continues with
/// the next operand.
template <typename CoreOpTy>
static void materializeHostConstantsInCore(CoreOpTy coreOp,
IRRewriter& rewriter,
OperationFolder& constantFolder,
bool& hasFailure) {
// base value -> byte offset -> memref type -> materialized copy.
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
DominanceInfo dominance(coreOp);
// Snapshot the ops first: the loop below inserts new ops into the body.
SmallVector<Operation*> ops;
coreOp.getBody().front().walk([&](Operation* op) {
if (!isa<pim::PimHaltOp, scf::YieldOp>(op))
ops.push_back(op);
});
for (Operation* op : ops) {
// memref.load ops producing an index value are left untouched.
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
continue;
for (OpOperand& operand : op->getOpOperands()) {
Value originalValue = operand.get();
// Only memref operands are considered; operands flagged by
// isExplicitHostMemCopyOperand are skipped.
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber()))
continue;
auto resolvedAddress = resolveContiguousAddress(originalValue);
if (failed(resolvedAddress))
continue;
// Only addresses whose base is a memref.get_global are materialized.
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
if (!getGlobalOp)
continue;
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
if (!originalType || !originalType.hasStaticShape()) {
op->emitOpError("host constant materialization requires a static memref operand");
hasFailure = true;
continue;
}
auto& cachedByOffset = materializedValues[resolvedAddress->base];
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
auto cachedValue = cachedByType.find(originalType);
// Reuse an earlier copy only if it properly dominates this user.
if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) {
operand.set(cachedValue->second);
continue;
}
auto type = dyn_cast<ShapedType>(originalValue.getType());
auto totalBytes = type ? pim::getCheckedShapedTypeSizeInBytes(type, op, "host constant materialization byte size")
: FailureOr<uint64_t>(failure());
auto totalBytesAttr =
succeeded(totalBytes)
? pim::getCheckedI32Attr(rewriter, op, *totalBytes, "host constant materialization byte size")
: FailureOr<IntegerAttr>(failure());
// Both the copy size and the host byte offset must pass the checked conversions.
if (failed(totalBytesAttr)
|| failed(pim::checkedSize(resolvedAddress->byteOffset, op, "host constant materialization byte offset"))) {
hasFailure = true;
continue;
}
// Allocate using only the shape and element type of the original memref type.
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
rewriter.setInsertionPoint(op);
Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType);
Value deviceDst = localAlloc;
// Cast back so the copy result keeps the operand's original type.
if (contiguousType != originalType)
deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0);
Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset);
// Copy from the global at hostOffset into the start of the new buffer.
Value copiedValue = pim::PimMemCopyHostToDevOp::create(rewriter,
op->getLoc(),
originalType,
zeroOffset,
hostOffset,
deviceDst,
getGlobalOp.getResult(),
*totalBytesAttr)
.getOutput();
// Record the copy for later users of the same (base, offset, type).
cachedByType[originalType] = copiedValue;
operand.set(copiedValue);
}
}
}
/// Module pass that materializes host constants inside every PimCoreOp and
/// PimCoreBatchOp of each non-external function, and reports an error for any
/// PimConcatOp found in a function's entry block.
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)

  StringRef getArgument() const override { return "materialize-pim-host-constants"; }

  StringRef getDescription() const override {
    return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
  }

  void runOnOperation() override {
    ModuleOp root = getOperation();
    MLIRContext* context = root.getContext();
    IRRewriter rewriter(context);
    OperationFolder folder(context);
    bool anyFailure = false;

    for (func::FuncOp func : root.getOps<func::FuncOp>()) {
      if (func.isExternal())
        continue;

      for (pim::PimCoreOp core : func.getOps<pim::PimCoreOp>())
        materializeHostConstantsInCore(core, rewriter, folder, anyFailure);
      for (pim::PimCoreBatchOp batchCore : func.getOps<pim::PimCoreBatchOp>())
        materializeHostConstantsInCore(batchCore, rewriter, folder, anyFailure);

      // Collect the concat ops of the entry block before reporting on them.
      SmallVector<pim::PimConcatOp> strayConcats;
      for (Operation& candidate : func.getBody().front())
        if (auto concat = dyn_cast<pim::PimConcatOp>(&candidate))
          strayConcats.push_back(concat);

      for (pim::PimConcatOp concat : strayConcats) {
        rewriter.setInsertionPoint(concat);
        concat.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
        anyFailure = true;
      }
    }

    if (!anyFailure) {
      dumpModule(root, "pim4_materialized");
      return;
    }

    root.emitError("PIM host-constant materialization failed; see diagnostics above");
    signalPassFailure();
  }
};
} // namespace
/// Creates a new instance of the host-constant materialization pass.
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
  auto pass = std::make_unique<MaterializeHostConstantsPass>();
  return pass;
}
} // namespace onnx_mlir
@@ -46,19 +46,6 @@ static bool isCodegenAddressableValue(Value value) {
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
}
/// Returns true when the contiguous address of `value` has a base that is a
/// block argument or the result of a memref.alloc / memref.get_global.
/// The address is taken from resolveContiguousAddress (using `knowledge`) when
/// that succeeds, otherwise from compileContiguousAddressExpr.
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
  auto isAddressableBase = [](Value base) {
    if (isa<BlockArgument>(base))
      return true;
    return isa<memref::AllocOp, memref::GetGlobalOp>(base.getDefiningOp());
  };

  if (auto resolved = resolveContiguousAddress(value, knowledge); succeeded(resolved))
    return isAddressableBase(resolved->base);

  auto compiled = compileContiguousAddressExpr(value);
  return succeeded(compiled) && isAddressableBase(compiled->base);
}
static bool isConstantGlobalView(Value value) {
while (true) {
Operation* defOp = value.getDefiningOp();
@@ -138,6 +125,24 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
memref::GetGlobalOp>(op);
}
/// Returns true when the contiguous address of `value` has a base that is a
/// block argument or the result of a memref.get_global.
/// The address is taken from resolveContiguousAddress (using `knowledge`) when
/// that succeeds, otherwise from compileContiguousAddressExpr; if neither
/// succeeds the value is not host-addressable.
static bool isHostAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
  Value storageBase;
  if (auto resolved = resolveContiguousAddress(value, knowledge); succeeded(resolved)) {
    storageBase = resolved->base;
  }
  else if (auto compiled = compileContiguousAddressExpr(value); succeeded(compiled)) {
    storageBase = compiled->base;
  }
  else {
    return false;
  }

  return isa<BlockArgument>(storageBase) || isa_and_nonnull<memref::GetGlobalOp>(storageBase.getDefiningOp());
}
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
@@ -311,10 +316,10 @@ private:
}
if (isExplicitHostMemCopyOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand, knowledge)) {
if (!isHostAddressableValue(operand, knowledge)) {
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << "host operand #" << operandIndex
<< " is not backed by contiguous addressable storage";
<< " must be backed by host-addressable storage";
});
hasFailure = true;
}