quieter validation scripts (with optional verbose flag)
This commit is contained in:
@@ -17,532 +17,12 @@ namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
enum class ListDelimiter {
|
||||
Square,
|
||||
Paren
|
||||
};
|
||||
|
||||
static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||
return onnx_mlir::compact_asm::parseOpenDelimiter(
|
||||
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
|
||||
}
|
||||
|
||||
static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||
return onnx_mlir::compact_asm::parseOptionalCloseDelimiter(
|
||||
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
|
||||
}
|
||||
|
||||
static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
|
||||
onnx_mlir::compact_asm::printOpenDelimiter(
|
||||
printer, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
|
||||
}
|
||||
|
||||
static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
|
||||
onnx_mlir::compact_asm::printCloseDelimiter(
|
||||
printer, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
|
||||
}
|
||||
using namespace onnx_mlir::compact_asm;
|
||||
|
||||
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
|
||||
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||
}
|
||||
|
||||
template <typename EntryT, typename ParseEntryFn>
|
||||
static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<EntryT>& entries,
|
||||
ParseEntryFn parseEntry) {
|
||||
return onnx_mlir::compact_asm::parseCompressedRepeatedList(
|
||||
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter), entries, parseEntry);
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static ParseResult
|
||||
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
while (true) {
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
SmallVector<IntT> subgroup;
|
||||
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(values, subgroup);
|
||||
}
|
||||
else {
|
||||
int64_t first = 0;
|
||||
if (parser.parseInteger(first))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
int64_t last = 0;
|
||||
if (parser.parseInteger(last) || last < first)
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
|
||||
|
||||
int64_t step = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("by"))) {
|
||||
if (parser.parseInteger(step) || step <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
|
||||
}
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
if ((last - first) % step != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"range end must be reachable from start using the given step");
|
||||
|
||||
for (int64_t value = first; value <= last; value += step)
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
values.push_back(static_cast<IntT>(value));
|
||||
}
|
||||
else {
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
values.push_back(static_cast<IntT>(first));
|
||||
}
|
||||
}
|
||||
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
break;
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static ParseResult
|
||||
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
return parseCompressedIntegerEntries(parser, delimiter, values);
|
||||
}
|
||||
|
||||
template <typename RangeT, typename PrintEntryFn>
|
||||
static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
|
||||
for (size_t index = 0; index < entries.size();) {
|
||||
size_t runEnd = index + 1;
|
||||
while (runEnd < entries.size() && entries[runEnd] == entries[index])
|
||||
++runEnd;
|
||||
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printEntry(entries[index]);
|
||||
size_t runLength = runEnd - index;
|
||||
if (runLength > 1)
|
||||
printer << " x" << runLength;
|
||||
index = runEnd;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef<IntT> values, ListDelimiter delimiter) {
|
||||
struct FlatCompression {
|
||||
enum class Kind {
|
||||
Single,
|
||||
EqualRun,
|
||||
Progression
|
||||
};
|
||||
|
||||
Kind kind = Kind::Single;
|
||||
size_t covered = 1;
|
||||
size_t repeatCount = 1;
|
||||
size_t progressionValueCount = 1;
|
||||
int64_t step = 1;
|
||||
IntT firstValue {};
|
||||
IntT lastValue {};
|
||||
};
|
||||
|
||||
auto computeFlatCompression = [&](size_t start) {
|
||||
FlatCompression compression;
|
||||
compression.firstValue = values[start];
|
||||
compression.lastValue = values[start];
|
||||
|
||||
auto findEqualRunEnd = [&](size_t runStart) {
|
||||
size_t runEnd = runStart + 1;
|
||||
while (runEnd < values.size() && values[runEnd] == values[runStart])
|
||||
++runEnd;
|
||||
return runEnd;
|
||||
};
|
||||
|
||||
size_t firstRunEnd = findEqualRunEnd(start);
|
||||
compression.repeatCount = firstRunEnd - start;
|
||||
size_t progressionEnd = firstRunEnd;
|
||||
int64_t step = 0;
|
||||
IntT lastValue = values[start];
|
||||
|
||||
if (firstRunEnd < values.size()) {
|
||||
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
|
||||
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
|
||||
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
|
||||
progressionEnd = secondRunEnd;
|
||||
lastValue = values[firstRunEnd];
|
||||
size_t currentRunStart = secondRunEnd;
|
||||
while (currentRunStart < values.size()) {
|
||||
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
|
||||
if (currentRunEnd - currentRunStart != compression.repeatCount)
|
||||
break;
|
||||
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
|
||||
break;
|
||||
lastValue = values[currentRunStart];
|
||||
progressionEnd = currentRunEnd;
|
||||
currentRunStart = currentRunEnd;
|
||||
}
|
||||
}
|
||||
else {
|
||||
step = 0;
|
||||
}
|
||||
}
|
||||
|
||||
compression.covered = 1;
|
||||
if (progressionEnd > firstRunEnd) {
|
||||
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
|
||||
if (progressionValueCount >= 3) {
|
||||
compression.kind = FlatCompression::Kind::Progression;
|
||||
compression.covered = progressionEnd - start;
|
||||
compression.progressionValueCount = progressionValueCount;
|
||||
compression.step = step;
|
||||
compression.lastValue = lastValue;
|
||||
return compression;
|
||||
}
|
||||
}
|
||||
|
||||
if (compression.repeatCount > 1) {
|
||||
compression.kind = FlatCompression::Kind::EqualRun;
|
||||
compression.covered = compression.repeatCount;
|
||||
return compression;
|
||||
}
|
||||
|
||||
return compression;
|
||||
};
|
||||
|
||||
auto findRepeatedSublist = [&](size_t start) {
|
||||
size_t bestLength = 0;
|
||||
size_t bestRepeatCount = 1;
|
||||
size_t remaining = values.size() - start;
|
||||
|
||||
for (size_t length = 2; length * 2 <= remaining; ++length) {
|
||||
size_t repeatCount = 1;
|
||||
ArrayRef<IntT> candidate = values.slice(start, length);
|
||||
while (start + (repeatCount + 1) * length <= values.size()
|
||||
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
|
||||
++repeatCount;
|
||||
}
|
||||
|
||||
if (repeatCount <= 1)
|
||||
continue;
|
||||
|
||||
size_t covered = length * repeatCount;
|
||||
size_t bestCovered = bestLength * bestRepeatCount;
|
||||
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
|
||||
bestLength = length;
|
||||
bestRepeatCount = repeatCount;
|
||||
}
|
||||
}
|
||||
|
||||
return std::pair(bestLength, bestRepeatCount);
|
||||
};
|
||||
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
for (size_t index = 0; index < values.size();) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
|
||||
FlatCompression flat = computeFlatCompression(index);
|
||||
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
|
||||
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
|
||||
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
|
||||
printCompressedIntegerSequence(printer, values.slice(index, sublistLength), ListDelimiter::Paren);
|
||||
printer << " x" << sublistRepeatCount;
|
||||
index += repeatedSublistCoverage;
|
||||
continue;
|
||||
}
|
||||
switch (flat.kind) {
|
||||
case FlatCompression::Kind::Progression:
|
||||
printer << flat.firstValue << " to " << flat.lastValue;
|
||||
if (flat.step != 1)
|
||||
printer << " by " << flat.step;
|
||||
if (flat.repeatCount > 1)
|
||||
printer << " x" << flat.repeatCount;
|
||||
index += flat.covered;
|
||||
break;
|
||||
case FlatCompression::Kind::EqualRun:
|
||||
printer << flat.firstValue << " x" << flat.repeatCount;
|
||||
index += flat.covered;
|
||||
break;
|
||||
case FlatCompression::Kind::Single:
|
||||
printer << flat.firstValue;
|
||||
index += flat.covered;
|
||||
break;
|
||||
}
|
||||
}
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
|
||||
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
|
||||
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
|
||||
}
|
||||
|
||||
static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
for (size_t index = 0; index < values.size();) {
|
||||
size_t equalRunEnd = index + 1;
|
||||
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
|
||||
++equalRunEnd;
|
||||
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
if (equalRunEnd - index > 1) {
|
||||
printer.printOperand(values[index]);
|
||||
printer << " x" << (equalRunEnd - index);
|
||||
index = equalRunEnd;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t rangeEnd = index + 1;
|
||||
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
|
||||
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|
||||
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index))
|
||||
break;
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
|
||||
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|
||||
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index))
|
||||
break;
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
|
||||
printer.printOperand(values[index]);
|
||||
if (rangeEnd - index >= 3) {
|
||||
printer << " to ";
|
||||
printer.printOperand(values[rangeEnd - 1]);
|
||||
}
|
||||
else if (rangeEnd - index == 2) {
|
||||
printer << ", ";
|
||||
printer.printOperand(values[index + 1]);
|
||||
}
|
||||
index = rangeEnd;
|
||||
}
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands);
|
||||
static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands);
|
||||
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty);
|
||||
|
||||
static bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
|
||||
return onnx_mlir::compact_asm::hasRepeatedTuple(values, tupleSize);
|
||||
}
|
||||
|
||||
static bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
|
||||
return onnx_mlir::compact_asm::hasRepeatedTuple(types, tupleSize);
|
||||
}
|
||||
|
||||
static void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize) {
|
||||
onnx_mlir::compact_asm::printValueTupleRun(
|
||||
printer, values, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square);
|
||||
}
|
||||
|
||||
static void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize) {
|
||||
onnx_mlir::compact_asm::printTypeTupleRun(
|
||||
printer, types, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
return onnx_mlir::compact_asm::parseCompressedOrTupleOperandList(
|
||||
parser, onnx_mlir::compact_asm::ListDelimiter::Square, operands);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, SmallVectorImpl<Type>& types) {
|
||||
return onnx_mlir::compact_asm::parseCompressedOrTupleTypeList(
|
||||
parser, onnx_mlir::compact_asm::ListDelimiter::Square, types);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
|
||||
OpAsmParser::UnresolvedOperand firstOperand,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
OpAsmParser::UnresolvedOperand lastOperand;
|
||||
if (parser.parseOperand(lastOperand))
|
||||
return failure();
|
||||
if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number)
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
|
||||
for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number)
|
||||
operands.push_back({firstOperand.location, firstOperand.name, number});
|
||||
}
|
||||
else {
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
operands.push_back(firstOperand);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
OpAsmParser::UnresolvedOperand firstOperand;
|
||||
if (parser.parseOperand(firstOperand))
|
||||
return failure();
|
||||
return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOperandList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
while (true) {
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
break;
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
|
||||
for (size_t index = 0; index < values.size();) {
|
||||
size_t equalRunEnd = index + 1;
|
||||
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
|
||||
++equalRunEnd;
|
||||
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
if (equalRunEnd - index > 1) {
|
||||
printer.printOperand(values[index]);
|
||||
printer << " x" << (equalRunEnd - index);
|
||||
index = equalRunEnd;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t rangeEnd = index + 1;
|
||||
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
|
||||
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|
||||
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index))
|
||||
break;
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
|
||||
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|
||||
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index))
|
||||
break;
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
|
||||
printer.printOperand(values[index]);
|
||||
if (rangeEnd - index >= 3) {
|
||||
printer << " to ";
|
||||
printer.printOperand(values[rangeEnd - 1]);
|
||||
}
|
||||
else if (rangeEnd - index == 2) {
|
||||
printer << ", ";
|
||||
printer.printOperand(values[index + 1]);
|
||||
}
|
||||
index = rangeEnd;
|
||||
}
|
||||
}
|
||||
|
||||
static void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
|
||||
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
|
||||
Type firstType;
|
||||
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
|
||||
if (!firstTypeResult.has_value()) {
|
||||
if (allowEmpty)
|
||||
return success();
|
||||
return parser.emitError(parser.getCurrentLocation(), "expected type");
|
||||
}
|
||||
if (failed(*firstTypeResult))
|
||||
return failure();
|
||||
|
||||
auto appendType = [&](Type type) -> ParseResult {
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
types.push_back(type);
|
||||
return success();
|
||||
};
|
||||
|
||||
if (appendType(firstType))
|
||||
return failure();
|
||||
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
Type nextType;
|
||||
if (parser.parseType(nextType) || appendType(nextType))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printChannelMetadata(OpAsmPrinter& printer,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
@@ -567,90 +47,6 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
static void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
|
||||
if (block.getNumArguments() == 0) {
|
||||
printer << "() = ()";
|
||||
return;
|
||||
}
|
||||
|
||||
if (block.getNumArguments() == 1) {
|
||||
printer.printOperand(block.getArgument(0));
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||
return;
|
||||
}
|
||||
|
||||
printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
|
||||
OpAsmParser::Argument firstArgument,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
OpAsmParser::Argument lastArgument;
|
||||
if (parser.parseArgument(lastArgument))
|
||||
return failure();
|
||||
if (firstArgument.ssaName.name != lastArgument.ssaName.name
|
||||
|| firstArgument.ssaName.number > lastArgument.ssaName.number) {
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
|
||||
}
|
||||
for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) {
|
||||
OpAsmParser::Argument argument;
|
||||
argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number};
|
||||
arguments.push_back(argument);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
arguments.push_back(firstArgument);
|
||||
return success();
|
||||
}
|
||||
|
||||
static ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
OpAsmParser::Argument firstArgument;
|
||||
if (parser.parseArgument(firstArgument))
|
||||
return failure();
|
||||
return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments);
|
||||
}
|
||||
|
||||
static void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes))
|
||||
argument.type = inputType;
|
||||
}
|
||||
|
||||
static ParseResult parseArgumentBindings(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
if (succeeded(parser.parseOptionalRParen())) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
OpAsmParser::Argument firstArgument;
|
||||
if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
if (parser.parseRParen() || parser.parseEqual()
|
||||
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
OpAsmParser::Argument argument;
|
||||
if (parser.parseArgument(argument) || parser.parseEqual()
|
||||
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
||||
@@ -875,7 +271,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
printer << " lanes " << getLaneCount() << " ";
|
||||
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
|
||||
printValueTupleRun(printer, getWeights(), weightsPerLane);
|
||||
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Square);
|
||||
else
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
@@ -892,7 +288,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
|
||||
printer << " : ";
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
|
||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane);
|
||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Square);
|
||||
else
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
@@ -916,7 +312,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
|
||||
if (parseCompressedOrTupleOperandList(parser, weights))
|
||||
if (parseCompressedOrTupleOperandList(parser, ListDelimiter::Square, weights))
|
||||
return failure();
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
@@ -927,7 +323,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedOrTupleTypeList(parser, weightTypes)
|
||||
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
|
||||
@@ -16,7 +16,7 @@ def _read_chunk(fd, treat_eio_as_eof=False):
|
||||
raise
|
||||
|
||||
|
||||
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
||||
def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output=True):
|
||||
selector = selectors.DefaultSelector()
|
||||
recent_output = bytearray()
|
||||
captured_output = bytearray()
|
||||
@@ -32,19 +32,22 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
||||
os.close(key.fileobj)
|
||||
continue
|
||||
|
||||
reporter._clear()
|
||||
os.write(1, data)
|
||||
reporter._render()
|
||||
if stream_output:
|
||||
reporter._clear()
|
||||
os.write(1, data)
|
||||
reporter._render()
|
||||
captured_output.extend(data)
|
||||
recent_output.extend(data)
|
||||
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
|
||||
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
|
||||
if stream_output:
|
||||
recent_output.extend(data)
|
||||
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
|
||||
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
|
||||
finally:
|
||||
selector.close()
|
||||
|
||||
return_code = process.wait()
|
||||
if return_code != 0:
|
||||
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
|
||||
error_output = captured_output if not stream_output else recent_output
|
||||
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(error_output))
|
||||
return bytes(captured_output)
|
||||
|
||||
|
||||
@@ -62,6 +65,18 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False
|
||||
subprocess.run(cmd, cwd=cwd, check=True)
|
||||
return None
|
||||
|
||||
stream_output = bool(getattr(reporter, "verbose", False))
|
||||
if not stream_output:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=cwd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert process.stdout is not None
|
||||
output = _stream_output(process.stdout.fileno(), process, reporter, stream_output=False)
|
||||
return output.decode("utf-8", errors="replace") if capture_output else None
|
||||
|
||||
try:
|
||||
master_fd, slave_fd = pty.openpty()
|
||||
except OSError:
|
||||
|
||||
+24
-14
@@ -27,6 +27,10 @@ def print_validation_error(reporter, rel, exc):
|
||||
file=sys.stderr, flush=True)
|
||||
if isinstance(exc, subprocess.CalledProcessError):
|
||||
print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
|
||||
if exc.output:
|
||||
output_text = exc.output.decode("utf-8", errors="replace") if isinstance(exc.output, bytes) else str(exc.output)
|
||||
if output_text:
|
||||
print(output_text, file=sys.stderr, end="" if output_text.endswith("\n") else "\n", flush=True)
|
||||
else:
|
||||
print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
|
||||
print("=" * 72, file=sys.stderr, flush=True)
|
||||
@@ -60,6 +64,8 @@ def main():
|
||||
help="Core count to pass to Raptor. If omitted, Raptor uses its default.")
|
||||
ap.add_argument("--clean", action="store_true",
|
||||
help="Remove generated validation artifacts under each model workspace and exit.")
|
||||
ap.add_argument("--verbose", action="store_true",
|
||||
help="Print per-stage progress and subprocess logs for passing validations too.")
|
||||
a = ap.parse_args()
|
||||
|
||||
script_dir = Path(__file__).parent.resolve()
|
||||
@@ -101,7 +107,7 @@ def main():
|
||||
pass_timing_counts = {label: 0 for _, label in PIM_PASS_LABELS}
|
||||
total_timing_sum = 0.0
|
||||
timed_benchmark_count = 0
|
||||
reporter = ProgressReporter(len(onnx_files))
|
||||
reporter = ProgressReporter(len(onnx_files), verbose=a.verbose)
|
||||
for index, onnx_path in enumerate(onnx_files, start=1):
|
||||
rel = onnx_path.relative_to(operations_dir)
|
||||
try:
|
||||
@@ -112,6 +118,7 @@ def main():
|
||||
reporter=reporter,
|
||||
model_index=index,
|
||||
model_total=len(onnx_files),
|
||||
verbose=a.verbose,
|
||||
)
|
||||
results[str(rel)] = result.passed
|
||||
if result.pim_pass_timings:
|
||||
@@ -134,22 +141,25 @@ def main():
|
||||
# Summary
|
||||
n_passed = sum(1 for passed in results.values() if passed)
|
||||
n_total = len(results)
|
||||
status_width = len("Result")
|
||||
path_width = max(len("Operation"), *(len(rel) for rel in results))
|
||||
separator = f"+-{'-' * path_width}-+-{'-' * status_width}-+"
|
||||
|
||||
print("\n" + Style.BRIGHT + Fore.CYAN + "Summary" + Style.RESET_ALL)
|
||||
print(separator)
|
||||
print(f"| {'Operation'.ljust(path_width)} | {'Result'.ljust(status_width)} |")
|
||||
print(separator)
|
||||
for rel, passed in results.items():
|
||||
plain_status = "PASS" if passed else "FAIL"
|
||||
status = Fore.GREEN + plain_status.ljust(status_width) + Style.RESET_ALL if passed else \
|
||||
Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL
|
||||
print(f"| {rel.ljust(path_width)} | {status} |")
|
||||
print(separator)
|
||||
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
|
||||
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
|
||||
failing = [rel for rel, passed in results.items() if not passed]
|
||||
if a.verbose or failing:
|
||||
status_width = len("Result")
|
||||
path_width = max(len("Operation"), *(len(rel) for rel in results))
|
||||
separator = f"+-{'-' * path_width}-+-{'-' * status_width}-+"
|
||||
print(separator)
|
||||
print(f"| {'Operation'.ljust(path_width)} | {'Result'.ljust(status_width)} |")
|
||||
print(separator)
|
||||
for rel, passed in results.items():
|
||||
if not a.verbose and passed:
|
||||
continue
|
||||
plain_status = "PASS" if passed else "FAIL"
|
||||
status = Fore.GREEN + plain_status.ljust(status_width) + Style.RESET_ALL if passed else \
|
||||
Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL
|
||||
print(f"| {rel.ljust(path_width)} | {status} |")
|
||||
print(separator)
|
||||
print_average_pim_pass_timings(
|
||||
pass_timing_sums,
|
||||
pass_timing_counts,
|
||||
|
||||
+17
-13
@@ -36,7 +36,7 @@ class ValidationResult:
|
||||
|
||||
|
||||
class ProgressReporter:
|
||||
def __init__(self, total_models, stages_per_model=STAGE_COUNT, enabled=None):
|
||||
def __init__(self, total_models, stages_per_model=STAGE_COUNT, enabled=None, verbose=False):
|
||||
self.total_models = total_models
|
||||
self.stages_per_model = stages_per_model
|
||||
self.total_steps = max(1, total_models * stages_per_model)
|
||||
@@ -45,6 +45,7 @@ class ProgressReporter:
|
||||
self.failed_models = 0
|
||||
self.current_label = ""
|
||||
self.enabled = sys.stdout.isatty() if enabled is None else enabled
|
||||
self.verbose = verbose
|
||||
self.columns = shutil.get_terminal_size((100, 20)).columns
|
||||
self.suspended = False
|
||||
|
||||
@@ -96,6 +97,8 @@ class ProgressReporter:
|
||||
sys.stdout.flush()
|
||||
|
||||
def log(self, message="", color=None):
|
||||
if not self.enabled and not self.verbose:
|
||||
return
|
||||
if self.enabled:
|
||||
self._clear()
|
||||
if color:
|
||||
@@ -228,7 +231,7 @@ def parse_pim_simulator_outputs(output_bin_path, outputs_descriptor):
|
||||
return arrays
|
||||
|
||||
|
||||
def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1e-3):
|
||||
def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1e-3, verbose=False):
|
||||
all_passed = True
|
||||
rows = []
|
||||
for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor):
|
||||
@@ -245,26 +248,27 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1
|
||||
result_width = len("Result")
|
||||
separator = f" +-{'-' * name_width}-+-{'-' * diff_width}-+-{'-' * result_width}-+"
|
||||
|
||||
print(separator)
|
||||
print(f" | {'Output'.ljust(name_width)} | {'Max diff'.ljust(diff_width)} | {'Result'} |")
|
||||
print(separator)
|
||||
for name, diff_text, passed in rows:
|
||||
status_text = ("PASS" if passed else "FAIL").ljust(result_width)
|
||||
status = Fore.GREEN + status_text + Style.RESET_ALL if passed else Fore.RED + status_text + Style.RESET_ALL
|
||||
print(f" | {name.ljust(name_width)} | {diff_text.ljust(diff_width)} | {status} |")
|
||||
print(separator)
|
||||
if verbose or not all_passed:
|
||||
print(separator)
|
||||
print(f" | {'Output'.ljust(name_width)} | {'Max diff'.ljust(diff_width)} | {'Result'} |")
|
||||
print(separator)
|
||||
for name, diff_text, passed in rows:
|
||||
status_text = ("PASS" if passed else "FAIL").ljust(result_width)
|
||||
status = Fore.GREEN + status_text + Style.RESET_ALL if passed else Fore.RED + status_text + Style.RESET_ALL
|
||||
print(f" | {name.ljust(name_width)} | {diff_text.ljust(diff_width)} | {status} |")
|
||||
print(separator)
|
||||
return all_passed
|
||||
|
||||
|
||||
def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
simulator_dir, crossbar_size=64, crossbar_count=8, core_count=None, threshold=1e-3,
|
||||
reporter=None, model_index=1, model_total=1):
|
||||
reporter=None, model_index=1, model_total=1, verbose=False):
|
||||
network_onnx_path = Path(network_onnx_path).resolve()
|
||||
raptor_path = Path(raptor_path).resolve()
|
||||
onnx_include_dir = Path(onnx_include_dir).resolve()
|
||||
simulator_dir = Path(simulator_dir).resolve()
|
||||
owns_reporter = reporter is None
|
||||
reporter = reporter or ProgressReporter(model_total)
|
||||
reporter = reporter or ProgressReporter(model_total, verbose=verbose)
|
||||
|
||||
workspace_dir = network_onnx_path.parent
|
||||
clean_workspace_artifacts(workspace_dir, network_onnx_path.stem)
|
||||
@@ -331,7 +335,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compare Outputs")
|
||||
sim_arrays = parse_pim_simulator_outputs(output_bin_path, outputs_descriptor)
|
||||
reporter.suspend()
|
||||
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
|
||||
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold, verbose=verbose)
|
||||
reporter.resume()
|
||||
reporter.advance()
|
||||
reporter.record_result(passed)
|
||||
|
||||
Reference in New Issue
Block a user