This commit is contained in:
@@ -13,10 +13,14 @@ using namespace bufferization;
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
/// Returns true when `memrefValue` resolves (via resolveContiguousAddress) to a
/// contiguous address with a zero byte offset whose base value is produced
/// directly by a `memref.alloc` operation.
///
/// A base that has no defining operation (for example a block argument) is
/// rejected explicitly: `isa<>` must not be called on a null `Operation*`.
static bool isCompactDeviceLocalExecutableMemRef(Value memrefValue) {
  auto resolved = resolveContiguousAddress(memrefValue);
  if (failed(resolved) || resolved->byteOffset != 0)
    return false;

  // getDefiningOp() yields nullptr for values that are not op results.
  auto* baseDefiningOp = resolved->base.getDefiningOp();
  return baseDefiningOp != nullptr && isa<memref::AllocOp>(baseDefiningOp);
}
|
||||
|
||||
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
bool isContiguous =
|
||||
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
|
||||
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
|
||||
if (isCompactDeviceLocalExecutableMemRef(memrefValue))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
@@ -44,9 +48,7 @@ FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location lo
|
||||
}
|
||||
|
||||
Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
bool isContiguous =
|
||||
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
|
||||
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
|
||||
if (isCompactDeviceLocalExecutableMemRef(memrefValue))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
|
||||
@@ -29,6 +29,13 @@ namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Returns true when `value` is a memref that resolves (via
/// resolveContiguousAddress) to a contiguous address with a zero byte offset
/// whose base value is produced directly by a `memref.alloc` operation.
///
/// Non-memref values are rejected up front. A base without a defining
/// operation (for example a block argument) is rejected explicitly, because
/// `isa<>` must not be called on a null `Operation*`.
static bool isCompactExecutableRuntimeMemRef(Value value) {
  if (!isa<BaseMemRefType>(value.getType()))
    return false;

  auto resolved = resolveContiguousAddress(value);
  if (failed(resolved) || resolved->byteOffset != 0)
    return false;

  // getDefiningOp() yields nullptr for values that are not op results.
  auto* baseDefiningOp = resolved->base.getDefiningOp();
  return baseDefiningOp != nullptr && isa<memref::AllocOp>(baseDefiningOp);
}
|
||||
|
||||
struct MemRefCopyWorkItem {
|
||||
memref::CopyOp copyOp;
|
||||
StaticValueKnowledge knowledge;
|
||||
@@ -271,10 +278,10 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod
|
||||
auto verifyOperand = [&](Value operand, unsigned operandIndex) {
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
return;
|
||||
if (succeeded(resolveContiguousAddress(operand)) || succeeded(compileContiguousAddressExpr(operand)))
|
||||
if (isCompactExecutableRuntimeMemRef(operand))
|
||||
return;
|
||||
op->emitOpError() << "operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage after PIM bufferization";
|
||||
<< " must be a compact device-local memref after PIM bufferization";
|
||||
hasFailure = true;
|
||||
};
|
||||
|
||||
@@ -283,8 +290,16 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod
|
||||
memCopyOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||
hasFailure = true;
|
||||
}
|
||||
verifyOperand(memCopyOp.getTarget(), 0);
|
||||
verifyOperand(memCopyOp.getSource(), 1);
|
||||
if (failed(resolveContiguousAddress(memCopyOp.getTarget()))
|
||||
&& failed(compileContiguousAddressExpr(memCopyOp.getTarget()))) {
|
||||
memCopyOp.emitOpError("target is not backed by contiguous addressable storage after PIM bufferization");
|
||||
hasFailure = true;
|
||||
}
|
||||
if (failed(resolveContiguousAddress(memCopyOp.getSource()))
|
||||
&& failed(compileContiguousAddressExpr(memCopyOp.getSource()))) {
|
||||
memCopyOp.emitOpError("source is not backed by contiguous addressable storage after PIM bufferization");
|
||||
hasFailure = true;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (auto loadOp = dyn_cast<PimMemCopyHostToDevOp>(op)) {
|
||||
@@ -292,8 +307,17 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod
|
||||
loadOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||
hasFailure = true;
|
||||
}
|
||||
verifyOperand(loadOp.getDeviceTarget(), 2);
|
||||
verifyOperand(loadOp.getHostSource(), 3);
|
||||
if (failed(resolveContiguousAddress(loadOp.getDeviceTarget()))
|
||||
&& failed(compileContiguousAddressExpr(loadOp.getDeviceTarget()))) {
|
||||
loadOp.emitOpError(
|
||||
"device target is not backed by contiguous addressable storage after PIM bufferization");
|
||||
hasFailure = true;
|
||||
}
|
||||
if (failed(resolveContiguousAddress(loadOp.getHostSource()))
|
||||
&& failed(compileContiguousAddressExpr(loadOp.getHostSource()))) {
|
||||
loadOp.emitOpError("host source is not backed by contiguous addressable storage after PIM bufferization");
|
||||
hasFailure = true;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (auto storeOp = dyn_cast<PimMemCopyDevToHostOp>(op)) {
|
||||
@@ -301,8 +325,17 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod
|
||||
storeOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||
hasFailure = true;
|
||||
}
|
||||
verifyOperand(storeOp.getHostTarget(), 2);
|
||||
verifyOperand(storeOp.getDeviceSource(), 3);
|
||||
if (failed(resolveContiguousAddress(storeOp.getHostTarget()))
|
||||
&& failed(compileContiguousAddressExpr(storeOp.getHostTarget()))) {
|
||||
storeOp.emitOpError("host target is not backed by contiguous addressable storage after PIM bufferization");
|
||||
hasFailure = true;
|
||||
}
|
||||
if (failed(resolveContiguousAddress(storeOp.getDeviceSource()))
|
||||
&& failed(compileContiguousAddressExpr(storeOp.getDeviceSource()))) {
|
||||
storeOp.emitOpError(
|
||||
"device source is not backed by contiguous addressable storage after PIM bufferization");
|
||||
hasFailure = true;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (auto sendOp = dyn_cast<PimSendOp>(op)) {
|
||||
@@ -340,7 +373,8 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod
|
||||
});
|
||||
|
||||
if (hasFailure) {
|
||||
moduleOp.emitError("PIM bufferization must fully normalize executable runtime operand contiguity before codegen");
|
||||
moduleOp.emitError("PIM bufferization must fully normalize executable runtime operands to compact "
|
||||
"device-local memrefs before codegen");
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
|
||||
@@ -46,6 +46,13 @@ static bool isCodegenAddressableValue(Value value) {
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
|
||||
}
|
||||
|
||||
/// Returns true when `value` is a memref that resolves (via
/// resolveContiguousAddress, using the supplied static `knowledge`) to a
/// contiguous address with a zero byte offset whose base value is produced
/// directly by a `memref.alloc` operation.
///
/// Non-memref values are rejected up front. A base without a defining
/// operation (for example a block argument) is rejected explicitly, because
/// `isa<>` must not be called on a null `Operation*`.
static bool isCompactExecutableRuntimeMemRef(Value value, const StaticValueKnowledge& knowledge) {
  if (!isa<BaseMemRefType>(value.getType()))
    return false;

  auto resolved = resolveContiguousAddress(value, knowledge);
  if (failed(resolved) || resolved->byteOffset != 0)
    return false;

  // getDefiningOp() yields nullptr for values that are not op results.
  auto* baseDefiningOp = resolved->base.getDefiningOp();
  return baseDefiningOp != nullptr && isa<memref::AllocOp>(baseDefiningOp);
}
|
||||
|
||||
static bool isConstantGlobalView(Value value) {
|
||||
while (true) {
|
||||
Operation* defOp = value.getDefiningOp();
|
||||
@@ -336,6 +343,14 @@ private:
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
if (!isCompactExecutableRuntimeMemRef(operand, knowledge)) {
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "operand #" << operandIndex
|
||||
<< " must be a compact device-local memref; subviews and offset views are "
|
||||
"not legal for executable PIM ops";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op)) {
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1452,6 +1452,85 @@ def slice_large_channel_1024():
|
||||
save_model(model, "slice/large_channel_1024", "slice_large_channel_1024.onnx")
|
||||
|
||||
|
||||
def slice_nonzero_channel_offset_add():
    """Generate the ``slice/nonzero_channel_offset_add`` model.

    Input X (FLOAT, [1, 4, 16]) is sliced twice along axis 1: channels [0, 2)
    produce S0 and channels [2, 4) produce S1, so S1 starts at a non-zero
    channel offset. Output Y (FLOAT, [1, 2, 16]) is ``Add(S0, S1)``.
    """
    graph_input = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4, 16])
    graph_output = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 16])

    # Slice parameters, built in the same order they are listed as initializers.
    slice_params = [
        make_int64_initializer("starts0", [0]),
        make_int64_initializer("ends0", [2]),
        make_int64_initializer("starts1", [2]),
        make_int64_initializer("ends1", [4]),
        make_int64_initializer("axes", [1]),
    ]

    nodes = [
        helper.make_node("Slice", ["X", "starts0", "ends0", "axes"], ["S0"]),
        helper.make_node("Slice", ["X", "starts1", "ends1", "axes"], ["S1"]),
        helper.make_node("Add", ["S0", "S1"], ["Y"]),
    ]

    graph = helper.make_graph(
        nodes,
        "slice_nonzero_channel_offset_add",
        [graph_input],
        [graph_output],
        initializer=slice_params,
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(graph, opset_imports=[opset]),
        "slice/nonzero_channel_offset_add",
        "slice_nonzero_channel_offset_add.onnx",
    )
|
||||
|
||||
|
||||
def slice_nonzero_channel_offset_sub():
    """Generate the ``slice/nonzero_channel_offset_sub`` model.

    Input X (FLOAT, [1, 4, 16]) is sliced twice along axis 1: channels [0, 2)
    produce S0 and channels [2, 4) produce S1, so S1 starts at a non-zero
    channel offset. Output Y (FLOAT, [1, 2, 16]) is ``Sub(S0, S1)``.
    """
    graph_input = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4, 16])
    graph_output = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 16])

    # Slice parameters, built in the same order they are listed as initializers.
    slice_params = [
        make_int64_initializer("starts0", [0]),
        make_int64_initializer("ends0", [2]),
        make_int64_initializer("starts1", [2]),
        make_int64_initializer("ends1", [4]),
        make_int64_initializer("axes", [1]),
    ]

    nodes = [
        helper.make_node("Slice", ["X", "starts0", "ends0", "axes"], ["S0"]),
        helper.make_node("Slice", ["X", "starts1", "ends1", "axes"], ["S1"]),
        helper.make_node("Sub", ["S0", "S1"], ["Y"]),
    ]

    graph = helper.make_graph(
        nodes,
        "slice_nonzero_channel_offset_sub",
        [graph_input],
        [graph_output],
        initializer=slice_params,
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(graph, opset_imports=[opset]),
        "slice/nonzero_channel_offset_sub",
        "slice_nonzero_channel_offset_sub.onnx",
    )
