#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/LogicalResult.h"

#include <memory>
#include <optional>

#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace spatial {

namespace {

using namespace onnx_mlir::compact_asm;

/// Consumes the keyword `preferred` or, failing that, its `legacy` spelling.
/// Returns true if either keyword was present.
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
  return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
}

/// Builds a DenseI32ArrayAttr holding `values` with the parser's builder.
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
  return parser.getBuilder().getDenseI32ArrayAttr(values);
}

/// Builds an i32 IntegerAttr holding `value` with the parser's builder.
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
  return parser.getBuilder().getI32IntegerAttr(value);
}

/// Parses a bare keyword token and wraps its spelling in a StringAttr.
/// Only identifier-shaped strings can round-trip through this form.
static ParseResult parseBareStringAttr(OpAsmParser& parser, StringAttr& attr) {
  StringRef value;
  if (parser.parseKeyword(&value))
    return failure();
  attr = parser.getBuilder().getStringAttr(value);
  return success();
}

/// Prints `(%a, %b, ...)` using only the SSA names of `arguments` (no types).
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<Value> arguments) {
  printer << "(";
  for (auto [index, argument] : llvm::enumerate(arguments)) {
    if (index != 0)
      printer << ", ";
    printer.printOperand(argument);
  }
  printer << ")";
}

/// Parses `(%a, %b, ...)` (possibly `()`) into untyped block arguments,
/// appending them to `arguments`.
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
  return parser.parseArgumentList(arguments, OpAsmParser::Delimiter::Paren);
}

/// Assigns types to the compute_batch entry-block arguments and collects them
/// into `regionArgs` in region order: lane index, weights, inputs, shared outputs.
static void applyBatchRegionArgumentTypes(ArrayRef<Type> inputTypes,
                                          ArrayRef<Type> weightTypes,
                                          ArrayRef<Type> outputTypes,
                                          OpAsmParser::Argument& laneArg,
                                          SmallVectorImpl<OpAsmParser::Argument>& weightArgs,
                                          SmallVectorImpl<OpAsmParser::Argument>& inputArgs,
                                          SmallVectorImpl<OpAsmParser::Argument>& outputArgs,
                                          SmallVectorImpl<OpAsmParser::Argument>& regionArgs,
                                          Builder& builder) {
  laneArg.type = builder.getIndexType();
  regionArgs.push_back(laneArg);
  applyArgumentTypes(weightTypes, weightArgs);
  llvm::append_range(regionArgs, weightArgs);
  applyArgumentTypes(inputTypes, inputArgs);
  applyArgumentTypes(outputTypes, outputArgs);
  llvm::append_range(regionArgs, inputArgs);
  llvm::append_range(regionArgs, outputArgs);
}

/// Prints `<args> = <operands>`, each side as a compressed list in `delimiter`.
static void printBoundValueList(OpAsmPrinter& printer,
                                ValueRange arguments,
                                ValueRange operands,
                                ListDelimiter delimiter) {
  printCompressedValueList(printer, arguments, delimiter);
  printer << " = ";
  printCompressedValueList(printer, operands, delimiter);
}

/// Parses `<args> = <operands>` as printed by printBoundValueList. The
/// argument side may be empty (`[]` / `()`).
static ParseResult parseBoundValueList(OpAsmParser& parser,
                                       ListDelimiter delimiter,
                                       SmallVectorImpl<OpAsmParser::Argument>& arguments,
                                       SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  auto parseRequiredCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
    switch (currentDelimiter) {
    case ListDelimiter::Paren: return parser.parseRParen();
    case ListDelimiter::Square: return parser.parseRSquare();
    }
    llvm_unreachable("unsupported delimiter");
  };

  if (parseOpenDelimiter(parser, delimiter))
    return failure();

  // Only a non-empty argument list has entries followed by a mandatory close.
  if (failed(parseOptionalCloseDelimiter(parser, delimiter))) {
    do {
      if (parseOneCompressedArgumentEntry(parser, arguments))
        return failure();
    } while (succeeded(parser.parseOptionalComma()));
    if (parseRequiredCloseDelimiter(delimiter))
      return failure();
  }

  if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
    return failure();
  return success();
}

