#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"

#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <type_traits>

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace spatial {

// Custom assembly for the Spatial ops. Long operand, type and integer lists are
// printed in a compressed form; the grammar shared by the helpers below is:
//
//   repeat        ::= `x` integer                       (integer > 0)
//   int-entry     ::= integer (`to` integer (`by` integer)?)? repeat?
//   operand-entry ::= ssa-use (`to` ssa-use | repeat)?
//   type-entry    ::= type repeat?
//
// The repeat keyword and its count are always printed as two tokens ("x 3"):
// written without the space, "x3" is a single identifier for the MLIR lexer and
// the `x` keyword would never be recognised when the text is parsed back.

namespace {

/// Bracket style surrounding a compressed list.
enum class ListDelimiter { Square, Paren };

/// Parses the mandatory opening bracket of a list.
static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
  if (delimiter == ListDelimiter::Square)
    return parser.parseLSquare();
  return parser.parseLParen();
}

/// Tries to parse the closing bracket of a list; fails without a diagnostic
/// when the next token is something else.
static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
  if (delimiter == ListDelimiter::Square)
    return parser.parseOptionalRSquare();
  return parser.parseOptionalRParen();
}

static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
  printer << (delimiter == ListDelimiter::Square ? "[" : "(");
}

static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
  printer << (delimiter == ListDelimiter::Square ? "]" : ")");
}

/// Parses an optional `x <count>` suffix. `repeatCount` is set to 1 when the
/// suffix is absent. A non-positive count is reported at the count's location;
/// a malformed integer is reported once, by parseInteger itself.
static ParseResult parseOptionalRepeatCount(OpAsmParser& parser, int64_t& repeatCount) {
  repeatCount = 1;
  if (failed(parser.parseOptionalKeyword("x")))
    return success();

  llvm::SMLoc countLoc = parser.getCurrentLocation();
  if (parser.parseInteger(repeatCount))
    return failure();
  if (repeatCount <= 0)
    return parser.emitError(countLoc, "repeat count after 'x' must be positive");
  return success();
}

/// Prints the ` x <count>` suffix. The space between the keyword and the count
/// is required, see the grammar note at the top of the file.
static void printRepeatCount(OpAsmPrinter& printer, size_t repeatCount) {
  printer << " x " << repeatCount;
}

/// Parses a bracketed, comma-separated list whose entries may each carry a
/// repeat suffix. `parseEntry` parses one entry into its argument; the entry is
/// appended to `entries` once per repetition. An empty list is accepted.
template <typename EntryT, typename ParseEntryFn>
static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
                                               ListDelimiter delimiter,
                                               SmallVectorImpl<EntryT>& entries,
                                               ParseEntryFn parseEntry) {
  if (parseOpenDelimiter(parser, delimiter))
    return failure();
  if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
    return success();

  while (true) {
    EntryT entry;
    if (parseEntry(entry))
      return failure();

    int64_t repeatCount = 1;
    if (parseOptionalRepeatCount(parser, repeatCount))
      return failure();
    entries.append(static_cast<size_t>(repeatCount), entry);

    if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
      break;
    if (parser.parseComma())
      return failure();
  }
  return success();
}

/// Parses a bracketed list of types with optional repeat suffixes.
static ParseResult parseCompressedTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
  return parseCompressedRepeatedList(parser, delimiter, types, [&](Type& type) { return parser.parseType(type); });
}

/// Returns true when `value` is representable in the element type `IntT`.
template <typename IntT>
static bool fitsInto(int64_t value) {
  static_assert(std::is_integral<IntT>::value && std::is_signed<IntT>::value && sizeof(IntT) <= sizeof(int64_t),
                "compressed integer lists hold signed integers of at most 64 bits");
  return value >= static_cast<int64_t>(std::numeric_limits<IntT>::min())
      && value <= static_cast<int64_t>(std::numeric_limits<IntT>::max());
}