|
||||
|
||||
|
||||
def slice_nonzero_channel_offset_mul():
    """Generate the ``slice/nonzero_channel_offset_mul`` model.

    Input X (FLOAT, [1, 4, 16]) is sliced twice along axis 1: channels [0, 2)
    produce S0 and channels [2, 4) produce S1, so S1 starts at a non-zero
    channel offset. Output Y (FLOAT, [1, 2, 16]) is ``Mul(S0, S1)``.
    """
    graph_input = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4, 16])
    graph_output = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 16])

    # Slice parameters, built in the same order they are listed as initializers.
    slice_params = [
        make_int64_initializer("starts0", [0]),
        make_int64_initializer("ends0", [2]),
        make_int64_initializer("starts1", [2]),
        make_int64_initializer("ends1", [4]),
        make_int64_initializer("axes", [1]),
    ]

    nodes = [
        helper.make_node("Slice", ["X", "starts0", "ends0", "axes"], ["S0"]),
        helper.make_node("Slice", ["X", "starts1", "ends1", "axes"], ["S1"]),
        helper.make_node("Mul", ["S0", "S1"], ["Y"]),
    ]

    graph = helper.make_graph(
        nodes,
        "slice_nonzero_channel_offset_mul",
        [graph_input],
        [graph_output],
        initializer=slice_params,
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(graph, opset_imports=[opset]),
        "slice/nonzero_channel_offset_mul",
        "slice_nonzero_channel_offset_mul.onnx",
    )