/// Collects `count` entry-block arguments through `getArgument(index)` into
/// `out`. Returns false as soon as one is unavailable, in which case the
/// caller falls back to the generic op form.
template <typename GetArgumentFn>
static bool collectBodyArguments(size_t count, GetArgumentFn getArgument, SmallVectorImpl<Value>& out) {
  out.reserve(count);
  for (size_t index = 0; index < count; ++index) {
    auto argument = getArgument(index);
    if (!argument)
      return false;
    out.push_back(*argument);
  }
  return true;
}

/// Parses the optional core-placement clause (`coreKeyword` / `legacyCoreKeyword`
/// followed by whatever `parseCoreValue` consumes) and the optional
/// `crossbarWeights N` clause. The two are accepted in either order, each at
/// most once. The compute_batch printer emits `crossbarWeights` first, while
/// the parser historically expected the core clause first.
/// `hasCore` reports whether the core clause was seen.
template <typename ParseCoreValueFn>
static ParseResult parseOptionalCoreAndCrossbarClauses(OpAsmParser& parser,
                                                       StringRef coreKeyword,
                                                       StringRef legacyCoreKeyword,
                                                       bool& hasCore,
                                                       ParseCoreValueFn parseCoreValue) {
  hasCore = false;
  bool hasCrossbarWeights = false;
  while (true) {
    if (!hasCore && parseOptionalKeywordAlias(parser, coreKeyword, legacyCoreKeyword)) {
      hasCore = true;
      if (parseCoreValue())
        return failure();
      continue;
    }
    if (!hasCrossbarWeights && parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights")) {
      hasCrossbarWeights = true;
      // The printer recomputes this count from the op, so the parsed value is discarded.
      int32_t crossbarWeightCount = 0;
      if (parser.parseInteger(crossbarWeightCount))
        return failure();
      continue;
    }
    return success();
  }
}

/// Prints the `: [weight types] (input types) -> result types { body }` tail
/// shared by the compute and compute_batch custom forms.
template <typename OpTy>
void printComputeSignatureAndBody(OpTy op, OpAsmPrinter& printer) {
  printer << " : ";
  printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
  printer << " ";
  printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
  printer << " -> ";
  printCompressedTypeSequence(printer, op.getResultTypes());
  printer << " ";
  printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
}

