big cleanup: remove remaining pim many operations, simplify bufferization logic
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -18,7 +18,6 @@ namespace {
|
||||
|
||||
static bool isAddressOnlyHostOp(Operation* op) {
|
||||
return isa<arith::ConstantOp,
|
||||
pim::PimEmptyManyOp,
|
||||
memref::AllocOp,
|
||||
memref::GetGlobalOp,
|
||||
memref::SubViewOp,
|
||||
@@ -37,12 +36,24 @@ static bool isBaseAddressableValue(Value value) {
|
||||
Operation* defOp = value.getDefiningOp();
|
||||
if (!defOp)
|
||||
return false;
|
||||
if (isa<pim::PimEmptyManyOp, memref::AllocOp, memref::GetGlobalOp>(defOp))
|
||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(defOp))
|
||||
return true;
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) { value = subview.getSource(); continue; }
|
||||
if (auto cast = dyn_cast<memref::CastOp>(defOp)) { value = cast.getSource(); continue; }
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) { value = collapse.getSrc(); continue; }
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) { value = expand.getSrc(); continue; }
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
value = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
|
||||
value = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||
value = collapse.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||
value = expand.getSrc();
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -52,7 +63,38 @@ static bool isCodegenAddressableValue(Value value) {
|
||||
if (failed(resolvedAddress))
|
||||
return false;
|
||||
return isa<BlockArgument>(resolvedAddress->base)
|
||||
|| isa<pim::PimEmptyManyOp, memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
}
|
||||
|
||||
static bool isConstantGlobalView(Value value) {
|
||||
auto allStaticSubviewParts = [](memref::SubViewOp subview) {
|
||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
};
|
||||
|
||||
while (true) {
|
||||
Operation* defOp = value.getDefiningOp();
|
||||
if (!defOp)
|
||||
return false;
|
||||
if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)) {
|
||||
auto moduleOp = getGlobalOp->getParentOfType<ModuleOp>();
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
return globalOp && globalOp.getConstant() && globalOp.getInitialValue()
|
||||
&& isa<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
}
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!allStaticSubviewParts(subview))
|
||||
return false;
|
||||
value = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
|
||||
value = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
@@ -125,13 +167,17 @@ private:
|
||||
bool hasFailure = false;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
|
||||
auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp) {
|
||||
if (!getGlobalOp && !isConstantGlobalView(weight)) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must be materialized as memref.get_global before JSON codegen";
|
||||
<< " must be materialized as a constant memref.global or a static view of one before JSON "
|
||||
"codegen";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!getGlobalOp)
|
||||
continue;
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
|
||||
@@ -185,7 +231,7 @@ private:
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<pim::PimEmptyManyOp, memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||
op.emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
||||
hasFailure = true;
|
||||
@@ -197,7 +243,7 @@ private:
|
||||
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
||||
return verifyAddressOnlySource(op, subviewOp.getSource());
|
||||
return verifyAddressOnlyBase(op, subviewOp.getSource());
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(op))
|
||||
return verifyAddressOnlySource(op, castOp.getSource());
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
|
||||
@@ -221,6 +267,14 @@ private:
|
||||
op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
|
||||
return failure();
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source) {
|
||||
if (isBaseAddressableValue(source))
|
||||
return success();
|
||||
|
||||
op->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Reference in New Issue
Block a user