#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/LogicalResult.h"

#include <cstdint>
#include <memory>

#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace spatial {

namespace {

using namespace onnx_mlir::compact_asm;

/// Consumes the optional keyword `preferred` or, if it is absent, its older
/// spelling `legacy`. Returns true when either spelling was present.
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
  return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
}

/// Wraps `values` in a DenseI32ArrayAttr created by the parser's builder.
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
  return parser.getBuilder().getDenseI32ArrayAttr(values);
}

/// Wraps `value` in an i32 IntegerAttr created by the parser's builder.
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
  return parser.getBuilder().getI32IntegerAttr(value);
}

/// Prints `arguments` as `(%a, %b, ...)`: SSA names only, without types.
/// Inverse of parseBlockArgumentList.
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<Value> arguments) {
  printer << "(";
  for (auto [index, argument] : llvm::enumerate(arguments)) {
    if (index != 0)
      printer << ", ";
    printer.printOperand(argument);
  }
  printer << ")";
}

/// Parses `()` or `(%a, %b, ...)` and appends the arguments to `arguments`.
/// Types are not parsed here; callers assign them later from the op's type
/// signature.
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
  if (parser.parseLParen())
    return failure();
  if (succeeded(parser.parseOptionalRParen()))
    return success();

  do {
    OpAsmParser::Argument argument;
    if (parser.parseArgument(argument))
      return failure();
    arguments.push_back(argument);
  } while (succeeded(parser.parseOptionalComma()));

  return parser.parseRParen();
}

/// Assigns types to the compute_batch region arguments and appends them to
/// `regionArgs` in entry-block order: lane index (typed `index`), weights,
/// inputs, then shared outputs. Callers check beforehand that each `*Types`
/// list has the same length as its `*Args` list.
static void applyBatchRegionArgumentTypes(ArrayRef<Type> inputTypes,
                                          ArrayRef<Type> weightTypes,
                                          ArrayRef<Type> outputTypes,
                                          OpAsmParser::Argument& laneArg,
                                          SmallVectorImpl<OpAsmParser::Argument>& weightArgs,
                                          SmallVectorImpl<OpAsmParser::Argument>& inputArgs,
                                          SmallVectorImpl<OpAsmParser::Argument>& outputArgs,
                                          SmallVectorImpl<OpAsmParser::Argument>& regionArgs,
                                          Builder& builder) {
  laneArg.type = builder.getIndexType();
  regionArgs.push_back(laneArg);

  applyArgumentTypes(weightTypes, weightArgs);
  llvm::append_range(regionArgs, weightArgs);

  applyArgumentTypes(inputTypes, inputArgs);
  applyArgumentTypes(outputTypes, outputArgs);
  llvm::append_range(regionArgs, inputArgs);
  llvm::append_range(regionArgs, outputArgs);
}

/// Prints `<arguments> = <operands>`, both as compressed lists enclosed in
/// `delimiter`.
static void printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
  printCompressedValueList(printer, arguments, delimiter);
  printer << " = ";
  printCompressedValueList(printer, operands, delimiter);
}

/// Parses `<arguments> = <operands>` as printed by printBoundValueList.
/// The argument list may be empty, e.g. `[] = []`.
static ParseResult parseBoundValueList(OpAsmParser& parser,
                                       ListDelimiter delimiter,
                                       SmallVectorImpl<OpAsmParser::Argument>& arguments,
                                       SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  // Local lambda rather than a free function so that it cannot collide (via
  // ADL) with any same-named helper in compact_asm.
  auto parseRequiredCloseDelimiter = [&]() -> ParseResult {
    switch (delimiter) {
    case ListDelimiter::Paren:
      return parser.parseRParen();
    case ListDelimiter::Square:
      return parser.parseRSquare();
    }
    llvm_unreachable("unsupported delimiter");
  };

  if (parseOpenDelimiter(parser, delimiter))
    return failure();

  // Non-empty binding list: comma-separated compressed entries, then the
  // matching close delimiter.
  if (failed(parseOptionalCloseDelimiter(parser, delimiter))) {
    do {
      if (parseOneCompressedArgumentEntry(parser, arguments))
        return failure();
    } while (succeeded(parser.parseOptionalComma()));
    if (parseRequiredCloseDelimiter())
      return failure();
  }

  if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
    return failure();
  return success();
}