/// Custom printer for compute-like ops:
///   [w-args] = [weights] (in-args) = (inputs) [coreId N] crossbarWeights M
///   {attrs} : [weight types] (input types) -> results { body }
/// Falls back to the generic form if any entry-block argument is unavailable.
template <typename ComputeOpTy>
void printComputeLikeOp(ComputeOpTy op, OpAsmPrinter& printer) {
  SmallVector<Value> weightArgs;
  SmallVector<Value> inputArgs;
  if (!collectBodyArguments(
          op.getWeights().size(), [&](unsigned index) { return op.getWeightArgument(index); }, weightArgs)
      || !collectBodyArguments(
          op.getInputs().size(), [&](unsigned index) { return op.getInputArgument(index); }, inputArgs))
    return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);

  printer << " ";
  printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
  printer << " ";
  printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
  if (auto coreIdAttr = op->template getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
    printer << " coreId " << coreIdAttr.getInt();
  printer << " crossbarWeights " << collectDistinctCrossbarWeights(op.getOperation()).size();
  printer.printOptionalAttrDict(op->getAttrs(),
                                {op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
  printComputeSignatureAndBody(op, printer);
}

/// Parses the form emitted by printComputeLikeOp. The `coreId` and
/// `crossbarWeights` clauses may appear in either order.
static ParseResult parseComputeLikeOp(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::Argument> weightArgs;
  SmallVector<OpAsmParser::Argument> inputArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> weightTypes;
  SmallVector<Type> inputTypes;
  SmallVector<Type> outputTypes;
  int32_t coreId = 0;
  bool hasCoreId = false;

  if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)
      || parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs)
      || parseOptionalCoreAndCrossbarClauses(
          parser, "coreId", "core_id", hasCoreId, [&] { return parser.parseInteger(coreId); }))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedRepeatedList(
          parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
      || parseCompressedRepeatedList(
          parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
      || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
    return failure();

  if (weights.size() != weightTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
  if (weightArgs.size() != weights.size())
    return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
  if (inputs.size() != inputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  if (inputArgs.size() != inputs.size())
    return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
  if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
    return parser.emitError(parser.getCurrentLocation(),
                            "coreId cannot be specified both positionally and in attr-dict");

  auto& builder = parser.getBuilder();
  result.addAttribute(
      "operandSegmentSizes",
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
  if (hasCoreId)
    result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));

  if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
      || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputTypes);

  // Entry-block arguments: weights first, then inputs.
  Region* body = result.addRegion();
  applyArgumentTypes(weightTypes, weightArgs);
  applyArgumentTypes(inputTypes, inputArgs);
  SmallVector<OpAsmParser::Argument> regionArgs;
  llvm::append_range(regionArgs, weightArgs);
  llvm::append_range(regionArgs, inputArgs);
  return parser.parseRegion(*body, regionArgs);
}

/// Custom printer for compute_batch-like ops:
///   %lane = 0 to N [w-args] = [weights] (in-args) = (inputs)
///   [shared_outs(%o...)] crossbarWeights M [coreIds [...]] {attrs}
///   : [weight types] (input types) -> results { body }
/// Falls back to the generic form if any entry-block argument is unavailable.
template <typename ComputeBatchOpTy>
void printComputeBatchLikeOp(ComputeBatchOpTy op, OpAsmPrinter& printer) {
  auto laneArg = op.getLaneArgument();
  SmallVector<Value> weightArgs;
  SmallVector<Value> inputArgs;
  SmallVector<Value> outputArgs;
  if (!laneArg
      || !collectBodyArguments(
          op.getWeights().size(), [&](unsigned index) { return op.getWeightArgument(index); }, weightArgs)
      || !collectBodyArguments(
          op.getInputs().size(), [&](unsigned index) { return op.getInputArgument(index); }, inputArgs)
      || !collectBodyArguments(
          op.getNumResults(), [&](unsigned index) { return op.getOutputArgument(index); }, outputArgs))
    return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);

  printer << " ";
  printer.printOperand(*laneArg);
  printer << " = 0 to " << op.getLaneCount();
  printer << " ";
  printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
  printer << " ";
  printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
  if (op.getNumResults() != 0) {
    printer << " shared_outs";
    printBlockArgumentList(printer, outputArgs);
  }
  printer << " crossbarWeights "
          << getComputeInstanceCrossbarUsage({op.getOperation(), 0, op.getLaneCount()}).size();
  if (auto coreIdsAttr = op->template getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
    printer << " coreIds ";
    printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
  }
  printer.printOptionalAttrDict(op->getAttrs(),
                                {op.getLaneCountAttrName().getValue(),
                                 op.getOperandSegmentSizesAttrName().getValue(),
                                 onnx_mlir::kCoreIdsAttrName});
  printComputeSignatureAndBody(op, printer);
}

/// Parses the form emitted by printComputeBatchLikeOp. The `coreIds` and
/// `crossbarWeights` clauses may appear in either order. Previously only
/// `coreIds`-first was accepted, which the printer never produces.
static ParseResult parseComputeBatchLikeOp(OpAsmParser& parser, OperationState& result) {
  int64_t lowerBound = 0;
  int32_t laneCount = 0;
  OpAsmParser::Argument laneArg;
  SmallVector<OpAsmParser::Argument> weightArgs;
  SmallVector<OpAsmParser::Argument> inputArgs;
  SmallVector<OpAsmParser::Argument> outputArgs;
  SmallVector<OpAsmParser::Argument> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> weightTypes;
  SmallVector<Type> inputTypes;
  SmallVector<Type> outputTypes;
  SmallVector<int32_t> coreIds;
  bool hasCoreIds = false;

  if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
      || parser.parseKeyword("to") || parser.parseInteger(laneCount))
    return failure();
  if (lowerBound != 0)
    return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");

  if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)
      || parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
    return failure();
  if (succeeded(parser.parseOptionalKeyword("shared_outs")) && parseBlockArgumentList(parser, outputArgs))
    return failure();
  if (parseOptionalCoreAndCrossbarClauses(
          parser, "coreIds", "core_ids", hasCoreIds, [&] { return parseCompressedIntegerList(parser, coreIds); }))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
      || parseCompressedRepeatedList(
          parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
      || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
    return failure();

  if (weights.size() != weightTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
  if (weightArgs.size() != weights.size())
    return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
  if (inputs.size() != inputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  if (inputArgs.size() != inputs.size())
    return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
  if (outputArgs.size() != outputTypes.size())
    return parser.emitError(parser.getCurrentLocation(),
                            "number of shared output bindings and result types must match");
  if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
    return parser.emitError(parser.getCurrentLocation(),
                            "coreIds cannot be specified both positionally and in attr-dict");
  // laneCount is always given positionally (`0 to N`); a second copy in the
  // attr-dict would yield a duplicate attribute.
  if (result.attributes.get("laneCount"))
    return parser.emitError(parser.getCurrentLocation(),
                            "laneCount cannot be specified both positionally and in attr-dict");

  auto& builder = parser.getBuilder();
  result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
  result.addAttribute(
      "operandSegmentSizes",
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
  if (hasCoreIds)
    result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));

  if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
      || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputTypes);

  Region* body = result.addRegion();
  applyBatchRegionArgumentTypes(
      inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, builder);
  return parser.parseRegion(*body, regionArgs);
}

} // namespace

