640 lines
28 KiB
C++
640 lines
28 KiB
C++
#include "mlir/IR/DialectImplementation.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
|
|
#include "llvm/Support/LogicalResult.h"
|
|
|
|
#include <string>
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace spatial {
|
|
|
|
namespace {
|
|
using namespace onnx_mlir::compact_asm;
|
|
|
|
/// Tries to consume either the `preferred` keyword or its `legacy` spelling.
/// Returns true if one of them was consumed. The legacy spelling is only tried
/// when the preferred one is absent.
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
  if (succeeded(parser.parseOptionalKeyword(preferred)))
    return true;
  return succeeded(parser.parseOptionalKeyword(legacy));
}
|
|
|
|
/// Builds a `DenseI32ArrayAttr` holding `values` using the parser's builder.
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
  Builder& builder = parser.getBuilder();
  return builder.getDenseI32ArrayAttr(values);
}
|
|
|
|
/// Builds a 32-bit integer attribute holding `value` using the parser's builder.
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
  Builder& builder = parser.getBuilder();
  return builder.getI32IntegerAttr(value);
}
|
|
|
|
/// Parses a string attribute written either as a bare identifier (`foo`) or as
/// a quoted string literal (`"foo bar"`), storing it in `attr`.
///
/// Plain identifiers are parsed exactly as before. Accepting a quoted literal
/// as well lets values that are not valid bare identifiers round-trip. Such
/// values include strings containing spaces or `-`, or starting with a digit.
static ParseResult parseBareStringAttr(OpAsmParser& parser, StringAttr& attr) {
  std::string value;
  if (parser.parseKeywordOrString(&value))
    return failure();
  attr = parser.getBuilder().getStringAttr(value);
  return success();
}
|
|
|
|
/// Prints `arguments` as a parenthesized, comma-separated SSA list, e.g.
/// `(%a, %b)`. An empty list prints as `()`.
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
  printer << "(";
  const char* separator = "";
  for (BlockArgument argument : arguments) {
    printer << separator;
    printer.printOperand(argument);
    separator = ", ";
  }
  printer << ")";
}
|
|
|
|
/// Parses a parenthesized, comma-separated list of untyped SSA arguments
/// (`()` or `(%a, %b, ...)`) and appends them to `arguments`.
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
  if (parser.parseLParen())
    return failure();

  // `()` is a valid, empty list.
  if (succeeded(parser.parseOptionalRParen()))
    return success();

  do {
    OpAsmParser::Argument next;
    if (parser.parseArgument(next))
      return failure();
    arguments.push_back(next);
  } while (succeeded(parser.parseOptionalComma()));

  return parser.parseRParen();
}
|
|
|
|
/// Assigns types to the parsed batch-body arguments and collects them into
/// `regionArgs` in entry-block order. That order is the lane index (typed
/// `index`), then the weight bindings, then the input bindings, then the shared
/// output bindings. Each binding group takes its types from the matching type
/// list.
static void applyBatchRegionArgumentTypes(ArrayRef<Type> inputTypes,
                                          ArrayRef<Type> weightTypes,
                                          ArrayRef<Type> outputTypes,
                                          OpAsmParser::Argument& laneArg,
                                          SmallVectorImpl<OpAsmParser::Argument>& weightArgs,
                                          SmallVectorImpl<OpAsmParser::Argument>& inputArgs,
                                          SmallVectorImpl<OpAsmParser::Argument>& outputArgs,
                                          SmallVectorImpl<OpAsmParser::Argument>& regionArgs,
                                          Builder& builder) {
  // Type every group first; each call only touches its own argument vector.
  laneArg.type = builder.getIndexType();
  applyArgumentTypes(weightTypes, weightArgs);
  applyArgumentTypes(inputTypes, inputArgs);
  applyArgumentTypes(outputTypes, outputArgs);

  // Then lay them out in entry-block order.
  regionArgs.push_back(laneArg);
  llvm::append_range(regionArgs, weightArgs);
  llvm::append_range(regionArgs, inputArgs);
  llvm::append_range(regionArgs, outputArgs);
}
|
|
|
|
static void
|
|
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
|
printCompressedValueList(printer, arguments, delimiter);
|
|
printer << " = ";
|
|
printCompressedValueList(printer, operands, delimiter);
|
|
}
|
|
|
|
/// Parses a binding clause of the form `<bindings> = <operands>`, where both
/// sides are compressed lists wrapped in `delimiter`. Bindings are appended to
/// `arguments` and the unresolved operands to `operands`. An empty binding
/// list (e.g. `[] = [...]`) is accepted.
static ParseResult parseBoundValueList(OpAsmParser& parser,
                                       ListDelimiter delimiter,
                                       SmallVectorImpl<OpAsmParser::Argument>& arguments,
                                       SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (parseOpenDelimiter(parser, delimiter))
    return failure();

  // A non-empty binding list still needs its closing delimiter consumed
  // explicitly. An empty one has already been closed by the optional probe.
  if (failed(parseOptionalCloseDelimiter(parser, delimiter))) {
    do {
      if (parseOneCompressedArgumentEntry(parser, arguments))
        return failure();
    } while (succeeded(parser.parseOptionalComma()));

    ParseResult closed = [&]() -> ParseResult {
      switch (delimiter) {
      case ListDelimiter::Paren: return parser.parseRParen();
      case ListDelimiter::Square: return parser.parseRSquare();
      }
      llvm_unreachable("unsupported delimiter");
    }();
    if (failed(closed))
      return failure();
  }

  if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
    return failure();
  return success();
}
|
|
|
|
template <typename ComputeOpTy>
|
|
void printComputeLikeOp(ComputeOpTy op, OpAsmPrinter& printer) {
|
|
SmallVector<Value> weightArgs;
|
|
weightArgs.reserve(op.getWeights().size());
|
|
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
|
|
auto weightArg = op.getWeightArgument(index);
|
|
if (!weightArg)
|
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
|
weightArgs.push_back(*weightArg);
|
|
}
|
|
SmallVector<Value> inputArgs;
|
|
inputArgs.reserve(op.getInputs().size());
|
|
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
|
|
auto inputArg = op.getInputArgument(index);
|
|
if (!inputArg)
|
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
|
inputArgs.push_back(*inputArg);
|
|
}
|
|
|
|
printer << " ";
|
|
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
|
|
printer << " ";
|
|
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
|
|
|
|
if (auto coreIdAttr = op->template getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
|
printer << " coreId " << coreIdAttr.getInt();
|
|
printer << " crossbarWeights " << collectDistinctCrossbarWeights(op.getOperation()).size();
|
|
|
|
printer.printOptionalAttrDict(op->getAttrs(), {op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
|
|
|
printer << " : ";
|
|
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
|
|
printer << " ";
|
|
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
|
|
printer << " -> ";
|
|
printCompressedTypeSequence(printer, op.getResultTypes());
|
|
printer << " ";
|
|
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
|
|
}
|
|
|
|
template <typename ComputeOpTy>
|
|
ParseResult parseComputeLikeOp(OpAsmParser& parser, OperationState& result) {
|
|
SmallVector<OpAsmParser::Argument> weightArgs;
|
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
|
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
|
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
|
SmallVector<Type> weightTypes;
|
|
SmallVector<Type> inputTypes;
|
|
SmallVector<Type> outputTypes;
|
|
int32_t crossbarWeightCount = 0;
|
|
int32_t coreId = 0;
|
|
|
|
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
|
return failure();
|
|
|
|
SmallVector<OpAsmParser::Argument> inputArgs;
|
|
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
|
return failure();
|
|
|
|
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
|
if (hasCoreId && parser.parseInteger(coreId))
|
|
return failure();
|
|
|
|
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
|
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
|
return failure();
|
|
(void) crossbarWeightCount;
|
|
|
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
|
|| parseCompressedRepeatedList(
|
|
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
|
|| parseCompressedRepeatedList(
|
|
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
|
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
|
return failure();
|
|
|
|
if (weights.size() != weightTypes.size())
|
|
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
|
if (weightArgs.size() != weights.size())
|
|
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
|
if (inputs.size() != inputTypes.size())
|
|
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
|
if (inputArgs.size() != inputs.size())
|
|
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
|
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"coreId cannot be specified both positionally and in attr-dict");
|
|
|
|
auto& builder = parser.getBuilder();
|
|
result.addAttribute(
|
|
"operandSegmentSizes",
|
|
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
|
if (hasCoreId)
|
|
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
|
|
|
|
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
|
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
|
return failure();
|
|
result.addTypes(outputTypes);
|
|
|
|
Region* body = result.addRegion();
|
|
applyArgumentTypes(weightTypes, weightArgs);
|
|
applyArgumentTypes(inputTypes, inputArgs);
|
|
llvm::append_range(regionArgs, weightArgs);
|
|
llvm::append_range(regionArgs, inputArgs);
|
|
return parser.parseRegion(*body, regionArgs);
|
|
}
|
|
|
|
template <typename ComputeBatchOpTy>
|
|
void printComputeBatchLikeOp(ComputeBatchOpTy op, OpAsmPrinter& printer) {
|
|
auto laneArg = op.getLaneArgument();
|
|
SmallVector<Value> weightArgs;
|
|
weightArgs.reserve(op.getWeights().size());
|
|
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
|
|
auto weightArg = op.getWeightArgument(index);
|
|
if (!weightArg)
|
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
|
weightArgs.push_back(*weightArg);
|
|
}
|
|
SmallVector<Value> inputArgs;
|
|
inputArgs.reserve(op.getInputs().size());
|
|
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
|
|
auto inputArg = op.getInputArgument(index);
|
|
if (!inputArg)
|
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
|
inputArgs.push_back(*inputArg);
|
|
}
|
|
|
|
SmallVector<BlockArgument> outputArgs;
|
|
if (!laneArg)
|
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
|
if (op.getNumResults() != 0) {
|
|
outputArgs.reserve(op.getNumResults());
|
|
for (unsigned index = 0; index < op.getNumResults(); ++index) {
|
|
auto outputArg = op.getOutputArgument(index);
|
|
if (!outputArg)
|
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
|
outputArgs.push_back(*outputArg);
|
|
}
|
|
}
|
|
|
|
printer << " ";
|
|
printer.printOperand(*laneArg);
|
|
printer << " = 0 to " << op.getLaneCount();
|
|
printer << " ";
|
|
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
|
|
printer << " ";
|
|
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
|
|
if (op.getNumResults() != 0) {
|
|
printer << " shared_outs";
|
|
printBlockArgumentList(printer, outputArgs);
|
|
}
|
|
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({op.getOperation(), 0, op.getLaneCount()}).size();
|
|
if (auto coreIdsAttr = op->template getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
|
printer << " coreIds ";
|
|
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
|
}
|
|
printer.printOptionalAttrDict(
|
|
op->getAttrs(),
|
|
{op.getLaneCountAttrName().getValue(), op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
|
printer << " : ";
|
|
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
|
|
printer << " ";
|
|
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
|
|
printer << " -> ";
|
|
printCompressedTypeSequence(printer, op.getResultTypes());
|
|
printer << " ";
|
|
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
|
|
}
|
|
|
|
template <typename ComputeBatchOpTy>
|
|
ParseResult parseComputeBatchLikeOp(OpAsmParser& parser, OperationState& result) {
|
|
int64_t lowerBound = 0;
|
|
int32_t laneCount = 0;
|
|
OpAsmParser::Argument laneArg;
|
|
SmallVector<OpAsmParser::Argument> weightArgs;
|
|
SmallVector<OpAsmParser::Argument> inputArgs;
|
|
SmallVector<OpAsmParser::Argument> outputArgs;
|
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
|
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
|
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
|
SmallVector<Type> weightTypes;
|
|
SmallVector<Type> inputTypes;
|
|
SmallVector<Type> outputTypes;
|
|
int32_t crossbarWeightCount = 0;
|
|
SmallVector<int32_t> coreIds;
|
|
|
|
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
|
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
|
return failure();
|
|
if (lowerBound != 0)
|
|
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
|
|
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
|
return failure();
|
|
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
|
return failure();
|
|
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
|
|
if (parseBlockArgumentList(parser, outputArgs))
|
|
return failure();
|
|
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
|
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
|
return failure();
|
|
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
|
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
|
return failure();
|
|
(void) crossbarWeightCount;
|
|
|
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
|
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
|
|| parseCompressedRepeatedList(
|
|
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
|
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
|
return failure();
|
|
|
|
if (weights.size() != weightTypes.size())
|
|
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
|
if (weightArgs.size() != weights.size())
|
|
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
|
if (inputs.size() != inputTypes.size())
|
|
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
|
if (inputArgs.size() != inputs.size())
|
|
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
|
if (outputArgs.size() != outputTypes.size())
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"number of shared output bindings and result types must match");
|
|
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"coreIds cannot be specified both positionally and in attr-dict");
|
|
|
|
auto& builder = parser.getBuilder();
|
|
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
|
result.addAttribute(
|
|
"operandSegmentSizes",
|
|
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
|
if (hasCoreIds)
|
|
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
|
|
|
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
|
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
|
return failure();
|
|
result.addTypes(outputTypes);
|
|
|
|
Region* body = result.addRegion();
|
|
applyBatchRegionArgumentTypes(
|
|
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
|
|
return parser.parseRegion(*body, regionArgs);
|
|
}
|
|
|
|
} // namespace
|
|
|
|
/// Prints `spat.yield <values> attr-dict : <types>`, using the compressed
/// sequence form for both the values and their types.
void SpatYieldOp::print(OpAsmPrinter& printer) {
  auto outputs = getOutputs();
  printer << " ";
  printCompressedValueSequence(printer, outputs);
  printer.printOptionalAttrDict(getOperation()->getAttrs());
  printer << " : ";
  printCompressedTypeSequence(printer, outputs.getTypes());
}
|
|
|
|
/// Parses `spat.yield [<values>] attr-dict : <types>`. The value list may be
/// empty; when present it uses the compressed operand syntax. The number of
/// values must equal the number of types.
ParseResult SpatYieldOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> yielded;
  SmallVector<Type> yieldedTypes;

  // Probe for a leading SSA value; its absence means an empty value list.
  OpAsmParser::UnresolvedOperand leading;
  OptionalParseResult leadingResult = parser.parseOptionalOperand(leading);
  if (leadingResult.has_value()) {
    if (failed(*leadingResult) || parseCompressedOperandEntryWithFirst(parser, leading, yielded))
      return failure();
    while (succeeded(parser.parseOptionalComma())) {
      if (parseOneCompressedOperandEntry(parser, yielded))
        return failure();
    }
  }

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeSequence(parser, yieldedTypes, /*allowEmpty=*/true))
    return failure();

  if (yielded.size() != yieldedTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match");

  return parser.resolveOperands(yielded, yieldedTypes, parser.getCurrentLocation(), result.operands);
}
|
|
|
|
/// Prints `spat.extract_rows %input attr-dict : <input-type> -> <result-types>`.
void SpatExtractRowsOp::print(OpAsmPrinter& printer) {
  Value input = getInput();
  printer << " ";
  printer.printOperand(input);
  printer.printOptionalAttrDict(getOperation()->getAttrs());
  printer << " : ";
  printer.printType(input.getType());
  printer << " -> ";
  printCompressedTypeSequence(printer, getResultTypes());
}
|
|
|
|
/// Parses `%input attr-dict : <input-type> -> <result-types>`. At least one
/// result type is required.
ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) {
  OpAsmParser::UnresolvedOperand source;
  Type sourceType;
  SmallVector<Type> rowTypes;

  bool failedToParse = parser.parseOperand(source) || parser.parseOptionalAttrDict(result.attributes)
                    || parser.parseColon() || parser.parseType(sourceType) || parser.parseArrow()
                    || parseCompressedTypeSequence(parser, rowTypes, /*allowEmpty=*/false)
                    || parser.resolveOperand(source, sourceType, result.operands);
  if (failedToParse)
    return failure();

  result.addTypes(rowTypes);
  return success();
}
|
|
|
|
/// Prints `spat.concat axis N <inputs> attr-dict : (<input-types>) -> <type>`.
/// The axis is printed positionally and therefore kept out of attr-dict.
void SpatConcatOp::print(OpAsmPrinter& printer) {
  auto inputs = getInputs();
  printer << " axis " << getAxis();
  printer << " ";
  printCompressedValueSequence(printer, inputs);
  printer.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{getAxisAttrName().getValue()});
  printer << " : ";
  printCompressedTypeList(printer, TypeRange(inputs), ListDelimiter::Paren);
  printer << " -> ";
  printer.printType(getOutput().getType());
}
|
|
|
|
/// Parses `axis N <inputs> attr-dict : (<input-types>) -> <type>`. Rejects a
/// mismatch between operand and type counts, and an `axis` that is also
/// present in attr-dict.
ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
  int64_t concatAxis = 0;
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  SmallVector<Type> operandTypes;
  Type resultType;

  auto parseOneType = [&](Type& type) { return parser.parseType(type); };
  if (parser.parseKeyword("axis") || parser.parseInteger(concatAxis) || parseCompressedOperandSequence(parser, operands)
      || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedRepeatedList(parser, ListDelimiter::Paren, operandTypes, parseOneType)
      || parser.parseArrow() || parser.parseType(resultType))
    return failure();

  if (operands.size() != operandTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  if (result.attributes.get("axis"))
    return parser.emitError(parser.getCurrentLocation(), "axis cannot be specified both positionally and in attr-dict");

  result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(concatAxis));
  if (parser.resolveOperands(operands, operandTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(resultType);
  return success();
}
|
|
|
|
/// Prints the custom assembly for `spat.blueprint`:
///
///   fragments(<input>, <fragments...>) layout L physical P offsets <list>
///     sizes <list> map M [mode X] [operandIndices <list>]
///     [sourceOffsets <list>] [strides <list>] [conflict C] [coverage V]
///     attr-dict : (<operand-types>) -> <type>
///
/// String-valued attributes are printed with printKeywordOrString. Values
/// that are valid bare identifiers print unquoted, exactly as before. Any other
/// value is quoted and escaped, so the output stays parseable instead of
/// breaking the token stream.
void SpatBlueprintOp::print(OpAsmPrinter& printer) {
  SmallVector<Value> operands {getInput()};
  llvm::append_range(operands, getFragments());

  printer << " fragments";
  printCompressedValueList(printer, operands, ListDelimiter::Paren);
  printer << " layout ";
  printer.printKeywordOrString(getLogicalLayout());
  printer << " physical ";
  printer.printKeywordOrString(getPhysicalLayout());
  printer << " offsets ";
  printCompressedIntegerList(printer, getFragmentOffsets());
  printer << " sizes ";
  printCompressedIntegerList(printer, getFragmentSizes());
  printer << " map ";
  printer.printKeywordOrString(getIndexMap());

  if (std::optional<StringRef> mode = getMode()) {
    printer << " mode ";
    printer.printKeywordOrString(*mode);
  }
  if (std::optional<ArrayRef<int64_t>> operandIndices = getFragmentOperandIndices()) {
    printer << " operandIndices ";
    printCompressedIntegerList(printer, *operandIndices);
  }
  if (std::optional<ArrayRef<int64_t>> sourceOffsets = getFragmentSourceOffsets()) {
    printer << " sourceOffsets ";
    printCompressedIntegerList(printer, *sourceOffsets);
  }
  if (std::optional<ArrayRef<int64_t>> strides = getFragmentStrides()) {
    printer << " strides ";
    printCompressedIntegerList(printer, *strides);
  }
  if (std::optional<StringRef> conflictPolicy = getConflictPolicy()) {
    printer << " conflict ";
    printer.printKeywordOrString(*conflictPolicy);
  }
  if (std::optional<StringRef> coveragePolicy = getCoveragePolicy()) {
    printer << " coverage ";
    printer.printKeywordOrString(*coveragePolicy);
  }

  // Everything printed positionally above is kept out of attr-dict.
  printer.printOptionalAttrDict((*this)->getAttrs(),
                                {getLogicalLayoutAttrName().getValue(),
                                 getPhysicalLayoutAttrName().getValue(),
                                 getFragmentOffsetsAttrName().getValue(),
                                 getFragmentSizesAttrName().getValue(),
                                 getIndexMapAttrName().getValue(),
                                 getModeAttrName().getValue(),
                                 getFragmentOperandIndicesAttrName().getValue(),
                                 getFragmentSourceOffsetsAttrName().getValue(),
                                 getFragmentStridesAttrName().getValue(),
                                 getConflictPolicyAttrName().getValue(),
                                 getCoveragePolicyAttrName().getValue()});
  printer << " : ";
  printCompressedTypeList(printer, TypeRange(operands), ListDelimiter::Paren);
  printer << " -> ";
  printer.printType(getOutput().getType());
}
|
|
|
|
/// Parses the custom assembly for `spat.blueprint`:
///
///   fragments(<operands>) layout L physical P offsets <list> sizes <list>
///     map M [mode X] [operandIndices <list>] [sourceOffsets <list>]
///     [strides <list>] [conflict C] [coverage V]
///     attr-dict : (<operand-types>) -> <type>
///
/// The optional clauses must appear in the order shown. An optional
/// integer-list clause that is present creates its attribute even when the
/// list is empty. Attribute presence follows the keyword, not the list
/// contents, so a present-but-empty attribute round-trips instead of being
/// dropped.
ParseResult SpatBlueprintOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  SmallVector<Type> operandTypes;
  Type outputType;
  StringAttr logicalLayout;
  StringAttr physicalLayout;
  StringAttr indexMap;
  StringAttr mode;
  StringAttr conflictPolicy;
  StringAttr coveragePolicy;
  SmallVector<int64_t> fragmentOffsets;
  SmallVector<int64_t> fragmentSizes;
  SmallVector<int64_t> fragmentOperandIndices;
  SmallVector<int64_t> fragmentSourceOffsets;
  SmallVector<int64_t> fragmentStrides;

  if (parser.parseKeyword("fragments")
      || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)
      || parser.parseKeyword("layout") || parseBareStringAttr(parser, logicalLayout)
      || parser.parseKeyword("physical") || parseBareStringAttr(parser, physicalLayout)
      || parser.parseKeyword("offsets") || parseCompressedIntegerList(parser, fragmentOffsets)
      || parser.parseKeyword("sizes") || parseCompressedIntegerList(parser, fragmentSizes)
      || parser.parseKeyword("map") || parseBareStringAttr(parser, indexMap))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("mode")) && parseBareStringAttr(parser, mode))
    return failure();

  bool hasOperandIndices = succeeded(parser.parseOptionalKeyword("operandIndices"));
  if (hasOperandIndices && parseCompressedIntegerList(parser, fragmentOperandIndices))
    return failure();

  bool hasSourceOffsets = succeeded(parser.parseOptionalKeyword("sourceOffsets"));
  if (hasSourceOffsets && parseCompressedIntegerList(parser, fragmentSourceOffsets))
    return failure();

  bool hasStrides = succeeded(parser.parseOptionalKeyword("strides"));
  if (hasStrides && parseCompressedIntegerList(parser, fragmentStrides))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("conflict")) && parseBareStringAttr(parser, conflictPolicy))
    return failure();
  if (succeeded(parser.parseOptionalKeyword("coverage")) && parseBareStringAttr(parser, coveragePolicy))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedRepeatedList(
          parser, ListDelimiter::Paren, operandTypes, [&](Type& type) { return parser.parseType(type); })
      || parser.parseArrow() || parser.parseType(outputType))
    return failure();
  if (operands.empty())
    return parser.emitError(parser.getCurrentLocation(), "spat.blueprint requires at least one fragment operand");
  if (operands.size() != operandTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of fragment operands and types must match");

  auto& builder = parser.getBuilder();
  result.addAttribute("logicalLayout", logicalLayout);
  result.addAttribute("physicalLayout", physicalLayout);
  result.addAttribute("fragmentOffsets", builder.getDenseI64ArrayAttr(fragmentOffsets));
  result.addAttribute("fragmentSizes", builder.getDenseI64ArrayAttr(fragmentSizes));
  result.addAttribute("indexMap", indexMap);
  if (mode)
    result.addAttribute("mode", mode);
  if (hasOperandIndices)
    result.addAttribute("fragmentOperandIndices", builder.getDenseI64ArrayAttr(fragmentOperandIndices));
  if (hasSourceOffsets)
    result.addAttribute("fragmentSourceOffsets", builder.getDenseI64ArrayAttr(fragmentSourceOffsets));
  if (hasStrides)
    result.addAttribute("fragmentStrides", builder.getDenseI64ArrayAttr(fragmentStrides));
  if (conflictPolicy)
    result.addAttribute("conflictPolicy", conflictPolicy);
  if (coveragePolicy)
    result.addAttribute("coveragePolicy", coveragePolicy);

  if (parser.resolveOperands(operands, operandTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputType);
  return success();
}
|
|
|
|
void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
|
|
ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) {
|
|
return parseComputeLikeOp<SpatGraphCompute>(parser, result);
|
|
}
|
|
void SpatScheduledCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
|
|
ParseResult SpatScheduledCompute::parse(OpAsmParser& parser, OperationState& result) {
|
|
return parseComputeLikeOp<SpatScheduledCompute>(parser, result);
|
|
}
|
|
void SpatGraphComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
|
|
ParseResult SpatGraphComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
|
return parseComputeBatchLikeOp<SpatGraphComputeBatch>(parser, result);
|
|
}
|
|
void SpatScheduledComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
|
|
ParseResult SpatScheduledComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
|
return parseComputeBatchLikeOp<SpatScheduledComputeBatch>(parser, result);
|
|
}
|
|
|
|
/// Prints the body region followed by attr-dict. Entry-block arguments and
/// block terminators are omitted from the custom form.
void SpatInParallelOp::print(OpAsmPrinter& printer) {
  Region& body = getRegion();
  printer << " ";
  printer.printRegion(body, /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false);
  printer.printOptionalAttrDict(getOperation()->getAttrs());
}
|
|
|
|
/// Parses a region without entry-block arguments, followed by an optional
/// attr-dict. An empty region (`{}`) is given a single empty block, so the op
/// always owns at least one block.
ParseResult SpatInParallelOp::parse(OpAsmParser& parser, OperationState& result) {
  auto body = std::make_unique<Region>();
  SmallVector<OpAsmParser::Argument, 4> noArguments;
  if (parser.parseRegion(*body, noArguments))
    return failure();

  if (body->empty()) {
    MLIRContext* context = parser.getBuilder().getContext();
    OpBuilder(context).createBlock(body.get());
  }

  result.addRegion(std::move(body));
  return parser.parseOptionalAttrDict(result.attributes);
}
|
|
|
|
} // namespace spatial
|
|
} // namespace onnx_mlir
|