/// Parses one `int-entry` and appends the values it denotes to `values`.
/// Values that do not fit `IntT` are rejected instead of being truncated.
template <typename IntT>
static ParseResult parseCompressedIntegerEntry(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
  llvm::SMLoc firstLoc = parser.getCurrentLocation();
  int64_t first = 0;
  if (parser.parseInteger(first))
    return failure();
  if (!fitsInto<IntT>(first))
    return parser.emitError(firstLoc, "integer value is out of range for the list element type");

  // A plain value is treated as the one-element range `first to first`.
  int64_t last = first;
  int64_t step = 1;
  if (succeeded(parser.parseOptionalKeyword("to"))) {
    llvm::SMLoc lastLoc = parser.getCurrentLocation();
    if (parser.parseInteger(last))
      return failure();
    if (last < first)
      return parser.emitError(lastLoc, "invalid ascending range");
    if (!fitsInto<IntT>(last))
      return parser.emitError(lastLoc, "integer value is out of range for the list element type");

    if (succeeded(parser.parseOptionalKeyword("by"))) {
      llvm::SMLoc stepLoc = parser.getCurrentLocation();
      if (parser.parseInteger(step))
        return failure();
      if (step <= 0)
        return parser.emitError(stepLoc, "step after 'by' must be positive");
    }
  }

  int64_t repeatCount = 1;
  if (parseOptionalRepeatCount(parser, repeatCount))
    return failure();

  // `last >= first`, so the unsigned difference is exact even when the signed
  // subtraction would overflow.
  const uint64_t span = static_cast<uint64_t>(last) - static_cast<uint64_t>(first);
  if (span % static_cast<uint64_t>(step) != 0)
    return parser.emitError(parser.getCurrentLocation(),
                            "range end must be reachable from start using the given step");

  // The range hits `last` exactly (checked above); stopping there keeps
  // `value += step` from ever stepping past `last` and overflowing.
  for (int64_t value = first;; value += step) {
    values.append(static_cast<size_t>(repeatCount), static_cast<IntT>(value));
    if (value == last)
      break;
  }
  return success();
}

/// Parses `[` int-entry (`,` int-entry)* `]`; the empty list `[]` is accepted.
template <typename IntT>
static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
  if (parser.parseLSquare())
    return failure();
  if (succeeded(parser.parseOptionalRSquare()))
    return success();

  while (true) {
    if (parseCompressedIntegerEntry(parser, values))
      return failure();
    if (succeeded(parser.parseOptionalRSquare()))
      break;
    if (parser.parseComma())
      return failure();
  }
  return success();
}

/// Prints `entries` comma-separated, collapsing each run of equal consecutive
/// entries into a single entry followed by a repeat suffix.
template <typename RangeT, typename PrintEntryFn>
static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
  for (size_t index = 0; index < entries.size();) {
    size_t runEnd = index + 1;
    while (runEnd < entries.size() && entries[runEnd] == entries[index])
      ++runEnd;

    if (index != 0)
      printer << ", ";
    printEntry(entries[index]);
    const size_t runLength = runEnd - index;
    if (runLength > 1)
      printRepeatCount(printer, runLength);
    index = runEnd;
  }
}