/// Checks that the weight/input lists parsed for compute and compute_batch
/// agree in length: operands vs. types, and bindings vs. operands. Errors are
/// reported at `loc`.
static ParseResult checkBindingCounts(OpAsmParser& parser,
                                      llvm::SMLoc loc,
                                      size_t numWeights,
                                      size_t numWeightTypes,
                                      size_t numWeightArgs,
                                      size_t numInputs,
                                      size_t numInputTypes,
                                      size_t numInputArgs) {
  if (numWeights != numWeightTypes)
    return parser.emitError(loc, "number of weights and weight types must match");
  if (numWeightArgs != numWeights)
    return parser.emitError(loc, "number of weight bindings and weight operands must match");
  if (numInputs != numInputTypes)
    return parser.emitError(loc, "number of inputs and input types must match");
  if (numInputArgs != numInputs)
    return parser.emitError(loc, "number of argument bindings and input operands must match");
  return success();
}

/// Prints `[weight bindings] = [weights] (input bindings) = (inputs)`, the
/// operand/region-argument prefix shared by compute and compute_batch.
template <typename ComputeOpT>
static void printWeightAndInputBindings(OpAsmPrinter& printer, ComputeOpT op) {
  SmallVector<Value> weightArgs;
  weightArgs.reserve(op.getWeights().size());
  for (unsigned index = 0, count = op.getWeights().size(); index < count; ++index)
    weightArgs.push_back(op.getWeightArgument(index));
  printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);

  printer << " ";

  SmallVector<Value> inputArgs;
  inputArgs.reserve(op.getInputs().size());
  for (unsigned index = 0, count = op.getInputs().size(); index < count; ++index)
    inputArgs.push_back(op.getInputArgument(index));
  printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
}

/// Prints ` : [weight types] (input types) -> result-types <region>`, the
/// suffix shared by compute and compute_batch. Entry-block arguments are not
/// printed because the binding lists already name them.
template <typename ComputeOpT>
static void printComputeSignatureAndBody(OpAsmPrinter& printer, ComputeOpT op) {
  printer << " : ";
  printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
  printer << " ";
  printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
  printer << " -> ";
  printCompressedTypeSequence(printer, op.getResultTypes());
  printer << " ";
  printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
}

} // namespace

//===----------------------------------------------------------------------===//
// SpatYieldOp
//   custom form: ` <outputs> attr-dict : <output types>`
//   Both lists are compressed sequences and may be empty.
//===----------------------------------------------------------------------===//

void SpatYieldOp::print(OpAsmPrinter& printer) {
  printer << " ";
  printCompressedValueSequence(printer, getOutputs());
  printer.printOptionalAttrDict((*this)->getAttrs());
  printer << " : ";
  printCompressedTypeSequence(printer, getOutputs().getTypes());
}

ParseResult SpatYieldOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> outputs;
  SmallVector<Type> outputTypes;

  // Anchor operand-related diagnostics at the start of the operand list.
  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
  OpAsmParser::UnresolvedOperand firstOutput;
  OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput);
  if (firstOutputResult.has_value()) {
    if (failed(*firstOutputResult))
      return failure();
    if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs))
      return failure();
    while (succeeded(parser.parseOptionalComma()))
      if (parseOneCompressedOperandEntry(parser, outputs))
        return failure();
  }

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
    return failure();

  if (outputs.size() != outputTypes.size())
    return parser.emitError(operandsLoc, "number of outputs and output types must match");
  return parser.resolveOperands(outputs, outputTypes, operandsLoc, result.operands);
}

//===----------------------------------------------------------------------===//
// SpatExtractRowsOp
//   custom form: ` %input attr-dict : input-type -> <result types>`
//===----------------------------------------------------------------------===//

