@@ -32,6 +32,14 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
static ParseResult parseBareStringAttr(OpAsmParser& parser, StringAttr& attr) {
|
||||
StringRef value;
|
||||
if (parser.parseKeyword(&value))
|
||||
return failure();
|
||||
attr = parser.getBuilder().getStringAttr(value);
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
||||
printer << "(";
|
||||
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
||||
@@ -466,6 +474,131 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatBlueprintOp::print(OpAsmPrinter& printer) {
|
||||
SmallVector<Value> operands {getInput()};
|
||||
llvm::append_range(operands, getFragments());
|
||||
|
||||
printer << " fragments";
|
||||
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||
printer << " layout " << getLogicalLayout();
|
||||
printer << " physical " << getPhysicalLayout();
|
||||
printer << " offsets ";
|
||||
printCompressedIntegerList(printer, getFragmentOffsets());
|
||||
printer << " sizes ";
|
||||
printCompressedIntegerList(printer, getFragmentSizes());
|
||||
printer << " map " << getIndexMap();
|
||||
if (std::optional<StringRef> mode = getMode())
|
||||
printer << " mode " << *mode;
|
||||
if (std::optional<ArrayRef<int64_t>> operandIndices = getFragmentOperandIndices()) {
|
||||
printer << " operandIndices ";
|
||||
printCompressedIntegerList(printer, *operandIndices);
|
||||
}
|
||||
if (std::optional<ArrayRef<int64_t>> sourceOffsets = getFragmentSourceOffsets()) {
|
||||
printer << " sourceOffsets ";
|
||||
printCompressedIntegerList(printer, *sourceOffsets);
|
||||
}
|
||||
if (std::optional<ArrayRef<int64_t>> strides = getFragmentStrides()) {
|
||||
printer << " strides ";
|
||||
printCompressedIntegerList(printer, *strides);
|
||||
}
|
||||
if (std::optional<StringRef> conflictPolicy = getConflictPolicy())
|
||||
printer << " conflict " << *conflictPolicy;
|
||||
if (std::optional<StringRef> coveragePolicy = getCoveragePolicy())
|
||||
printer << " coverage " << *coveragePolicy;
|
||||
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(),
|
||||
{getLogicalLayoutAttrName().getValue(),
|
||||
getPhysicalLayoutAttrName().getValue(),
|
||||
getFragmentOffsetsAttrName().getValue(),
|
||||
getFragmentSizesAttrName().getValue(),
|
||||
getIndexMapAttrName().getValue(),
|
||||
getModeAttrName().getValue(),
|
||||
getFragmentOperandIndicesAttrName().getValue(),
|
||||
getFragmentSourceOffsetsAttrName().getValue(),
|
||||
getFragmentStridesAttrName().getValue(),
|
||||
getConflictPolicyAttrName().getValue(),
|
||||
getCoveragePolicyAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(operands), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
printer.printType(getOutput().getType());
|
||||
}
|
||||
|
||||
ParseResult SpatBlueprintOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> operands;
|
||||
SmallVector<Type> operandTypes;
|
||||
Type outputType;
|
||||
StringAttr logicalLayout;
|
||||
StringAttr physicalLayout;
|
||||
StringAttr indexMap;
|
||||
StringAttr mode;
|
||||
StringAttr conflictPolicy;
|
||||
StringAttr coveragePolicy;
|
||||
SmallVector<int64_t> fragmentOffsets;
|
||||
SmallVector<int64_t> fragmentSizes;
|
||||
SmallVector<int64_t> fragmentOperandIndices;
|
||||
SmallVector<int64_t> fragmentSourceOffsets;
|
||||
SmallVector<int64_t> fragmentStrides;
|
||||
|
||||
if (parser.parseKeyword("fragments")
|
||||
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)
|
||||
|| parser.parseKeyword("layout") || parseBareStringAttr(parser, logicalLayout)
|
||||
|| parser.parseKeyword("physical") || parseBareStringAttr(parser, physicalLayout)
|
||||
|| parser.parseKeyword("offsets") || parseCompressedIntegerList(parser, fragmentOffsets)
|
||||
|| parser.parseKeyword("sizes") || parseCompressedIntegerList(parser, fragmentSizes)
|
||||
|| parser.parseKeyword("map") || parseBareStringAttr(parser, indexMap))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("mode")) && parseBareStringAttr(parser, mode))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalKeyword("operandIndices"))
|
||||
&& parseCompressedIntegerList(parser, fragmentOperandIndices))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalKeyword("sourceOffsets"))
|
||||
&& parseCompressedIntegerList(parser, fragmentSourceOffsets))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalKeyword("strides")) && parseCompressedIntegerList(parser, fragmentStrides))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalKeyword("conflict")) && parseBareStringAttr(parser, conflictPolicy))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalKeyword("coverage")) && parseBareStringAttr(parser, coveragePolicy))
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, operandTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parser.parseType(outputType))
|
||||
return failure();
|
||||
if (operands.empty())
|
||||
return parser.emitError(parser.getCurrentLocation(), "spat.blueprint requires at least one fragment operand");
|
||||
if (operands.size() != operandTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of fragment operands and types must match");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute("logicalLayout", logicalLayout);
|
||||
result.addAttribute("physicalLayout", physicalLayout);
|
||||
result.addAttribute("fragmentOffsets", builder.getDenseI64ArrayAttr(fragmentOffsets));
|
||||
result.addAttribute("fragmentSizes", builder.getDenseI64ArrayAttr(fragmentSizes));
|
||||
result.addAttribute("indexMap", indexMap);
|
||||
if (mode)
|
||||
result.addAttribute("mode", mode);
|
||||
if (!fragmentOperandIndices.empty())
|
||||
result.addAttribute("fragmentOperandIndices", builder.getDenseI64ArrayAttr(fragmentOperandIndices));
|
||||
if (!fragmentSourceOffsets.empty())
|
||||
result.addAttribute("fragmentSourceOffsets", builder.getDenseI64ArrayAttr(fragmentSourceOffsets));
|
||||
if (!fragmentStrides.empty())
|
||||
result.addAttribute("fragmentStrides", builder.getDenseI64ArrayAttr(fragmentStrides));
|
||||
if (conflictPolicy)
|
||||
result.addAttribute("conflictPolicy", conflictPolicy);
|
||||
if (coveragePolicy)
|
||||
result.addAttribute("coveragePolicy", coveragePolicy);
|
||||
|
||||
if (parser.resolveOperands(operands, operandTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
|
||||
ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseComputeLikeOp<SpatGraphCompute>(parser, result);
|
||||
|
||||
Reference in New Issue
Block a user