/// Prints `values` as a bracketed compressed integer list. Ascending arithmetic
/// progressions of at least three distinct values (each repeated the same
/// number of times) become `a to b [by s] [x r]`; other runs of equal values
/// become `v x r`.
template <typename IntT>
static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
  auto findEqualRunEnd = [&](size_t start) {
    size_t end = start + 1;
    while (end < values.size() && values[end] == values[start])
      ++end;
    return end;
  };

  // Distance from `from` up to `to`, or 0 when `to` is not strictly greater or
  // the distance cannot be written as an int64 step. Unsigned arithmetic keeps
  // the subtraction well defined for any pair of values.
  auto ascendingStep = [](IntT from, IntT to) -> uint64_t {
    if (!(from < to))
      return 0;
    const uint64_t distance = static_cast<uint64_t>(to) - static_cast<uint64_t>(from);
    return distance <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max()) ? distance : 0;
  };

  printer << "[";
  for (size_t index = 0; index < values.size();) {
    if (index != 0)
      printer << ", ";

    const size_t firstRunEnd = findEqualRunEnd(index);
    const size_t repeatCount = firstRunEnd - index; // always >= 1
    size_t progressionEnd = firstRunEnd;
    uint64_t step = 0;
    IntT lastValue = values[index];

    if (firstRunEnd < values.size()) {
      const size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
      const uint64_t candidateStep = ascendingStep(values[index], values[firstRunEnd]);
      if (candidateStep != 0 && secondRunEnd - firstRunEnd == repeatCount) {
        step = candidateStep;
        progressionEnd = secondRunEnd;
        lastValue = values[firstRunEnd];
        // Extend the progression while the following runs have the same length
        // and keep advancing by the same step.
        for (size_t runStart = secondRunEnd; runStart < values.size();) {
          const size_t runEnd = findEqualRunEnd(runStart);
          if (runEnd - runStart != repeatCount || ascendingStep(lastValue, values[runStart]) != step)
            break;
          lastValue = values[runStart];
          progressionEnd = runEnd;
          runStart = runEnd;
        }
      }
    }

    const size_t progressionValueCount = (progressionEnd - index) / repeatCount;
    if (progressionEnd > firstRunEnd && progressionValueCount >= 3) {
      printer << values[index] << " to " << lastValue;
      if (step != 1)
        printer << " by " << step;
      if (repeatCount > 1)
        printRepeatCount(printer, repeatCount);
      index = progressionEnd;
      continue;
    }

    // Too short for a range: print only the first run.
    printer << values[index];
    if (repeatCount > 1)
      printRepeatCount(printer, repeatCount);
    index = firstRunEnd;
  }
  printer << "]";
}

/// Splits a name such as "%arg12" or "%0#3" into the stem ("%arg", "%0#") and
/// the trailing decimal index. Returns false when there is no decimal suffix,
/// when the suffix has a redundant leading zero (it could not be regenerated
/// verbatim from the index) or when it does not fit in 64 bits.
static bool splitNumberedName(StringRef name, StringRef& stem, uint64_t& index) {
  size_t stemSize = name.size();
  while (stemSize > 0 && name[stemSize - 1] >= '0' && name[stemSize - 1] <= '9')
    --stemSize;

  StringRef digits = name.drop_front(stemSize);
  if (digits.empty() || (digits.size() > 1 && digits.front() == '0'))
    return false;
  if (digits.getAsInteger(10, index))
    return false;
  stem = name.take_front(stemSize);
  return true;
}

/// Returns the text printOperand produces for `value`.
static std::string renderOperand(OpAsmPrinter& printer, Value value) {
  std::string text;
  llvm::raw_string_ostream stream(text);
  printer.printOperand(value, stream);
  stream.flush();
  return text;
}

/// Returns true when `candidateText` is exactly the text the parser regenerates
/// for the entry `offset` positions after `firstText` in an `A to B` range.
static bool continuesOperandRange(StringRef firstText, StringRef candidateText, uint64_t offset) {
  StringRef firstStem;
  StringRef candidateStem;
  uint64_t firstIndex = 0;
  uint64_t candidateIndex = 0;
  if (!splitNumberedName(firstText, firstStem, firstIndex)
      || !splitNumberedName(candidateText, candidateStem, candidateIndex))
    return false;
  return firstStem == candidateStem && candidateIndex - firstIndex == offset && candidateIndex >= firstIndex;
}

