big cleanup: remove remaining pim many operations, simplify bufferization logic
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-11 14:38:13 +02:00
parent b1272d2283
commit 5ff364027b
12 changed files with 390 additions and 1164 deletions
+65 -11
View File
@@ -18,7 +18,6 @@ namespace {
static bool isAddressOnlyHostOp(Operation* op) {
return isa<arith::ConstantOp,
pim::PimEmptyManyOp,
memref::AllocOp,
memref::GetGlobalOp,
memref::SubViewOp,
@@ -37,12 +36,24 @@ static bool isBaseAddressableValue(Value value) {
Operation* defOp = value.getDefiningOp();
if (!defOp)
return false;
if (isa<pim::PimEmptyManyOp, memref::AllocOp, memref::GetGlobalOp>(defOp))
if (isa<memref::AllocOp, memref::GetGlobalOp>(defOp))
return true;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) { value = subview.getSource(); continue; }
if (auto cast = dyn_cast<memref::CastOp>(defOp)) { value = cast.getSource(); continue; }
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) { value = collapse.getSrc(); continue; }
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) { value = expand.getSrc(); continue; }
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
value = subview.getSource();
continue;
}
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
value = cast.getSource();
continue;
}
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
value = collapse.getSrc();
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
value = expand.getSrc();
continue;
}
return false;
}
}
@@ -52,7 +63,38 @@ static bool isCodegenAddressableValue(Value value) {
if (failed(resolvedAddress))
return false;
return isa<BlockArgument>(resolvedAddress->base)
|| isa<pim::PimEmptyManyOp, memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
}
static bool isConstantGlobalView(Value value) {
auto allStaticSubviewParts = [](memref::SubViewOp subview) {
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
};
while (true) {
Operation* defOp = value.getDefiningOp();
if (!defOp)
return false;
if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)) {
auto moduleOp = getGlobalOp->getParentOfType<ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
return globalOp && globalOp.getConstant() && globalOp.getInitialValue()
&& isa<DenseElementsAttr>(*globalOp.getInitialValue());
}
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
if (!allStaticSubviewParts(subview))
return false;
value = subview.getSource();
continue;
}
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
value = cast.getSource();
continue;
}
return false;
}
}
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
@@ -125,13 +167,17 @@ private:
bool hasFailure = false;
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
if (!getGlobalOp && !isConstantGlobalView(weight)) {
coreOp.emitOpError() << "weight #" << weightIndex
<< " must be materialized as memref.get_global before JSON codegen";
<< " must be materialized as a constant memref.global or a static view of one before JSON "
"codegen";
hasFailure = true;
continue;
}
if (!getGlobalOp)
continue;
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
@@ -185,7 +231,7 @@ private:
continue;
}
if (!isa<pim::PimEmptyManyOp, memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
op.emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
hasFailure = true;
@@ -197,7 +243,7 @@ private:
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
return verifyAddressOnlySource(op, subviewOp.getSource());
return verifyAddressOnlyBase(op, subviewOp.getSource());
if (auto castOp = dyn_cast<memref::CastOp>(op))
return verifyAddressOnlySource(op, castOp.getSource());
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
@@ -221,6 +267,14 @@ private:
op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
return failure();
}
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source) {
if (isBaseAddressableValue(source))
return success();
op->emitOpError("depends on a value that is not backed by addressable storage");
return failure();
}
};
} // namespace