// Custom (compact) assembly printers and parsers for PIM dialect operations.

#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/LogicalResult.h"

#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace pim {

namespace {

using namespace onnx_mlir::compact_asm;

/// Builds a `DenseI32ArrayAttr` holding `values` using the parser's builder.
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
  return parser.getBuilder().getDenseI32ArrayAttr(values);
}

/// Builds a 32-bit integer attribute holding `value` using the parser's builder.
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
  return parser.getBuilder().getI32IntegerAttr(value);
}

/// Consumes the keyword `preferred` or, if that is absent, the keyword `legacy`.
/// Returns true when either spelling was consumed.
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
  return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
}

/// Prints `<arguments> = <operands>`, rendering each side with
/// `printCompressedValueList` using `delimiter`.
static void printBoundValueList(OpAsmPrinter& printer,
                                ValueRange arguments,
                                ValueRange operands,
                                ListDelimiter delimiter) {
  printCompressedValueList(printer, arguments, delimiter);
  printer << " = ";
  printCompressedValueList(printer, operands, delimiter);
}

/// Parses the closing token that matches `delimiter` (`)` or `]`); fails if it
/// is not the next token.
static ParseResult parseRequiredCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
  switch (delimiter) {
  case ListDelimiter::Paren: return parser.parseRParen();
  case ListDelimiter::Square: return parser.parseRSquare();
  }
  llvm_unreachable("unsupported delimiter");
}

/// Parses `<open> [arg (, arg)*] <close> = <operand list>`, the inverse of
/// `printBoundValueList`.
///
/// @param delimiter  Delimiter used for both the argument and operand lists.
/// @param arguments  Receives the region-argument bindings (left-hand side).
/// @param operands   Receives the unresolved operands (right-hand side).
static ParseResult parseBoundValueList(OpAsmParser& parser,
                                       ListDelimiter delimiter,
                                       SmallVectorImpl<OpAsmParser::Argument>& arguments,
                                       SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (parseOpenDelimiter(parser, delimiter))
    return failure();

  // An immediately-closed list means "no bindings"; otherwise read a
  // comma-separated list of argument entries and then the closing token.
  if (failed(parseOptionalCloseDelimiter(parser, delimiter))) {
    do {
      if (parseOneCompressedArgumentEntry(parser, arguments))
        return failure();
    } while (succeeded(parser.parseOptionalComma()));
    if (parseRequiredCloseDelimiter(parser, delimiter))
      return failure();
  }

  if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
    return failure();
  return success();
}

/// Checks that a `[bindings] = [operands] : [types]` group is consistent:
/// operand count must equal type count, and binding count must equal operand
/// count. `noun` is the singular name used in diagnostics (e.g. "weight").
static ParseResult verifyBoundListSizes(OpAsmParser& parser,
                                        llvm::SMLoc loc,
                                        StringRef noun,
                                        size_t numBindings,
                                        size_t numOperands,
                                        size_t numTypes) {
  if (numOperands != numTypes)
    return parser.emitError(loc) << "number of " << noun << "s and " << noun << " types must match";
  if (numBindings != numOperands)
    return parser.emitError(loc) << "number of " << noun << " bindings and " << noun << " operands must match";
  return success();
}

/// Fails with a diagnostic at `loc` if `attrName` is already present in
/// `attributes`. Used when a value was given positionally and is about to be
/// added under the same name, which would otherwise create a duplicate entry.
static ParseResult rejectAttrDictDuplicate(OpAsmParser& parser,
                                           llvm::SMLoc loc,
                                           const NamedAttrList& attributes,
                                           StringRef attrName,
                                           StringRef displayName) {
  if (!attributes.get(attrName))
    return success();
  return parser.emitError(loc) << displayName << " cannot be specified both positionally and in attr-dict";
}

/// Returns `[getArgument(0), ..., getArgument(count - 1)]` as values, for
/// printing region block arguments next to the operands they are bound to.
template <typename ArgumentGetter>
static SmallVector<Value> collectRegionArguments(size_t count, ArgumentGetter getArgument) {
  SmallVector<Value> arguments;
  arguments.reserve(count);
  for (unsigned index = 0; index < count; ++index)
    arguments.push_back(getArgument(index));
  return arguments;
}

/// Prints ` <keyword> ` followed by `coreIds` as a compressed integer list.
static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<int32_t> coreIds) {
  printer << " " << keyword << " ";
  printCompressedIntegerList(printer, coreIds);
}

} // namespace

