quieter validation scripts (with optional verbose flag)
This commit is contained in:
@@ -17,532 +17,12 @@ namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
enum class ListDelimiter {
|
||||
Square,
|
||||
Paren
|
||||
};
|
||||
|
||||
static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||
return onnx_mlir::compact_asm::parseOpenDelimiter(
|
||||
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
|
||||
}
|
||||
|
||||
static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||
return onnx_mlir::compact_asm::parseOptionalCloseDelimiter(
|
||||
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
|
||||
}
|
||||
|
||||
static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
|
||||
onnx_mlir::compact_asm::printOpenDelimiter(
|
||||
printer, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
|
||||
}
|
||||
|
||||
static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
|
||||
onnx_mlir::compact_asm::printCloseDelimiter(
|
||||
printer, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
|
||||
}
|
||||
using namespace onnx_mlir::compact_asm;
|
||||
|
||||
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
|
||||
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||
}
|
||||
|
||||
template <typename EntryT, typename ParseEntryFn>
|
||||
static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<EntryT>& entries,
|
||||
ParseEntryFn parseEntry) {
|
||||
return onnx_mlir::compact_asm::parseCompressedRepeatedList(
|
||||
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter), entries, parseEntry);
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static ParseResult
|
||||
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
while (true) {
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
SmallVector<IntT> subgroup;
|
||||
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(values, subgroup);
|
||||
}
|
||||
else {
|
||||
int64_t first = 0;
|
||||
if (parser.parseInteger(first))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
int64_t last = 0;
|
||||
if (parser.parseInteger(last) || last < first)
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
|
||||
|
||||
int64_t step = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("by"))) {
|
||||
if (parser.parseInteger(step) || step <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
|
||||
}
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
if ((last - first) % step != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"range end must be reachable from start using the given step");
|
||||
|
||||
for (int64_t value = first; value <= last; value += step)
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
values.push_back(static_cast<IntT>(value));
|
||||
}
|
||||
else {
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
values.push_back(static_cast<IntT>(first));
|
||||
}
|
||||
}
|
||||
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
break;
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static ParseResult
|
||||
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
return parseCompressedIntegerEntries(parser, delimiter, values);
|
||||
}
|
||||
|
||||
template <typename RangeT, typename PrintEntryFn>
|
||||
static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
|
||||
for (size_t index = 0; index < entries.size();) {
|
||||
size_t runEnd = index + 1;
|
||||
while (runEnd < entries.size() && entries[runEnd] == entries[index])
|
||||
++runEnd;
|
||||
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printEntry(entries[index]);
|
||||
size_t runLength = runEnd - index;
|
||||
if (runLength > 1)
|
||||
printer << " x" << runLength;
|
||||
index = runEnd;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef<IntT> values, ListDelimiter delimiter) {
|
||||
struct FlatCompression {
|
||||
enum class Kind {
|
||||
Single,
|
||||
EqualRun,
|
||||
Progression
|
||||
};
|
||||
|
||||
Kind kind = Kind::Single;
|
||||
size_t covered = 1;
|
||||
size_t repeatCount = 1;
|
||||
size_t progressionValueCount = 1;
|
||||
int64_t step = 1;
|
||||
IntT firstValue {};
|
||||
IntT lastValue {};
|
||||
};
|
||||
|
||||
auto computeFlatCompression = [&](size_t start) {
|
||||
FlatCompression compression;
|
||||
compression.firstValue = values[start];
|
||||
compression.lastValue = values[start];
|
||||
|
||||
auto findEqualRunEnd = [&](size_t runStart) {
|
||||
size_t runEnd = runStart + 1;
|
||||
while (runEnd < values.size() && values[runEnd] == values[runStart])
|
||||
++runEnd;
|
||||
return runEnd;
|
||||
};
|
||||
|
||||
size_t firstRunEnd = findEqualRunEnd(start);
|
||||
compression.repeatCount = firstRunEnd - start;
|
||||
size_t progressionEnd = firstRunEnd;
|
||||
int64_t step = 0;
|
||||
IntT lastValue = values[start];
|
||||
|
||||
if (firstRunEnd < values.size()) {
|
||||
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
|
||||
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
|
||||
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
|
||||
progressionEnd = secondRunEnd;
|
||||
lastValue = values[firstRunEnd];
|
||||
size_t currentRunStart = secondRunEnd;
|
||||
while (currentRunStart < values.size()) {
|
||||
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
|
||||
if (currentRunEnd - currentRunStart != compression.repeatCount)
|
||||
break;
|
||||
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
|
||||
break;
|
||||
lastValue = values[currentRunStart];
|
||||
progressionEnd = currentRunEnd;
|
||||
currentRunStart = currentRunEnd;
|
||||
}
|
||||
}
|
||||
else {
|
||||
step = 0;
|
||||
}
|
||||
}
|
||||
|
||||
compression.covered = 1;
|
||||
if (progressionEnd > firstRunEnd) {
|
||||
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
|
||||
if (progressionValueCount >= 3) {
|
||||
compression.kind = FlatCompression::Kind::Progression;
|
||||
compression.covered = progressionEnd - start;
|
||||
compression.progressionValueCount = progressionValueCount;
|
||||
compression.step = step;
|
||||
compression.lastValue = lastValue;
|
||||
return compression;
|
||||
}
|
||||
}
|
||||
|
||||
if (compression.repeatCount > 1) {
|
||||
compression.kind = FlatCompression::Kind::EqualRun;
|
||||
compression.covered = compression.repeatCount;
|
||||
return compression;
|
||||
}
|
||||
|
||||
return compression;
|
||||
};
|
||||
|
||||
auto findRepeatedSublist = [&](size_t start) {
|
||||
size_t bestLength = 0;
|
||||
size_t bestRepeatCount = 1;
|
||||
size_t remaining = values.size() - start;
|
||||
|
||||
for (size_t length = 2; length * 2 <= remaining; ++length) {
|
||||
size_t repeatCount = 1;
|
||||
ArrayRef<IntT> candidate = values.slice(start, length);
|
||||
while (start + (repeatCount + 1) * length <= values.size()
|
||||
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
|
||||
++repeatCount;
|
||||
}
|
||||
|
||||
if (repeatCount <= 1)
|
||||
continue;
|
||||
|
||||
size_t covered = length * repeatCount;
|
||||
size_t bestCovered = bestLength * bestRepeatCount;
|
||||
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
|
||||
bestLength = length;
|
||||
bestRepeatCount = repeatCount;
|
||||
}
|
||||
}
|
||||
|
||||
return std::pair(bestLength, bestRepeatCount);
|
||||
};
|
||||
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
for (size_t index = 0; index < values.size();) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
|
||||
FlatCompression flat = computeFlatCompression(index);
|
||||
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
|
||||
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
|
||||
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
|
||||
printCompressedIntegerSequence(printer, values.slice(index, sublistLength), ListDelimiter::Paren);
|
||||
printer << " x" << sublistRepeatCount;
|
||||
index += repeatedSublistCoverage;
|
||||
continue;
|
||||
}
|
||||
switch (flat.kind) {
|
||||
case FlatCompression::Kind::Progression:
|
||||
printer << flat.firstValue << " to " << flat.lastValue;
|
||||
if (flat.step != 1)
|
||||
printer << " by " << flat.step;
|
||||
if (flat.repeatCount > 1)
|
||||
printer << " x" << flat.repeatCount;
|
||||
index += flat.covered;
|
||||
break;
|
||||
case FlatCompression::Kind::EqualRun:
|
||||
printer << flat.firstValue << " x" << flat.repeatCount;
|
||||
index += flat.covered;
|
||||
break;
|
||||
case FlatCompression::Kind::Single:
|
||||
printer << flat.firstValue;
|
||||
index += flat.covered;
|
||||
break;
|
||||
}
|
||||
}
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
|
||||
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
|
||||
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
|
||||
}
|
||||
|
||||
static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
for (size_t index = 0; index < values.size();) {
|
||||
size_t equalRunEnd = index + 1;
|
||||
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
|
||||
++equalRunEnd;
|
||||
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
if (equalRunEnd - index > 1) {
|
||||
printer.printOperand(values[index]);
|
||||
printer << " x" << (equalRunEnd - index);
|
||||
index = equalRunEnd;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t rangeEnd = index + 1;
|
||||
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
|
||||
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|
||||
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index))
|
||||
break;
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
|
||||
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|
||||
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index))
|
||||
break;
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
|
||||
printer.printOperand(values[index]);
|
||||
if (rangeEnd - index >= 3) {
|
||||
printer << " to ";
|
||||
printer.printOperand(values[rangeEnd - 1]);
|
||||
}
|
||||
else if (rangeEnd - index == 2) {
|
||||
printer << ", ";
|
||||
printer.printOperand(values[index + 1]);
|
||||
}
|
||||
index = rangeEnd;
|
||||
}
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands);
|
||||
static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands);
|
||||
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty);
|
||||
|
||||
static bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
|
||||
return onnx_mlir::compact_asm::hasRepeatedTuple(values, tupleSize);
|
||||
}
|
||||
|
||||
static bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
|
||||
return onnx_mlir::compact_asm::hasRepeatedTuple(types, tupleSize);
|
||||
}
|
||||
|
||||
static void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize) {
|
||||
onnx_mlir::compact_asm::printValueTupleRun(
|
||||
printer, values, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square);
|
||||
}
|
||||
|
||||
static void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize) {
|
||||
onnx_mlir::compact_asm::printTypeTupleRun(
|
||||
printer, types, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
return onnx_mlir::compact_asm::parseCompressedOrTupleOperandList(
|
||||
parser, onnx_mlir::compact_asm::ListDelimiter::Square, operands);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, SmallVectorImpl<Type>& types) {
|
||||
return onnx_mlir::compact_asm::parseCompressedOrTupleTypeList(
|
||||
parser, onnx_mlir::compact_asm::ListDelimiter::Square, types);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
|
||||
OpAsmParser::UnresolvedOperand firstOperand,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
OpAsmParser::UnresolvedOperand lastOperand;
|
||||
if (parser.parseOperand(lastOperand))
|
||||
return failure();
|
||||
if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number)
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
|
||||
for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number)
|
||||
operands.push_back({firstOperand.location, firstOperand.name, number});
|
||||
}
|
||||
else {
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
operands.push_back(firstOperand);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
OpAsmParser::UnresolvedOperand firstOperand;
|
||||
if (parser.parseOperand(firstOperand))
|
||||
return failure();
|
||||
return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOperandList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
while (true) {
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
break;
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
|
||||
for (size_t index = 0; index < values.size();) {
|
||||
size_t equalRunEnd = index + 1;
|
||||
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
|
||||
++equalRunEnd;
|
||||
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
if (equalRunEnd - index > 1) {
|
||||
printer.printOperand(values[index]);
|
||||
printer << " x" << (equalRunEnd - index);
|
||||
index = equalRunEnd;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t rangeEnd = index + 1;
|
||||
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
|
||||
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|
||||
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index))
|
||||
break;
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
|
||||
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|
||||
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index))
|
||||
break;
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
|
||||
printer.printOperand(values[index]);
|
||||
if (rangeEnd - index >= 3) {
|
||||
printer << " to ";
|
||||
printer.printOperand(values[rangeEnd - 1]);
|
||||
}
|
||||
else if (rangeEnd - index == 2) {
|
||||
printer << ", ";
|
||||
printer.printOperand(values[index + 1]);
|
||||
}
|
||||
index = rangeEnd;
|
||||
}
|
||||
}
|
||||
|
||||
static void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
|
||||
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
|
||||
Type firstType;
|
||||
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
|
||||
if (!firstTypeResult.has_value()) {
|
||||
if (allowEmpty)
|
||||
return success();
|
||||
return parser.emitError(parser.getCurrentLocation(), "expected type");
|
||||
}
|
||||
if (failed(*firstTypeResult))
|
||||
return failure();
|
||||
|
||||
auto appendType = [&](Type type) -> ParseResult {
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
types.push_back(type);
|
||||
return success();
|
||||
};
|
||||
|
||||
if (appendType(firstType))
|
||||
return failure();
|
||||
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
Type nextType;
|
||||
if (parser.parseType(nextType) || appendType(nextType))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printChannelMetadata(OpAsmPrinter& printer,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
@@ -567,90 +47,6 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
static void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
|
||||
if (block.getNumArguments() == 0) {
|
||||
printer << "() = ()";
|
||||
return;
|
||||
}
|
||||
|
||||
if (block.getNumArguments() == 1) {
|
||||
printer.printOperand(block.getArgument(0));
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||
return;
|
||||
}
|
||||
|
||||
printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
|
||||
OpAsmParser::Argument firstArgument,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
OpAsmParser::Argument lastArgument;
|
||||
if (parser.parseArgument(lastArgument))
|
||||
return failure();
|
||||
if (firstArgument.ssaName.name != lastArgument.ssaName.name
|
||||
|| firstArgument.ssaName.number > lastArgument.ssaName.number) {
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
|
||||
}
|
||||
for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) {
|
||||
OpAsmParser::Argument argument;
|
||||
argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number};
|
||||
arguments.push_back(argument);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
arguments.push_back(firstArgument);
|
||||
return success();
|
||||
}
|
||||
|
||||
static ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
OpAsmParser::Argument firstArgument;
|
||||
if (parser.parseArgument(firstArgument))
|
||||
return failure();
|
||||
return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments);
|
||||
}
|
||||
|
||||
static void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes))
|
||||
argument.type = inputType;
|
||||
}
|
||||
|
||||
static ParseResult parseArgumentBindings(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
if (succeeded(parser.parseOptionalRParen())) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
OpAsmParser::Argument firstArgument;
|
||||
if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
if (parser.parseRParen() || parser.parseEqual()
|
||||
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
OpAsmParser::Argument argument;
|
||||
if (parser.parseArgument(argument) || parser.parseEqual()
|
||||
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
||||
@@ -875,7 +271,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
printer << " lanes " << getLaneCount() << " ";
|
||||
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
|
||||
printValueTupleRun(printer, getWeights(), weightsPerLane);
|
||||
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Square);
|
||||
else
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
@@ -892,7 +288,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
|
||||
printer << " : ";
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
|
||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane);
|
||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Square);
|
||||
else
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
@@ -916,7 +312,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
|
||||
if (parseCompressedOrTupleOperandList(parser, weights))
|
||||
if (parseCompressedOrTupleOperandList(parser, ListDelimiter::Square, weights))
|
||||
return failure();
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
@@ -927,7 +323,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedOrTupleTypeList(parser, weightTypes)
|
||||
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
|
||||
Reference in New Issue
Block a user