/// Prints ` %a, %b {attrs} : type-a, type-b` (compressed value/type sequences).
void SpatYieldOp::print(OpAsmPrinter& printer) {
  printer << " ";
  printCompressedValueSequence(printer, getOutputs());
  printer.printOptionalAttrDict((*this)->getAttrs());
  printer << " : ";
  printCompressedTypeSequence(printer, getOutputs().getTypes());
}

/// Parses the SpatYieldOp form; the operand list may be empty.
ParseResult SpatYieldOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> outputs;
  SmallVector<Type> outputTypes;

  OpAsmParser::UnresolvedOperand firstOutput;
  OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput);
  if (firstOutputResult.has_value()) {
    if (failed(*firstOutputResult))
      return failure();
    if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs))
      return failure();
    while (succeeded(parser.parseOptionalComma()))
      if (parseOneCompressedOperandEntry(parser, outputs))
        return failure();
  }

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
    return failure();

  if (outputs.size() != outputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match");

  return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
}

/// Prints ` %input {attrs} : input-type -> result-types`.
void SpatExtractRowsOp::print(OpAsmPrinter& printer) {
  printer << " ";
  printer.printOperand(getInput());
  printer.printOptionalAttrDict((*this)->getAttrs());
  printer << " : ";
  printer.printType(getInput().getType());
  printer << " -> ";
  printCompressedTypeSequence(printer, getResultTypes());
}

/// Parses the SpatExtractRowsOp form; at least one result type is required.
ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) {
  OpAsmParser::UnresolvedOperand input;
  Type inputType;
  SmallVector<Type> outputTypes;

  if (parser.parseOperand(input) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parser.parseType(inputType) || parser.parseArrow()
      || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
    return failure();

  if (parser.resolveOperand(input, inputType, result.operands))
    return failure();
  result.addTypes(outputTypes);
  return success();
}

/// Prints ` axis A %inputs {attrs} : (input types) -> output-type`; `axis` is
/// elided from the attr-dict because it is printed positionally.
void SpatConcatOp::print(OpAsmPrinter& printer) {
  printer << " axis " << getAxis();
  printer << " ";
  printCompressedValueSequence(printer, getInputs());
  printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()});
  printer << " : ";
  printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
  printer << " -> ";
  printer.printType(getOutput().getType());
}

/// Parses the SpatConcatOp form; rejects `axis` given both positionally and in the attr-dict.
ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
  int64_t axis = 0;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> inputTypes;
  Type outputType;

  if (parser.parseKeyword("axis") || parser.parseInteger(axis))
    return failure();
  if (parseCompressedOperandSequence(parser, inputs))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedRepeatedList(
          parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
      || parser.parseArrow() || parser.parseType(outputType))
    return failure();

  if (inputs.size() != inputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  if (result.attributes.get("axis"))
    return parser.emitError(parser.getCurrentLocation(),
                            "axis cannot be specified both positionally and in attr-dict");

  result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axis));
  if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputType);
  return success();
}