/// Prints `[args] = [weights] coreId N {attrs} : [types] -> () { region }`.
void PimCoreOp::print(OpAsmPrinter& printer) {
  SmallVector<Value> weightArgs =
      collectRegionArguments(getWeights().size(), [&](unsigned index) { return getWeightArgument(index); });

  printer << " ";
  printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
  printer << " coreId " << getCoreId();
  printer.printOptionalAttrDict((*this)->getAttrs(), {getCoreIdAttrName().getValue()});
  printer << " : ";
  printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
  printer << " -> () ";
  printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}

/// Parses the format produced by `PimCoreOp::print`. The positional core id
/// may be spelled `coreId` or `core_id`. If it is omitted, `coreId` may be
/// supplied through the attr-dict instead, but not both ways at once.
ParseResult PimCoreOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::Argument> weightArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  SmallVector<Type> weightTypes;
  int32_t coreId = 0;
  auto parseOneType = [&](Type& type) { return parser.parseType(type); };

  llvm::SMLoc weightsLoc = parser.getCurrentLocation();
  if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
    return failure();

  bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
  if (hasCoreId && parser.parseInteger(coreId))
    return failure();

  llvm::SMLoc attrDictLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parseCompressedRepeatedList(parser, ListDelimiter::Square, weightTypes, parseOneType) ||
      parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
    return failure();

  if (verifyBoundListSizes(parser, weightsLoc, "weight", weightArgs.size(), weights.size(), weightTypes.size()))
    return failure();
  if (hasCoreId && rejectAttrDictDuplicate(parser, attrDictLoc, result.attributes, "coreId", "coreId"))
    return failure();

  if (parser.resolveOperands(weights, weightTypes, weightsLoc, result.operands))
    return failure();
  if (hasCoreId)
    result.addAttribute("coreId", getI32Attr(parser, coreId));

  Region* body = result.addRegion();
  applyArgumentTypes(weightTypes, weightArgs);
  return parser.parseRegion(*body, weightArgs);
}

/// Prints `%lane = 0 to N [args] = [weights] (args) = (inputs) [coreIds ...]
/// {attrs} : [weight types] (input types) -> () { region }`.
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
  printer << " ";
  printer.printOperand(getLaneArgument());
  printer << " = 0 to " << getLaneCount() << " ";

  SmallVector<Value> weightArgs =
      collectRegionArguments(getWeights().size(), [&](unsigned index) { return getWeightArgument(index); });
  printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
  printer << " ";

  SmallVector<Value> inputArgs =
      collectRegionArguments(getInputs().size(), [&](unsigned index) { return getInputArgument(index); });
  printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);

  SmallVector<StringRef, 3> elidedAttrs{getLaneCountAttrName().getValue(),
                                        getOperandSegmentSizesAttrName().getValue()};
  // Only elide the core-id attribute when it was printed positionally;
  // otherwise (an attribute of another type) it must stay in the attr-dict so
  // it survives a print/parse round-trip.
  if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
    printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef());
    elidedAttrs.push_back(onnx_mlir::kCoreIdsAttrName);
  }
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);

  printer << " : ";
  printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
  printer << " ";
  printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
  printer << " -> () ";
  printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}

