use seed in validate.py for deterministic tests

This commit is contained in:
NiccoloN
2026-05-13 21:49:36 +02:00
parent 061139aefb
commit 55eda487dc
2 changed files with 4 additions and 2 deletions
+2 -2
View File
@@ -268,7 +268,7 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1
def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
simulator_dir, crossbar_size=64, crossbar_count=8, core_count=None, threshold=1e-3,
reporter=None, model_index=1, model_total=1, verbose=False):
seed=0, reporter=None, model_index=1, model_total=1, verbose=False):
network_onnx_path = Path(network_onnx_path).resolve()
raptor_path = Path(raptor_path).resolve()
onnx_include_dir = Path(onnx_include_dir).resolve()
@@ -306,7 +306,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Generate Inputs")
inputs_descriptor, outputs_descriptor = onnx_io(network_onnx_path)
inputs_list, _inputs_dict = gen_random_inputs(inputs_descriptor)
inputs_list, _inputs_dict = gen_random_inputs(inputs_descriptor, seed=seed)
flags, _files = save_inputs_to_files(network_onnx_path, inputs_list, out_dir=workspace_dir / "inputs")
print_info(reporter, f"Saved {len(inputs_list)} input file(s) to {workspace_dir / 'inputs'}")
reporter.advance()