/// Returns the end (exclusive) of the longest operand range starting at
/// `values[start]` that may be printed as `first to last`. Entries must be
/// consecutive results of one operation or consecutive arguments of one block,
/// and their printed names must be reconstructible by the parser.
static size_t findOperandRangeEnd(OpAsmPrinter& printer, ValueRange values, size_t start) {
  const Value first = values[start];
  auto isStructuralSuccessor = [&](Value candidate, size_t offset) {
    if (auto firstResult = dyn_cast<OpResult>(first)) {
      auto nextResult = dyn_cast<OpResult>(candidate);
      return nextResult && nextResult.getOwner() == firstResult.getOwner()
          && nextResult.getResultNumber() == firstResult.getResultNumber() + offset;
    }
    if (auto firstArg = dyn_cast<BlockArgument>(first)) {
      auto nextArg = dyn_cast<BlockArgument>(candidate);
      return nextArg && nextArg.getOwner() == firstArg.getOwner()
          && nextArg.getArgNumber() == firstArg.getArgNumber() + offset;
    }
    return false;
  };

  size_t rangeEnd = start + 1;
  if (rangeEnd >= values.size() || !isStructuralSuccessor(values[rangeEnd], 1))
    return rangeEnd;

  const std::string firstText = renderOperand(printer, first);
  while (rangeEnd < values.size()) {
    const size_t offset = rangeEnd - start;
    if (!isStructuralSuccessor(values[rangeEnd], offset))
      break;
    // Custom SSA names need not be sequential; only compress what parses back.
    if (!continuesOperandRange(firstText, renderOperand(printer, values[rangeEnd]), offset))
      break;
    ++rangeEnd;
  }
  return rangeEnd;
}

/// Prints `values` comma-separated without brackets. Runs of one repeated value
/// become `%v x n`; ranges of three or more consecutive values (see
/// findOperandRangeEnd) become `%first to %last`.
static void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
  for (size_t index = 0; index < values.size();) {
    size_t equalRunEnd = index + 1;
    while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
      ++equalRunEnd;

    if (index != 0)
      printer << ", ";

    if (equalRunEnd - index > 1) {
      printer.printOperand(values[index]);
      printRepeatCount(printer, equalRunEnd - index);
      index = equalRunEnd;
      continue;
    }

    const size_t rangeEnd = findOperandRangeEnd(printer, values, index);
    printer.printOperand(values[index]);
    if (rangeEnd - index >= 3) {
      printer << " to ";
      printer.printOperand(values[rangeEnd - 1]);
    }
    else if (rangeEnd - index == 2) {
      // A two-element range is not shorter than listing both operands.
      printer << ", ";
      printer.printOperand(values[index + 1]);
    }
    index = rangeEnd;
  }
}

/// Bracketed variant of printCompressedValueSequence.
static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
  printOpenDelimiter(printer, delimiter);
  printCompressedValueSequence(printer, values);
  printCloseDelimiter(printer, delimiter);
}

/// Prints `types` comma-separated without brackets, collapsing equal runs.
static void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
  printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
}

/// Bracketed variant of printCompressedTypeSequence.
static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
  printOpenDelimiter(printer, delimiter);
  printCompressedTypeSequence(printer, types);
  printCloseDelimiter(printer, delimiter);
}

/// Expands the range `firstOperand to lastOperand` into `operands`.
///
/// Two spellings are accepted:
///  * the same SSA name with ascending result numbers (`%r#0 to %r#3`);
///  * names that share a stem and end in ascending decimal indices
///    (`%arg0 to %arg3`), which is how ranges of block arguments are printed.
static ParseResult expandOperandRange(OpAsmParser& parser,
                                      const OpAsmParser::UnresolvedOperand& firstOperand,
                                      const OpAsmParser::UnresolvedOperand& lastOperand,
                                      SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (firstOperand.name == lastOperand.name) {
    if (firstOperand.number > lastOperand.number)
      return parser.emitError(lastOperand.location, "invalid operand range");
    for (unsigned number = firstOperand.number;; ++number) {
      operands.push_back({firstOperand.location, firstOperand.name, number});
      if (number == lastOperand.number)
        break;
    }
    return success();
  }

  StringRef firstStem;
  StringRef lastStem;
  uint64_t firstIndex = 0;
  uint64_t lastIndex = 0;
  if (firstOperand.number != 0 || lastOperand.number != 0
      || !splitNumberedName(firstOperand.name, firstStem, firstIndex)
      || !splitNumberedName(lastOperand.name, lastStem, lastIndex) || firstStem != lastStem
      || firstIndex > lastIndex)
    return parser.emitError(lastOperand.location, "invalid operand range");

  // The names differ, so firstIndex < lastIndex here.
  operands.push_back(firstOperand);
  for (uint64_t index = firstIndex + 1; index < lastIndex; ++index) {
    OpAsmParser::UnresolvedOperand operand = firstOperand;
    // UnresolvedOperand only references its name; a StringAttr is used to get
    // storage for the generated name that outlives this function.
    operand.name = parser.getBuilder().getStringAttr(Twine(firstStem) + Twine(index)).getValue();
    operands.push_back(operand);
  }
  operands.push_back(lastOperand);
  return success();
}

