better spatial IR compaction with better custom syntax, scf.for and
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
spat.map
This commit is contained in:
@@ -42,6 +42,10 @@ static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter)
|
||||
printer << (delimiter == ListDelimiter::Square ? "]" : ")");
|
||||
}
|
||||
|
||||
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
|
||||
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||
}
|
||||
|
||||
template <typename EntryT, typename ParseEntryFn>
|
||||
static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
@@ -75,51 +79,65 @@ static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
|
||||
if (parser.parseLSquare())
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRSquare()))
|
||||
static ParseResult
|
||||
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
while (true) {
|
||||
int64_t first = 0;
|
||||
if (parser.parseInteger(first))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
SmallVector<IntT> subgroup;
|
||||
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
int64_t last = 0;
|
||||
if (parser.parseInteger(last) || last < first)
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
|
||||
|
||||
int64_t step = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("by"))) {
|
||||
if (parser.parseInteger(step) || step <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
|
||||
}
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
if ((last - first) % step != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"range end must be reachable from start using the given step");
|
||||
|
||||
for (int64_t value = first; value <= last; value += step)
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
values.push_back(static_cast<IntT>(value));
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(values, subgroup);
|
||||
}
|
||||
else {
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
int64_t first = 0;
|
||||
if (parser.parseInteger(first))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
int64_t last = 0;
|
||||
if (parser.parseInteger(last) || last < first)
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
|
||||
|
||||
int64_t step = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("by"))) {
|
||||
if (parser.parseInteger(step) || step <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
|
||||
}
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
if ((last - first) % step != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"range end must be reachable from start using the given step");
|
||||
|
||||
for (int64_t value = first; value <= last; value += step)
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
values.push_back(static_cast<IntT>(value));
|
||||
}
|
||||
else {
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
values.push_back(static_cast<IntT>(first));
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
values.push_back(static_cast<IntT>(first));
|
||||
}
|
||||
|
||||
if (succeeded(parser.parseOptionalRSquare()))
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
break;
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
@@ -128,6 +146,14 @@ static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorIm
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static ParseResult
|
||||
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
return parseCompressedIntegerEntries(parser, delimiter, values);
|
||||
}
|
||||
|
||||
template <typename RangeT, typename PrintEntryFn>
|
||||
static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
|
||||
for (size_t index = 0; index < entries.size();) {
|
||||
@@ -146,35 +172,51 @@ static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, Prin
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
|
||||
printer << "[";
|
||||
for (size_t index = 0; index < values.size();) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
|
||||
auto findEqualRunEnd = [&](size_t start) {
|
||||
size_t end = start + 1;
|
||||
while (end < values.size() && values[end] == values[start])
|
||||
++end;
|
||||
return end;
|
||||
static void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef<IntT> values, ListDelimiter delimiter) {
|
||||
struct FlatCompression {
|
||||
enum class Kind {
|
||||
Single,
|
||||
EqualRun,
|
||||
Progression
|
||||
};
|
||||
|
||||
size_t firstRunEnd = findEqualRunEnd(index);
|
||||
size_t repeatCount = firstRunEnd - index;
|
||||
Kind kind = Kind::Single;
|
||||
size_t covered = 1;
|
||||
size_t repeatCount = 1;
|
||||
size_t progressionValueCount = 1;
|
||||
int64_t step = 1;
|
||||
IntT firstValue {};
|
||||
IntT lastValue {};
|
||||
};
|
||||
|
||||
auto computeFlatCompression = [&](size_t start) {
|
||||
FlatCompression compression;
|
||||
compression.firstValue = values[start];
|
||||
compression.lastValue = values[start];
|
||||
|
||||
auto findEqualRunEnd = [&](size_t runStart) {
|
||||
size_t runEnd = runStart + 1;
|
||||
while (runEnd < values.size() && values[runEnd] == values[runStart])
|
||||
++runEnd;
|
||||
return runEnd;
|
||||
};
|
||||
|
||||
size_t firstRunEnd = findEqualRunEnd(start);
|
||||
compression.repeatCount = firstRunEnd - start;
|
||||
size_t progressionEnd = firstRunEnd;
|
||||
int64_t step = 0;
|
||||
IntT lastValue = values[index];
|
||||
IntT lastValue = values[start];
|
||||
|
||||
if (firstRunEnd < values.size()) {
|
||||
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
|
||||
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[index]);
|
||||
if (step > 0 && secondRunEnd - firstRunEnd == repeatCount) {
|
||||
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
|
||||
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
|
||||
progressionEnd = secondRunEnd;
|
||||
lastValue = values[firstRunEnd];
|
||||
size_t currentRunStart = secondRunEnd;
|
||||
while (currentRunStart < values.size()) {
|
||||
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
|
||||
if (currentRunEnd - currentRunStart != repeatCount)
|
||||
if (currentRunEnd - currentRunStart != compression.repeatCount)
|
||||
break;
|
||||
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
|
||||
break;
|
||||
@@ -188,27 +230,99 @@ static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> val
|
||||
}
|
||||
}
|
||||
|
||||
size_t progressionValueCount = repeatCount == 0 ? 0 : (progressionEnd - index) / repeatCount;
|
||||
if (progressionEnd > firstRunEnd && progressionValueCount >= 3) {
|
||||
printer << values[index] << " to " << lastValue;
|
||||
if (step != 1)
|
||||
printer << " by " << step;
|
||||
if (repeatCount > 1)
|
||||
printer << " x" << repeatCount;
|
||||
index = progressionEnd;
|
||||
continue;
|
||||
compression.covered = 1;
|
||||
if (progressionEnd > firstRunEnd) {
|
||||
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
|
||||
if (progressionValueCount >= 3) {
|
||||
compression.kind = FlatCompression::Kind::Progression;
|
||||
compression.covered = progressionEnd - start;
|
||||
compression.progressionValueCount = progressionValueCount;
|
||||
compression.step = step;
|
||||
compression.lastValue = lastValue;
|
||||
return compression;
|
||||
}
|
||||
}
|
||||
|
||||
if (repeatCount > 1) {
|
||||
printer << values[index] << " x" << repeatCount;
|
||||
index = firstRunEnd;
|
||||
continue;
|
||||
if (compression.repeatCount > 1) {
|
||||
compression.kind = FlatCompression::Kind::EqualRun;
|
||||
compression.covered = compression.repeatCount;
|
||||
return compression;
|
||||
}
|
||||
|
||||
printer << values[index];
|
||||
index = firstRunEnd;
|
||||
return compression;
|
||||
};
|
||||
|
||||
auto findRepeatedSublist = [&](size_t start) {
|
||||
size_t bestLength = 0;
|
||||
size_t bestRepeatCount = 1;
|
||||
size_t remaining = values.size() - start;
|
||||
|
||||
for (size_t length = 2; length * 2 <= remaining; ++length) {
|
||||
size_t repeatCount = 1;
|
||||
ArrayRef<IntT> candidate = values.slice(start, length);
|
||||
while (start + (repeatCount + 1) * length <= values.size()
|
||||
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
|
||||
++repeatCount;
|
||||
}
|
||||
|
||||
if (repeatCount <= 1)
|
||||
continue;
|
||||
|
||||
size_t covered = length * repeatCount;
|
||||
size_t bestCovered = bestLength * bestRepeatCount;
|
||||
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
|
||||
bestLength = length;
|
||||
bestRepeatCount = repeatCount;
|
||||
}
|
||||
}
|
||||
|
||||
return std::pair(bestLength, bestRepeatCount);
|
||||
};
|
||||
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
for (size_t index = 0; index < values.size();) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
|
||||
FlatCompression flat = computeFlatCompression(index);
|
||||
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
|
||||
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
|
||||
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
|
||||
printCompressedIntegerSequence(printer, values.slice(index, sublistLength), ListDelimiter::Paren);
|
||||
printer << " x" << sublistRepeatCount;
|
||||
index += repeatedSublistCoverage;
|
||||
continue;
|
||||
}
|
||||
switch (flat.kind) {
|
||||
case FlatCompression::Kind::Progression:
|
||||
printer << flat.firstValue << " to " << flat.lastValue;
|
||||
if (flat.step != 1)
|
||||
printer << " by " << flat.step;
|
||||
if (flat.repeatCount > 1)
|
||||
printer << " x" << flat.repeatCount;
|
||||
index += flat.covered;
|
||||
break;
|
||||
case FlatCompression::Kind::EqualRun:
|
||||
printer << flat.firstValue << " x" << flat.repeatCount;
|
||||
index += flat.covered;
|
||||
break;
|
||||
case FlatCompression::Kind::Single:
|
||||
printer << flat.firstValue;
|
||||
index += flat.covered;
|
||||
break;
|
||||
}
|
||||
}
|
||||
printer << "]";
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
|
||||
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
|
||||
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
|
||||
}
|
||||
|
||||
static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
|
||||
@@ -267,6 +381,165 @@ static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, List
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands);
|
||||
static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands);
|
||||
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty);
|
||||
|
||||
static bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
|
||||
if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
|
||||
return false;
|
||||
|
||||
SmallVector<Value> valueVec(values.begin(), values.end());
|
||||
ArrayRef<Value> tuple(valueVec.data(), tupleSize);
|
||||
for (size_t index = tupleSize; index < values.size(); index += tupleSize)
|
||||
if (!llvm::equal(tuple, ArrayRef<Value>(valueVec).slice(index, tupleSize)))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
|
||||
if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
|
||||
return false;
|
||||
|
||||
SmallVector<Type> typeVec(types.begin(), types.end());
|
||||
ArrayRef<Type> tuple(typeVec.data(), tupleSize);
|
||||
for (size_t index = tupleSize; index < types.size(); index += tupleSize)
|
||||
if (!llvm::equal(tuple, ArrayRef<Type>(typeVec).slice(index, tupleSize)))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
static void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize) {
|
||||
printer << "[";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
for (size_t index = 0; index < tupleSize; ++index) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printer.printOperand(values[index]);
|
||||
}
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer << " x" << (values.size() / tupleSize) << "]";
|
||||
}
|
||||
|
||||
static void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize) {
|
||||
printer << "[";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
for (size_t index = 0; index < tupleSize; ++index) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printer.printType(types[index]);
|
||||
}
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer << " x" << (types.size() / tupleSize) << "]";
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parser.parseLSquare())
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRSquare()))
|
||||
return success();
|
||||
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
|
||||
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(operands, tupleOperands);
|
||||
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
tupleOperands.clear();
|
||||
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(operands, tupleOperands);
|
||||
}
|
||||
return parser.parseRSquare();
|
||||
}
|
||||
|
||||
while (true) {
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRSquare()))
|
||||
return success();
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, SmallVectorImpl<Type>& types) {
|
||||
if (parser.parseLSquare())
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRSquare()))
|
||||
return success();
|
||||
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
SmallVector<Type> tupleTypes;
|
||||
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(types, tupleTypes);
|
||||
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
tupleTypes.clear();
|
||||
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(types, tupleTypes);
|
||||
}
|
||||
return parser.parseRSquare();
|
||||
}
|
||||
|
||||
while (true) {
|
||||
Type type;
|
||||
if (parser.parseType(type))
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
types.push_back(type);
|
||||
|
||||
if (succeeded(parser.parseOptionalRSquare()))
|
||||
return success();
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
|
||||
OpAsmParser::UnresolvedOperand firstOperand,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
@@ -440,19 +713,88 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
static void buildImplicitRegionArgs(OpAsmParser& parser,
|
||||
ArrayRef<Type> inputTypes,
|
||||
SmallVectorImpl<std::string>& generatedNames,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
generatedNames.reserve(inputTypes.size());
|
||||
arguments.reserve(inputTypes.size());
|
||||
for (auto [index, inputType] : llvm::enumerate(inputTypes)) {
|
||||
generatedNames.push_back("arg" + std::to_string(index + 1));
|
||||
OpAsmParser::Argument arg;
|
||||
arg.ssaName = {parser.getCurrentLocation(), generatedNames.back(), 0};
|
||||
arg.type = inputType;
|
||||
arguments.push_back(arg);
|
||||
static void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
|
||||
if (block.getNumArguments() == 0) {
|
||||
printer << "() = ()";
|
||||
return;
|
||||
}
|
||||
|
||||
if (block.getNumArguments() == 1) {
|
||||
printer.printOperand(block.getArgument(0));
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||
return;
|
||||
}
|
||||
|
||||
printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
|
||||
OpAsmParser::Argument firstArgument,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
OpAsmParser::Argument lastArgument;
|
||||
if (parser.parseArgument(lastArgument))
|
||||
return failure();
|
||||
if (firstArgument.ssaName.name != lastArgument.ssaName.name
|
||||
|| firstArgument.ssaName.number > lastArgument.ssaName.number) {
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
|
||||
}
|
||||
for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) {
|
||||
OpAsmParser::Argument argument;
|
||||
argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number};
|
||||
arguments.push_back(argument);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
arguments.push_back(firstArgument);
|
||||
return success();
|
||||
}
|
||||
|
||||
static ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
OpAsmParser::Argument firstArgument;
|
||||
if (parser.parseArgument(firstArgument))
|
||||
return failure();
|
||||
return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments);
|
||||
}
|
||||
|
||||
static void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes))
|
||||
argument.type = inputType;
|
||||
}
|
||||
|
||||
static ParseResult parseArgumentBindings(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
if (succeeded(parser.parseOptionalRParen())) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
OpAsmParser::Argument firstArgument;
|
||||
if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
if (parser.parseRParen() || parser.parseEqual()
|
||||
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
OpAsmParser::Argument argument;
|
||||
if (parser.parseArgument(argument) || parser.parseEqual()
|
||||
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -519,8 +861,8 @@ ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result
|
||||
|
||||
void SpatConcatOp::print(OpAsmPrinter& printer) {
|
||||
printer << " axis " << getAxis();
|
||||
printer << " args = ";
|
||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
@@ -537,11 +879,7 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
if (parser.parseKeyword("axis") || parser.parseInteger(axis))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("args"))) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
|
||||
return failure();
|
||||
}
|
||||
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
|
||||
if (parseCompressedOperandSequence(parser, inputs)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -563,14 +901,54 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatMapOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
printer << " : ";
|
||||
printer.printType(getInputs().front().getType());
|
||||
printer << " -> ";
|
||||
printer.printType(getOutputs().front().getType());
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult SpatMapOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
Type inputType;
|
||||
Type outputType;
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
return failure();
|
||||
if (inputs.empty())
|
||||
return parser.emitError(parser.getCurrentLocation(), "map requires at least one input");
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|
||||
|| parser.parseArrow() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
SmallVector<Type> inputTypes(inputs.size(), inputType);
|
||||
SmallVector<Type> outputTypes(inputs.size(), outputType);
|
||||
if (regionArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
Region* body = result.addRegion();
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
printer << " args = ";
|
||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
printer << " core_id " << coreIdAttr.getInt();
|
||||
printer << " coreId " << coreIdAttr.getInt();
|
||||
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(),
|
||||
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
||||
@@ -587,7 +965,6 @@ void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
|
||||
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<std::string> generatedArgNames;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
@@ -598,15 +975,10 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("args"))) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
|
||||
return failure();
|
||||
}
|
||||
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
return failure();
|
||||
}
|
||||
|
||||
bool hasCoreId = succeeded(parser.parseOptionalKeyword("core_id"));
|
||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||
if (hasCoreId && parser.parseInteger(coreId))
|
||||
return failure();
|
||||
|
||||
@@ -622,9 +994,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (regionArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"core_id cannot be specified both positionally and in attr-dict");
|
||||
"coreId cannot be specified both positionally and in attr-dict");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute(
|
||||
@@ -639,27 +1013,34 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
printer << " lanes " << getLaneCount() << " ";
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
printer << " args = ";
|
||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
|
||||
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
|
||||
printValueTupleRun(printer, getWeights(), weightsPerLane);
|
||||
else
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
|
||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName)) {
|
||||
printer << " core_ids ";
|
||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
||||
printer << " coreIds ";
|
||||
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
||||
}
|
||||
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
|
||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane);
|
||||
else
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
@@ -671,7 +1052,6 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||
int32_t laneCount = 0;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<std::string> generatedArgNames;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
@@ -682,24 +1062,18 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
|
||||
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
|
||||
if (parseCompressedOrTupleOperandList(parser, weights))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("args"))) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
|
||||
return failure();
|
||||
}
|
||||
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
return failure();
|
||||
}
|
||||
|
||||
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("core_ids"));
|
||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parseCompressedOrTupleTypeList(parser, weightTypes)
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
@@ -709,8 +1083,11 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(), "core_id cannot be specified both in core_ids and attr-dict");
|
||||
if (regionArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreIds cannot be specified both positionally and in attr-dict");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
||||
@@ -718,7 +1095,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
"operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
if (hasCoreIds)
|
||||
result.addAttribute(onnx_mlir::kCoreIdAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
||||
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
@@ -726,7 +1103,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
@@ -867,6 +1244,55 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelSendManyBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||
}
|
||||
|
||||
ParseResult SpatChannelSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
@@ -908,5 +1334,47 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelReceiveManyBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user