This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -23,23 +24,23 @@ enum class ListDelimiter {
|
||||
};
|
||||
|
||||
static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||
if (delimiter == ListDelimiter::Square)
|
||||
return parser.parseLSquare();
|
||||
return parser.parseLParen();
|
||||
return onnx_mlir::compact_asm::parseOpenDelimiter(
|
||||
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
|
||||
}
|
||||
|
||||
static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||
if (delimiter == ListDelimiter::Square)
|
||||
return parser.parseOptionalRSquare();
|
||||
return parser.parseOptionalRParen();
|
||||
return onnx_mlir::compact_asm::parseOptionalCloseDelimiter(
|
||||
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
|
||||
}
|
||||
|
||||
static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
|
||||
printer << (delimiter == ListDelimiter::Square ? "[" : "(");
|
||||
onnx_mlir::compact_asm::printOpenDelimiter(
|
||||
printer, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
|
||||
}
|
||||
|
||||
static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
|
||||
printer << (delimiter == ListDelimiter::Square ? "]" : ")");
|
||||
onnx_mlir::compact_asm::printCloseDelimiter(
|
||||
printer, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
|
||||
}
|
||||
|
||||
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
|
||||
@@ -51,31 +52,8 @@ static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<EntryT>& entries,
|
||||
ParseEntryFn parseEntry) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
while (true) {
|
||||
EntryT entry;
|
||||
if (parseEntry(entry))
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
entries.push_back(entry);
|
||||
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
break;
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
return onnx_mlir::compact_asm::parseCompressedRepeatedList(
|
||||
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter), entries, parseEntry);
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
@@ -388,156 +366,32 @@ static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
|
||||
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty);
|
||||
|
||||
static bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
|
||||
if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
|
||||
return false;
|
||||
|
||||
SmallVector<Value> valueVec(values.begin(), values.end());
|
||||
ArrayRef<Value> tuple(valueVec.data(), tupleSize);
|
||||
for (size_t index = tupleSize; index < values.size(); index += tupleSize)
|
||||
if (!llvm::equal(tuple, ArrayRef<Value>(valueVec).slice(index, tupleSize)))
|
||||
return false;
|
||||
return true;
|
||||
return onnx_mlir::compact_asm::hasRepeatedTuple(values, tupleSize);
|
||||
}
|
||||
|
||||
static bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
|
||||
if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
|
||||
return false;
|
||||
|
||||
SmallVector<Type> typeVec(types.begin(), types.end());
|
||||
ArrayRef<Type> tuple(typeVec.data(), tupleSize);
|
||||
for (size_t index = tupleSize; index < types.size(); index += tupleSize)
|
||||
if (!llvm::equal(tuple, ArrayRef<Type>(typeVec).slice(index, tupleSize)))
|
||||
return false;
|
||||
return true;
|
||||
return onnx_mlir::compact_asm::hasRepeatedTuple(types, tupleSize);
|
||||
}
|
||||
|
||||
static void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize) {
|
||||
printer << "[";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
for (size_t index = 0; index < tupleSize; ++index) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printer.printOperand(values[index]);
|
||||
}
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer << " x" << (values.size() / tupleSize) << "]";
|
||||
onnx_mlir::compact_asm::printValueTupleRun(
|
||||
printer, values, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square);
|
||||
}
|
||||
|
||||
static void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize) {
|
||||
printer << "[";
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
for (size_t index = 0; index < tupleSize; ++index) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printer.printType(types[index]);
|
||||
}
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer << " x" << (types.size() / tupleSize) << "]";
|
||||
onnx_mlir::compact_asm::printTypeTupleRun(
|
||||
printer, types, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parser.parseLSquare())
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRSquare()))
|
||||
return success();
|
||||
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
|
||||
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(operands, tupleOperands);
|
||||
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
tupleOperands.clear();
|
||||
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(operands, tupleOperands);
|
||||
}
|
||||
return parser.parseRSquare();
|
||||
}
|
||||
|
||||
while (true) {
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRSquare()))
|
||||
return success();
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
return onnx_mlir::compact_asm::parseCompressedOrTupleOperandList(
|
||||
parser, onnx_mlir::compact_asm::ListDelimiter::Square, operands);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, SmallVectorImpl<Type>& types) {
|
||||
if (parser.parseLSquare())
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRSquare()))
|
||||
return success();
|
||||
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
SmallVector<Type> tupleTypes;
|
||||
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(types, tupleTypes);
|
||||
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
tupleTypes.clear();
|
||||
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(types, tupleTypes);
|
||||
}
|
||||
return parser.parseRSquare();
|
||||
}
|
||||
|
||||
while (true) {
|
||||
Type type;
|
||||
if (parser.parseType(type))
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
types.push_back(type);
|
||||
|
||||
if (succeeded(parser.parseOptionalRSquare()))
|
||||
return success();
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
return onnx_mlir::compact_asm::parseCompressedOrTupleTypeList(
|
||||
parser, onnx_mlir::compact_asm::ListDelimiter::Square, types);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
|
||||
|
||||
Reference in New Issue
Block a user