compact pim IR
Validate Operations / validate-operations (push) Successful in 22m15s

This commit is contained in:
NiccoloN
2026-05-06 17:16:51 +02:00
parent 7bb58e80de
commit f2fe147961
13 changed files with 2264 additions and 307 deletions
+21 -167
View File
@@ -8,6 +8,7 @@
#include <string>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -23,23 +24,23 @@ enum class ListDelimiter {
};
static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseLSquare();
return parser.parseLParen();
return onnx_mlir::compact_asm::parseOpenDelimiter(
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
}
static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseOptionalRSquare();
return parser.parseOptionalRParen();
return onnx_mlir::compact_asm::parseOptionalCloseDelimiter(
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
}
static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
printer << (delimiter == ListDelimiter::Square ? "[" : "(");
onnx_mlir::compact_asm::printOpenDelimiter(
printer, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
}
static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
printer << (delimiter == ListDelimiter::Square ? "]" : ")");
onnx_mlir::compact_asm::printCloseDelimiter(
printer, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
}
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
@@ -51,31 +52,8 @@ static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<EntryT>& entries,
ParseEntryFn parseEntry) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
EntryT entry;
if (parseEntry(entry))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
entries.push_back(entry);
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
break;
if (parser.parseComma())
return failure();
}
return success();
return onnx_mlir::compact_asm::parseCompressedRepeatedList(
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter), entries, parseEntry);
}
template <typename IntT>
@@ -388,156 +366,32 @@ static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty);
static bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
return false;
SmallVector<Value> valueVec(values.begin(), values.end());
ArrayRef<Value> tuple(valueVec.data(), tupleSize);
for (size_t index = tupleSize; index < values.size(); index += tupleSize)
if (!llvm::equal(tuple, ArrayRef<Value>(valueVec).slice(index, tupleSize)))
return false;
return true;
return onnx_mlir::compact_asm::hasRepeatedTuple(values, tupleSize);
}
static bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
return false;
SmallVector<Type> typeVec(types.begin(), types.end());
ArrayRef<Type> tuple(typeVec.data(), tupleSize);
for (size_t index = tupleSize; index < types.size(); index += tupleSize)
if (!llvm::equal(tuple, ArrayRef<Type>(typeVec).slice(index, tupleSize)))
return false;
return true;
return onnx_mlir::compact_asm::hasRepeatedTuple(types, tupleSize);
}
static void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize) {
printer << "[";
printOpenDelimiter(printer, ListDelimiter::Paren);
for (size_t index = 0; index < tupleSize; ++index) {
if (index != 0)
printer << ", ";
printer.printOperand(values[index]);
}
printCloseDelimiter(printer, ListDelimiter::Paren);
printer << " x" << (values.size() / tupleSize) << "]";
onnx_mlir::compact_asm::printValueTupleRun(
printer, values, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square);
}
static void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize) {
printer << "[";
printOpenDelimiter(printer, ListDelimiter::Paren);
for (size_t index = 0; index < tupleSize; ++index) {
if (index != 0)
printer << ", ";
printer.printType(types[index]);
}
printCloseDelimiter(printer, ListDelimiter::Paren);
printer << " x" << (types.size() / tupleSize) << "]";
onnx_mlir::compact_asm::printTypeTupleRun(
printer, types, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square);
}
static ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parser.parseLSquare())
return failure();
if (succeeded(parser.parseOptionalRSquare()))
return success();
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(operands, tupleOperands);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseLParen())
return failure();
tupleOperands.clear();
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
return failure();
repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(operands, tupleOperands);
}
return parser.parseRSquare();
}
while (true) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
if (succeeded(parser.parseOptionalRSquare()))
return success();
if (parser.parseComma())
return failure();
}
return onnx_mlir::compact_asm::parseCompressedOrTupleOperandList(
parser, onnx_mlir::compact_asm::ListDelimiter::Square, operands);
}
static ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, SmallVectorImpl<Type>& types) {
if (parser.parseLSquare())
return failure();
if (succeeded(parser.parseOptionalRSquare()))
return success();
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<Type> tupleTypes;
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(types, tupleTypes);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseLParen())
return failure();
tupleTypes.clear();
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
return failure();
repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(types, tupleTypes);
}
return parser.parseRSquare();
}
while (true) {
Type type;
if (parser.parseType(type))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
types.push_back(type);
if (succeeded(parser.parseOptionalRSquare()))
return success();
if (parser.parseComma())
return failure();
}
return onnx_mlir::compact_asm::parseCompressedOrTupleTypeList(
parser, onnx_mlir::compact_asm::ListDelimiter::Square, types);
}
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,