remove duplicated code
Validate Operations / validate-operations (push) Has been cancelled

quieter validation scripts (with optional verbose flag)
This commit is contained in:
NiccoloN
2026-05-11 15:52:26 +02:00
parent 5ff364027b
commit 57f0cca8c0
4 changed files with 69 additions and 644 deletions
+5 -609
View File
@@ -17,532 +17,12 @@ namespace onnx_mlir {
namespace spatial {
namespace {
enum class ListDelimiter {
Square,
Paren
};
static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
return onnx_mlir::compact_asm::parseOpenDelimiter(
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
}
static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
return onnx_mlir::compact_asm::parseOptionalCloseDelimiter(
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
}
static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
onnx_mlir::compact_asm::printOpenDelimiter(
printer, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
}
static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
onnx_mlir::compact_asm::printCloseDelimiter(
printer, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
}
using namespace onnx_mlir::compact_asm;
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
}
template <typename EntryT, typename ParseEntryFn>
static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<EntryT>& entries,
ParseEntryFn parseEntry) {
return onnx_mlir::compact_asm::parseCompressedRepeatedList(
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter), entries, parseEntry);
}
template <typename IntT>
static ParseResult
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<IntT> subgroup;
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(values, subgroup);
}
else {
int64_t first = 0;
if (parser.parseInteger(first))
return failure();
if (succeeded(parser.parseOptionalKeyword("to"))) {
int64_t last = 0;
if (parser.parseInteger(last) || last < first)
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
int64_t step = 1;
if (succeeded(parser.parseOptionalKeyword("by"))) {
if (parser.parseInteger(step) || step <= 0)
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
}
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
if ((last - first) % step != 0)
return parser.emitError(parser.getCurrentLocation(),
"range end must be reachable from start using the given step");
for (int64_t value = first; value <= last; value += step)
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(value));
}
else {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(first));
}
}
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
break;
if (parser.parseComma())
return failure();
}
return success();
}
template <typename IntT>
static ParseResult
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
return parseCompressedIntegerEntries(parser, delimiter, values);
}
template <typename RangeT, typename PrintEntryFn>
static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
for (size_t index = 0; index < entries.size();) {
size_t runEnd = index + 1;
while (runEnd < entries.size() && entries[runEnd] == entries[index])
++runEnd;
if (index != 0)
printer << ", ";
printEntry(entries[index]);
size_t runLength = runEnd - index;
if (runLength > 1)
printer << " x" << runLength;
index = runEnd;
}
}
template <typename IntT>
static void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef<IntT> values, ListDelimiter delimiter) {
struct FlatCompression {
enum class Kind {
Single,
EqualRun,
Progression
};
Kind kind = Kind::Single;
size_t covered = 1;
size_t repeatCount = 1;
size_t progressionValueCount = 1;
int64_t step = 1;
IntT firstValue {};
IntT lastValue {};
};
auto computeFlatCompression = [&](size_t start) {
FlatCompression compression;
compression.firstValue = values[start];
compression.lastValue = values[start];
auto findEqualRunEnd = [&](size_t runStart) {
size_t runEnd = runStart + 1;
while (runEnd < values.size() && values[runEnd] == values[runStart])
++runEnd;
return runEnd;
};
size_t firstRunEnd = findEqualRunEnd(start);
compression.repeatCount = firstRunEnd - start;
size_t progressionEnd = firstRunEnd;
int64_t step = 0;
IntT lastValue = values[start];
if (firstRunEnd < values.size()) {
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
progressionEnd = secondRunEnd;
lastValue = values[firstRunEnd];
size_t currentRunStart = secondRunEnd;
while (currentRunStart < values.size()) {
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
if (currentRunEnd - currentRunStart != compression.repeatCount)
break;
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
break;
lastValue = values[currentRunStart];
progressionEnd = currentRunEnd;
currentRunStart = currentRunEnd;
}
}
else {
step = 0;
}
}
compression.covered = 1;
if (progressionEnd > firstRunEnd) {
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
if (progressionValueCount >= 3) {
compression.kind = FlatCompression::Kind::Progression;
compression.covered = progressionEnd - start;
compression.progressionValueCount = progressionValueCount;
compression.step = step;
compression.lastValue = lastValue;
return compression;
}
}
if (compression.repeatCount > 1) {
compression.kind = FlatCompression::Kind::EqualRun;
compression.covered = compression.repeatCount;
return compression;
}
return compression;
};
auto findRepeatedSublist = [&](size_t start) {
size_t bestLength = 0;
size_t bestRepeatCount = 1;
size_t remaining = values.size() - start;
for (size_t length = 2; length * 2 <= remaining; ++length) {
size_t repeatCount = 1;
ArrayRef<IntT> candidate = values.slice(start, length);
while (start + (repeatCount + 1) * length <= values.size()
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
++repeatCount;
}
if (repeatCount <= 1)
continue;
size_t covered = length * repeatCount;
size_t bestCovered = bestLength * bestRepeatCount;
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
bestLength = length;
bestRepeatCount = repeatCount;
}
}
return std::pair(bestLength, bestRepeatCount);
};
printOpenDelimiter(printer, delimiter);
for (size_t index = 0; index < values.size();) {
if (index != 0)
printer << ", ";
FlatCompression flat = computeFlatCompression(index);
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
printCompressedIntegerSequence(printer, values.slice(index, sublistLength), ListDelimiter::Paren);
printer << " x" << sublistRepeatCount;
index += repeatedSublistCoverage;
continue;
}
switch (flat.kind) {
case FlatCompression::Kind::Progression:
printer << flat.firstValue << " to " << flat.lastValue;
if (flat.step != 1)
printer << " by " << flat.step;
if (flat.repeatCount > 1)
printer << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::EqualRun:
printer << flat.firstValue << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::Single:
printer << flat.firstValue;
index += flat.covered;
break;
}
}
printCloseDelimiter(printer, delimiter);
}
template <typename IntT>
static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
}
template <typename IntT>
static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
}
static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
for (size_t index = 0; index < values.size();) {
size_t equalRunEnd = index + 1;
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
++equalRunEnd;
if (index != 0)
printer << ", ";
if (equalRunEnd - index > 1) {
printer.printOperand(values[index]);
printer << " x" << (equalRunEnd - index);
index = equalRunEnd;
continue;
}
size_t rangeEnd = index + 1;
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
while (rangeEnd < values.size()) {
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index))
break;
++rangeEnd;
}
}
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
while (rangeEnd < values.size()) {
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index))
break;
++rangeEnd;
}
}
printer.printOperand(values[index]);
if (rangeEnd - index >= 3) {
printer << " to ";
printer.printOperand(values[rangeEnd - 1]);
}
else if (rangeEnd - index == 2) {
printer << ", ";
printer.printOperand(values[index + 1]);
}
index = rangeEnd;
}
printCloseDelimiter(printer, delimiter);
}
static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
printCloseDelimiter(printer, delimiter);
}
static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands);
static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands);
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty);
static bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
return onnx_mlir::compact_asm::hasRepeatedTuple(values, tupleSize);
}
static bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
return onnx_mlir::compact_asm::hasRepeatedTuple(types, tupleSize);
}
static void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize) {
onnx_mlir::compact_asm::printValueTupleRun(
printer, values, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square);
}
static void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize) {
onnx_mlir::compact_asm::printTypeTupleRun(
printer, types, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square);
}
static ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
return onnx_mlir::compact_asm::parseCompressedOrTupleOperandList(
parser, onnx_mlir::compact_asm::ListDelimiter::Square, operands);
}
static ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, SmallVectorImpl<Type>& types) {
return onnx_mlir::compact_asm::parseCompressedOrTupleTypeList(
parser, onnx_mlir::compact_asm::ListDelimiter::Square, types);
}
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
OpAsmParser::UnresolvedOperand firstOperand,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (succeeded(parser.parseOptionalKeyword("to"))) {
OpAsmParser::UnresolvedOperand lastOperand;
if (parser.parseOperand(lastOperand))
return failure();
if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number)
return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number)
operands.push_back({firstOperand.location, firstOperand.name, number});
}
else {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
operands.push_back(firstOperand);
}
return success();
}
static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
OpAsmParser::UnresolvedOperand firstOperand;
if (parser.parseOperand(firstOperand))
return failure();
return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
}
static ParseResult parseCompressedOperandList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
break;
if (parser.parseComma())
return failure();
}
return success();
}
static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
return success();
}
static void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
for (size_t index = 0; index < values.size();) {
size_t equalRunEnd = index + 1;
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
++equalRunEnd;
if (index != 0)
printer << ", ";
if (equalRunEnd - index > 1) {
printer.printOperand(values[index]);
printer << " x" << (equalRunEnd - index);
index = equalRunEnd;
continue;
}
size_t rangeEnd = index + 1;
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
while (rangeEnd < values.size()) {
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index))
break;
++rangeEnd;
}
}
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
while (rangeEnd < values.size()) {
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index))
break;
++rangeEnd;
}
}
printer.printOperand(values[index]);
if (rangeEnd - index >= 3) {
printer << " to ";
printer.printOperand(values[rangeEnd - 1]);
}
else if (rangeEnd - index == 2) {
printer << ", ";
printer.printOperand(values[index + 1]);
}
index = rangeEnd;
}
}
static void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
}
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
Type firstType;
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
if (!firstTypeResult.has_value()) {
if (allowEmpty)
return success();
return parser.emitError(parser.getCurrentLocation(), "expected type");
}
if (failed(*firstTypeResult))
return failure();
auto appendType = [&](Type type) -> ParseResult {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
types.push_back(type);
return success();
};
if (appendType(firstType))
return failure();
while (succeeded(parser.parseOptionalComma())) {
Type nextType;
if (parser.parseType(nextType) || appendType(nextType))
return failure();
}
return success();
}
static void printChannelMetadata(OpAsmPrinter& printer,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
@@ -567,90 +47,6 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
return parser.getBuilder().getI32IntegerAttr(value);
}
static void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
if (block.getNumArguments() == 0) {
printer << "() = ()";
return;
}
if (block.getNumArguments() == 1) {
printer.printOperand(block.getArgument(0));
printer << " = ";
printCompressedValueList(printer, operands, ListDelimiter::Paren);
return;
}
printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
printer << " = ";
printCompressedValueList(printer, operands, ListDelimiter::Paren);
}
static ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
OpAsmParser::Argument firstArgument,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
if (succeeded(parser.parseOptionalKeyword("to"))) {
OpAsmParser::Argument lastArgument;
if (parser.parseArgument(lastArgument))
return failure();
if (firstArgument.ssaName.name != lastArgument.ssaName.name
|| firstArgument.ssaName.number > lastArgument.ssaName.number) {
return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
}
for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) {
OpAsmParser::Argument argument;
argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number};
arguments.push_back(argument);
}
return success();
}
arguments.push_back(firstArgument);
return success();
}
static ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
OpAsmParser::Argument firstArgument;
if (parser.parseArgument(firstArgument))
return failure();
return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments);
}
static void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes))
argument.type = inputType;
}
static ParseResult parseArgumentBindings(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::Argument>& arguments,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (succeeded(parser.parseOptionalLParen())) {
if (succeeded(parser.parseOptionalRParen())) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
return failure();
return success();
}
OpAsmParser::Argument firstArgument;
if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
if (parser.parseRParen() || parser.parseEqual()
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
return failure();
return success();
}
OpAsmParser::Argument argument;
if (parser.parseArgument(argument) || parser.parseEqual()
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
return failure();
arguments.push_back(argument);
return success();
}
} // namespace
void SpatYieldOp::print(OpAsmPrinter& printer) {
@@ -875,7 +271,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
printer << " lanes " << getLaneCount() << " ";
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
printValueTupleRun(printer, getWeights(), weightsPerLane);
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Square);
else
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
printer << " ";
@@ -892,7 +288,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
printer << " : ";
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane);
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Square);
else
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
@@ -916,7 +312,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
return failure();
if (parseCompressedOrTupleOperandList(parser, weights))
if (parseCompressedOrTupleOperandList(parser, ListDelimiter::Square, weights))
return failure();
if (parseArgumentBindings(parser, regionArgs, inputs))
@@ -927,7 +323,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
return failure();
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedOrTupleTypeList(parser, weightTypes)
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
+17 -2
View File
@@ -16,7 +16,7 @@ def _read_chunk(fd, treat_eio_as_eof=False):
raise
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output=True):
selector = selectors.DefaultSelector()
recent_output = bytearray()
captured_output = bytearray()
@@ -32,10 +32,12 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
os.close(key.fileobj)
continue
if stream_output:
reporter._clear()
os.write(1, data)
reporter._render()
captured_output.extend(data)
if stream_output:
recent_output.extend(data)
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
@@ -44,7 +46,8 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
return_code = process.wait()
if return_code != 0:
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
error_output = captured_output if not stream_output else recent_output
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(error_output))
return bytes(captured_output)
@@ -62,6 +65,18 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
subprocess.run(cmd, cwd=cwd, check=True)
return None
stream_output = bool(getattr(reporter, "verbose", False))
if not stream_output:
process = subprocess.Popen(
cmd,
cwd=cwd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
assert process.stdout is not None
output = _stream_output(process.stdout.fileno(), process, reporter, stream_output=False)
return output.decode("utf-8", errors="replace") if capture_output else None
try:
master_fd, slave_fd = pty.openpty()
except OSError:
+15 -5
View File
@@ -27,6 +27,10 @@ def print_validation_error(reporter, rel, exc):
file=sys.stderr, flush=True)
if isinstance(exc, subprocess.CalledProcessError):
print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
if exc.output:
output_text = exc.output.decode("utf-8", errors="replace") if isinstance(exc.output, bytes) else str(exc.output)
if output_text:
print(output_text, file=sys.stderr, end="" if output_text.endswith("\n") else "\n", flush=True)
else:
print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
print("=" * 72, file=sys.stderr, flush=True)
@@ -60,6 +64,8 @@ def main():
help="Core count to pass to Raptor. If omitted, Raptor uses its default.")
ap.add_argument("--clean", action="store_true",
help="Remove generated validation artifacts under each model workspace and exit.")
ap.add_argument("--verbose", action="store_true",
help="Print per-stage progress and subprocess logs for passing validations too.")
a = ap.parse_args()
script_dir = Path(__file__).parent.resolve()
@@ -101,7 +107,7 @@ def main():
pass_timing_counts = {label: 0 for _, label in PIM_PASS_LABELS}
total_timing_sum = 0.0
timed_benchmark_count = 0
reporter = ProgressReporter(len(onnx_files))
reporter = ProgressReporter(len(onnx_files), verbose=a.verbose)
for index, onnx_path in enumerate(onnx_files, start=1):
rel = onnx_path.relative_to(operations_dir)
try:
@@ -112,6 +118,7 @@ def main():
reporter=reporter,
model_index=index,
model_total=len(onnx_files),
verbose=a.verbose,
)
results[str(rel)] = result.passed
if result.pim_pass_timings:
@@ -134,22 +141,25 @@ def main():
# Summary
n_passed = sum(1 for passed in results.values() if passed)
n_total = len(results)
print("\n" + Style.BRIGHT + Fore.CYAN + "Summary" + Style.RESET_ALL)
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
failing = [rel for rel, passed in results.items() if not passed]
if a.verbose or failing:
status_width = len("Result")
path_width = max(len("Operation"), *(len(rel) for rel in results))
separator = f"+-{'-' * path_width}-+-{'-' * status_width}-+"
print("\n" + Style.BRIGHT + Fore.CYAN + "Summary" + Style.RESET_ALL)
print(separator)
print(f"| {'Operation'.ljust(path_width)} | {'Result'.ljust(status_width)} |")
print(separator)
for rel, passed in results.items():
if not a.verbose and passed:
continue
plain_status = "PASS" if passed else "FAIL"
status = Fore.GREEN + plain_status.ljust(status_width) + Style.RESET_ALL if passed else \
Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL
print(f"| {rel.ljust(path_width)} | {status} |")
print(separator)
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
print_average_pim_pass_timings(
pass_timing_sums,
pass_timing_counts,
+9 -5
View File
@@ -36,7 +36,7 @@ class ValidationResult:
class ProgressReporter:
def __init__(self, total_models, stages_per_model=STAGE_COUNT, enabled=None):
def __init__(self, total_models, stages_per_model=STAGE_COUNT, enabled=None, verbose=False):
self.total_models = total_models
self.stages_per_model = stages_per_model
self.total_steps = max(1, total_models * stages_per_model)
@@ -45,6 +45,7 @@ class ProgressReporter:
self.failed_models = 0
self.current_label = ""
self.enabled = sys.stdout.isatty() if enabled is None else enabled
self.verbose = verbose
self.columns = shutil.get_terminal_size((100, 20)).columns
self.suspended = False
@@ -96,6 +97,8 @@ class ProgressReporter:
sys.stdout.flush()
def log(self, message="", color=None):
if not self.enabled and not self.verbose:
return
if self.enabled:
self._clear()
if color:
@@ -228,7 +231,7 @@ def parse_pim_simulator_outputs(output_bin_path, outputs_descriptor):
return arrays
def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1e-3):
def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1e-3, verbose=False):
all_passed = True
rows = []
for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor):
@@ -245,6 +248,7 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1
result_width = len("Result")
separator = f" +-{'-' * name_width}-+-{'-' * diff_width}-+-{'-' * result_width}-+"
if verbose or not all_passed:
print(separator)
print(f" | {'Output'.ljust(name_width)} | {'Max diff'.ljust(diff_width)} | {'Result'} |")
print(separator)
@@ -258,13 +262,13 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1
def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
simulator_dir, crossbar_size=64, crossbar_count=8, core_count=None, threshold=1e-3,
reporter=None, model_index=1, model_total=1):
reporter=None, model_index=1, model_total=1, verbose=False):
network_onnx_path = Path(network_onnx_path).resolve()
raptor_path = Path(raptor_path).resolve()
onnx_include_dir = Path(onnx_include_dir).resolve()
simulator_dir = Path(simulator_dir).resolve()
owns_reporter = reporter is None
reporter = reporter or ProgressReporter(model_total)
reporter = reporter or ProgressReporter(model_total, verbose=verbose)
workspace_dir = network_onnx_path.parent
clean_workspace_artifacts(workspace_dir, network_onnx_path.stem)
@@ -331,7 +335,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compare Outputs")
sim_arrays = parse_pim_simulator_outputs(output_bin_path, outputs_descriptor)
reporter.suspend()
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold, verbose=verbose)
reporter.resume()
reporter.advance()
reporter.record_result(passed)