/// Parses the remainder of an `operand-entry` whose first operand has already
/// been consumed, appending the denoted operands to `operands`.
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
                                                        OpAsmParser::UnresolvedOperand firstOperand,
                                                        SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (succeeded(parser.parseOptionalKeyword("to"))) {
    OpAsmParser::UnresolvedOperand lastOperand;
    if (parser.parseOperand(lastOperand))
      return failure();
    return expandOperandRange(parser, firstOperand, lastOperand, operands);
  }

  int64_t repeatCount = 1;
  if (parseOptionalRepeatCount(parser, repeatCount))
    return failure();
  operands.append(static_cast<size_t>(repeatCount), firstOperand);
  return success();
}

/// Parses one complete `operand-entry`.
static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
                                                  SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  OpAsmParser::UnresolvedOperand firstOperand;
  if (parser.parseOperand(firstOperand))
    return failure();
  return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
}

/// Parses a bracketed, possibly empty, list of operand entries.
static ParseResult parseCompressedOperandList(OpAsmParser& parser,
                                              ListDelimiter delimiter,
                                              SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (parseOpenDelimiter(parser, delimiter))
    return failure();
  if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
    return success();

  while (true) {
    if (parseOneCompressedOperandEntry(parser, operands))
      return failure();
    if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
      break;
    if (parser.parseComma())
      return failure();
  }
  return success();
}

/// Parses a non-empty, unbracketed, comma-separated list of operand entries.
static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
                                                  SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (parseOneCompressedOperandEntry(parser, operands))
    return failure();
  while (succeeded(parser.parseOptionalComma()))
    if (parseOneCompressedOperandEntry(parser, operands))
      return failure();
  return success();
}

/// Parses the parenthesised input list, optionally introduced by `args =`.
static ParseResult parseArgsOperandList(OpAsmParser& parser,
                                        SmallVectorImpl<OpAsmParser::UnresolvedOperand>& inputs) {
  if (succeeded(parser.parseOptionalKeyword("args")) && parser.parseEqual())
    return failure();
  return parseCompressedOperandList(parser, ListDelimiter::Paren, inputs);
}

/// Parses an unbracketed, comma-separated list of type entries. When no type
/// is present the list is empty, which is an error unless `allowEmpty` is set.
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
  Type firstType;
  OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
  if (!firstTypeResult.has_value()) {
    if (allowEmpty)
      return success();
    return parser.emitError(parser.getCurrentLocation(), "expected type");
  }
  if (failed(*firstTypeResult))
    return failure();

  auto appendType = [&](Type type) -> ParseResult {
    int64_t repeatCount = 1;
    if (parseOptionalRepeatCount(parser, repeatCount))
      return failure();
    types.append(static_cast<size_t>(repeatCount), type);
    return success();
  };

  if (appendType(firstType))
    return failure();
  while (succeeded(parser.parseOptionalComma())) {
    Type nextType;
    if (parser.parseType(nextType) || appendType(nextType))
      return failure();
  }
  return success();
}

/// Prints ` channels [...] from [...] to [...]`.
static void printChannelMetadata(OpAsmPrinter& printer,
                                 ArrayRef<int64_t> channelIds,
                                 ArrayRef<int32_t> sourceCoreIds,
                                 ArrayRef<int32_t> targetCoreIds) {
  printer << " channels ";
  printCompressedIntegerList(printer, channelIds);
  printer << " from ";
  printCompressedIntegerList(printer, sourceCoreIds);
  printer << " to ";
  printCompressedIntegerList(printer, targetCoreIds);
}

static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef<int64_t> values) {
  return parser.getBuilder().getDenseI64ArrayAttr(values);
}

