This commit is contained in:
@@ -0,0 +1,486 @@
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
namespace {
|
||||
using namespace onnx_mlir::compact_asm;
|
||||
|
||||
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
|
||||
return parser.getBuilder().getDenseI32ArrayAttr(values);
|
||||
}
|
||||
|
||||
static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<int32_t> coreIds) {
|
||||
printer << " " << keyword << " ";
|
||||
printCompressedIntegerList(printer, coreIds);
|
||||
}
|
||||
|
||||
static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keyword, SmallVectorImpl<int32_t>& coreIds) {
|
||||
if (failed(parser.parseOptionalKeyword(keyword)))
|
||||
return success();
|
||||
return parseCompressedIntegerList(parser, coreIds);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " lanes " << getLaneCount() << " ";
|
||||
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
|
||||
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Paren);
|
||||
else
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Paren);
|
||||
printer << " ";
|
||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Square);
|
||||
|
||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef());
|
||||
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
printer << " : ";
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
|
||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Paren);
|
||||
else
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Paren);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Square);
|
||||
printer << " -> ()";
|
||||
}
|
||||
|
||||
ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
int32_t laneCount = 0;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int32_t> coreIds;
|
||||
|
||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)
|
||||
|| parseCompressedOrTupleOperandList(parser, ListDelimiter::Paren, weights)
|
||||
|| parseCompressedOperandList(parser, ListDelimiter::Square, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("coreIds"));
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes))
|
||||
return failure();
|
||||
|
||||
Region* body = result.addRegion();
|
||||
if (parser.parseRegion(*body))
|
||||
return failure();
|
||||
|
||||
if (parser.parseColon() || parseCompressedOrTupleTypeList(parser, ListDelimiter::Paren, weightTypes)
|
||||
|| parseCompressedTypeList(parser, ListDelimiter::Square, inputTypes) || parser.parseArrow()
|
||||
|| parser.parseLParen() || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreIds cannot be specified both positionally and in attr-dict");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
||||
result.addAttribute("operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr(
|
||||
{static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
if (hasCoreIds)
|
||||
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) {
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimYieldOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getOutputs());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getOutputs().getTypes());
|
||||
}
|
||||
|
||||
ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> outputs;
|
||||
SmallVector<Type> outputTypes;
|
||||
|
||||
OpAsmParser::UnresolvedOperand firstOutput;
|
||||
OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput);
|
||||
if (firstOutputResult.has_value()) {
|
||||
if (failed(*firstOutputResult))
|
||||
return failure();
|
||||
if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedOperandEntry(parser, outputs))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
return failure();
|
||||
|
||||
if (outputs.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match");
|
||||
|
||||
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void PimMapOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
printer << " : ";
|
||||
printer.printType(getInputs().front().getType());
|
||||
printer << " -> ";
|
||||
printer.printType(getOutputs().front().getType());
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult PimMapOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
Type inputType;
|
||||
Type outputType;
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
return failure();
|
||||
if (inputs.empty())
|
||||
return parser.emitError(parser.getCurrentLocation(), "map requires at least one input");
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|
||||
|| parser.parseArrow() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
SmallVector<Type> inputTypes(inputs.size(), inputType);
|
||||
SmallVector<Type> outputTypes(inputs.size(), outputType);
|
||||
if (regionArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
Region* body = result.addRegion();
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void PimEmptyManyOp::print(OpAsmPrinter& printer) {
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
printer << " : ";
|
||||
printer.printType(getOutputs().front().getType());
|
||||
printer << " x" << getOutputs().size();
|
||||
}
|
||||
|
||||
ParseResult PimEmptyManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
Type outputType;
|
||||
int64_t resultCount = 0;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType)
|
||||
|| parser.parseKeyword("x") || parser.parseInteger(resultCount))
|
||||
return failure();
|
||||
|
||||
if (resultCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "result count after 'x' must be positive");
|
||||
|
||||
SmallVector<Type> resultTypes(resultCount, outputType);
|
||||
result.addTypes(resultTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimSendBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
printCoreIdList(printer, "to", getTargetCoreIds());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getInput().getType());
|
||||
}
|
||||
|
||||
ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"targetCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!targetCoreIds.empty())
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void PimSendManyOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printCoreIdList(printer, "to", getTargetCoreIds());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||
}
|
||||
|
||||
ParseResult PimSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"targetCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!targetCoreIds.empty())
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
|
||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void PimSendManyBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printCoreIdList(printer, "to", getTargetCoreIds());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||
}
|
||||
|
||||
ParseResult PimSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"targetCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!targetCoreIds.empty())
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
|
||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void PimReceiveManyOp::print(OpAsmPrinter& printer) {
|
||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
||||
printer << " into ";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
printCompressedValueSequence(printer, getOutputBuffers());
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getOutputs().getTypes());
|
||||
}
|
||||
|
||||
ParseResult PimReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
|
||||
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|
||||
|| parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (outputBuffers.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
|
||||
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"sourceCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!sourceCoreIds.empty())
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
|
||||
if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
||||
printer << " into ";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOperand(getOutputBuffer());
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getOutputBuffer().getType());
|
||||
printer << " -> ";
|
||||
printer.printType(getOutput().getType());
|
||||
}
|
||||
|
||||
ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand outputBuffer;
|
||||
Type outputBufferType;
|
||||
Type outputType;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
|
||||
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|
||||
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|
||||
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|
||||
|| parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"sourceCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!sourceCoreIds.empty())
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
|
||||
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimReceiveManyBatchOp::print(OpAsmPrinter& printer) {
|
||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
||||
printer << " into ";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
printCompressedValueSequence(printer, getOutputBuffers());
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getOutputs().getTypes());
|
||||
}
|
||||
|
||||
ParseResult PimReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
|
||||
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|
||||
|| parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (outputBuffers.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
|
||||
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"sourceCoreIds cannot be specified both positionally and in attr-dict");
|
||||
if (!sourceCoreIds.empty())
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
|
||||
if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimExtractRowsOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
printer << " into ";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
printCompressedValueSequence(printer, getOutputBuffers());
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
printer << " : ";
|
||||
printer.printType(getInput().getType());
|
||||
printer << " -> ";
|
||||
printCompressedTypeSequence(printer, getOutputs().getTypes());
|
||||
}
|
||||
|
||||
ParseResult PimExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
|
||||
Type inputType;
|
||||
SmallVector<Type> outputTypes;
|
||||
|
||||
if (parser.parseOperand(input) || parser.parseKeyword("into") || parser.parseLParen()
|
||||
|| parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (outputBuffers.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
|
||||
if (parser.resolveOperand(input, inputType, result.operands)
|
||||
|| parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
void PimConcatOp::print(OpAsmPrinter& printer) {
|
||||
printer << " axis " << getAxis() << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printer << " into ";
|
||||
printer.printOperand(getOutputBuffer());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
printer.printType(getOutput().getType());
|
||||
}
|
||||
|
||||
ParseResult PimConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
int64_t axis = 0;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
OpAsmParser::UnresolvedOperand outputBuffer;
|
||||
SmallVector<Type> inputTypes;
|
||||
Type outputType;
|
||||
|
||||
if (parser.parseKeyword("axis") || parser.parseInteger(axis) || parseCompressedOperandSequence(parser, inputs)
|
||||
|| parser.parseKeyword("into") || parser.parseOperand(outputBuffer)
|
||||
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseLParen()
|
||||
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false) || parser.parseRParen()
|
||||
|| parser.parseArrow() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (result.attributes.get("axis"))
|
||||
return parser.emitError(parser.getCurrentLocation(), "axis cannot be specified both positionally and in attr-dict");
|
||||
|
||||
result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axis));
|
||||
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)
|
||||
|| parser.resolveOperand(outputBuffer, outputType, result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user