From 57f0cca8c0173312a7544daec5877ca6ff4d9c85 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 11 May 2026 15:52:26 +0200 Subject: [PATCH] remove duplicated code quieter validation scripts (with optional verbose flag) --- src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp | 614 +--------------------- validation/subprocess_utils.py | 31 +- validation/validate.py | 38 +- validation/validate_one.py | 30 +- 4 files changed, 69 insertions(+), 644 deletions(-) diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index f83fa4e..28c0dd9 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -17,532 +17,12 @@ namespace onnx_mlir { namespace spatial { namespace { - -enum class ListDelimiter { - Square, - Paren -}; - -static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { - return onnx_mlir::compact_asm::parseOpenDelimiter( - parser, static_cast(delimiter)); -} - -static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { - return onnx_mlir::compact_asm::parseOptionalCloseDelimiter( - parser, static_cast(delimiter)); -} - -static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { - onnx_mlir::compact_asm::printOpenDelimiter( - printer, static_cast(delimiter)); -} - -static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { - onnx_mlir::compact_asm::printCloseDelimiter( - printer, static_cast(delimiter)); -} +using namespace onnx_mlir::compact_asm; static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) { return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy)); } -template -static ParseResult parseCompressedRepeatedList(OpAsmParser& parser, - ListDelimiter delimiter, - SmallVectorImpl& entries, - ParseEntryFn parseEntry) { - return onnx_mlir::compact_asm::parseCompressedRepeatedList( - parser, static_cast(delimiter), entries, parseEntry); -} - -template -static ParseResult -parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& values) { - if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) - return success(); - - while (true) { - if (succeeded(parser.parseOptionalLParen())) { - SmallVector subgroup; - if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup)) - return failure(); - - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t repeat = 0; repeat < repeatCount; ++repeat) - llvm::append_range(values, subgroup); - } - else { - int64_t first = 0; - if (parser.parseInteger(first)) - return failure(); - - if (succeeded(parser.parseOptionalKeyword("to"))) { - int64_t last = 0; - if (parser.parseInteger(last) || last < first) - return parser.emitError(parser.getCurrentLocation(), "invalid ascending range"); - - int64_t step = 1; - if (succeeded(parser.parseOptionalKeyword("by"))) { - if (parser.parseInteger(step) || step <= 0) - return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive"); - } - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - if ((last - first) % step != 0) - return parser.emitError(parser.getCurrentLocation(), - "range end must be reachable from start using the given step"); - - for (int64_t value = first; value <= last; value += step) - for (int64_t index = 0; index < repeatCount; ++index) - values.push_back(static_cast(value)); - } - else { - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t index = 0; index < repeatCount; ++index) - values.push_back(static_cast(first)); - } - } - - if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) - break; - if (parser.parseComma()) - return failure(); - } - - return success(); -} - -template -static ParseResult -parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& values) { - if (parseOpenDelimiter(parser, delimiter)) - return failure(); - return parseCompressedIntegerEntries(parser, delimiter, values); -} - -template -static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) { - for (size_t index = 0; index < entries.size();) { - size_t runEnd = index + 1; - while (runEnd < entries.size() && entries[runEnd] == entries[index]) - ++runEnd; - - if (index != 0) - printer << ", "; - printEntry(entries[index]); - size_t runLength = runEnd - index; - if (runLength > 1) - printer << " x" << runLength; - index = runEnd; - } -} - -template -static void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef values, ListDelimiter delimiter) { - struct FlatCompression { - enum class Kind { - Single, - EqualRun, - Progression - }; - - Kind kind = Kind::Single; - size_t covered = 1; - size_t repeatCount = 1; - size_t progressionValueCount = 1; - int64_t step = 1; - IntT firstValue {}; - IntT lastValue {}; - }; - - auto computeFlatCompression = [&](size_t start) { - FlatCompression compression; - compression.firstValue = values[start]; - compression.lastValue = values[start]; - - auto findEqualRunEnd = [&](size_t runStart) { - size_t runEnd = runStart + 1; - while (runEnd < values.size() && values[runEnd] == values[runStart]) - ++runEnd; - return runEnd; - }; - - size_t firstRunEnd = findEqualRunEnd(start); - compression.repeatCount = firstRunEnd - start; - size_t progressionEnd = firstRunEnd; - int64_t step = 0; - IntT lastValue = values[start]; - - if (firstRunEnd < values.size()) { - size_t secondRunEnd = findEqualRunEnd(firstRunEnd); - step = static_cast(values[firstRunEnd]) - static_cast(values[start]); - if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) { - progressionEnd = secondRunEnd; - lastValue = values[firstRunEnd]; - size_t currentRunStart = secondRunEnd; - while (currentRunStart < values.size()) { - size_t currentRunEnd = findEqualRunEnd(currentRunStart); - if (currentRunEnd - currentRunStart != compression.repeatCount) - break; - if (static_cast(values[currentRunStart]) != static_cast(lastValue) + step) - break; - lastValue = values[currentRunStart]; - progressionEnd = currentRunEnd; - currentRunStart = currentRunEnd; - } - } - else { - step = 0; - } - } - - compression.covered = 1; - if (progressionEnd > firstRunEnd) { - size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount; - if (progressionValueCount >= 3) { - compression.kind = FlatCompression::Kind::Progression; - compression.covered = progressionEnd - start; - compression.progressionValueCount = progressionValueCount; - compression.step = step; - compression.lastValue = lastValue; - return compression; - } - } - - if (compression.repeatCount > 1) { - compression.kind = FlatCompression::Kind::EqualRun; - compression.covered = compression.repeatCount; - return compression; - } - - return compression; - }; - - auto findRepeatedSublist = [&](size_t start) { - size_t bestLength = 0; - size_t bestRepeatCount = 1; - size_t remaining = values.size() - start; - - for (size_t length = 2; length * 2 <= remaining; ++length) { - size_t repeatCount = 1; - ArrayRef candidate = values.slice(start, length); - while (start + (repeatCount + 1) * length <= values.size() - && llvm::equal(candidate, values.slice(start + repeatCount * length, length))) { - ++repeatCount; - } - - if (repeatCount <= 1) - continue; - - size_t covered = length * repeatCount; - size_t bestCovered = bestLength * bestRepeatCount; - if (covered > bestCovered || (covered == bestCovered && length < bestLength)) { - bestLength = length; - bestRepeatCount = repeatCount; - } - } - - return std::pair(bestLength, bestRepeatCount); - }; - - printOpenDelimiter(printer, delimiter); - for (size_t index = 0; index < values.size();) { - if (index != 0) - printer << ", "; - - FlatCompression flat = computeFlatCompression(index); - auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index); - size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount; - if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) { - printCompressedIntegerSequence(printer, values.slice(index, sublistLength), ListDelimiter::Paren); - printer << " x" << sublistRepeatCount; - index += repeatedSublistCoverage; - continue; - } - switch (flat.kind) { - case FlatCompression::Kind::Progression: - printer << flat.firstValue << " to " << flat.lastValue; - if (flat.step != 1) - printer << " by " << flat.step; - if (flat.repeatCount > 1) - printer << " x" << flat.repeatCount; - index += flat.covered; - break; - case FlatCompression::Kind::EqualRun: - printer << flat.firstValue << " x" << flat.repeatCount; - index += flat.covered; - break; - case FlatCompression::Kind::Single: - printer << flat.firstValue; - index += flat.covered; - break; - } - } - printCloseDelimiter(printer, delimiter); -} - -template -static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl& values) { - return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values); -} - -template -static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef values) { - printCompressedIntegerSequence(printer, values, ListDelimiter::Square); -} - -static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) { - printOpenDelimiter(printer, delimiter); - for (size_t index = 0; index < values.size();) { - size_t equalRunEnd = index + 1; - while (equalRunEnd < values.size() && values[equalRunEnd] == values[index]) - ++equalRunEnd; - - if (index != 0) - printer << ", "; - if (equalRunEnd - index > 1) { - printer.printOperand(values[index]); - printer << " x" << (equalRunEnd - index); - index = equalRunEnd; - continue; - } - - size_t rangeEnd = index + 1; - if (auto firstResult = dyn_cast(values[index])) { - while (rangeEnd < values.size()) { - auto nextResult = dyn_cast(values[rangeEnd]); - if (!nextResult || nextResult.getOwner() != firstResult.getOwner() - || nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) - break; - ++rangeEnd; - } - } - else if (auto firstArg = dyn_cast(values[index])) { - while (rangeEnd < values.size()) { - auto nextArg = dyn_cast(values[rangeEnd]); - if (!nextArg || nextArg.getOwner() != firstArg.getOwner() - || nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) - break; - ++rangeEnd; - } - } - - printer.printOperand(values[index]); - if (rangeEnd - index >= 3) { - printer << " to "; - printer.printOperand(values[rangeEnd - 1]); - } - else if (rangeEnd - index == 2) { - printer << ", "; - printer.printOperand(values[index + 1]); - } - index = rangeEnd; - } - printCloseDelimiter(printer, delimiter); -} - -static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) { - printOpenDelimiter(printer, delimiter); - printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); }); - printCloseDelimiter(printer, delimiter); -} - -static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser, - SmallVectorImpl& operands); -static ParseResult parseCompressedOperandSequence(OpAsmParser& parser, - SmallVectorImpl& operands); -static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl& types, bool allowEmpty); - -static bool hasRepeatedTuple(ValueRange values, size_t tupleSize) { - return onnx_mlir::compact_asm::hasRepeatedTuple(values, tupleSize); -} - -static bool hasRepeatedTuple(TypeRange types, size_t tupleSize) { - return onnx_mlir::compact_asm::hasRepeatedTuple(types, tupleSize); -} - -static void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize) { - onnx_mlir::compact_asm::printValueTupleRun( - printer, values, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square); -} - -static void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize) { - onnx_mlir::compact_asm::printTypeTupleRun( - printer, types, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square); -} - -static ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser, - SmallVectorImpl& operands) { - return onnx_mlir::compact_asm::parseCompressedOrTupleOperandList( - parser, onnx_mlir::compact_asm::ListDelimiter::Square, operands); -} - -static ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, SmallVectorImpl& types) { - return onnx_mlir::compact_asm::parseCompressedOrTupleTypeList( - parser, onnx_mlir::compact_asm::ListDelimiter::Square, types); -} - -static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser, - OpAsmParser::UnresolvedOperand firstOperand, - SmallVectorImpl& operands) { - if (succeeded(parser.parseOptionalKeyword("to"))) { - OpAsmParser::UnresolvedOperand lastOperand; - if (parser.parseOperand(lastOperand)) - return failure(); - if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number) - return parser.emitError(parser.getCurrentLocation(), "invalid operand range"); - for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number) - operands.push_back({firstOperand.location, firstOperand.name, number}); - } - else { - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t index = 0; index < repeatCount; ++index) - operands.push_back(firstOperand); - } - - return success(); -} - -static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser, - SmallVectorImpl& operands) { - OpAsmParser::UnresolvedOperand firstOperand; - if (parser.parseOperand(firstOperand)) - return failure(); - return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands); -} - -static ParseResult parseCompressedOperandList(OpAsmParser& parser, - ListDelimiter delimiter, - SmallVectorImpl& operands) { - if (parseOpenDelimiter(parser, delimiter)) - return failure(); - if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) - return success(); - - while (true) { - if (parseOneCompressedOperandEntry(parser, operands)) - return failure(); - if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) - break; - if (parser.parseComma()) - return failure(); - } - - return success(); -} - -static ParseResult parseCompressedOperandSequence(OpAsmParser& parser, - SmallVectorImpl& operands) { - if (parseOneCompressedOperandEntry(parser, operands)) - return failure(); - while (succeeded(parser.parseOptionalComma())) - if (parseOneCompressedOperandEntry(parser, operands)) - return failure(); - return success(); -} - -static void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) { - for (size_t index = 0; index < values.size();) { - size_t equalRunEnd = index + 1; - while (equalRunEnd < values.size() && values[equalRunEnd] == values[index]) - ++equalRunEnd; - - if (index != 0) - printer << ", "; - if (equalRunEnd - index > 1) { - printer.printOperand(values[index]); - printer << " x" << (equalRunEnd - index); - index = equalRunEnd; - continue; - } - - size_t rangeEnd = index + 1; - if (auto firstResult = dyn_cast(values[index])) { - while (rangeEnd < values.size()) { - auto nextResult = dyn_cast(values[rangeEnd]); - if (!nextResult || nextResult.getOwner() != firstResult.getOwner() - || nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) - break; - ++rangeEnd; - } - } - else if (auto firstArg = dyn_cast(values[index])) { - while (rangeEnd < values.size()) { - auto nextArg = dyn_cast(values[rangeEnd]); - if (!nextArg || nextArg.getOwner() != firstArg.getOwner() - || nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) - break; - ++rangeEnd; - } - } - - printer.printOperand(values[index]); - if (rangeEnd - index >= 3) { - printer << " to "; - printer.printOperand(values[rangeEnd - 1]); - } - else if (rangeEnd - index == 2) { - printer << ", "; - printer.printOperand(values[index + 1]); - } - index = rangeEnd; - } -} - -static void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) { - printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); }); -} - -static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl& types, bool allowEmpty) { - Type firstType; - OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType); - if (!firstTypeResult.has_value()) { - if (allowEmpty) - return success(); - return parser.emitError(parser.getCurrentLocation(), "expected type"); - } - if (failed(*firstTypeResult)) - return failure(); - - auto appendType = [&](Type type) -> ParseResult { - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t index = 0; index < repeatCount; ++index) - types.push_back(type); - return success(); - }; - - if (appendType(firstType)) - return failure(); - - while (succeeded(parser.parseOptionalComma())) { - Type nextType; - if (parser.parseType(nextType) || appendType(nextType)) - return failure(); - } - - return success(); -} - static void printChannelMetadata(OpAsmPrinter& printer, ArrayRef channelIds, ArrayRef sourceCoreIds, @@ -567,90 +47,6 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) { return parser.getBuilder().getI32IntegerAttr(value); } -static void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) { - if (block.getNumArguments() == 0) { - printer << "() = ()"; - return; - } - - if (block.getNumArguments() == 1) { - printer.printOperand(block.getArgument(0)); - printer << " = "; - printCompressedValueList(printer, operands, ListDelimiter::Paren); - return; - } - - printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren); - printer << " = "; - printCompressedValueList(printer, operands, ListDelimiter::Paren); -} - -static ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser, - OpAsmParser::Argument firstArgument, - SmallVectorImpl& arguments) { - if (succeeded(parser.parseOptionalKeyword("to"))) { - OpAsmParser::Argument lastArgument; - if (parser.parseArgument(lastArgument)) - return failure(); - if (firstArgument.ssaName.name != lastArgument.ssaName.name - || firstArgument.ssaName.number > lastArgument.ssaName.number) { - return parser.emitError(parser.getCurrentLocation(), "invalid argument range"); - } - for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) { - OpAsmParser::Argument argument; - argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number}; - arguments.push_back(argument); - } - return success(); - } - - arguments.push_back(firstArgument); - return success(); -} - -static ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser, - SmallVectorImpl& arguments) { - OpAsmParser::Argument firstArgument; - if (parser.parseArgument(firstArgument)) - return failure(); - return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments); -} - -static void applyArgumentTypes(ArrayRef inputTypes, SmallVectorImpl& arguments) { - for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes)) - argument.type = inputType; -} - -static ParseResult parseArgumentBindings(OpAsmParser& parser, - SmallVectorImpl& arguments, - SmallVectorImpl& operands) { - if (succeeded(parser.parseOptionalLParen())) { - if (succeeded(parser.parseOptionalRParen())) { - if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) - return failure(); - return success(); - } - - OpAsmParser::Argument firstArgument; - if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments)) - return failure(); - while (succeeded(parser.parseOptionalComma())) - if (parseOneCompressedArgumentEntry(parser, arguments)) - return failure(); - if (parser.parseRParen() || parser.parseEqual() - || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) - return failure(); - return success(); - } - - OpAsmParser::Argument argument; - if (parser.parseArgument(argument) || parser.parseEqual() - || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) - return failure(); - arguments.push_back(argument); - return success(); -} - } // namespace void SpatYieldOp::print(OpAsmPrinter& printer) { @@ -875,7 +271,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) { printer << " lanes " << getLaneCount() << " "; size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast(getLaneCount()) : 0; if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane)) - printValueTupleRun(printer, getWeights(), weightsPerLane); + printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Square); else printCompressedValueList(printer, getWeights(), ListDelimiter::Square); printer << " "; @@ -892,7 +288,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) { printer << " : "; if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane)) - printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane); + printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Square); else printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); printer << " "; @@ -916,7 +312,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)) return failure(); - if (parseCompressedOrTupleOperandList(parser, weights)) + if (parseCompressedOrTupleOperandList(parser, ListDelimiter::Square, weights)) return failure(); if (parseArgumentBindings(parser, regionArgs, inputs)) @@ -927,7 +323,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) return failure(); if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedOrTupleTypeList(parser, weightTypes) + || parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes) || parseCompressedRepeatedList( parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) diff --git a/validation/subprocess_utils.py b/validation/subprocess_utils.py index de69ee0..2c2225c 100644 --- a/validation/subprocess_utils.py +++ b/validation/subprocess_utils.py @@ -16,7 +16,7 @@ def _read_chunk(fd, treat_eio_as_eof=False): raise -def _stream_output(fd, process, reporter, treat_eio_as_eof=False): +def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output=True): selector = selectors.DefaultSelector() recent_output = bytearray() captured_output = bytearray() @@ -32,19 +32,22 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False): os.close(key.fileobj) continue - reporter._clear() - os.write(1, data) - reporter._render() + if stream_output: + reporter._clear() + os.write(1, data) + reporter._render() captured_output.extend(data) - recent_output.extend(data) - if len(recent_output) > MAX_ERROR_OUTPUT_BYTES: - del recent_output[:-MAX_ERROR_OUTPUT_BYTES] + if stream_output: + recent_output.extend(data) + if len(recent_output) > MAX_ERROR_OUTPUT_BYTES: + del recent_output[:-MAX_ERROR_OUTPUT_BYTES] finally: selector.close() return_code = process.wait() if return_code != 0: - raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output)) + error_output = captured_output if not stream_output else recent_output + raise subprocess.CalledProcessError(return_code, process.args, output=bytes(error_output)) return bytes(captured_output) @@ -62,6 +65,18 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False subprocess.run(cmd, cwd=cwd, check=True) return None + stream_output = bool(getattr(reporter, "verbose", False)) + if not stream_output: + process = subprocess.Popen( + cmd, + cwd=cwd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + assert process.stdout is not None + output = _stream_output(process.stdout.fileno(), process, reporter, stream_output=False) + return output.decode("utf-8", errors="replace") if capture_output else None + try: master_fd, slave_fd = pty.openpty() except OSError: diff --git a/validation/validate.py b/validation/validate.py index e3c8d6c..9121efe 100644 --- a/validation/validate.py +++ b/validation/validate.py @@ -27,6 +27,10 @@ def print_validation_error(reporter, rel, exc): file=sys.stderr, flush=True) if isinstance(exc, subprocess.CalledProcessError): print(format_return_status(exc.returncode), file=sys.stderr, flush=True) + if exc.output: + output_text = exc.output.decode("utf-8", errors="replace") if isinstance(exc.output, bytes) else str(exc.output) + if output_text: + print(output_text, file=sys.stderr, end="" if output_text.endswith("\n") else "\n", flush=True) else: print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True) print("=" * 72, file=sys.stderr, flush=True) @@ -60,6 +64,8 @@ def main(): help="Core count to pass to Raptor. If omitted, Raptor uses its default.") ap.add_argument("--clean", action="store_true", help="Remove generated validation artifacts under each model workspace and exit.") + ap.add_argument("--verbose", action="store_true", + help="Print per-stage progress and subprocess logs for passing validations too.") a = ap.parse_args() script_dir = Path(__file__).parent.resolve() @@ -101,7 +107,7 @@ def main(): pass_timing_counts = {label: 0 for _, label in PIM_PASS_LABELS} total_timing_sum = 0.0 timed_benchmark_count = 0 - reporter = ProgressReporter(len(onnx_files)) + reporter = ProgressReporter(len(onnx_files), verbose=a.verbose) for index, onnx_path in enumerate(onnx_files, start=1): rel = onnx_path.relative_to(operations_dir) try: @@ -112,6 +118,7 @@ def main(): reporter=reporter, model_index=index, model_total=len(onnx_files), + verbose=a.verbose, ) results[str(rel)] = result.passed if result.pim_pass_timings: @@ -134,22 +141,25 @@ def main(): # Summary n_passed = sum(1 for passed in results.values() if passed) n_total = len(results) - status_width = len("Result") - path_width = max(len("Operation"), *(len(rel) for rel in results)) - separator = f"+-{'-' * path_width}-+-{'-' * status_width}-+" - print("\n" + Style.BRIGHT + Fore.CYAN + "Summary" + Style.RESET_ALL) - print(separator) - print(f"| {'Operation'.ljust(path_width)} | {'Result'.ljust(status_width)} |") - print(separator) - for rel, passed in results.items(): - plain_status = "PASS" if passed else "FAIL" - status = Fore.GREEN + plain_status.ljust(status_width) + Style.RESET_ALL if passed else \ - Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL - print(f"| {rel.ljust(path_width)} | {status} |") - print(separator) print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL) print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL) + failing = [rel for rel, passed in results.items() if not passed] + if a.verbose or failing: + status_width = len("Result") + path_width = max(len("Operation"), *(len(rel) for rel in results)) + separator = f"+-{'-' * path_width}-+-{'-' * status_width}-+" + print(separator) + print(f"| {'Operation'.ljust(path_width)} | {'Result'.ljust(status_width)} |") + print(separator) + for rel, passed in results.items(): + if not a.verbose and passed: + continue + plain_status = "PASS" if passed else "FAIL" + status = Fore.GREEN + plain_status.ljust(status_width) + Style.RESET_ALL if passed else \ + Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL + print(f"| {rel.ljust(path_width)} | {status} |") + print(separator) print_average_pim_pass_timings( pass_timing_sums, pass_timing_counts, diff --git a/validation/validate_one.py b/validation/validate_one.py index cd5c1e7..aa796cc 100644 --- a/validation/validate_one.py +++ b/validation/validate_one.py @@ -36,7 +36,7 @@ class ValidationResult: class ProgressReporter: - def __init__(self, total_models, stages_per_model=STAGE_COUNT, enabled=None): + def __init__(self, total_models, stages_per_model=STAGE_COUNT, enabled=None, verbose=False): self.total_models = total_models self.stages_per_model = stages_per_model self.total_steps = max(1, total_models * stages_per_model) @@ -45,6 +45,7 @@ class ProgressReporter: self.failed_models = 0 self.current_label = "" self.enabled = sys.stdout.isatty() if enabled is None else enabled + self.verbose = verbose self.columns = shutil.get_terminal_size((100, 20)).columns self.suspended = False @@ -96,6 +97,8 @@ class ProgressReporter: sys.stdout.flush() def log(self, message="", color=None): + if not self.enabled and not self.verbose: + return if self.enabled: self._clear() if color: @@ -228,7 +231,7 @@ def parse_pim_simulator_outputs(output_bin_path, outputs_descriptor): return arrays -def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1e-3): +def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1e-3, verbose=False): all_passed = True rows = [] for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor): @@ -245,26 +248,27 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1 result_width = len("Result") separator = f" +-{'-' * name_width}-+-{'-' * diff_width}-+-{'-' * result_width}-+" - print(separator) - print(f" | {'Output'.ljust(name_width)} | {'Max diff'.ljust(diff_width)} | {'Result'} |") - print(separator) - for name, diff_text, passed in rows: - status_text = ("PASS" if passed else "FAIL").ljust(result_width) - status = Fore.GREEN + status_text + Style.RESET_ALL if passed else Fore.RED + status_text + Style.RESET_ALL - print(f" | {name.ljust(name_width)} | {diff_text.ljust(diff_width)} | {status} |") - print(separator) + if verbose or not all_passed: + print(separator) + print(f" | {'Output'.ljust(name_width)} | {'Max diff'.ljust(diff_width)} | {'Result'} |") + print(separator) + for name, diff_text, passed in rows: + status_text = ("PASS" if passed else "FAIL").ljust(result_width) + status = Fore.GREEN + status_text + Style.RESET_ALL if passed else Fore.RED + status_text + Style.RESET_ALL + print(f" | {name.ljust(name_width)} | {diff_text.ljust(diff_width)} | {status} |") + print(separator) return all_passed def validate_network(network_onnx_path, raptor_path, onnx_include_dir, simulator_dir, crossbar_size=64, crossbar_count=8, core_count=None, threshold=1e-3, - reporter=None, model_index=1, model_total=1): + reporter=None, model_index=1, model_total=1, verbose=False): network_onnx_path = Path(network_onnx_path).resolve() raptor_path = Path(raptor_path).resolve() onnx_include_dir = Path(onnx_include_dir).resolve() simulator_dir = Path(simulator_dir).resolve() owns_reporter = reporter is None - reporter = reporter or ProgressReporter(model_total) + reporter = reporter or ProgressReporter(model_total, verbose=verbose) workspace_dir = network_onnx_path.parent clean_workspace_artifacts(workspace_dir, network_onnx_path.stem) @@ -331,7 +335,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compare Outputs") sim_arrays = parse_pim_simulator_outputs(output_bin_path, outputs_descriptor) reporter.suspend() - passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold) + passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold, verbose=verbose) reporter.resume() reporter.advance() reporter.record_result(passed)