|
||||
|
||||
|
||||
def slice_yolo_like_decode_tail():
    """Generate the ``slice/yolo_like_decode_tail`` model.

    Input X (FLOAT, [1, 4, 16]) is sliced along axis 1 into S0 (channels
    [0, 2)) and S1 (channels [2, 4), i.e. a non-zero channel offset). The
    graph then computes::

        D0 = Sub(S1, S0)
        A0 = Add(S0, S1)
        D1 = Sub(A0, D0)
        Y  = Concat(D0, D1, axis=1)

    which yields output Y (FLOAT, [1, 4, 16]).
    """
    graph_input = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4, 16])
    graph_output = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 16])

    # Slice parameters, built in the same order they are listed as initializers.
    slice_params = [
        make_int64_initializer("starts0", [0]),
        make_int64_initializer("ends0", [2]),
        make_int64_initializer("starts1", [2]),
        make_int64_initializer("ends1", [4]),
        make_int64_initializer("axes", [1]),
    ]

    nodes = [
        helper.make_node("Slice", ["X", "starts0", "ends0", "axes"], ["S0"]),
        helper.make_node("Slice", ["X", "starts1", "ends1", "axes"], ["S1"]),
        helper.make_node("Sub", ["S1", "S0"], ["D0"]),
        helper.make_node("Add", ["S0", "S1"], ["A0"]),
        helper.make_node("Sub", ["A0", "D0"], ["D1"]),
        helper.make_node("Concat", ["D0", "D1"], ["Y"], axis=1),
    ]

    graph = helper.make_graph(
        nodes,
        "slice_yolo_like_decode_tail",
        [graph_input],
        [graph_output],
        initializer=slice_params,
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(graph, opset_imports=[opset]),
        "slice/yolo_like_decode_tail",
        "slice_yolo_like_decode_tail.onnx",
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gather tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -2001,6 +2080,10 @@ if __name__ == "__main__":
|
||||
slice_nchw_spatial_crop()
|
||||
slice_after_conv()
|
||||
slice_large_channel_1024()
|
||||
slice_nonzero_channel_offset_add()
|
||||
slice_nonzero_channel_offset_sub()
|
||||
slice_nonzero_channel_offset_mul()
|
||||
slice_yolo_like_decode_tail()
|
||||
|
||||
print("\nGenerating Softmax tests:")
|
||||
softmax_basic()
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user