compact syntax for spatial tensor ops
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
better IR compaction after dcp merge remove pim.mvm op better memory report
This commit is contained in:
@@ -47,6 +47,95 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
template <typename TensorSendOpTy>
|
||||
static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
||||
printer << " ";
|
||||
printer.printOperand(op.getInput());
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getInput().getType());
|
||||
}
|
||||
|
||||
template <typename TensorReceiveOpTy>
|
||||
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getOutput().getType());
|
||||
}
|
||||
|
||||
static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) {
|
||||
Type outputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
||||
@@ -316,6 +405,12 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorSendOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
@@ -362,6 +457,18 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorSendOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorReceiveOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
@@ -403,5 +510,11 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorReceiveOp(parser, result);
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user