static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
  return parser.getBuilder().getDenseI32ArrayAttr(values);
}

static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
  return parser.getBuilder().getI32IntegerAttr(value);
}

/// Channel routing information written positionally after the operands.
struct ChannelMetadata {
  bool present = false;
  SmallVector<int64_t> channelIds;
  SmallVector<int32_t> sourceCoreIds;
  SmallVector<int32_t> targetCoreIds;
};

/// Parses the optional `channels [...] from [...] to [...]` clause.
static ParseResult parseOptionalChannelMetadata(OpAsmParser& parser, ChannelMetadata& metadata) {
  metadata.present = succeeded(parser.parseOptionalKeyword("channels"));
  if (!metadata.present)
    return success();
  if (parseCompressedIntegerList(parser, metadata.channelIds) || parser.parseKeyword("from")
      || parseCompressedIntegerList(parser, metadata.sourceCoreIds) || parser.parseKeyword("to")
      || parseCompressedIntegerList(parser, metadata.targetCoreIds))
    return failure();
  return success();
}

/// Stores positionally parsed channel metadata as attributes of `result`.
/// Must be called after the attr-dict has been parsed: the same attributes
/// appearing there as well is an error.
static ParseResult addChannelMetadata(OpAsmParser& parser, OperationState& result, const ChannelMetadata& metadata) {
  if (!metadata.present)
    return success();
  if (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
      || result.attributes.get("targetCoreIds"))
    return parser.emitError(parser.getCurrentLocation(),
                            "channel metadata cannot be specified both positionally and in attr-dict");

  result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, metadata.channelIds));
  result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, metadata.sourceCoreIds));
  result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, metadata.targetCoreIds));
  return success();
}

/// Creates one region argument per input type, named "arg1", "arg2", ...
///
/// `generatedNames` owns the name storage referenced by `arguments` and must
/// stay alive (and must not reallocate) until the region has been parsed; it is
/// expected to be empty on entry so that the reserve below covers every push.
///
/// NOTE(review): the generated names carry no leading '%', while operand names
/// obtained from the parser presumably include it — confirm that references
/// inside the region body resolve to these arguments.
static void buildImplicitRegionArgs(OpAsmParser& parser,
                                    ArrayRef<Type> inputTypes,
                                    SmallVectorImpl<std::string>& generatedNames,
                                    SmallVectorImpl<OpAsmParser::Argument>& arguments) {
  generatedNames.reserve(inputTypes.size());
  arguments.reserve(inputTypes.size());
  for (auto [index, inputType] : llvm::enumerate(inputTypes)) {
    generatedNames.push_back("arg" + std::to_string(index + 1));

    OpAsmParser::Argument arg;
    arg.ssaName = {parser.getCurrentLocation(), generatedNames.back(), 0};
    arg.type = inputType;
    arguments.push_back(arg);
  }
}

} // namespace

/// Format: operand-entries? attr-dict `:` type-entries?
void SpatYieldOp::print(OpAsmPrinter& printer) {
  printer << " ";
  printCompressedValueSequence(printer, getOutputs());
  printer.printOptionalAttrDict((*this)->getAttrs());
  printer << " : ";
  printCompressedTypeSequence(printer, getOutputs().getTypes());
}

ParseResult SpatYieldOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> outputs;
  SmallVector<Type> outputTypes;

  // The operand list may be empty, so the first operand is optional.
  OpAsmParser::UnresolvedOperand firstOutput;
  OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput);
  if (firstOutputResult.has_value()) {
    if (failed(*firstOutputResult))
      return failure();
    if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs))
      return failure();
    while (succeeded(parser.parseOptionalComma()))
      if (parseOneCompressedOperandEntry(parser, outputs))
        return failure();
  }

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
    return failure();

  if (outputs.size() != outputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match");

  return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
}

