273 lines
12 KiB
C++
273 lines
12 KiB
C++
#include "mlir/IR/DialectImplementation.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/Value.h"
|
|
|
|
#include "llvm/Support/LogicalResult.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace pim {
|
|
|
|
namespace {
|
|
using namespace onnx_mlir::compact_asm;
|
|
|
|
/// Wraps `values` in a DenseI32ArrayAttr created through the parser's builder.
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
  auto& builder = parser.getBuilder();
  return builder.getDenseI32ArrayAttr(values);
}
|
|
|
|
/// Wraps `value` in a 32-bit IntegerAttr created through the parser's builder.
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
  auto& builder = parser.getBuilder();
  return builder.getI32IntegerAttr(value);
}
|
|
|
|
/// Tries to consume the keyword `preferred`; if it is absent, tries `legacy`.
/// Returns true when either spelling was consumed, false when neither was.
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
  // The preferred spelling is always attempted first.
  if (succeeded(parser.parseOptionalKeyword(preferred)))
    return true;
  return succeeded(parser.parseOptionalKeyword(legacy));
}
|
|
|
|
static void
|
|
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
|
printCompressedValueList(printer, arguments, delimiter);
|
|
printer << " = ";
|
|
printCompressedValueList(printer, operands, delimiter);
|
|
}
|
|
|
|
/// Parses `<arguments> = <operands>`, the textual form emitted by
/// printBoundValueList.
///
/// The argument side is an explicitly delimited, comma-separated list of
/// entries (each consumed by parseOneCompressedArgumentEntry) and may be
/// empty. The operand side is parsed by parseCompressedOperandList using the
/// same `delimiter`. Parsed entries are appended to `arguments` and
/// `operands`. This function does not check that the two sides have the same
/// number of elements.
static ParseResult parseBoundValueList(OpAsmParser& parser,
                                       ListDelimiter delimiter,
                                       SmallVectorImpl<OpAsmParser::Argument>& arguments,
                                       SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (parseOpenDelimiter(parser, delimiter))
    return failure();

  // An immediately following close delimiter means the argument list is empty.
  const bool argumentListIsEmpty = succeeded(parseOptionalCloseDelimiter(parser, delimiter));
  if (!argumentListIsEmpty) {
    // At least one entry is required here; further entries are comma-separated.
    do {
      if (parseOneCompressedArgumentEntry(parser, arguments))
        return failure();
    } while (succeeded(parser.parseOptionalComma()));

    // The closing token is mandatory once entries have been parsed.
    auto expectClosingToken = [&]() -> ParseResult {
      switch (delimiter) {
      case ListDelimiter::Paren: return parser.parseRParen();
      case ListDelimiter::Square: return parser.parseRSquare();
      }
      llvm_unreachable("unsupported delimiter");
    };
    if (expectClosingToken())
      return failure();
  }

  // Both the empty and the non-empty form continue with `= <operands>`.
  if (parser.parseEqual())
    return failure();
  if (parseCompressedOperandList(parser, delimiter, operands))
    return failure();
  return success();
}
|
|
|
|
/// Prints ` <keyword> ` followed by `coreIds` rendered with
/// printCompressedIntegerList.
static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<int32_t> coreIds) {
  // The keyword is surrounded by single spaces on both sides.
  printer << " ";
  printer << keyword;
  printer << " ";
  printCompressedIntegerList(printer, coreIds);
}
|
|
|
|
} // namespace
|
|
|
|
/// Prints the custom assembly form of the op:
///
///   <weight-bindings> coreId <id> attr-dict : <weight-types> -> () <region>
///
/// The weight bindings and weight types use square delimiters. The `coreId`
/// attribute is printed inline and therefore elided from the attribute
/// dictionary. The region is printed without its entry block arguments.
void PimCoreOp::print(OpAsmPrinter& printer) {
  auto weightOperands = getWeights();
  const size_t weightCount = weightOperands.size();

  // One bound value per weight operand, as reported by getWeightArgument().
  SmallVector<Value> boundWeightArgs;
  boundWeightArgs.reserve(weightCount);
  for (unsigned weightIdx = 0; weightIdx < weightCount; ++weightIdx)
    boundWeightArgs.push_back(getWeightArgument(weightIdx));

  printer << " ";
  printBoundValueList(printer, boundWeightArgs, weightOperands, ListDelimiter::Square);

  printer << " coreId " << getCoreId();
  printer.printOptionalAttrDict((*this)->getAttrs(),
                                /*elidedAttrs=*/{getCoreIdAttrName().getValue()});

  printer << " : ";
  printCompressedTypeList(printer, TypeRange(weightOperands), ListDelimiter::Square);
  printer << " -> () ";

  printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
|
|
|
|
/// Parses the custom assembly form produced by PimCoreOp::print:
///
///   <weight-bindings> [coreId|core_id <id>] attr-dict
///   : <weight-types> -> () <region>
///
/// The inline core id is optional and accepts the legacy `core_id` spelling.
/// It is an error to give it inline and in the attribute dictionary at the
/// same time. The bound weight arguments become the region's entry
/// arguments, typed with the parsed weight types.
ParseResult PimCoreOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::Argument> boundWeightArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> weightOperands;
  SmallVector<Type> weightOperandTypes;
  int32_t parsedCoreId = 0;

  if (parseBoundValueList(parser, ListDelimiter::Square, boundWeightArgs, weightOperands))
    return failure();

  // Optional inline core id; the integer is required once the keyword is seen.
  const bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
  if (hasCoreId) {
    if (parser.parseInteger(parsedCoreId))
      return failure();
  }

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  if (parseCompressedRepeatedList(
          parser, ListDelimiter::Square, weightOperandTypes, [&](Type& type) { return parser.parseType(type); }))
    return failure();
  // The op has no results, so the trailing signature is always `-> ()`.
  if (parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
    return failure();

  // Consistency checks between the three parsed lists.
  if (weightOperands.size() != weightOperandTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
  if (boundWeightArgs.size() != weightOperands.size())
    return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");

  // The inline form and the attr-dict form of coreId are mutually exclusive.
  if (hasCoreId) {
    if (result.attributes.get("coreId"))
      return parser.emitError(parser.getCurrentLocation(),
                              "coreId cannot be specified both positionally and in attr-dict");
  }

  if (parser.resolveOperands(weightOperands, weightOperandTypes, parser.getCurrentLocation(), result.operands))
    return failure();

  if (hasCoreId)
    result.addAttribute("coreId", getI32Attr(parser, parsedCoreId));

  // Region arguments take the weight types before the body is parsed.
  Region* bodyRegion = result.addRegion();
  applyArgumentTypes(weightOperandTypes, boundWeightArgs);
  return parser.parseRegion(*bodyRegion, boundWeightArgs);
}
|
|
|
|
/// Prints the custom assembly form of the op:
///
///   <lane-arg> = 0 to <laneCount> <weight-bindings> <input-bindings>
///   [coreIds <ids>] attr-dict : <weight-types> <input-types> -> () <region>
///
/// Weights use square delimiters and inputs use parentheses. `laneCount`,
/// the operand segment sizes and the core id list are printed inline (or are
/// implied by the syntax) and are elided from the attribute dictionary. The
/// region is printed without its entry block arguments.
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
  // Lane header; the lower bound is always printed as the literal 0.
  printer << " ";
  printer.printOperand(getLaneArgument());
  printer << " = 0 to " << getLaneCount() << " ";

  // Weight bindings: one bound value per weight operand.
  auto weightOperands = getWeights();
  const size_t weightCount = weightOperands.size();
  SmallVector<Value> boundWeightArgs;
  boundWeightArgs.reserve(weightCount);
  for (unsigned weightIdx = 0; weightIdx < weightCount; ++weightIdx)
    boundWeightArgs.push_back(getWeightArgument(weightIdx));
  printBoundValueList(printer, boundWeightArgs, weightOperands, ListDelimiter::Square);

  printer << " ";

  // Input bindings: one bound value per input operand.
  auto inputOperands = getInputs();
  const size_t inputCount = inputOperands.size();
  SmallVector<Value> boundInputArgs;
  boundInputArgs.reserve(inputCount);
  for (unsigned inputIdx = 0; inputIdx < inputCount; ++inputIdx)
    boundInputArgs.push_back(getInputArgument(inputIdx));
  printBoundValueList(printer, boundInputArgs, inputOperands, ListDelimiter::Paren);

  // The core id list is only printed when the attribute is present.
  if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
    printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef());

  printer.printOptionalAttrDict((*this)->getAttrs(),
                                /*elidedAttrs=*/
                                {getLaneCountAttrName().getValue(),
                                 getOperandSegmentSizesAttrName().getValue(),
                                 onnx_mlir::kCoreIdsAttrName});

  printer << " : ";
  printCompressedTypeList(printer, TypeRange(weightOperands), ListDelimiter::Square);
  printer << " ";
  printCompressedTypeList(printer, TypeRange(inputOperands), ListDelimiter::Paren);
  printer << " -> () ";

  printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
|
|
|
|
/// Parses the custom assembly form produced by PimCoreBatchOp::print:
///
///   <lane-arg> = 0 to <laneCount> <weight-bindings> <input-bindings>
///   [coreIds|core_ids <ids>] attr-dict
///   : <weight-types> <input-types> -> () <region>
///
/// Weight lists use square delimiters and input lists use parentheses. The
/// lower bound must be the literal 0. The `laneCount` and
/// `operandSegmentSizes` attributes are always synthesized from the parsed
/// syntax, and `coreIds` is synthesized when given inline. None of them may
/// also appear in the attribute dictionary, because adding an attribute name
/// twice would leave `result.attributes` with duplicate entries.
///
/// The region's entry arguments are, in order: the lane argument (index
/// type), the bound weight arguments, then the bound input arguments.
ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  int64_t lowerBound = 0;
  int32_t laneCount = 0;
  OpAsmParser::Argument laneArg;
  SmallVector<OpAsmParser::Argument> weightArgs;
  SmallVector<OpAsmParser::Argument> inputArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> weightTypes;
  SmallVector<Type> inputTypes;
  SmallVector<int32_t> coreIds;

  // Lane header: `<lane-arg> = <lower> to <laneCount>`.
  if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
      || parser.parseKeyword("to") || parser.parseInteger(laneCount))
    return failure();
  if (lowerBound != 0)
    return parser.emitError(parser.getCurrentLocation(), "core_batch currently requires a zero lower bound");

  if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)
      || parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
    return failure();

  // Optional inline core id list; the legacy `core_ids` spelling is accepted.
  bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
  if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedRepeatedList(
          parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
      || parseCompressedRepeatedList(
          parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
      || parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
    return failure();

  // Consistency checks between bindings, operands and types.
  if (weights.size() != weightTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
  if (weightArgs.size() != weights.size())
    return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
  if (inputs.size() != inputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  if (inputArgs.size() != inputs.size())
    return parser.emitError(parser.getCurrentLocation(), "number of input bindings and input operands must match");

  // Attributes carried by the custom syntax must not be repeated in the
  // attr-dict: they are appended unconditionally below, and a second entry
  // with the same name would produce a duplicate attribute.
  if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
    return parser.emitError(parser.getCurrentLocation(),
                            "coreIds cannot be specified both positionally and in attr-dict");
  if (result.attributes.get("laneCount"))
    return parser.emitError(parser.getCurrentLocation(),
                            "laneCount cannot be specified both positionally and in attr-dict");
  if (result.attributes.get("operandSegmentSizes"))
    return parser.emitError(parser.getCurrentLocation(),
                            "operandSegmentSizes is inferred from the operand lists and cannot be specified in "
                            "attr-dict");

  auto& builder = parser.getBuilder();
  result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
  // Segment sizes follow the operand order used below: weights, then inputs.
  result.addAttribute(
      "operandSegmentSizes",
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
  if (hasCoreIds)
    result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));

  if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
      || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) {
    return failure();
  }

  // Assemble the region entry arguments: lane index, weights, inputs.
  Region* body = result.addRegion();
  SmallVector<OpAsmParser::Argument> regionArgs;
  laneArg.type = builder.getIndexType();
  regionArgs.push_back(laneArg);
  applyArgumentTypes(weightTypes, weightArgs);
  llvm::append_range(regionArgs, weightArgs);
  applyArgumentTypes(inputTypes, inputArgs);
  llvm::append_range(regionArgs, inputArgs);
  return parser.parseRegion(*body, regionArgs);
}
|
|
|
|
/// Prints the custom assembly form of the op:
///
///   axis <axis> <inputs> into <output-buffer> attr-dict
///   : (<input-types>) -> <output-type>
///
/// The `axis` attribute is printed inline and elided from the attribute
/// dictionary.
void PimConcatOp::print(OpAsmPrinter& printer) {
  // Operand part.
  printer << " axis " << getAxis() << " ";
  printCompressedValueSequence(printer, getInputs());
  printer << " into ";
  printer.printOperand(getOutputBuffer());

  printer.printOptionalAttrDict((*this)->getAttrs(),
                                /*elidedAttrs=*/{getAxisAttrName().getValue()});

  // Type signature: parenthesized input types, then the result type.
  printer << " : ";
  printOpenDelimiter(printer, ListDelimiter::Paren);
  printCompressedTypeSequence(printer, TypeRange(getInputs()));
  printCloseDelimiter(printer, ListDelimiter::Paren);
  printer << " -> ";
  printer.printType(getOutput().getType());
}
|
|
|
|
/// Parses the custom assembly form produced by PimConcatOp::print:
///
///   axis <axis> <inputs> into <output-buffer> attr-dict
///   : (<input-types>) -> <output-type>
///
/// The type list must not be empty. `axis` may not also appear in the
/// attribute dictionary. The output buffer operand is resolved with the
/// parsed output type, which is also the op's single result type.
ParseResult PimConcatOp::parse(OpAsmParser& parser, OperationState& result) {
  int64_t axisValue = 0;
  SmallVector<OpAsmParser::UnresolvedOperand> inputOperands;
  OpAsmParser::UnresolvedOperand outputBufferOperand;
  SmallVector<Type> inputOperandTypes;
  Type resultType;

  // `axis <int> <inputs> into <buffer>`
  if (parser.parseKeyword("axis") || parser.parseInteger(axisValue))
    return failure();
  if (parseCompressedOperandSequence(parser, inputOperands))
    return failure();
  if (parser.parseKeyword("into") || parser.parseOperand(outputBufferOperand))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // `: (<input-types>) -> <output-type>`
  if (parser.parseColon() || parser.parseLParen())
    return failure();
  if (parseCompressedTypeSequence(parser, inputOperandTypes, /*allowEmpty=*/false))
    return failure();
  if (parser.parseRParen() || parser.parseArrow() || parser.parseType(resultType))
    return failure();

  if (inputOperands.size() != inputOperandTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");

  // `axis` is always given inline, so any attr-dict entry is a duplicate.
  if (result.attributes.get("axis"))
    return parser.emitError(parser.getCurrentLocation(), "axis cannot be specified both positionally and in attr-dict");
  result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axisValue));

  // Operand order: all inputs first, then the output buffer.
  if (parser.resolveOperands(inputOperands, inputOperandTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  if (parser.resolveOperand(outputBufferOperand, resultType, result.operands))
    return failure();

  result.addTypes(resultType);
  return success();
}
|
|
|
|
} // namespace pim
|
|
} // namespace onnx_mlir
|