/// Parses the format produced by `PimCoreBatchOp::print`. The lower bound must
/// be 0. `laneCount` and `operandSegmentSizes` are always synthesized here, so
/// they are rejected if also given in the attr-dict. The positional core-id
/// list (`coreIds` / `core_ids`) is likewise exclusive with the attr-dict entry.
ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  int64_t lowerBound = 0;
  int32_t laneCount = 0;
  OpAsmParser::Argument laneArg;
  SmallVector<OpAsmParser::Argument> weightArgs;
  SmallVector<OpAsmParser::Argument> inputArgs;
  SmallVector<OpAsmParser::Argument> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> weightTypes;
  SmallVector<Type> inputTypes;
  SmallVector<int32_t> coreIds;
  auto parseOneType = [&](Type& type) { return parser.parseType(type); };

  if (parser.parseArgument(laneArg) || parser.parseEqual())
    return failure();
  llvm::SMLoc lowerBoundLoc = parser.getCurrentLocation();
  if (parser.parseInteger(lowerBound) || parser.parseKeyword("to") || parser.parseInteger(laneCount))
    return failure();
  if (lowerBound != 0)
    return parser.emitError(lowerBoundLoc, "core_batch currently requires a zero lower bound");

  llvm::SMLoc weightsLoc = parser.getCurrentLocation();
  if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
    return failure();
  llvm::SMLoc inputsLoc = parser.getCurrentLocation();
  if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
    return failure();

  bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
  if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
    return failure();

  llvm::SMLoc attrDictLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parseCompressedRepeatedList(parser, ListDelimiter::Square, weightTypes, parseOneType) ||
      parseCompressedRepeatedList(parser, ListDelimiter::Paren, inputTypes, parseOneType) ||
      parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
    return failure();

  if (verifyBoundListSizes(parser, weightsLoc, "weight", weightArgs.size(), weights.size(), weightTypes.size()) ||
      verifyBoundListSizes(parser, inputsLoc, "input", inputArgs.size(), inputs.size(), inputTypes.size()))
    return failure();
  if (hasCoreIds &&
      rejectAttrDictDuplicate(parser, attrDictLoc, result.attributes, onnx_mlir::kCoreIdsAttrName, "coreIds"))
    return failure();
  // Both attributes below are added unconditionally; an attr-dict copy would
  // leave two entries with the same name.
  if (rejectAttrDictDuplicate(parser, attrDictLoc, result.attributes, "laneCount", "laneCount"))
    return failure();
  if (result.attributes.get("operandSegmentSizes"))
    return parser.emitError(attrDictLoc,
                            "operandSegmentSizes is derived from the operand lists and cannot be specified in "
                            "attr-dict");

  auto& builder = parser.getBuilder();
  result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
  result.addAttribute(
      "operandSegmentSizes",
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
  if (hasCoreIds)
    result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));

  if (parser.resolveOperands(weights, weightTypes, weightsLoc, result.operands) ||
      parser.resolveOperands(inputs, inputTypes, inputsLoc, result.operands))
    return failure();

  // Region arguments are ordered: lane index, then weights, then inputs.
  Region* body = result.addRegion();
  laneArg.type = builder.getIndexType();
  regionArgs.push_back(laneArg);
  applyArgumentTypes(weightTypes, weightArgs);
  llvm::append_range(regionArgs, weightArgs);
  applyArgumentTypes(inputTypes, inputArgs);
  llvm::append_range(regionArgs, inputArgs);
  return parser.parseRegion(*body, regionArgs);
}

/// Prints `axis N <inputs> into %out {attrs} : (<input types>) -> <result type>`.
void PimConcatOp::print(OpAsmPrinter& printer) {
  printer << " axis " << getAxis() << " ";
  printCompressedValueSequence(printer, getInputs());
  printer << " into ";
  printer.printOperand(getOutputBuffer());
  printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()});
  printer << " : ";
  printOpenDelimiter(printer, ListDelimiter::Paren);
  printCompressedTypeSequence(printer, TypeRange(getInputs()));
  printCloseDelimiter(printer, ListDelimiter::Paren);
  printer << " -> ";
  printer.printType(getOutput().getType());
}

/// Parses the format produced by `PimConcatOp::print`. The output buffer
/// operand is resolved with the result type, and `axis` is stored as a 64-bit
/// integer attribute.
ParseResult PimConcatOp::parse(OpAsmParser& parser, OperationState& result) {
  int64_t axis = 0;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  OpAsmParser::UnresolvedOperand outputBuffer;
  SmallVector<Type> inputTypes;
  Type outputType;

  if (parser.parseKeyword("axis") || parser.parseInteger(axis))
    return failure();
  llvm::SMLoc inputsLoc = parser.getCurrentLocation();
  if (parseCompressedOperandSequence(parser, inputs) || parser.parseKeyword("into") ||
      parser.parseOperand(outputBuffer))
    return failure();
  llvm::SMLoc attrDictLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseLParen() ||
      parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false) || parser.parseRParen() ||
      parser.parseArrow() || parser.parseType(outputType))
    return failure();

  if (inputs.size() != inputTypes.size())
    return parser.emitError(inputsLoc, "number of inputs and input types must match");
  if (rejectAttrDictDuplicate(parser, attrDictLoc, result.attributes, "axis", "axis"))
    return failure();
  result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axis));

  if (parser.resolveOperands(inputs, inputTypes, inputsLoc, result.operands) ||
      parser.resolveOperand(outputBuffer, outputType, result.operands))
    return failure();
  result.addTypes(outputType);
  return success();
}

} // namespace pim
} // namespace onnx_mlir