/// Format: operand attr-dict `:` type `->` type-entries
void SpatExtractRowsOp::print(OpAsmPrinter& printer) {
  printer << " ";
  printer.printOperand(getInput());
  printer.printOptionalAttrDict((*this)->getAttrs());
  printer << " : ";
  printer.printType(getInput().getType());
  printer << " -> ";
  printCompressedTypeSequence(printer, getResultTypes());
}

ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) {
  OpAsmParser::UnresolvedOperand input;
  Type inputType;
  SmallVector<Type> outputTypes;

  if (parser.parseOperand(input) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parser.parseType(inputType) || parser.parseArrow()
      || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
    return failure();

  if (parser.resolveOperand(input, inputType, result.operands))
    return failure();
  result.addTypes(outputTypes);
  return success();
}

/// Format: `axis` integer `args` `=` `(` operands `)` attr-dict `:` `(` types `)` `->` type
void SpatConcatOp::print(OpAsmPrinter& printer) {
  printer << " axis " << getAxis();
  printer << " args = ";
  printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
  printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()});
  printer << " : ";
  printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
  printer << " -> ";
  printer.printType(getOutput().getType());
}

ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
  int64_t axis = 0;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> inputTypes;
  Type outputType;

  if (parser.parseKeyword("axis") || parser.parseInteger(axis))
    return failure();
  if (parseArgsOperandList(parser, inputs))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeList(parser, ListDelimiter::Paren, inputTypes) || parser.parseArrow()
      || parser.parseType(outputType))
    return failure();

  if (inputs.size() != inputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  if (result.attributes.get("axis"))
    return parser.emitError(parser.getCurrentLocation(),
                            "axis cannot be specified both positionally and in attr-dict");

  result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axis));
  if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputType);
  return success();
}

/// Format: `[` weights `]` `args` `=` `(` inputs `)` (`core_id` integer)? attr-dict
///         `:` `[` weight-types `]` `(` input-types `)` `->` type-entries? region
void SpatCompute::print(OpAsmPrinter& printer) {
  printer << " ";
  printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
  printer << " args = ";
  printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);

  if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
    printer << " core_id " << coreIdAttr.getInt();

  printer.printOptionalAttrDict((*this)->getAttrs(),
                                {getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});

  printer << " : ";
  printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
  printer << " ";
  printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
  printer << " -> ";
  printCompressedTypeSequence(printer, getResultTypes());
  printer << " ";
  // Entry block arguments are implicit; the parser rebuilds them from the
  // input types (see buildImplicitRegionArgs).
  printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}

ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::Argument> regionArgs;
  SmallVector<std::string> generatedArgNames;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> weightTypes;
  SmallVector<Type> inputTypes;
  SmallVector<Type> outputTypes;
  int32_t coreId = 0;

  if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
    return failure();
  if (parseArgsOperandList(parser, inputs))
    return failure();

  bool hasCoreId = succeeded(parser.parseOptionalKeyword("core_id"));
  if (hasCoreId && parser.parseInteger(coreId))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeList(parser, ListDelimiter::Square, weightTypes)
      || parseCompressedTypeList(parser, ListDelimiter::Paren, inputTypes) || parser.parseArrow()
      || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
    return failure();

  if (weights.size() != weightTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
  if (inputs.size() != inputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
    return parser.emitError(parser.getCurrentLocation(),
                            "core_id cannot be specified both positionally and in attr-dict");

  auto& builder = parser.getBuilder();
  result.addAttribute(
      "operandSegmentSizes",
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
  if (hasCoreId)
    result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));

  if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
      || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputTypes);

  Region* body = result.addRegion();
  buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
  return parser.parseRegion(*body, regionArgs);
}

