80a7298552
Validate Operations / validate-operations (push) Has been cancelled
better reports (dcp merge and memory)
746 lines
27 KiB
C++
746 lines
27 KiB
C++
#ifndef ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
|
|
#define ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
|
|
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include "llvm/Support/LogicalResult.h"
|
|
|
|
namespace onnx_mlir {
|
|
namespace compact_asm {
|
|
|
|
using namespace mlir;
|
|
|
|
/// Bracket style surrounding a compressed list in the custom assembly format.
enum class ListDelimiter {
  Square = 0, ///< `[` ... `]`
  Paren = 1   ///< `(` ... `)`
};
|
|
|
|
/// Consumes the opening token of a list: `[` for ListDelimiter::Square, `(` for
/// anything else. Returns the parser's result (failure when the token is absent).
inline ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
  const bool isSquare = delimiter == ListDelimiter::Square;
  return isSquare ? parser.parseLSquare() : parser.parseLParen();
}
|
|
|
|
/// Consumes the closing token of a list if it is next: `]` for
/// ListDelimiter::Square, `)` for anything else. Returns success only when the
/// token was consumed.
inline ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
  const bool isSquare = delimiter == ListDelimiter::Square;
  return isSquare ? parser.parseOptionalRSquare() : parser.parseOptionalRParen();
}
|
|
|
|
template <typename StreamT>
|
|
inline void printOpenDelimiter(StreamT& stream, ListDelimiter delimiter) {
|
|
stream << (delimiter == ListDelimiter::Square ? "[" : "(");
|
|
}
|
|
|
|
template <typename StreamT>
|
|
inline void printCloseDelimiter(StreamT& stream, ListDelimiter delimiter) {
|
|
stream << (delimiter == ListDelimiter::Square ? "]" : ")");
|
|
}
|
|
|
|
template <typename EntryT, typename ParseEntryFn>
|
|
inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
|
ListDelimiter delimiter,
|
|
SmallVectorImpl<EntryT>& entries,
|
|
ParseEntryFn parseEntry) {
|
|
if (parseOpenDelimiter(parser, delimiter))
|
|
return failure();
|
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
|
return success();
|
|
|
|
while (true) {
|
|
EntryT entry;
|
|
if (parseEntry(entry))
|
|
return failure();
|
|
|
|
int64_t repeatCount = 1;
|
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
|
}
|
|
for (int64_t index = 0; index < repeatCount; ++index)
|
|
entries.push_back(entry);
|
|
|
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
|
return success();
|
|
if (parser.parseComma())
|
|
return failure();
|
|
}
|
|
}
|
|
|
|
template <typename IntT>
|
|
inline ParseResult
|
|
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
|
return success();
|
|
|
|
while (true) {
|
|
if (succeeded(parser.parseOptionalLParen())) {
|
|
SmallVector<IntT> subgroup;
|
|
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
|
|
return failure();
|
|
|
|
int64_t repeatCount = 1;
|
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
|
}
|
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
|
llvm::append_range(values, subgroup);
|
|
}
|
|
else {
|
|
int64_t first = 0;
|
|
if (parser.parseInteger(first))
|
|
return failure();
|
|
|
|
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
|
int64_t last = 0;
|
|
if (parser.parseInteger(last) || last < first)
|
|
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
|
|
|
|
int64_t step = 1;
|
|
if (succeeded(parser.parseOptionalKeyword("by"))) {
|
|
if (parser.parseInteger(step) || step <= 0)
|
|
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
|
|
}
|
|
int64_t repeatCount = 1;
|
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
|
}
|
|
if ((last - first) % step != 0) {
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"range end must be reachable from start using the given step");
|
|
}
|
|
|
|
for (int64_t value = first; value <= last; value += step)
|
|
for (int64_t index = 0; index < repeatCount; ++index)
|
|
values.push_back(static_cast<IntT>(value));
|
|
}
|
|
else {
|
|
int64_t repeatCount = 1;
|
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
|
}
|
|
for (int64_t index = 0; index < repeatCount; ++index)
|
|
values.push_back(static_cast<IntT>(first));
|
|
}
|
|
}
|
|
|
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
|
return success();
|
|
if (parser.parseComma())
|
|
return failure();
|
|
}
|
|
}
|
|
|
|
template <typename IntT>
|
|
inline ParseResult
|
|
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
|
if (parseOpenDelimiter(parser, delimiter))
|
|
return failure();
|
|
return parseCompressedIntegerEntries(parser, delimiter, values);
|
|
}
|
|
|
|
template <typename RangeT, typename PrintEntryFn>
|
|
inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
|
|
for (size_t index = 0; index < entries.size();) {
|
|
size_t runEnd = index + 1;
|
|
while (runEnd < entries.size() && entries[runEnd] == entries[index])
|
|
++runEnd;
|
|
|
|
if (index != 0)
|
|
printer << ", ";
|
|
printEntry(entries[index]);
|
|
size_t runLength = runEnd - index;
|
|
if (runLength > 1)
|
|
printer << " x" << runLength;
|
|
index = runEnd;
|
|
}
|
|
}
|
|
|
|
template <typename StreamT, typename IntT>
|
|
inline void printCompressedIntegerEntries(StreamT& stream, ArrayRef<IntT> values) {
|
|
struct FlatCompression {
|
|
enum class Kind {
|
|
Single,
|
|
EqualRun,
|
|
Progression
|
|
};
|
|
|
|
Kind kind = Kind::Single;
|
|
size_t covered = 1;
|
|
size_t repeatCount = 1;
|
|
size_t progressionValueCount = 1;
|
|
int64_t step = 1;
|
|
IntT firstValue {};
|
|
IntT lastValue {};
|
|
};
|
|
|
|
auto computeFlatCompression = [&](size_t start) {
|
|
FlatCompression compression;
|
|
compression.firstValue = values[start];
|
|
compression.lastValue = values[start];
|
|
|
|
auto findEqualRunEnd = [&](size_t runStart) {
|
|
size_t runEnd = runStart + 1;
|
|
while (runEnd < values.size() && values[runEnd] == values[runStart])
|
|
++runEnd;
|
|
return runEnd;
|
|
};
|
|
|
|
size_t firstRunEnd = findEqualRunEnd(start);
|
|
compression.repeatCount = firstRunEnd - start;
|
|
size_t progressionEnd = firstRunEnd;
|
|
int64_t step = 0;
|
|
IntT lastValue = values[start];
|
|
|
|
if (firstRunEnd < values.size()) {
|
|
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
|
|
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
|
|
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
|
|
progressionEnd = secondRunEnd;
|
|
lastValue = values[firstRunEnd];
|
|
size_t currentRunStart = secondRunEnd;
|
|
while (currentRunStart < values.size()) {
|
|
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
|
|
if (currentRunEnd - currentRunStart != compression.repeatCount)
|
|
break;
|
|
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
|
|
break;
|
|
lastValue = values[currentRunStart];
|
|
progressionEnd = currentRunEnd;
|
|
currentRunStart = currentRunEnd;
|
|
}
|
|
}
|
|
else {
|
|
step = 0;
|
|
}
|
|
}
|
|
|
|
compression.covered = 1;
|
|
if (progressionEnd > firstRunEnd) {
|
|
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
|
|
if (progressionValueCount >= 3) {
|
|
compression.kind = FlatCompression::Kind::Progression;
|
|
compression.covered = progressionEnd - start;
|
|
compression.progressionValueCount = progressionValueCount;
|
|
compression.step = step;
|
|
compression.lastValue = lastValue;
|
|
return compression;
|
|
}
|
|
}
|
|
|
|
if (compression.repeatCount > 1) {
|
|
compression.kind = FlatCompression::Kind::EqualRun;
|
|
compression.covered = compression.repeatCount;
|
|
return compression;
|
|
}
|
|
|
|
return compression;
|
|
};
|
|
|
|
auto findRepeatedSublist = [&](size_t start) {
|
|
size_t bestLength = 0;
|
|
size_t bestRepeatCount = 1;
|
|
size_t remaining = values.size() - start;
|
|
|
|
for (size_t length = 2; length * 2 <= remaining; ++length) {
|
|
size_t repeatCount = 1;
|
|
ArrayRef<IntT> candidate = values.slice(start, length);
|
|
while (start + (repeatCount + 1) * length <= values.size()
|
|
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
|
|
++repeatCount;
|
|
}
|
|
|
|
if (repeatCount <= 1)
|
|
continue;
|
|
|
|
size_t covered = length * repeatCount;
|
|
size_t bestCovered = bestLength * bestRepeatCount;
|
|
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
|
|
bestLength = length;
|
|
bestRepeatCount = repeatCount;
|
|
}
|
|
}
|
|
|
|
return std::pair(bestLength, bestRepeatCount);
|
|
};
|
|
|
|
for (size_t index = 0; index < values.size();) {
|
|
if (index != 0)
|
|
stream << ", ";
|
|
|
|
FlatCompression flat = computeFlatCompression(index);
|
|
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
|
|
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
|
|
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
|
|
printOpenDelimiter(stream, ListDelimiter::Paren);
|
|
printCompressedIntegerEntries(stream, values.slice(index, sublistLength));
|
|
printCloseDelimiter(stream, ListDelimiter::Paren);
|
|
stream << " x" << sublistRepeatCount;
|
|
index += repeatedSublistCoverage;
|
|
continue;
|
|
}
|
|
|
|
switch (flat.kind) {
|
|
case FlatCompression::Kind::Progression:
|
|
stream << flat.firstValue << " to " << flat.lastValue;
|
|
if (flat.step != 1)
|
|
stream << " by " << flat.step;
|
|
if (flat.repeatCount > 1)
|
|
stream << " x" << flat.repeatCount;
|
|
index += flat.covered;
|
|
break;
|
|
case FlatCompression::Kind::EqualRun:
|
|
stream << flat.firstValue << " x" << flat.repeatCount;
|
|
index += flat.covered;
|
|
break;
|
|
case FlatCompression::Kind::Single:
|
|
stream << flat.firstValue;
|
|
index += flat.covered;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename StreamT, typename IntT>
|
|
inline void printCompressedIntegerSequence(StreamT& stream, ArrayRef<IntT> values, ListDelimiter delimiter) {
|
|
printOpenDelimiter(stream, delimiter);
|
|
printCompressedIntegerEntries(stream, values);
|
|
printCloseDelimiter(stream, delimiter);
|
|
}
|
|
|
|
template <typename IntT>
|
|
inline ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
|
|
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
|
|
}
|
|
|
|
template <typename IntT>
|
|
inline void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
|
|
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
|
|
}
|
|
|
|
/// Prints `values` comma-separated (no delimiters) using three compressions:
///   - a run of N >= 2 identical values prints as `%v xN`;
///   - three or more consecutive results of one op, or consecutive arguments of
///     one block, print as `%first to %last`;
///   - anything else prints one operand at a time (a 2-element consecutive
///     stretch prints both operands and advances past both).
///
/// NOTE(review): the operand parser only accepts `a to b` when both ends share
/// the same SSA name (`%x#i to %x#j`). Block arguments, or results with custom
/// names, may print with distinct names (e.g. `%arg0 to %arg2`). Confirm such
/// ranges re-parse.
inline void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
  const size_t count = values.size();

  // One past the end of the run of values identical to values[start].
  auto identicalRunEnd = [&](size_t start) {
    size_t end = start + 1;
    for (; end < count && values[end] == values[start]; ++end) {
    }
    return end;
  };

  // One past the end of the stretch starting at `start` whose values are
  // successive results of one op, or successive arguments of one block.
  auto consecutiveRunEnd = [&](size_t start) {
    size_t end = start + 1;
    Value head = values[start];
    if (auto headResult = dyn_cast<OpResult>(head)) {
      for (; end < count; ++end) {
        auto candidate = dyn_cast<OpResult>(values[end]);
        const bool extends = candidate && candidate.getOwner() == headResult.getOwner()
                          && candidate.getResultNumber() == headResult.getResultNumber() + (end - start);
        if (!extends)
          break;
      }
    }
    else if (auto headArg = dyn_cast<BlockArgument>(head)) {
      for (; end < count; ++end) {
        auto candidate = dyn_cast<BlockArgument>(values[end]);
        const bool extends = candidate && candidate.getOwner() == headArg.getOwner()
                          && candidate.getArgNumber() == headArg.getArgNumber() + (end - start);
        if (!extends)
          break;
      }
    }
    return end;
  };

  size_t position = 0;
  while (position < count) {
    if (position != 0)
      printer << ", ";

    const size_t repeatEnd = identicalRunEnd(position);
    const size_t repeats = repeatEnd - position;
    if (repeats > 1) {
      printer.printOperand(values[position]);
      printer << " x" << repeats;
      position = repeatEnd;
      continue;
    }

    const size_t rangeEnd = consecutiveRunEnd(position);
    const size_t rangeLength = rangeEnd - position;
    printer.printOperand(values[position]);
    if (rangeLength >= 3) {
      printer << " to ";
      printer.printOperand(values[rangeEnd - 1]);
    }
    else if (rangeLength == 2) {
      printer << ", ";
      printer.printOperand(values[position + 1]);
    }
    position = rangeEnd;
  }
}
|
|
|
|
/// Prints `types` comma-separated (no delimiters), collapsing each run of
/// identical adjacent types into `type xN`.
inline void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
  auto printOne = [&printer](Type type) { printer.printType(type); };
  printCompressedEqualRuns(printer, types, printOne);
}
|
|
|
|
/// Prints `values` as a compressed operand sequence wrapped in `delimiter`.
inline void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
  // Frame the sequence; the compression itself lives in printCompressedValueSequence.
  printOpenDelimiter(printer, delimiter);
  printCompressedValueSequence(printer, values);
  printCloseDelimiter(printer, delimiter);
}
|
|
|
|
/// Prints `types` as a compressed type sequence wrapped in `delimiter`.
inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
  // Frame the sequence; the compression itself lives in printCompressedTypeSequence.
  printOpenDelimiter(printer, delimiter);
  printCompressedTypeSequence(printer, types);
  printCloseDelimiter(printer, delimiter);
}
|
|
|
|
/// Parses a comma-separated type sequence with no delimiters, where any type may
/// be followed by `x N` (N > 0) to repeat it, appending the expansion to `types`.
/// If no type is present at all, succeeds without appending when `allowEmpty`
/// is set, and otherwise emits "expected type".
inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
  Type current;
  OptionalParseResult leading = parser.parseOptionalType(current);
  if (!leading.has_value()) {
    if (!allowEmpty)
      return parser.emitError(parser.getCurrentLocation(), "expected type");
    return success();
  }
  if (failed(*leading))
    return failure();

  while (true) {
    int64_t copies = 1;
    const bool hasRepeatSuffix = succeeded(parser.parseOptionalKeyword("x"));
    if (hasRepeatSuffix && (parser.parseInteger(copies) || copies <= 0))
      return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
    types.append(static_cast<size_t>(copies), current);

    // A comma introduces another (mandatory) type; otherwise the sequence ends.
    if (failed(parser.parseOptionalComma()))
      return success();
    if (parser.parseType(current))
      return failure();
  }
}
|
|
|
|
/// Completes one compressed operand entry whose leading operand `firstOperand`
/// has already been parsed, appending the expansion to `operands`:
///   `%v#i to %v#j`  expands to %v#i, %v#(i+1), ..., %v#j (same name, i <= j);
///   `%v x N`        appends N copies of %v (N > 0);
///   `%v`            appends %v once.
///
/// NOTE(review): a `to` range is rejected unless both ends have the same SSA
/// name. Ranges printed over block arguments (e.g. `%arg0 to %arg2`) would
/// therefore fail here. Confirm this is intended.
inline ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
                                                         OpAsmParser::UnresolvedOperand firstOperand,
                                                         SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (failed(parser.parseOptionalKeyword("to"))) {
    int64_t copies = 1;
    const bool hasRepeatSuffix = succeeded(parser.parseOptionalKeyword("x"));
    if (hasRepeatSuffix && (parser.parseInteger(copies) || copies <= 0))
      return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
    operands.append(static_cast<size_t>(copies), firstOperand);
    return success();
  }

  OpAsmParser::UnresolvedOperand lastOperand;
  if (parser.parseOperand(lastOperand))
    return failure();

  const bool sameName = firstOperand.name == lastOperand.name;
  if (!sameName || lastOperand.number < firstOperand.number)
    return parser.emitError(parser.getCurrentLocation(), "invalid operand range");

  for (unsigned resultNumber = firstOperand.number; resultNumber <= lastOperand.number; ++resultNumber)
    operands.push_back({firstOperand.location, firstOperand.name, resultNumber});
  return success();
}
|
|
|
|
/// Parses one compressed operand entry (`%v`, `%v x N`, or `%v#i to %v#j`)
/// and appends its expansion to `operands`.
inline ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
                                                  SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  OpAsmParser::UnresolvedOperand leading;
  if (failed(parser.parseOperand(leading)))
    return failure();
  return parseCompressedOperandEntryWithFirst(parser, leading, operands);
}
|
|
|
|
/// Parses a delimited, comma-separated list of compressed operand entries,
/// e.g. `(%a, %b x 2, %c#0 to %c#3)`, appending the expansion to `operands`.
/// An immediately closed list appends nothing.
inline ParseResult parseCompressedOperandList(OpAsmParser& parser,
                                              ListDelimiter delimiter,
                                              SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (parseOpenDelimiter(parser, delimiter))
    return failure();
  if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
    return success();

  for (;;) {
    if (parseOneCompressedOperandEntry(parser, operands))
      return failure();

    // After each entry: the list closes, or a comma must follow.
    const bool closed = succeeded(parseOptionalCloseDelimiter(parser, delimiter));
    if (closed)
      return success();
    if (parser.parseComma())
      return failure();
  }
}
|
|
|
|
/// Parses one or more comma-separated compressed operand entries (no
/// delimiters), appending the expansion to `operands`.
inline ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
                                                  SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  do {
    if (parseOneCompressedOperandEntry(parser, operands))
      return failure();
  } while (succeeded(parser.parseOptionalComma()));
  return success();
}
|
|
|
|
/// Parses a delimited compressed type list such as `[i32 x 2, f32]` or `()`,
/// appending the expanded types to `types`.
inline ParseResult parseCompressedTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
  if (parseOpenDelimiter(parser, delimiter))
    return failure();
  if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
    return success();

  if (parseCompressedTypeSequence(parser, types, /*allowEmpty=*/false))
    return failure();

  // The closing delimiter is mandatory here. Use the non-optional parsers so a
  // missing `]` / `)` reports a diagnostic; the optional variant used before
  // returned failure without emitting any error.
  return delimiter == ListDelimiter::Square ? parser.parseRSquare() : parser.parseRParen();
}
|
|
|
|
/// Returns true when `values` is non-empty, `tupleSize` is non-zero, the length
/// is a multiple of `tupleSize`, and every consecutive `tupleSize`-long chunk
/// equals the first chunk. A single chunk counts as repeated.
inline bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
  if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
    return false;

  // Compare sub-ranges of `values` directly instead of first copying every
  // element into a temporary SmallVector.
  ValueRange tuple = values.take_front(tupleSize);
  for (size_t index = tupleSize; index < values.size(); index += tupleSize)
    if (!llvm::equal(tuple, values.slice(index, tupleSize)))
      return false;
  return true;
}
|
|
|
|
/// Returns true when `types` is non-empty, `tupleSize` is non-zero, the length
/// is a multiple of `tupleSize`, and every consecutive `tupleSize`-long chunk
/// equals the first chunk. A single chunk counts as repeated.
inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
  if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
    return false;

  // Compare sub-ranges of `types` directly instead of first copying every
  // element into a temporary SmallVector.
  TypeRange tuple = types.take_front(tupleSize);
  for (size_t index = tupleSize; index < types.size(); index += tupleSize)
    if (!llvm::equal(tuple, types.slice(index, tupleSize)))
      return false;
  return true;
}
|
|
|
|
/// Prints `<open>(%a, %b, ...) xN<close>`. Only the first `tupleSize` values
/// are printed, and N is `values.size() / tupleSize`. The caller presumably has
/// already checked with hasRepeatedTuple that `values` is N copies of that
/// tuple. `tupleSize` must be non-zero.
inline void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize, ListDelimiter delimiter) {
  printOpenDelimiter(printer, delimiter);
  printOpenDelimiter(printer, ListDelimiter::Paren);
  for (size_t position = 0; position != tupleSize; ++position) {
    if (position > 0)
      printer << ", ";
    printer.printOperand(values[position]);
  }
  printCloseDelimiter(printer, ListDelimiter::Paren);

  const size_t tupleCount = values.size() / tupleSize;
  printer << " x" << tupleCount;
  printCloseDelimiter(printer, delimiter);
}
|
|
|
|
/// Prints `<open>(t0, t1, ...) xN<close>`. Only the first `tupleSize` types are
/// printed, and N is `types.size() / tupleSize`. The caller presumably has
/// already checked with hasRepeatedTuple that `types` is N copies of that
/// tuple. `tupleSize` must be non-zero.
inline void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize, ListDelimiter delimiter) {
  printOpenDelimiter(printer, delimiter);
  printOpenDelimiter(printer, ListDelimiter::Paren);
  for (size_t position = 0; position != tupleSize; ++position) {
    if (position > 0)
      printer << ", ";
    printer.printType(types[position]);
  }
  printCloseDelimiter(printer, ListDelimiter::Paren);

  const size_t tupleCount = types.size() / tupleSize;
  printer << " x" << tupleCount;
  printCloseDelimiter(printer, delimiter);
}
|
|
|
|
/// Parses a delimited operand list in one of two forms and appends the expansion
/// to `operands`:
///   tuple form: `[(%a, %b) x 3, (%c) x 2]`. Each parenthesized group is a
///               compressed operand sequence, repeated N times (N > 0, default 1);
///   flat form:  `[%a, %b x 2, %c#0 to %c#3]`, i.e. compressed operand entries.
/// The form is chosen by whether the first entry starts with `(`; the two forms
/// cannot be mixed. An immediately closed list appends nothing.
inline ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
                                                      ListDelimiter delimiter,
                                                      SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (parseOpenDelimiter(parser, delimiter))
    return failure();
  if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
    return success();

  if (succeeded(parser.parseOptionalLParen())) {
    SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
    // The first tuple's `(` is already consumed; every later tuple is `, (`.
    while (true) {
      tupleOperands.clear();
      if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
        return failure();

      int64_t repeatCount = 1;
      if (succeeded(parser.parseOptionalKeyword("x"))) {
        if (parser.parseInteger(repeatCount) || repeatCount <= 0)
          return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
      }
      for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
        llvm::append_range(operands, tupleOperands);

      if (failed(parser.parseOptionalComma()))
        break;
      if (parser.parseLParen())
        return failure();
    }

    // The closing delimiter is mandatory. The non-optional parsers emit a
    // diagnostic when it is missing; the optional variant used before failed
    // silently.
    return delimiter == ListDelimiter::Square ? parser.parseRSquare() : parser.parseRParen();
  }

  while (true) {
    if (parseOneCompressedOperandEntry(parser, operands))
      return failure();
    if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
      return success();
    if (parser.parseComma())
      return failure();
  }
}
|
|
|
|
/// Parses a delimited type list in one of two forms and appends the expansion
/// to `types`:
///   tuple form: `[(i32, f32) x 3, (i8) x 2]`. Each parenthesized group is a
///               non-empty compressed type sequence, repeated N times (N > 0,
///               default 1);
///   flat form:  `[i32 x 2, f32]`. Each type may carry an `x N` repeat suffix.
/// The form is chosen by whether the first entry starts with `(`; the two forms
/// cannot be mixed. An immediately closed list appends nothing.
inline ParseResult
parseCompressedOrTupleTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
  if (parseOpenDelimiter(parser, delimiter))
    return failure();
  if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
    return success();

  // Parses an optional `x N` suffix; `repeatCount` stays 1 when it is absent.
  auto parseRepeatCount = [&](int64_t& repeatCount) -> ParseResult {
    repeatCount = 1;
    if (succeeded(parser.parseOptionalKeyword("x"))) {
      if (parser.parseInteger(repeatCount) || repeatCount <= 0)
        return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
    }
    return success();
  };

  if (succeeded(parser.parseOptionalLParen())) {
    SmallVector<Type> tupleTypes;
    // The first tuple's `(` is already consumed; every later tuple is `, (`.
    while (true) {
      tupleTypes.clear();
      if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
        return failure();

      int64_t repeatCount = 1;
      if (parseRepeatCount(repeatCount))
        return failure();
      for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
        llvm::append_range(types, tupleTypes);

      if (failed(parser.parseOptionalComma()))
        break;
      if (parser.parseLParen())
        return failure();
    }

    // The closing delimiter is mandatory. The non-optional parsers emit a
    // diagnostic when it is missing; the optional variant used before failed
    // silently.
    return delimiter == ListDelimiter::Square ? parser.parseRSquare() : parser.parseRParen();
  }

  while (true) {
    Type type;
    if (parser.parseType(type))
      return failure();

    int64_t repeatCount = 1;
    if (parseRepeatCount(repeatCount))
      return failure();
    for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
      types.push_back(type);

    if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
      return success();
    if (parser.parseComma())
      return failure();
  }
}
|
|
|
|
/// Prints block-argument bindings in the form read by parseArgumentBindings:
///   `() = (operands)`    when the block has no arguments,
///   `%arg = (operands)`  when it has exactly one,
///   `(args) = (operands)` otherwise.
/// Both lists use the compressed value syntax.
inline void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
  if (block.getNumArguments() == 0) {
    // Print the operand list instead of a hard-coded `()`, so operands are never
    // silently dropped from the printed form. An empty operand list still
    // prints `() = ()`.
    printer << "() = ";
    printCompressedValueList(printer, operands, ListDelimiter::Paren);
    return;
  }

  if (block.getNumArguments() == 1) {
    printer.printOperand(block.getArgument(0));
    printer << " = ";
    printCompressedValueList(printer, operands, ListDelimiter::Paren);
    return;
  }

  printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
  printer << " = ";
  printCompressedValueList(printer, operands, ListDelimiter::Paren);
}
|
|
|
|
/// Completes one compressed block-argument entry whose leading argument
/// `firstArgument` has already been parsed, appending to `arguments`:
///   `%a to %b`  expands to one untyped argument per SSA number from %a to %b;
///               both ends must share a name and the first number must not
///               exceed the last;
///   `%a`        appends `firstArgument` unchanged.
inline ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
                                                          OpAsmParser::Argument firstArgument,
                                                          SmallVectorImpl<OpAsmParser::Argument>& arguments) {
  if (failed(parser.parseOptionalKeyword("to"))) {
    arguments.push_back(firstArgument);
    return success();
  }

  OpAsmParser::Argument lastArgument;
  if (parser.parseArgument(lastArgument))
    return failure();

  const auto& head = firstArgument.ssaName;
  const auto& tail = lastArgument.ssaName;
  if (head.name != tail.name || head.number > tail.number)
    return parser.emitError(parser.getCurrentLocation(), "invalid argument range");

  for (unsigned number = head.number; number <= tail.number; ++number) {
    OpAsmParser::Argument expanded;
    expanded.ssaName = {head.location, head.name, number};
    arguments.push_back(expanded);
  }
  return success();
}
|
|
|
|
/// Parses one compressed block-argument entry (`%a` or `%a to %b`) and appends
/// the resulting arguments to `arguments`.
inline ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
                                                   SmallVectorImpl<OpAsmParser::Argument>& arguments) {
  OpAsmParser::Argument leading;
  if (failed(parser.parseArgument(leading)))
    return failure();
  return parseCompressedArgumentEntryWithFirst(parser, leading, arguments);
}
|
|
|
|
/// Sets `arguments[i].type = inputTypes[i]` for every position. The two ranges
/// are walked with llvm::zip_equal, which expects them to have equal length.
inline void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
  for (auto&& [target, type] : llvm::zip_equal(arguments, inputTypes))
    target.type = type;
}
|
|
|
|
/// Parses block-argument bindings in one of three forms:
///   `() = (operands)`, `%arg = (operands)`, or `(args) = (operands)`.
/// Here `args` is a comma-separated list of compressed argument entries (which
/// may use `%a to %b`), and `operands` is a parenthesized compressed operand
/// list. Arguments are appended to `arguments` without types; operands are
/// appended to `operands`.
inline ParseResult parseArgumentBindings(OpAsmParser& parser,
                                         SmallVectorImpl<OpAsmParser::Argument>& arguments,
                                         SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  // Parses the `= (operands)` tail shared by every form.
  auto parseBoundOperands = [&]() -> ParseResult {
    if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
      return failure();
    return success();
  };

  if (failed(parser.parseOptionalLParen())) {
    // Bare single-argument form: `%arg = (...)`. The argument is recorded only
    // after the operand list parses successfully.
    OpAsmParser::Argument single;
    if (parser.parseArgument(single) || parseBoundOperands())
      return failure();
    arguments.push_back(single);
    return success();
  }

  // Empty argument list: `() = (...)`.
  if (succeeded(parser.parseOptionalRParen()))
    return parseBoundOperands();

  // Parenthesized argument list: `(%a, %b to %d) = (...)`.
  do {
    if (parseOneCompressedArgumentEntry(parser, arguments))
      return failure();
  } while (succeeded(parser.parseOptionalComma()));

  if (parser.parseRParen())
    return failure();
  return parseBoundOperands();
}
|
|
|
|
} // namespace compact_asm
|
|
} // namespace onnx_mlir
|
|
|
|
#endif
|