void SpatExtractRowsOp::print(OpAsmPrinter& printer) {
  printer << " ";
  printer.printOperand(getInput());
  printer.printOptionalAttrDict((*this)->getAttrs());
  printer << " : ";
  printer.printType(getInput().getType());
  printer << " -> ";
  printCompressedTypeSequence(printer, getResultTypes());
}

ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) {
  OpAsmParser::UnresolvedOperand input;
  Type inputType;
  SmallVector<Type> outputTypes;

  if (parser.parseOperand(input) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parser.parseType(inputType) || parser.parseArrow()
      || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
    return failure();

  if (parser.resolveOperand(input, inputType, result.operands))
    return failure();
  result.addTypes(outputTypes);
  return success();
}

//===----------------------------------------------------------------------===//
// SpatConcatOp
//   custom form: ` axis N <inputs> attr-dict : (<input types>) -> output-type`
//   `axis` is printed positionally and elided from the attr-dict.
//===----------------------------------------------------------------------===//

void SpatConcatOp::print(OpAsmPrinter& printer) {
  printer << " axis " << getAxis();
  printer << " ";
  printCompressedValueSequence(printer, getInputs());
  printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()});
  printer << " : ";
  printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
  printer << " -> ";
  printer.printType(getOutput().getType());
}

ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
  int64_t axis = 0;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> inputTypes;
  Type outputType;

  if (parser.parseKeyword("axis") || parser.parseInteger(axis))
    return failure();

  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
  if (parseCompressedOperandSequence(parser, inputs))
    return failure();

  llvm::SMLoc attrDictLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedRepeatedList(
          parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
      || parser.parseArrow() || parser.parseType(outputType))
    return failure();

  if (inputs.size() != inputTypes.size())
    return parser.emitError(operandsLoc, "number of inputs and input types must match");
  if (result.attributes.get("axis"))
    return parser.emitError(attrDictLoc, "axis cannot be specified both positionally and in attr-dict");

  result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axis));
  if (parser.resolveOperands(inputs, inputTypes, operandsLoc, result.operands))
    return failure();
  result.addTypes(outputType);
  return success();
}

//===----------------------------------------------------------------------===//
// SpatCompute
//   custom form:
//     ` [weight-bindings] = [weights] (input-bindings) = (inputs)
//       [coreId N] attr-dict
//       : [weight types] (input types) -> <result types> <region>`
//   The region's entry arguments are the weight bindings followed by the input
//   bindings. The parser also accepts the legacy keyword `core_id`.
//===----------------------------------------------------------------------===//

void SpatCompute::print(OpAsmPrinter& printer) {
  printer << " ";
  printWeightAndInputBindings(printer, *this);

  if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
    printer << " coreId " << coreIdAttr.getInt();

  printer.printOptionalAttrDict((*this)->getAttrs(),
                                {getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
  printComputeSignatureAndBody(printer, *this);
}

ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::Argument> weightArgs;
  SmallVector<OpAsmParser::Argument> inputArgs;
  SmallVector<OpAsmParser::Argument> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> weightTypes;
  SmallVector<Type> inputTypes;
  SmallVector<Type> outputTypes;
  int32_t coreId = 0;
  auto parseOneType = [&](Type& type) { return parser.parseType(type); };

  // Anchor count/resolution diagnostics at the start of the binding lists.
  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
  if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)
      || parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
    return failure();

  bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
  if (hasCoreId && parser.parseInteger(coreId))
    return failure();

  llvm::SMLoc attrDictLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedRepeatedList(parser, ListDelimiter::Square, weightTypes, parseOneType)
      || parseCompressedRepeatedList(parser, ListDelimiter::Paren, inputTypes, parseOneType)
      || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
    return failure();

  if (checkBindingCounts(parser,
                         operandsLoc,
                         weights.size(),
                         weightTypes.size(),
                         weightArgs.size(),
                         inputs.size(),
                         inputTypes.size(),
                         inputArgs.size()))
    return failure();
  if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
    return parser.emitError(attrDictLoc, "coreId cannot be specified both positionally and in attr-dict");

  auto& builder = parser.getBuilder();
  result.addAttribute(
      "operandSegmentSizes",
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
  if (hasCoreId)
    result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));

  if (parser.resolveOperands(weights, weightTypes, operandsLoc, result.operands)
      || parser.resolveOperands(inputs, inputTypes, operandsLoc, result.operands))
    return failure();
  result.addTypes(outputTypes);

  // Entry-block arguments: weight bindings, then input bindings.
  Region* body = result.addRegion();
  applyArgumentTypes(weightTypes, weightArgs);
  applyArgumentTypes(inputTypes, inputArgs);
  llvm::append_range(regionArgs, weightArgs);
  llvm::append_range(regionArgs, inputArgs);
  return parser.parseRegion(*body, regionArgs);
}

//===----------------------------------------------------------------------===//
// SpatComputeBatch
//   custom form:
//     ` %lane = 0 to laneCount
//       [weight-bindings] = [weights] (input-bindings) = (inputs)
//       [shared_outs(%o, ...)] [coreIds <list>] attr-dict
//       : [weight types] (input types) -> <result types> <region>`
//   Entry arguments: %lane (index), weights, inputs, shared outputs.
//   The parser also accepts the legacy keyword `core_ids`.
//===----------------------------------------------------------------------===//