/// Prints the blueprint as keyword clauses (fragments, layout, physical,
/// offsets, sizes, map, then optional mode / operandIndices / sourceOffsets /
/// strides / conflict / coverage). Every clause-backed attribute is elided
/// from the trailing attr-dict.
void SpatBlueprintOp::print(OpAsmPrinter& printer) {
  // The first operand is the input, followed by the fragment operands.
  SmallVector<Value> operands {getInput()};
  llvm::append_range(operands, getFragments());

  printer << " fragments";
  printCompressedValueList(printer, operands, ListDelimiter::Paren);
  printer << " layout " << getLogicalLayout();
  printer << " physical " << getPhysicalLayout();
  printer << " offsets ";
  printCompressedIntegerList(printer, getFragmentOffsets());
  printer << " sizes ";
  printCompressedIntegerList(printer, getFragmentSizes());
  printer << " map " << getIndexMap();
  if (auto mode = getMode())
    printer << " mode " << *mode;
  if (auto operandIndices = getFragmentOperandIndices()) {
    printer << " operandIndices ";
    printCompressedIntegerList(printer, *operandIndices);
  }
  if (auto sourceOffsets = getFragmentSourceOffsets()) {
    printer << " sourceOffsets ";
    printCompressedIntegerList(printer, *sourceOffsets);
  }
  if (auto strides = getFragmentStrides()) {
    printer << " strides ";
    printCompressedIntegerList(printer, *strides);
  }
  if (auto conflictPolicy = getConflictPolicy())
    printer << " conflict " << *conflictPolicy;
  if (auto coveragePolicy = getCoveragePolicy())
    printer << " coverage " << *coveragePolicy;
  printer.printOptionalAttrDict((*this)->getAttrs(),
                                {getLogicalLayoutAttrName().getValue(),
                                 getPhysicalLayoutAttrName().getValue(),
                                 getFragmentOffsetsAttrName().getValue(),
                                 getFragmentSizesAttrName().getValue(),
                                 getIndexMapAttrName().getValue(),
                                 getModeAttrName().getValue(),
                                 getFragmentOperandIndicesAttrName().getValue(),
                                 getFragmentSourceOffsetsAttrName().getValue(),
                                 getFragmentStridesAttrName().getValue(),
                                 getConflictPolicyAttrName().getValue(),
                                 getCoveragePolicyAttrName().getValue()});
  printer << " : ";
  printCompressedTypeList(printer, TypeRange(operands), ListDelimiter::Paren);
  printer << " -> ";
  printer.printType(getOutput().getType());
}

