#ifndef ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP #define ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/LogicalResult.h" namespace onnx_mlir { namespace compact_asm { using namespace mlir; enum class ListDelimiter { Square, Paren }; inline ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { if (delimiter == ListDelimiter::Square) return parser.parseLSquare(); return parser.parseLParen(); } inline ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { if (delimiter == ListDelimiter::Square) return parser.parseOptionalRSquare(); return parser.parseOptionalRParen(); } template inline void printOpenDelimiter(StreamT& stream, ListDelimiter delimiter) { stream << (delimiter == ListDelimiter::Square ? "[" : "("); } template inline void printCloseDelimiter(StreamT& stream, ListDelimiter delimiter) { stream << (delimiter == ListDelimiter::Square ? "]" : ")"); } template inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& entries, ParseEntryFn parseEntry) { if (parseOpenDelimiter(parser, delimiter)) return failure(); if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); while (true) { EntryT entry; if (parseEntry(entry)) return failure(); int64_t repeatCount = 1; if (succeeded(parser.parseOptionalKeyword("x"))) { if (parser.parseInteger(repeatCount) || repeatCount <= 0) return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } for (int64_t index = 0; index < repeatCount; ++index) entries.push_back(entry); if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); if (parser.parseComma()) return failure(); } } template inline ParseResult parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& values) { if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); while (true) { if (succeeded(parser.parseOptionalLParen())) { SmallVector subgroup; if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup)) return failure(); int64_t repeatCount = 1; if (succeeded(parser.parseOptionalKeyword("x"))) { if (parser.parseInteger(repeatCount) || repeatCount <= 0) return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } for (int64_t repeat = 0; repeat < repeatCount; ++repeat) llvm::append_range(values, subgroup); } else { int64_t first = 0; if (parser.parseInteger(first)) return failure(); if (succeeded(parser.parseOptionalKeyword("to"))) { int64_t last = 0; if (parser.parseInteger(last) || last < first) return parser.emitError(parser.getCurrentLocation(), "invalid ascending range"); int64_t step = 1; if (succeeded(parser.parseOptionalKeyword("by"))) { if (parser.parseInteger(step) || step <= 0) return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive"); } int64_t repeatCount = 1; if (succeeded(parser.parseOptionalKeyword("x"))) { if (parser.parseInteger(repeatCount) || repeatCount <= 0) return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } if ((last - first) % step != 0) { return parser.emitError(parser.getCurrentLocation(), "range end must be reachable from start using the given step"); } for (int64_t value = first; value <= last; value += step) for (int64_t index = 0; index < repeatCount; ++index) values.push_back(static_cast(value)); } else { int64_t repeatCount = 1; if (succeeded(parser.parseOptionalKeyword("x"))) { if (parser.parseInteger(repeatCount) || repeatCount <= 0) return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } for (int64_t index = 0; index < repeatCount; ++index) values.push_back(static_cast(first)); } } if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); if (parser.parseComma()) return failure(); } } template inline ParseResult parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& values) { if (parseOpenDelimiter(parser, delimiter)) return failure(); return parseCompressedIntegerEntries(parser, delimiter, values); } template inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) { for (size_t index = 0; index < entries.size();) { size_t runEnd = index + 1; while (runEnd < entries.size() && entries[runEnd] == entries[index]) ++runEnd; if (index != 0) printer << ", "; printEntry(entries[index]); size_t runLength = runEnd - index; if (runLength > 1) printer << " x" << runLength; index = runEnd; } } template inline void printCompressedIntegerEntries(StreamT& stream, ArrayRef values) { struct FlatCompression { enum class Kind { Single, EqualRun, Progression }; Kind kind = Kind::Single; size_t covered = 1; size_t repeatCount = 1; size_t progressionValueCount = 1; int64_t step = 1; IntT firstValue {}; IntT lastValue {}; }; auto computeFlatCompression = [&](size_t start) { FlatCompression compression; compression.firstValue = values[start]; compression.lastValue = values[start]; auto findEqualRunEnd = [&](size_t runStart) { size_t runEnd = runStart + 1; while (runEnd < values.size() && values[runEnd] == values[runStart]) ++runEnd; return runEnd; }; size_t firstRunEnd = findEqualRunEnd(start); compression.repeatCount = firstRunEnd - start; size_t progressionEnd = firstRunEnd; int64_t step = 0; IntT lastValue = values[start]; if (firstRunEnd < values.size()) { size_t secondRunEnd = findEqualRunEnd(firstRunEnd); step = static_cast(values[firstRunEnd]) - static_cast(values[start]); if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) { progressionEnd = secondRunEnd; lastValue = values[firstRunEnd]; size_t currentRunStart = secondRunEnd; while (currentRunStart < values.size()) { size_t currentRunEnd = findEqualRunEnd(currentRunStart); if (currentRunEnd - currentRunStart != compression.repeatCount) break; if (static_cast(values[currentRunStart]) != static_cast(lastValue) + step) break; lastValue = values[currentRunStart]; progressionEnd = currentRunEnd; currentRunStart = currentRunEnd; } } else { step = 0; } } compression.covered = 1; if (progressionEnd > firstRunEnd) { size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount; if (progressionValueCount >= 3) { compression.kind = FlatCompression::Kind::Progression; compression.covered = progressionEnd - start; compression.progressionValueCount = progressionValueCount; compression.step = step; compression.lastValue = lastValue; return compression; } } if (compression.repeatCount > 1) { compression.kind = FlatCompression::Kind::EqualRun; compression.covered = compression.repeatCount; return compression; } return compression; }; auto findRepeatedSublist = [&](size_t start) { size_t bestLength = 0; size_t bestRepeatCount = 1; size_t remaining = values.size() - start; for (size_t length = 2; length * 2 <= remaining; ++length) { size_t repeatCount = 1; ArrayRef candidate = values.slice(start, length); while (start + (repeatCount + 1) * length <= values.size() && llvm::equal(candidate, values.slice(start + repeatCount * length, length))) { ++repeatCount; } if (repeatCount <= 1) continue; size_t covered = length * repeatCount; size_t bestCovered = bestLength * bestRepeatCount; if (covered > bestCovered || (covered == bestCovered && length < bestLength)) { bestLength = length; bestRepeatCount = repeatCount; } } return std::pair(bestLength, bestRepeatCount); }; for (size_t index = 0; index < values.size();) { if (index != 0) stream << ", "; FlatCompression flat = computeFlatCompression(index); auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index); size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount; if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) { printOpenDelimiter(stream, ListDelimiter::Paren); printCompressedIntegerEntries(stream, values.slice(index, sublistLength)); printCloseDelimiter(stream, ListDelimiter::Paren); stream << " x" << sublistRepeatCount; index += repeatedSublistCoverage; continue; } switch (flat.kind) { case FlatCompression::Kind::Progression: stream << flat.firstValue << " to " << flat.lastValue; if (flat.step != 1) stream << " by " << flat.step; if (flat.repeatCount > 1) stream << " x" << flat.repeatCount; index += flat.covered; break; case FlatCompression::Kind::EqualRun: stream << flat.firstValue << " x" << flat.repeatCount; index += flat.covered; break; case FlatCompression::Kind::Single: stream << flat.firstValue; index += flat.covered; break; } } } template inline void printCompressedIntegerSequence(StreamT& stream, ArrayRef values, ListDelimiter delimiter) { printOpenDelimiter(stream, delimiter); printCompressedIntegerEntries(stream, values); printCloseDelimiter(stream, delimiter); } template inline ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl& values) { return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values); } template inline void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef values) { printCompressedIntegerSequence(printer, values, ListDelimiter::Square); } inline void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) { for (size_t index = 0; index < values.size();) { size_t equalRunEnd = index + 1; while (equalRunEnd < values.size() && values[equalRunEnd] == values[index]) ++equalRunEnd; if (index != 0) printer << ", "; if (equalRunEnd - index > 1) { printer.printOperand(values[index]); printer << " x" << (equalRunEnd - index); index = equalRunEnd; continue; } size_t rangeEnd = index + 1; if (auto firstResult = dyn_cast(values[index])) { while (rangeEnd < values.size()) { auto nextResult = dyn_cast(values[rangeEnd]); if (!nextResult || nextResult.getOwner() != firstResult.getOwner() || nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) { break; } ++rangeEnd; } } else if (auto firstArg = dyn_cast(values[index])) { while (rangeEnd < values.size()) { auto nextArg = dyn_cast(values[rangeEnd]); if (!nextArg || nextArg.getOwner() != firstArg.getOwner() || nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) { break; } ++rangeEnd; } } printer.printOperand(values[index]); if (rangeEnd - index >= 3) { printer << " to "; printer.printOperand(values[rangeEnd - 1]); } else if (rangeEnd - index == 2) { printer << ", "; printer.printOperand(values[index + 1]); } index = rangeEnd; } } inline void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) { printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); }); } inline void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) { printOpenDelimiter(printer, delimiter); printCompressedValueSequence(printer, values); printCloseDelimiter(printer, delimiter); } inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) { printOpenDelimiter(printer, delimiter); printCompressedTypeSequence(printer, types); printCloseDelimiter(printer, delimiter); } inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl& types, bool allowEmpty) { Type firstType; OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType); if (!firstTypeResult.has_value()) { if (allowEmpty) return success(); return parser.emitError(parser.getCurrentLocation(), "expected type"); } if (failed(*firstTypeResult)) return failure(); auto appendType = [&](Type type) -> ParseResult { int64_t repeatCount = 1; if (succeeded(parser.parseOptionalKeyword("x"))) { if (parser.parseInteger(repeatCount) || repeatCount <= 0) return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } for (int64_t index = 0; index < repeatCount; ++index) types.push_back(type); return success(); }; if (appendType(firstType)) return failure(); while (succeeded(parser.parseOptionalComma())) { Type nextType; if (parser.parseType(nextType) || appendType(nextType)) return failure(); } return success(); } inline ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser, OpAsmParser::UnresolvedOperand firstOperand, SmallVectorImpl& operands) { if (succeeded(parser.parseOptionalKeyword("to"))) { OpAsmParser::UnresolvedOperand lastOperand; if (parser.parseOperand(lastOperand)) return failure(); if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number) return parser.emitError(parser.getCurrentLocation(), "invalid operand range"); for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number) operands.push_back({firstOperand.location, firstOperand.name, number}); return success(); } int64_t repeatCount = 1; if (succeeded(parser.parseOptionalKeyword("x"))) { if (parser.parseInteger(repeatCount) || repeatCount <= 0) return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } for (int64_t index = 0; index < repeatCount; ++index) operands.push_back(firstOperand); return success(); } inline ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser, SmallVectorImpl& operands) { OpAsmParser::UnresolvedOperand firstOperand; if (parser.parseOperand(firstOperand)) return failure(); return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands); } inline ParseResult parseCompressedOperandList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& operands) { if (parseOpenDelimiter(parser, delimiter)) return failure(); if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); while (true) { if (parseOneCompressedOperandEntry(parser, operands)) return failure(); if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); if (parser.parseComma()) return failure(); } } inline ParseResult parseCompressedOperandSequence(OpAsmParser& parser, SmallVectorImpl& operands) { if (parseOneCompressedOperandEntry(parser, operands)) return failure(); while (succeeded(parser.parseOptionalComma())) if (parseOneCompressedOperandEntry(parser, operands)) return failure(); return success(); } inline ParseResult parseCompressedTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& types) { if (parseOpenDelimiter(parser, delimiter)) return failure(); if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); if (parseCompressedTypeSequence(parser, types, /*allowEmpty=*/false)) return failure(); return parseOptionalCloseDelimiter(parser, delimiter); } inline bool hasRepeatedTuple(ValueRange values, size_t tupleSize) { if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0) return false; SmallVector valueVec(values.begin(), values.end()); ArrayRef tuple(valueVec.data(), tupleSize); for (size_t index = tupleSize; index < values.size(); index += tupleSize) if (!llvm::equal(tuple, ArrayRef(valueVec).slice(index, tupleSize))) return false; return true; } inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) { if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0) return false; SmallVector typeVec(types.begin(), types.end()); ArrayRef tuple(typeVec.data(), tupleSize); for (size_t index = tupleSize; index < types.size(); index += tupleSize) if (!llvm::equal(tuple, ArrayRef(typeVec).slice(index, tupleSize))) return false; return true; } inline void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize, ListDelimiter delimiter) { printOpenDelimiter(printer, delimiter); printOpenDelimiter(printer, ListDelimiter::Paren); for (size_t index = 0; index < tupleSize; ++index) { if (index != 0) printer << ", "; printer.printOperand(values[index]); } printCloseDelimiter(printer, ListDelimiter::Paren); printer << " x" << (values.size() / tupleSize); printCloseDelimiter(printer, delimiter); } inline void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize, ListDelimiter delimiter) { printOpenDelimiter(printer, delimiter); printOpenDelimiter(printer, ListDelimiter::Paren); for (size_t index = 0; index < tupleSize; ++index) { if (index != 0) printer << ", "; printer.printType(types[index]); } printCloseDelimiter(printer, ListDelimiter::Paren); printer << " x" << (types.size() / tupleSize); printCloseDelimiter(printer, delimiter); } inline ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& operands) { if (parseOpenDelimiter(parser, delimiter)) return failure(); if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); if (succeeded(parser.parseOptionalLParen())) { SmallVector tupleOperands; if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen()) return failure(); int64_t repeatCount = 1; if (succeeded(parser.parseOptionalKeyword("x"))) { if (parser.parseInteger(repeatCount) || repeatCount <= 0) return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } for (int64_t repeat = 0; repeat < repeatCount; ++repeat) llvm::append_range(operands, tupleOperands); while (succeeded(parser.parseOptionalComma())) { if (parser.parseLParen()) return failure(); tupleOperands.clear(); if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen()) return failure(); repeatCount = 1; if (succeeded(parser.parseOptionalKeyword("x"))) { if (parser.parseInteger(repeatCount) || repeatCount <= 0) return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } for (int64_t repeat = 0; repeat < repeatCount; ++repeat) llvm::append_range(operands, tupleOperands); } return parseOptionalCloseDelimiter(parser, delimiter); } while (true) { if (parseOneCompressedOperandEntry(parser, operands)) return failure(); if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); if (parser.parseComma()) return failure(); } } inline ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& types) { if (parseOpenDelimiter(parser, delimiter)) return failure(); if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); if (succeeded(parser.parseOptionalLParen())) { SmallVector tupleTypes; if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen()) return failure(); int64_t repeatCount = 1; if (succeeded(parser.parseOptionalKeyword("x"))) { if (parser.parseInteger(repeatCount) || repeatCount <= 0) return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } for (int64_t repeat = 0; repeat < repeatCount; ++repeat) llvm::append_range(types, tupleTypes); while (succeeded(parser.parseOptionalComma())) { if (parser.parseLParen()) return failure(); tupleTypes.clear(); if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen()) return failure(); repeatCount = 1; if (succeeded(parser.parseOptionalKeyword("x"))) { if (parser.parseInteger(repeatCount) || repeatCount <= 0) return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } for (int64_t repeat = 0; repeat < repeatCount; ++repeat) llvm::append_range(types, tupleTypes); } return parseOptionalCloseDelimiter(parser, delimiter); } while (true) { Type type; if (parser.parseType(type)) return failure(); int64_t repeatCount = 1; if (succeeded(parser.parseOptionalKeyword("x"))) { if (parser.parseInteger(repeatCount) || repeatCount <= 0) return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } for (int64_t repeat = 0; repeat < repeatCount; ++repeat) types.push_back(type); if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); if (parser.parseComma()) return failure(); } } inline void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) { if (block.getNumArguments() == 0) { printer << "() = ()"; return; } if (block.getNumArguments() == 1) { printer.printOperand(block.getArgument(0)); printer << " = "; printCompressedValueList(printer, operands, ListDelimiter::Paren); return; } printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren); printer << " = "; printCompressedValueList(printer, operands, ListDelimiter::Paren); } inline ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser, OpAsmParser::Argument firstArgument, SmallVectorImpl& arguments) { if (succeeded(parser.parseOptionalKeyword("to"))) { OpAsmParser::Argument lastArgument; if (parser.parseArgument(lastArgument)) return failure(); if (firstArgument.ssaName.name != lastArgument.ssaName.name || firstArgument.ssaName.number > lastArgument.ssaName.number) { return parser.emitError(parser.getCurrentLocation(), "invalid argument range"); } for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) { OpAsmParser::Argument argument; argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number}; arguments.push_back(argument); } return success(); } arguments.push_back(firstArgument); return success(); } inline ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser, SmallVectorImpl& arguments) { OpAsmParser::Argument firstArgument; if (parser.parseArgument(firstArgument)) return failure(); return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments); } inline void applyArgumentTypes(ArrayRef inputTypes, SmallVectorImpl& arguments) { for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes)) argument.type = inputType; } inline ParseResult parseArgumentBindings(OpAsmParser& parser, SmallVectorImpl& arguments, SmallVectorImpl& operands) { if (succeeded(parser.parseOptionalLParen())) { if (succeeded(parser.parseOptionalRParen())) { if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) return failure(); return success(); } OpAsmParser::Argument firstArgument; if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments)) return failure(); while (succeeded(parser.parseOptionalComma())) if (parseOneCompressedArgumentEntry(parser, arguments)) return failure(); if (parser.parseRParen() || parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) { return failure(); } return success(); } OpAsmParser::Argument argument; if (parser.parseArgument(argument) || parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) { return failure(); } arguments.push_back(argument); return success(); } } // namespace compact_asm } // namespace onnx_mlir #endif