void SpatComputeBatch::print(OpAsmPrinter& printer) {
  printer << " ";
  printer.printOperand(getLaneArgument());
  printer << " = 0 to " << getLaneCount();
  printer << " ";
  printWeightAndInputBindings(printer, *this);

  if (getNumResults() != 0) {
    printer << " shared_outs";
    SmallVector<Value> outputArgs;
    outputArgs.reserve(getNumResults());
    for (unsigned index = 0, count = getNumResults(); index < count; ++index)
      outputArgs.push_back(getOutputArgument(index));
    printBlockArgumentList(printer, outputArgs);
  }

  if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
    printer << " coreIds ";
    printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
  }

  printer.printOptionalAttrDict(
      (*this)->getAttrs(),
      {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
  printComputeSignatureAndBody(printer, *this);
}

ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
  int64_t lowerBound = 0;
  int32_t laneCount = 0;
  OpAsmParser::Argument laneArg;
  SmallVector<OpAsmParser::Argument> weightArgs;
  SmallVector<OpAsmParser::Argument> inputArgs;
  SmallVector<OpAsmParser::Argument> outputArgs;
  SmallVector<OpAsmParser::Argument> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> weightTypes;
  SmallVector<Type> inputTypes;
  SmallVector<Type> outputTypes;
  SmallVector<int32_t> coreIds;

  if (parser.parseArgument(laneArg) || parser.parseEqual())
    return failure();

  // Remember where the lower bound starts so a non-zero value is reported at
  // the offending literal rather than after the lane count.
  llvm::SMLoc lowerBoundLoc = parser.getCurrentLocation();
  if (parser.parseInteger(lowerBound) || parser.parseKeyword("to") || parser.parseInteger(laneCount))
    return failure();
  if (lowerBound != 0)
    return parser.emitError(lowerBoundLoc, "compute_batch currently requires a zero lower bound");

  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
  if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)
      || parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
    return failure();

  llvm::SMLoc sharedOutsLoc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("shared_outs")) && parseBlockArgumentList(parser, outputArgs))
    return failure();

  bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
  if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
    return failure();

  llvm::SMLoc attrDictLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
      || parseCompressedRepeatedList(
          parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
      || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
    return failure();

  if (checkBindingCounts(parser,
                         operandsLoc,
                         weights.size(),
                         weightTypes.size(),
                         weightArgs.size(),
                         inputs.size(),
                         inputTypes.size(),
                         inputArgs.size()))
    return failure();
  if (outputArgs.size() != outputTypes.size())
    return parser.emitError(sharedOutsLoc, "number of shared output bindings and result types must match");
  if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
    return parser.emitError(attrDictLoc, "coreIds cannot be specified both positionally and in attr-dict");
  // laneCount is always given positionally; a copy in the attr-dict would
  // otherwise become a duplicate attribute.
  if (result.attributes.get("laneCount"))
    return parser.emitError(attrDictLoc, "laneCount cannot be specified both positionally and in attr-dict");

  auto& builder = parser.getBuilder();
  result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
  result.addAttribute(
      "operandSegmentSizes",
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
  if (hasCoreIds)
    result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));

  if (parser.resolveOperands(weights, weightTypes, operandsLoc, result.operands)
      || parser.resolveOperands(inputs, inputTypes, operandsLoc, result.operands))
    return failure();
  result.addTypes(outputTypes);

  Region* body = result.addRegion();
  applyBatchRegionArgumentTypes(
      inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, builder);
  return parser.parseRegion(*body, regionArgs);
}

//===----------------------------------------------------------------------===//
// SpatInParallelOp
//   custom form: ` <region> attr-dict`
//   The region is printed without entry-block arguments or block terminators.
//   An empty parsed region gets a single empty block.
//===----------------------------------------------------------------------===//

void SpatInParallelOp::print(OpAsmPrinter& printer) {
  printer << " ";
  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false);
  printer.printOptionalAttrDict((*this)->getAttrs());
}

ParseResult SpatInParallelOp::parse(OpAsmParser& parser, OperationState& result) {
  auto& builder = parser.getBuilder();
  std::unique_ptr<Region> region = std::make_unique<Region>();
  SmallVector<OpAsmParser::Argument> regionArgs;
  if (parser.parseRegion(*region, regionArgs))
    return failure();

  if (region->empty())
    OpBuilder(builder.getContext()).createBlock(region.get());
  result.addRegion(std::move(region));

  return parser.parseOptionalAttrDict(result.attributes);
}

} // namespace spatial
} // namespace onnx_mlir