/// Parses the form emitted by SpatBlueprintOp::print. Optional clauses must
/// appear in printer order. An explicitly empty integer list is kept as an
/// empty attribute so that it round-trips. An attribute given both as a
/// clause and in the attr-dict is rejected.
ParseResult SpatBlueprintOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  SmallVector<Type> operandTypes;
  Type outputType;
  StringAttr logicalLayout;
  StringAttr physicalLayout;
  StringAttr indexMap;
  StringAttr mode;
  StringAttr conflictPolicy;
  StringAttr coveragePolicy;
  SmallVector<int64_t> fragmentOffsets;
  SmallVector<int64_t> fragmentSizes;
  SmallVector<int64_t> fragmentOperandIndices;
  SmallVector<int64_t> fragmentSourceOffsets;
  SmallVector<int64_t> fragmentStrides;

  if (parser.parseKeyword("fragments") || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)
      || parser.parseKeyword("layout") || parseBareStringAttr(parser, logicalLayout)
      || parser.parseKeyword("physical") || parseBareStringAttr(parser, physicalLayout)
      || parser.parseKeyword("offsets") || parseCompressedIntegerList(parser, fragmentOffsets)
      || parser.parseKeyword("sizes") || parseCompressedIntegerList(parser, fragmentSizes)
      || parser.parseKeyword("map") || parseBareStringAttr(parser, indexMap))
    return failure();

  // Presence flags let an explicitly empty list (which the printer emits for a
  // present-but-empty attribute) round-trip instead of being dropped.
  bool hasMode = succeeded(parser.parseOptionalKeyword("mode"));
  if (hasMode && parseBareStringAttr(parser, mode))
    return failure();
  bool hasOperandIndices = succeeded(parser.parseOptionalKeyword("operandIndices"));
  if (hasOperandIndices && parseCompressedIntegerList(parser, fragmentOperandIndices))
    return failure();
  bool hasSourceOffsets = succeeded(parser.parseOptionalKeyword("sourceOffsets"));
  if (hasSourceOffsets && parseCompressedIntegerList(parser, fragmentSourceOffsets))
    return failure();
  bool hasStrides = succeeded(parser.parseOptionalKeyword("strides"));
  if (hasStrides && parseCompressedIntegerList(parser, fragmentStrides))
    return failure();
  bool hasConflictPolicy = succeeded(parser.parseOptionalKeyword("conflict"));
  if (hasConflictPolicy && parseBareStringAttr(parser, conflictPolicy))
    return failure();
  bool hasCoveragePolicy = succeeded(parser.parseOptionalKeyword("coverage"));
  if (hasCoveragePolicy && parseBareStringAttr(parser, coveragePolicy))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedRepeatedList(
          parser, ListDelimiter::Paren, operandTypes, [&](Type& type) { return parser.parseType(type); })
      || parser.parseArrow() || parser.parseType(outputType))
    return failure();

  if (operands.empty())
    return parser.emitError(parser.getCurrentLocation(), "spat.blueprint requires at least one fragment operand");
  if (operands.size() != operandTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of fragment operands and types must match");

  // Clause-backed attributes must not also be spelled in the attr-dict.
  auto rejectAttrDictDuplicate = [&](StringRef name, bool parsedPositionally) -> ParseResult {
    if (!parsedPositionally || !result.attributes.get(name))
      return success();
    return parser.emitError(parser.getCurrentLocation(),
                            Twine(name) + " cannot be specified both positionally and in attr-dict");
  };
  if (rejectAttrDictDuplicate("logicalLayout", true) || rejectAttrDictDuplicate("physicalLayout", true)
      || rejectAttrDictDuplicate("fragmentOffsets", true) || rejectAttrDictDuplicate("fragmentSizes", true)
      || rejectAttrDictDuplicate("indexMap", true) || rejectAttrDictDuplicate("mode", hasMode)
      || rejectAttrDictDuplicate("fragmentOperandIndices", hasOperandIndices)
      || rejectAttrDictDuplicate("fragmentSourceOffsets", hasSourceOffsets)
      || rejectAttrDictDuplicate("fragmentStrides", hasStrides)
      || rejectAttrDictDuplicate("conflictPolicy", hasConflictPolicy)
      || rejectAttrDictDuplicate("coveragePolicy", hasCoveragePolicy))
    return failure();

  auto& builder = parser.getBuilder();
  result.addAttribute("logicalLayout", logicalLayout);
  result.addAttribute("physicalLayout", physicalLayout);
  result.addAttribute("fragmentOffsets", builder.getDenseI64ArrayAttr(fragmentOffsets));
  result.addAttribute("fragmentSizes", builder.getDenseI64ArrayAttr(fragmentSizes));
  result.addAttribute("indexMap", indexMap);
  if (hasMode)
    result.addAttribute("mode", mode);
  if (hasOperandIndices)
    result.addAttribute("fragmentOperandIndices", builder.getDenseI64ArrayAttr(fragmentOperandIndices));
  if (hasSourceOffsets)
    result.addAttribute("fragmentSourceOffsets", builder.getDenseI64ArrayAttr(fragmentSourceOffsets));
  if (hasStrides)
    result.addAttribute("fragmentStrides", builder.getDenseI64ArrayAttr(fragmentStrides));
  if (hasConflictPolicy)
    result.addAttribute("conflictPolicy", conflictPolicy);
  if (hasCoveragePolicy)
    result.addAttribute("coveragePolicy", coveragePolicy);

  if (parser.resolveOperands(operands, operandTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputType);
  return success();
}

void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }

ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) {
  return parseComputeLikeOp(parser, result);
}

void SpatScheduledCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }

ParseResult SpatScheduledCompute::parse(OpAsmParser& parser, OperationState& result) {
  return parseComputeLikeOp(parser, result);
}

void SpatGraphComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }

ParseResult SpatGraphComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
  return parseComputeBatchLikeOp(parser, result);
}

void SpatScheduledComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }

ParseResult SpatScheduledComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
  return parseComputeBatchLikeOp(parser, result);
}

/// Prints the region (without entry-block args or terminators) followed by the attr-dict.
void SpatInParallelOp::print(OpAsmPrinter& printer) {
  printer << " ";
  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false);
  printer.printOptionalAttrDict((*this)->getAttrs());
}

/// Parses a region followed by an optional attr-dict. An empty `{}` region
/// gets a single empty block created for it.
ParseResult SpatInParallelOp::parse(OpAsmParser& parser, OperationState& result) {
  auto& builder = parser.getBuilder();
  std::unique_ptr<Region> region = std::make_unique<Region>();
  SmallVector<OpAsmParser::Argument> regionArgs;
  if (parser.parseRegion(*region, regionArgs))
    return failure();

  if (region->empty())
    OpBuilder(builder.getContext()).createBlock(region.get());
  result.addRegion(std::move(region));

  return parser.parseOptionalAttrDict(result.attributes);
}

} // namespace spatial
} // namespace onnx_mlir