/// Format: `lanes` integer `[` weights `]` `args` `=` `(` inputs `)` (`core_ids` int-list)?
///         attr-dict `:` `[` weight-types `]` `(` input-types `)` `->` type-entries? region
void SpatComputeBatch::print(OpAsmPrinter& printer) {
  printer << " lanes " << getLaneCount() << " ";
  printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
  printer << " args = ";
  printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);

  if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName)) {
    printer << " core_ids ";
    printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
  }

  printer.printOptionalAttrDict(
      (*this)->getAttrs(),
      {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});

  printer << " : ";
  printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
  printer << " ";
  printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
  printer << " -> ";
  printCompressedTypeSequence(printer, getResultTypes());
  printer << " ";
  printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}

ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
  int32_t laneCount = 0;
  SmallVector<OpAsmParser::Argument> regionArgs;
  SmallVector<std::string> generatedArgNames;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> weightTypes;
  SmallVector<Type> inputTypes;
  SmallVector<Type> outputTypes;
  SmallVector<int32_t> coreIds;

  if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
    return failure();
  if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
    return failure();
  if (parseArgsOperandList(parser, inputs))
    return failure();

  bool hasCoreIds = succeeded(parser.parseOptionalKeyword("core_ids"));
  if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeList(parser, ListDelimiter::Square, weightTypes)
      || parseCompressedTypeList(parser, ListDelimiter::Paren, inputTypes) || parser.parseArrow()
      || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
    return failure();

  if (weights.size() != weightTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
  if (inputs.size() != inputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdAttrName))
    return parser.emitError(parser.getCurrentLocation(),
                            "core_id cannot be specified both in core_ids and attr-dict");

  auto& builder = parser.getBuilder();
  result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
  result.addAttribute(
      "operandSegmentSizes",
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
  if (hasCoreIds)
    result.addAttribute(onnx_mlir::kCoreIdAttrName, getDenseI32ArrayAttr(parser, coreIds));

  if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
      || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputTypes);

  Region* body = result.addRegion();
  buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
  return parser.parseRegion(*body, regionArgs);
}

/// Format: operand-entries channel-metadata attr-dict `:` type-entries
void SpatChannelSendManyOp::print(OpAsmPrinter& printer) {
  printer << " ";
  printCompressedValueSequence(printer, getInputs());
  printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
  printer.printOptionalAttrDict(
      (*this)->getAttrs(),
      {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
  printer << " : ";
  printCompressedTypeSequence(printer, TypeRange(getInputs()));
}

ParseResult SpatChannelSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> inputTypes;
  ChannelMetadata metadata;

  if (parseCompressedOperandSequence(parser, inputs))
    return failure();
  if (parseOptionalChannelMetadata(parser, metadata))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
    return failure();

  if (inputs.size() != inputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  if (addChannelMetadata(parser, result, metadata))
    return failure();

  return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
}

/// Format: channel-metadata attr-dict `:` type-entries
void SpatChannelReceiveManyOp::print(OpAsmPrinter& printer) {
  printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
  printer.printOptionalAttrDict(
      (*this)->getAttrs(),
      {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
  printer << " : ";
  printCompressedTypeSequence(printer, getResultTypes());
}

ParseResult SpatChannelReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<Type> outputTypes;
  ChannelMetadata metadata;

  if (parseOptionalChannelMetadata(parser, metadata))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
    return failure();

  if (addChannelMetadata(parser, result, metadata))
    return failure();

  result.addTypes(outputTypes);
  return success();
}

/// Format: operand channel-metadata attr-dict `:` type
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
  printer << " ";
  printer.printOperand(getInput());
  printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
  printer.printOptionalAttrDict(
      (*this)->getAttrs(),
      {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
  printer << " : ";
  printer.printType(getInput().getType());
}

ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  OpAsmParser::UnresolvedOperand input;
  Type inputType;
  ChannelMetadata metadata;

  if (parser.parseOperand(input))
    return failure();
  if (parseOptionalChannelMetadata(parser, metadata))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
    return failure();

  if (addChannelMetadata(parser, result, metadata))
    return failure();

  return parser.resolveOperand(input, inputType, result.operands);
}

/// Format: channel-metadata attr-dict `:` type
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
  printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
  printer.printOptionalAttrDict(
      (*this)->getAttrs(),
      {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
  printer << " : ";
  printer.printType(getOutput().getType());
}

ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  Type outputType;
  ChannelMetadata metadata;

  if (parseOptionalChannelMetadata(parser, metadata))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
    return failure();

  if (addChannelMetadata(parser, result, metadata))
    return failure();

  result.addTypes(outputType);
  return success();
}

} // namespace spatial
} // namespace onnx_mlir