huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
@@ -7,8 +7,8 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -129,9 +129,8 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
if (parser.parseKeyword("axis") || parser.parseInteger(axis))
|
||||
return failure();
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs)) {
|
||||
if (parseCompressedOperandSequence(parser, inputs))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
@@ -151,46 +150,6 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatMapOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
printer << " : ";
|
||||
printer.printType(getInputs().front().getType());
|
||||
printer << " -> ";
|
||||
printer.printType(getOutputs().front().getType());
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult SpatMapOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
Type inputType;
|
||||
Type outputType;
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
return failure();
|
||||
if (inputs.empty())
|
||||
return parser.emitError(parser.getCurrentLocation(), "map requires at least one input");
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|
||||
|| parser.parseArrow() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
SmallVector<Type> inputTypes(inputs.size(), inputType);
|
||||
SmallVector<Type> outputTypes(inputs.size(), outputType);
|
||||
if (regionArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
Region* body = result.addRegion();
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
@@ -357,97 +316,6 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatChannelSendManyOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||
}
|
||||
|
||||
ParseResult SpatChannelSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveManyOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
@@ -494,55 +362,6 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelSendManyBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||
}
|
||||
|
||||
ParseResult SpatChannelSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
@@ -584,47 +403,5 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelReceiveManyBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user