compact pim IR
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m15s

This commit is contained in:
NiccoloN
2026-05-06 17:16:51 +02:00
parent 7bb58e80de
commit f2fe147961
13 changed files with 2264 additions and 307 deletions

View File

@@ -0,0 +1,755 @@
#ifndef ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
#define ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LogicalResult.h"
namespace onnx_mlir {
namespace compact_asm {
using namespace mlir;
// Which bracket pair wraps a compressed list: `[ ... ]` or `( ... )`.
enum class ListDelimiter {
Square,
Paren
};
/// Consumes the opening bracket selected by `delimiter` ('[' or '(').
inline ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
  return delimiter == ListDelimiter::Square ? parser.parseLSquare()
                                            : parser.parseLParen();
}
/// Consumes the closing bracket selected by `delimiter` if it is next in the
/// stream; reports failure (without emitting a diagnostic) otherwise.
inline ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
  return delimiter == ListDelimiter::Square ? parser.parseOptionalRSquare()
                                            : parser.parseOptionalRParen();
}
/// Prints the opening bracket for `delimiter`.
inline void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
  if (delimiter == ListDelimiter::Square)
    printer << "[";
  else
    printer << "(";
}
/// Prints the closing bracket for `delimiter`.
inline void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
  if (delimiter == ListDelimiter::Square)
    printer << "]";
  else
    printer << ")";
}
/// Parses a delimited, comma-separated list of entries where each entry may
/// carry an `xN` suffix that pushes it N times into `entries`.
/// Grammar: open (entry ('x' count)? (',' entry ('x' count)?)*)? close
/// `parseEntry` is invoked once per textual entry and fills an `EntryT`.
template <typename EntryT, typename ParseEntryFn>
inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
    ListDelimiter delimiter,
    SmallVectorImpl<EntryT>& entries,
    ParseEntryFn parseEntry) {
  if (failed(parseOpenDelimiter(parser, delimiter)))
    return failure();
  // Empty list: the close delimiter directly follows the open one.
  if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
    return success();
  do {
    EntryT entry;
    if (parseEntry(entry))
      return failure();
    int64_t count = 1;
    if (succeeded(parser.parseOptionalKeyword("x"))) {
      if (parser.parseInteger(count) || count <= 0)
        return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
    }
    // Expand the `xN` suffix by appending N copies of the entry.
    entries.append(count, entry);
    if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
      return success();
  } while (!parser.parseComma());
  return failure();
}
/// Parses the body of an already-opened compressed integer list, appending the
/// expanded values to `values`. Supported entry forms:
///   - `N`                     a single integer
///   - `N xR`                  N repeated R times
///   - `A to B [by S] [xR]`    ascending range A..B stepping by S, with each
///                             value emitted R times before the next one
///   - `( ... ) [xR]`          a nested parenthesized sub-list repeated R times
/// Consumes the closing delimiter of `delimiter` before returning success.
template <typename IntT>
inline ParseResult parseCompressedIntegerEntries(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<IntT>& values) {
// Empty list: close delimiter immediately follows the open one.
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
// Nested group `( ... ) xR`: recurse, then splice the subgroup R times.
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<IntT> subgroup;
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(values, subgroup);
}
else {
int64_t first = 0;
if (parser.parseInteger(first))
return failure();
// Range form: `first to last [by step] [x repeat]`.
if (succeeded(parser.parseOptionalKeyword("to"))) {
int64_t last = 0;
if (parser.parseInteger(last) || last < first)
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
int64_t step = 1;
if (succeeded(parser.parseOptionalKeyword("by"))) {
if (parser.parseInteger(step) || step <= 0)
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
}
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
// Reject ranges whose endpoint cannot be hit exactly, so printing and
// re-parsing round-trips.
if ((last - first) % step != 0) {
return parser.emitError(
parser.getCurrentLocation(), "range end must be reachable from start using the given step");
}
// Emit each value of the progression `repeatCount` times in a row.
for (int64_t value = first; value <= last; value += step)
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(value));
}
else {
// Scalar form: `first [x repeat]`.
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(first));
}
}
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
/// Parses a complete compressed integer list, including its opening bracket,
/// delegating the body (and closing bracket) to parseCompressedIntegerEntries.
template <typename IntT>
inline ParseResult parseCompressedIntegerSequence(OpAsmParser& parser,
    ListDelimiter delimiter,
    SmallVectorImpl<IntT>& values) {
  if (failed(parseOpenDelimiter(parser, delimiter)))
    return failure();
  return parseCompressedIntegerEntries(parser, delimiter, values);
}
/// Prints `entries` as a comma-separated list, collapsing maximal runs of
/// equal adjacent entries into `entry xN` form. `printEntry` renders one entry.
template <typename RangeT, typename PrintEntryFn>
inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
  size_t pos = 0;
  bool leading = true;
  while (pos < entries.size()) {
    // Extend the run while neighbours compare equal.
    size_t next = pos + 1;
    while (next < entries.size() && entries[next] == entries[pos])
      ++next;
    if (!leading)
      printer << ", ";
    leading = false;
    printEntry(entries[pos]);
    size_t count = next - pos;
    if (count > 1)
      printer << " x" << count;
    pos = next;
  }
}
/// Prints `values` using the compressed integer-list syntax that
/// parseCompressedIntegerSequence accepts. At each position it picks whichever
/// encoding covers the most elements:
///   - a repeated sub-list `(a, b) xR`,
///   - an arithmetic progression `A to B [by S] [xR]`,
///   - an equal run `V xR`,
///   - or a single value.
template <typename IntT>
inline void printCompressedIntegerSequence(OpAsmPrinter& printer,
ArrayRef<IntT> values,
ListDelimiter delimiter) {
// Best non-sublist ("flat") encoding available at a given start index.
struct FlatCompression {
enum class Kind {
Single,
EqualRun,
Progression
};
Kind kind = Kind::Single;
size_t covered = 1; // total elements consumed by this encoding
size_t repeatCount = 1; // per-value repeat factor (the `xR` suffix)
size_t progressionValueCount = 1; // distinct values in a progression
int64_t step = 1; // progression stride (the `by S` suffix)
IntT firstValue {};
IntT lastValue {};
};
// Computes the longest flat encoding starting at `start`. A progression is
// only recognized when every value repeats exactly `repeatCount` times and
// at least 3 distinct values participate (shorter ones are not worth it).
auto computeFlatCompression = [&](size_t start) {
FlatCompression compression;
compression.firstValue = values[start];
compression.lastValue = values[start];
// Index one past the run of values equal to values[runStart].
auto findEqualRunEnd = [&](size_t runStart) {
size_t runEnd = runStart + 1;
while (runEnd < values.size() && values[runEnd] == values[runStart])
++runEnd;
return runEnd;
};
size_t firstRunEnd = findEqualRunEnd(start);
compression.repeatCount = firstRunEnd - start;
size_t progressionEnd = firstRunEnd;
int64_t step = 0;
IntT lastValue = values[start];
if (firstRunEnd < values.size()) {
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
// Progressions must ascend and keep a uniform per-value repeat count.
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
progressionEnd = secondRunEnd;
lastValue = values[firstRunEnd];
size_t currentRunStart = secondRunEnd;
// Greedily absorb further runs that continue the arithmetic step.
while (currentRunStart < values.size()) {
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
if (currentRunEnd - currentRunStart != compression.repeatCount)
break;
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
break;
lastValue = values[currentRunStart];
progressionEnd = currentRunEnd;
currentRunStart = currentRunEnd;
}
}
else {
step = 0;
}
}
compression.covered = 1;
if (progressionEnd > firstRunEnd) {
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
// Require >= 3 distinct values; otherwise an equal run or singles win.
if (progressionValueCount >= 3) {
compression.kind = FlatCompression::Kind::Progression;
compression.covered = progressionEnd - start;
compression.progressionValueCount = progressionValueCount;
compression.step = step;
compression.lastValue = lastValue;
return compression;
}
}
if (compression.repeatCount > 1) {
compression.kind = FlatCompression::Kind::EqualRun;
compression.covered = compression.repeatCount;
return compression;
}
return compression;
};
// Finds the (length, repeatCount) of the best immediately-repeating sublist
// starting at `start`, e.g. `1 2 1 2 1 2` -> (2, 3). Returns length 0 when
// no sublist of length >= 2 repeats at least twice.
auto findRepeatedSublist = [&](size_t start) {
size_t bestLength = 0;
size_t bestRepeatCount = 1;
size_t remaining = values.size() - start;
for (size_t length = 2; length * 2 <= remaining; ++length) {
size_t repeatCount = 1;
ArrayRef<IntT> candidate = values.slice(start, length);
while (start + (repeatCount + 1) * length <= values.size()
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
++repeatCount;
}
if (repeatCount <= 1)
continue;
size_t covered = length * repeatCount;
size_t bestCovered = bestLength * bestRepeatCount;
// Prefer greater coverage; on ties prefer the shorter pattern.
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
bestLength = length;
bestRepeatCount = repeatCount;
}
}
return std::pair(bestLength, bestRepeatCount);
};
printOpenDelimiter(printer, delimiter);
for (size_t index = 0; index < values.size();) {
if (index != 0)
printer << ", ";
FlatCompression flat = computeFlatCompression(index);
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
// A repeated sublist wins only when it strictly beats the flat encoding;
// it is printed recursively in parenthesized form.
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
printCompressedIntegerSequence(printer, values.slice(index, sublistLength), ListDelimiter::Paren);
printer << " x" << sublistRepeatCount;
index += repeatedSublistCoverage;
continue;
}
switch (flat.kind) {
case FlatCompression::Kind::Progression:
printer << flat.firstValue << " to " << flat.lastValue;
if (flat.step != 1)
printer << " by " << flat.step;
if (flat.repeatCount > 1)
printer << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::EqualRun:
printer << flat.firstValue << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::Single:
printer << flat.firstValue;
index += flat.covered;
break;
}
}
printCloseDelimiter(printer, delimiter);
}
/// Convenience wrapper: parses a square-bracketed compressed integer list.
template <typename IntT>
inline ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
  constexpr ListDelimiter kDelimiter = ListDelimiter::Square;
  return parseCompressedIntegerSequence(parser, kDelimiter, values);
}
/// Convenience wrapper: prints `values` as a square-bracketed compressed list.
template <typename IntT>
inline void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
  constexpr ListDelimiter kDelimiter = ListDelimiter::Square;
  printCompressedIntegerSequence(printer, values, kDelimiter);
}
/// Prints `values` as a comma-separated list of SSA operands, compressing
/// - runs of the *same* value into `%v xN`, and
/// - runs of >= 3 *consecutive* results of one op (or consecutive block
///   arguments of one block) into `%v#a to %v#b` range form.
/// Pairs of consecutive values are printed as-is (a range would not be shorter).
inline void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
for (size_t index = 0; index < values.size();) {
// First try an equal run: identical Value repeated back-to-back.
size_t equalRunEnd = index + 1;
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
++equalRunEnd;
if (index != 0)
printer << ", ";
if (equalRunEnd - index > 1) {
printer.printOperand(values[index]);
printer << " x" << (equalRunEnd - index);
index = equalRunEnd;
continue;
}
// Otherwise look for a consecutive-number range rooted at values[index].
size_t rangeEnd = index + 1;
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
// Extend while each value is the next result number of the same op.
while (rangeEnd < values.size()) {
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) {
break;
}
++rangeEnd;
}
}
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
// Same idea for block arguments: consecutive argument numbers, one block.
while (rangeEnd < values.size()) {
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) {
break;
}
++rangeEnd;
}
}
printer.printOperand(values[index]);
// Ranges of length >= 3 use `to`; length 2 is printed as two operands.
if (rangeEnd - index >= 3) {
printer << " to ";
printer.printOperand(values[rangeEnd - 1]);
}
else if (rangeEnd - index == 2) {
printer << ", ";
printer.printOperand(values[index + 1]);
}
index = rangeEnd;
}
}
/// Prints `types` comma-separated, collapsing equal adjacent types to `T xN`.
inline void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
  auto emitOneType = [&printer](Type entry) { printer.printType(entry); };
  printCompressedEqualRuns(printer, types, emitOneType);
}
/// Prints `values` in compressed form wrapped in `delimiter` brackets.
inline void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printCompressedValueSequence(printer, values);
printCloseDelimiter(printer, delimiter);
}
/// Prints `types` in compressed form wrapped in `delimiter` brackets.
inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printCompressedTypeSequence(printer, types);
printCloseDelimiter(printer, delimiter);
}
/// Parses an undelimited, comma-separated type list where each type may carry
/// an `xN` repeat suffix. With `allowEmpty`, an absent first type succeeds and
/// leaves `types` untouched; otherwise it is an error.
inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser,
    SmallVectorImpl<Type>& types,
    bool allowEmpty) {
  Type leadingType;
  OptionalParseResult maybeLeading = parser.parseOptionalType(leadingType);
  // No type present at all.
  if (!maybeLeading.has_value()) {
    if (allowEmpty)
      return success();
    return parser.emitError(parser.getCurrentLocation(), "expected type");
  }
  if (failed(*maybeLeading))
    return failure();
  // Pushes `type`, expanded by an optional `xN` suffix.
  auto pushWithRepeat = [&](Type type) -> ParseResult {
    int64_t count = 1;
    if (succeeded(parser.parseOptionalKeyword("x"))) {
      if (parser.parseInteger(count) || count <= 0)
        return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
    }
    types.append(count, type);
    return success();
  };
  if (failed(pushWithRepeat(leadingType)))
    return failure();
  while (succeeded(parser.parseOptionalComma())) {
    Type followingType;
    if (parser.parseType(followingType) || failed(pushWithRepeat(followingType)))
      return failure();
  }
  return success();
}
/// Given an already-parsed `firstOperand`, parses the remainder of one
/// compressed operand entry and appends the expansion to `operands`:
///   - `%v#a to %v#b` expands to every operand number in [a, b], and
///   - `%v xN` pushes the same operand N times (default N = 1).
inline ParseResult parseCompressedOperandEntryWithFirst(
    OpAsmParser& parser,
    OpAsmParser::UnresolvedOperand firstOperand,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (succeeded(parser.parseOptionalKeyword("to"))) {
    OpAsmParser::UnresolvedOperand lastOperand;
    if (parser.parseOperand(lastOperand))
      return failure();
    // A valid range stays on one SSA name and must not descend.
    bool sameName = firstOperand.name == lastOperand.name;
    if (!sameName || firstOperand.number > lastOperand.number)
      return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
    for (unsigned num = firstOperand.number; num <= lastOperand.number; ++num)
      operands.push_back({firstOperand.location, firstOperand.name, num});
    return success();
  }
  int64_t count = 1;
  if (succeeded(parser.parseOptionalKeyword("x"))) {
    if (parser.parseInteger(count) || count <= 0)
      return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
  }
  operands.append(count, firstOperand);
  return success();
}
/// Parses one full compressed operand entry (operand, range, or repeat form)
/// and appends its expansion to `operands`.
inline ParseResult parseOneCompressedOperandEntry(
    OpAsmParser& parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  OpAsmParser::UnresolvedOperand head;
  if (failed(parser.parseOperand(head)))
    return failure();
  return parseCompressedOperandEntryWithFirst(parser, head, operands);
}
/// Parses a bracketed, comma-separated list of compressed operand entries,
/// appending all expanded operands to `operands`.
inline ParseResult parseCompressedOperandList(OpAsmParser& parser,
    ListDelimiter delimiter,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (failed(parseOpenDelimiter(parser, delimiter)))
    return failure();
  // Empty list: nothing between the brackets.
  if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
    return success();
  do {
    if (failed(parseOneCompressedOperandEntry(parser, operands)))
      return failure();
    if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
      return success();
  } while (!parser.parseComma());
  return failure();
}
/// Parses one or more comma-separated compressed operand entries (no
/// surrounding brackets), appending the expansions to `operands`.
inline ParseResult parseCompressedOperandSequence(
    OpAsmParser& parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  do {
    if (failed(parseOneCompressedOperandEntry(parser, operands)))
      return failure();
  } while (succeeded(parser.parseOptionalComma()));
  return success();
}
/// Parses a bracketed compressed type list (possibly empty), appending the
/// expanded types to `types`.
inline ParseResult parseCompressedTypeList(OpAsmParser& parser,
    ListDelimiter delimiter,
    SmallVectorImpl<Type>& types) {
  if (failed(parseOpenDelimiter(parser, delimiter)))
    return failure();
  // `[]` / `()` is a valid empty list.
  if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
    return success();
  if (failed(parseCompressedTypeSequence(parser, types, /*allowEmpty=*/false)))
    return failure();
  return parseOptionalCloseDelimiter(parser, delimiter);
}
/// Returns true when `values` is exactly the same `tupleSize`-element tuple
/// repeated back-to-back one or more times.
inline bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
  // Must decompose into a whole number of non-empty tuples.
  if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
    return false;
  SmallVector<Value> flattened(values.begin(), values.end());
  ArrayRef<Value> whole(flattened);
  ArrayRef<Value> leadTuple = whole.take_front(tupleSize);
  for (size_t offset = tupleSize; offset < whole.size(); offset += tupleSize)
    if (!llvm::equal(leadTuple, whole.slice(offset, tupleSize)))
      return false;
  return true;
}
/// Returns true when `types` is exactly the same `tupleSize`-element tuple of
/// types repeated back-to-back one or more times.
inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
  // Must decompose into a whole number of non-empty tuples.
  if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
    return false;
  SmallVector<Type> flattened(types.begin(), types.end());
  ArrayRef<Type> whole(flattened);
  ArrayRef<Type> leadTuple = whole.take_front(tupleSize);
  for (size_t offset = tupleSize; offset < whole.size(); offset += tupleSize)
    if (!llvm::equal(leadTuple, whole.slice(offset, tupleSize)))
      return false;
  return true;
}
/// Prints `values` (a repeated tuple — see hasRepeatedTuple) as
/// `<delim-open>(%a, %b, ...) xN<delim-close>` where N is the tuple count.
inline void printValueTupleRun(OpAsmPrinter& printer,
    ValueRange values,
    size_t tupleSize,
    ListDelimiter delimiter) {
  printOpenDelimiter(printer, delimiter);
  printOpenDelimiter(printer, ListDelimiter::Paren);
  // Only the first tuple is printed; the rest are implied by the repeat count.
  llvm::interleaveComma(values.take_front(tupleSize), printer,
      [&](Value element) { printer.printOperand(element); });
  printCloseDelimiter(printer, ListDelimiter::Paren);
  printer << " x" << (values.size() / tupleSize);
  printCloseDelimiter(printer, delimiter);
}
/// Prints `types` (a repeated tuple — see hasRepeatedTuple) as
/// `<delim-open>(T1, T2, ...) xN<delim-close>` where N is the tuple count.
inline void printTypeTupleRun(OpAsmPrinter& printer,
    TypeRange types,
    size_t tupleSize,
    ListDelimiter delimiter) {
  printOpenDelimiter(printer, delimiter);
  printOpenDelimiter(printer, ListDelimiter::Paren);
  // Only the first tuple is printed; the rest are implied by the repeat count.
  llvm::interleaveComma(types.take_front(tupleSize), printer,
      [&](Type element) { printer.printType(element); });
  printCloseDelimiter(printer, ListDelimiter::Paren);
  printer << " x" << (types.size() / tupleSize);
  printCloseDelimiter(printer, delimiter);
}
/// Parses a bracketed operand list in either form:
///   - tuple form:  `((%a, %b) x2, (%c, %d))` — each parenthesized group is a
///     compressed operand sequence optionally repeated with `xN`;
///   - flat form:   compressed operand entries separated by commas.
/// All expanded operands are appended to `operands`.
/// The original duplicated the group-parsing logic (sequence + ')' + repeat
/// count + append) for the first and the subsequent groups; it is factored
/// into a single local lambda here — behavior is unchanged.
inline ParseResult parseCompressedOrTupleOperandList(
    OpAsmParser& parser,
    ListDelimiter delimiter,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  if (parseOpenDelimiter(parser, delimiter))
    return failure();
  // Empty list.
  if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
    return success();
  // Tuple form: the '(' of the first group has been seen.
  if (succeeded(parser.parseOptionalLParen())) {
    // Parses the rest of one `( ... ) [xN]` group and appends its expansion.
    auto parseTupleGroupBody = [&]() -> ParseResult {
      SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
      if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
        return failure();
      int64_t repeatCount = 1;
      if (succeeded(parser.parseOptionalKeyword("x"))) {
        if (parser.parseInteger(repeatCount) || repeatCount <= 0)
          return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
      }
      for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
        llvm::append_range(operands, tupleOperands);
      return success();
    };
    if (parseTupleGroupBody())
      return failure();
    while (succeeded(parser.parseOptionalComma()))
      if (parser.parseLParen() || parseTupleGroupBody())
        return failure();
    return parseOptionalCloseDelimiter(parser, delimiter);
  }
  // Flat form: plain compressed operand entries.
  while (true) {
    if (parseOneCompressedOperandEntry(parser, operands))
      return failure();
    if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
      return success();
    if (parser.parseComma())
      return failure();
  }
}
/// Parses a bracketed type list in either form:
///   - tuple form:  `((T1, T2) x2, (T3, T4))` — each parenthesized group is a
///     compressed type sequence optionally repeated with `xN`;
///   - flat form:   types, each with an optional `xN` suffix, comma-separated.
/// All expanded types are appended to `types`.
/// The original repeated the `xN` repeat-count parsing three times; it is
/// factored into one local lambda here — behavior is unchanged.
inline ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser,
    ListDelimiter delimiter,
    SmallVectorImpl<Type>& types) {
  if (parseOpenDelimiter(parser, delimiter))
    return failure();
  // Empty list.
  if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
    return success();
  // Parses an optional `xN` suffix into `repeatCount` (default 1).
  auto parseOptionalRepeatCount = [&](int64_t& repeatCount) -> ParseResult {
    repeatCount = 1;
    if (succeeded(parser.parseOptionalKeyword("x"))) {
      if (parser.parseInteger(repeatCount) || repeatCount <= 0)
        return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
    }
    return success();
  };
  // Tuple form: the '(' of the first group has been seen.
  if (succeeded(parser.parseOptionalLParen())) {
    // Parses the rest of one `( ... ) [xN]` group and appends its expansion.
    auto parseTupleGroupBody = [&]() -> ParseResult {
      SmallVector<Type> tupleTypes;
      if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
        return failure();
      int64_t repeatCount = 1;
      if (parseOptionalRepeatCount(repeatCount))
        return failure();
      for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
        llvm::append_range(types, tupleTypes);
      return success();
    };
    if (parseTupleGroupBody())
      return failure();
    while (succeeded(parser.parseOptionalComma()))
      if (parser.parseLParen() || parseTupleGroupBody())
        return failure();
    return parseOptionalCloseDelimiter(parser, delimiter);
  }
  // Flat form: `T [xN]` entries separated by commas.
  while (true) {
    Type type;
    if (parser.parseType(type))
      return failure();
    int64_t repeatCount = 1;
    if (parseOptionalRepeatCount(repeatCount))
      return failure();
    for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
      types.push_back(type);
    if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
      return success();
    if (parser.parseComma())
      return failure();
  }
}
/// Prints `block`'s argument bindings as `<args> = (<operands>)`, where the
/// argument side is `()` when empty, a bare operand when there is exactly one
/// argument, and a compressed parenthesized list otherwise.
inline void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
  unsigned argCount = block.getNumArguments();
  if (argCount == 0) {
    printer << "() = ()";
    return;
  }
  // A single argument is printed bare, without surrounding parens.
  if (argCount == 1)
    printer.printOperand(block.getArgument(0));
  else
    printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
  printer << " = ";
  printCompressedValueList(printer, operands, ListDelimiter::Paren);
}
/// Given an already-parsed `firstArgument`, parses the remainder of one
/// compressed argument entry: either nothing (single argument) or a range
/// `%v#a to %v#b` that expands to every argument number in [a, b].
inline ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
    OpAsmParser::Argument firstArgument,
    SmallVectorImpl<OpAsmParser::Argument>& arguments) {
  // No `to` keyword: this entry is just the single argument.
  if (!succeeded(parser.parseOptionalKeyword("to"))) {
    arguments.push_back(firstArgument);
    return success();
  }
  OpAsmParser::Argument lastArgument;
  if (parser.parseArgument(lastArgument))
    return failure();
  const auto& firstName = firstArgument.ssaName;
  const auto& lastName = lastArgument.ssaName;
  // A valid range stays on one SSA name and must not descend.
  if (firstName.name != lastName.name || firstName.number > lastName.number)
    return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
  for (unsigned num = firstName.number; num <= lastName.number; ++num) {
    OpAsmParser::Argument expanded;
    expanded.ssaName = {firstName.location, firstName.name, num};
    arguments.push_back(expanded);
  }
  return success();
}
/// Parses one full compressed argument entry (single argument or range form)
/// and appends its expansion to `arguments`.
inline ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
    SmallVectorImpl<OpAsmParser::Argument>& arguments) {
  OpAsmParser::Argument head;
  if (failed(parser.parseArgument(head)))
    return failure();
  return parseCompressedArgumentEntryWithFirst(parser, head, arguments);
}
/// Assigns each argument's type from the corresponding entry of `inputTypes`.
/// `zip_equal` asserts the two lists have the same length.
inline void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
  for (auto zipped : llvm::zip_equal(arguments, inputTypes))
    std::get<0>(zipped).type = std::get<1>(zipped);
}
/// Parses argument bindings of the form `<args> = (<operands>)`, where the
/// argument side is `()` (no arguments), a bare argument, or a parenthesized
/// compressed argument list. Arguments go to `arguments`, bound operands to
/// `operands`.
inline ParseResult parseArgumentBindings(OpAsmParser& parser,
    SmallVectorImpl<OpAsmParser::Argument>& arguments,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  // Shared tail: `= ( <compressed operands> )`.
  auto parseBoundOperands = [&]() -> ParseResult {
    if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
      return failure();
    return success();
  };
  // Bare form: `%arg = (...)`.
  if (!succeeded(parser.parseOptionalLParen())) {
    OpAsmParser::Argument soleArgument;
    if (parser.parseArgument(soleArgument) || parseBoundOperands())
      return failure();
    arguments.push_back(soleArgument);
    return success();
  }
  // Empty tuple: `() = (...)`.
  if (succeeded(parser.parseOptionalRParen()))
    return parseBoundOperands();
  // General form: `(%a, %b to %d, ...) = (...)`.
  OpAsmParser::Argument leadArgument;
  if (parser.parseArgument(leadArgument)
      || parseCompressedArgumentEntryWithFirst(parser, leadArgument, arguments))
    return failure();
  while (succeeded(parser.parseOptionalComma()))
    if (parseOneCompressedArgumentEntry(parser, arguments))
      return failure();
  if (parser.parseRParen())
    return failure();
  return parseBoundOperands();
}
} // namespace compact_asm
} // namespace onnx_mlir
#endif

View File

@@ -37,6 +37,11 @@ using namespace llvm;
using namespace mlir;
using namespace onnx_mlir;
// Byte footprint of a shaped value: element count times element bit width,
// divided by 8 (multiplication happens first, so packed sub-byte element
// types keep whole-tensor precision).
static size_t getValueSizeInBytes(mlir::Value value) {
  auto shapedType = cast<ShapedType>(value.getType());
  size_t totalBits = shapedType.getNumElements() * shapedType.getElementTypeBitWidth();
  return totalBits / 8;
}
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape());
@@ -382,10 +387,75 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
}
// Emits one "recv" per (output buffer, source core) pair; the two lists are
// index-aligned by construction of the op.
void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const {
  auto buffers = receiveManyOp.getOutputBuffers();
  auto sources = receiveManyOp.getSourceCoreIds();
  for (auto [buffer, source] : llvm::zip(buffers, sources))
    emitCommunicationOp("recv", addressOf(buffer, knowledge), source, getValueSizeInBytes(buffer));
}
// Emits a single "send": payload address, destination core, explicit size.
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
  auto inputAddr = addressOf(sendOp.getInput(), knowledge);
  emitCommunicationOp("send", inputAddr, sendOp.getTargetCoreId(), sendOp.getSize());
}
// Emits one "send" per (input, target core) pair; the two lists are
// index-aligned by construction of the op.
void PimCodeGen::codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const {
  auto payloads = sendManyOp.getInputs();
  auto targets = sendManyOp.getTargetCoreIds();
  for (auto [payload, target] : llvm::zip(payloads, targets))
    emitCommunicationOp("send", addressOf(payload, knowledge), target, getValueSizeInBytes(payload));
}
// Copies each row of a static rank-2 input into its own output buffer using
// one local-move ("lmv") per row; output buffer i receives row i.
void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const {
  auto inputType = cast<ShapedType>(extractRowsOp.getInput().getType());
  assert(inputType.hasStaticShape() && inputType.getRank() == 2 && "extract_rows codegen requires static rank-2 input");
  size_t bytesPerElement = inputType.getElementTypeBitWidth() / 8;
  // Row length in bytes = column count * element size.
  size_t bytesPerRow = static_cast<size_t>(inputType.getDimSize(1)) * bytesPerElement;
  size_t sourceAddr = addressOf(extractRowsOp.getInput(), knowledge);
  size_t row = 0;
  for (mlir::Value outputBuffer : extractRowsOp.getOutputBuffers()) {
    // Destination offset is 0: each buffer holds exactly one row.
    emitMemCopyOp("lmv",
        addressOf(outputBuffer, knowledge),
        0,
        sourceAddr,
        row * bytesPerRow,
        bytesPerRow,
        "len");
    ++row;
  }
}
// Lowers pim.concat to a series of local-move ("lmv") copies. The output is
// viewed as [outer..., axis, inner...]; for each input and each outer index,
// one contiguous block of `inputConcatDim * innerCount` elements is copied to
// its slot in the output, with `concatOffset` tracking where along the concat
// axis the current input's data begins.
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
assert(outputType.hasStaticShape() && "concat codegen requires static output shape");
int64_t axis = concatOp.getAxis();
ArrayRef<int64_t> outputShape = outputType.getShape();
size_t elementSize = outputType.getElementTypeBitWidth() / 8;
size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
// Product of dimensions before the concat axis.
size_t outerCount = 1;
for (int64_t dim = 0; dim < axis; ++dim)
outerCount *= static_cast<size_t>(outputShape[dim]);
// Product of dimensions after the concat axis (contiguous inner stride).
size_t innerCount = 1;
for (size_t dim = static_cast<size_t>(axis) + 1; dim < outputShape.size(); ++dim)
innerCount *= static_cast<size_t>(outputShape[dim]);
size_t outputConcatDim = static_cast<size_t>(outputShape[axis]);
// Running position along the output's concat axis.
size_t concatOffset = 0;
for (mlir::Value input : concatOp.getInputs()) {
auto inputType = cast<ShapedType>(input.getType());
assert(inputType.hasStaticShape() && "concat codegen requires static input shapes");
size_t inputConcatDim = static_cast<size_t>(inputType.getDimSize(axis));
// One block = this input's full extent along the axis times the inner stride.
size_t blockSizeInBytes = inputConcatDim * innerCount * elementSize;
size_t inputAddr = addressOf(input, knowledge);
for (size_t outerIndex = 0; outerIndex < outerCount; ++outerIndex) {
// Destination skips earlier inputs' data via concatOffset; source is laid
// out densely with the input's own (smaller) concat extent.
size_t dstOffset = (outerIndex * outputConcatDim + concatOffset) * innerCount * elementSize;
size_t srcOffset = outerIndex * inputConcatDim * innerCount * elementSize;
emitMemCopyOp("lmv", outputAddr, dstOffset, inputAddr, srcOffset, blockSizeInBytes, "len");
}
concatOffset += inputConcatDim;
}
}
template <typename MVMTy>
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
MVMTy mvmLikeOp,
@@ -396,11 +466,6 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
}
static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge);
@@ -682,6 +747,25 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
continue;
}
if (auto sendManyBatchOp = dyn_cast<pim::PimSendManyBatchOp>(op)) {
SmallVector<int32_t> laneTargetCoreIds;
laneTargetCoreIds.reserve(sendManyBatchOp.getInputs().size());
for (auto valueIndex : llvm::seq<size_t>(0, sendManyBatchOp.getInputs().size()))
laneTargetCoreIds.push_back(
sendManyBatchOp.getTargetCoreIds()[valueIndex * laneCount + static_cast<size_t>(lane)]);
SmallVector<mlir::Value> mappedInputs;
mappedInputs.reserve(sendManyBatchOp.getInputs().size());
for (mlir::Value input : sendManyBatchOp.getInputs())
mappedInputs.push_back(mapper.lookup(input));
pim::PimSendManyOp::create(builder,
sendManyBatchOp.getLoc(),
builder.getDenseI32ArrayAttr(laneTargetCoreIds),
ValueRange(mappedInputs));
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(builder,
@@ -694,6 +778,29 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
continue;
}
if (auto receiveManyBatchOp = dyn_cast<pim::PimReceiveManyBatchOp>(op)) {
SmallVector<int32_t> laneSourceCoreIds;
laneSourceCoreIds.reserve(receiveManyBatchOp.getOutputs().size());
for (auto valueIndex : llvm::seq<size_t>(0, receiveManyBatchOp.getOutputs().size()))
laneSourceCoreIds.push_back(
receiveManyBatchOp.getSourceCoreIds()[valueIndex * laneCount + static_cast<size_t>(lane)]);
SmallVector<mlir::Value> mappedOutputBuffers;
mappedOutputBuffers.reserve(receiveManyBatchOp.getOutputBuffers().size());
for (mlir::Value outputBuffer : receiveManyBatchOp.getOutputBuffers())
mappedOutputBuffers.push_back(mapper.lookup(outputBuffer));
auto scalarReceiveMany =
pim::PimReceiveManyOp::create(builder,
receiveManyBatchOp.getLoc(),
receiveManyBatchOp->getResultTypes(),
ValueRange(mappedOutputBuffers),
builder.getDenseI32ArrayAttr(laneSourceCoreIds));
for (auto [originalOutput, scalarOutput] : llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs()))
mapper.map(originalOutput, scalarOutput);
continue;
}
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
if (!hostSource)
@@ -812,8 +919,16 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(op))
coreCodeGen.codeGenReceiveManyOp(receiveManyOp, knowledge);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp, knowledge);
else if (auto sendManyOp = dyn_cast<pim::PimSendManyOp>(op))
coreCodeGen.codeGenSendManyOp(sendManyOp, knowledge);
else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op))
coreCodeGen.codeGenExtractRowsOp(extractRowsOp, knowledge);
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))

View File

@@ -116,7 +116,11 @@ public:
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const;
void codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const;
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
template <typename MVMTy>
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge);

View File

@@ -150,116 +150,105 @@ static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewri
// Lower spatial.channel_send_many to a single pim.send_many carrying all
// inputs, translating each spatial core id into the PIM core-id space.
// NOTE(review): the previous body also emitted one pim.send per input before
// creating the batched op, which would send every tensor twice; only the
// batched path is kept.
static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(sendManyOp);
  SmallVector<int32_t> targetCoreIds;
  targetCoreIds.reserve(sendManyOp.getTargetCoreIds().size());
  for (int32_t targetCoreId : sendManyOp.getTargetCoreIds())
    targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
  // pim.send_many has no results, so the spatial op is simply erased.
  PimSendManyOp::create(
      rewriter, sendManyOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), sendManyOp.getInputs());
  rewriter.eraseOp(sendManyOp);
}
// Materialize one pim.empty_many producing an uninitialized tensor for every
// entry of `outputTypes`, and return its results as a value vector.
static SmallVector<Value> createManyEmptyTensorsLike(IRRewriter& rewriter,
    Location loc,
    TypeRange outputTypes) {
  // The TypeRange can be forwarded directly; copying it into a temporary
  // SmallVector<Type> just to rewrap it in a TypeRange was redundant.
  auto emptyMany = pim::PimEmptyManyOp::create(rewriter, loc, outputTypes);
  return SmallVector<Value>(emptyMany.getOutputs().begin(), emptyMany.getOutputs().end());
}
// Lower spatial.channel_receive_many to one pim.receive_many that fills a
// freshly created destination buffer per result (destination-style op).
// NOTE(review): the previous body called rewriter.replaceOp twice on the same
// op — the second call operated on an already-erased op — and kept a dead
// per-output pim.receive expansion; only the batched path is kept.
static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(receiveManyOp);
  // Translate every spatial source core id into the PIM core-id space.
  SmallVector<int32_t> sourceCoreIds;
  sourceCoreIds.reserve(receiveManyOp.getSourceCoreIds().size());
  for (int32_t sourceCoreId : receiveManyOp.getSourceCoreIds())
    sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
  // One empty buffer per result; result types double as buffer types.
  SmallVector<Value> outputBuffers =
      createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes());
  auto receiveMany = PimReceiveManyOp::create(rewriter,
      receiveManyOp.getLoc(),
      receiveManyOp.getResultTypes(),
      ValueRange(outputBuffers),
      rewriter.getDenseI32ArrayAttr(sourceCoreIds));
  rewriter.replaceOp(receiveManyOp, receiveMany.getOutputs());
}
// Lower spatial.channel_send_many_batch to one pim.send_many_batch whose
// operands are looked up through `mapper` (the inputs live in the
// pre-cloning IR of the batched core region).
// NOTE(review): the previous body declared `targetCoreIds` twice in the same
// scope (a compile error) and kept a dead per-input pim.send_batch loop; only
// the batched path is kept.
static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp,
    int32_t laneCount,
    IRMapping& mapper,
    IRRewriter& rewriter) {
  // laneCount is retained for call-site compatibility: the batched op carries
  // the whole flattened core-id array, so no per-lane slicing happens here.
  (void)laneCount;
  SmallVector<int32_t> targetCoreIds;
  targetCoreIds.reserve(sendManyBatchOp.getTargetCoreIds().size());
  for (int32_t targetCoreId : sendManyBatchOp.getTargetCoreIds())
    targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
  SmallVector<Value> mappedInputs;
  mappedInputs.reserve(sendManyBatchOp.getInputs().size());
  for (Value input : sendManyBatchOp.getInputs())
    mappedInputs.push_back(mapper.lookup(input));
  pim::PimSendManyBatchOp::create(rewriter,
      sendManyBatchOp.getLoc(),
      rewriter.getDenseI32ArrayAttr(targetCoreIds),
      ValueRange(mappedInputs));
}
// Lower spatial.channel_receive_many_batch to one pim.receive_many_batch
// filling fresh destination buffers; results are recorded in `mapper` so that
// the batched-region cloning can rewire later uses.
// NOTE(review): the previous body interleaved a legacy per-output
// pim.receive_batch loop with the batched path (new statements nested inside
// the old loop, plus a stray closing brace); only the batched path is kept.
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
    int32_t laneCount,
    IRMapping& mapper,
    IRRewriter& rewriter) {
  // laneCount is retained for call-site compatibility: the batched op carries
  // the whole flattened core-id array, so no per-lane slicing happens here.
  (void)laneCount;
  SmallVector<int32_t> sourceCoreIds;
  sourceCoreIds.reserve(receiveManyBatchOp.getSourceCoreIds().size());
  for (int32_t sourceCoreId : receiveManyBatchOp.getSourceCoreIds())
    sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
  SmallVector<Value> outputBuffers =
      createManyEmptyTensorsLike(rewriter, receiveManyBatchOp.getLoc(), receiveManyBatchOp.getResultTypes());
  auto receiveMany = pim::PimReceiveManyBatchOp::create(rewriter,
      receiveManyBatchOp.getLoc(),
      receiveManyBatchOp.getResultTypes(),
      ValueRange(outputBuffers),
      rewriter.getDenseI32ArrayAttr(sourceCoreIds));
  // Record the mapping instead of replaceOp: the original op lives in the
  // pre-cloning IR that is consumed through `mapper`.
  for (auto [output, received] : llvm::zip(receiveManyBatchOp.getOutputs(), receiveMany.getOutputs()))
    mapper.map(output, received);
}
// Lower spatial.extract_rows to pim.extract_rows writing into freshly created
// destination buffers (one row tensor per result).
// NOTE(review): the previous body called rewriter.replaceOp twice on the same
// op and kept a dead tensor.extract_slice-per-row expansion; only the
// pim.extract_rows path is kept. The memref-input normalization is preserved
// and its tensor result is fed to the new op — TODO confirm against the
// intended commit, which passed getInput() directly.
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(extractRowsOp);
  Value input = extractRowsOp.getInput();
  if (auto memRefType = dyn_cast<MemRefType>(input.getType())) {
    // Bufferized producer: rewrap the memref as a tensor so the
    // destination-style pim op sees a tensor operand.
    auto tensorType = RankedTensorType::get(memRefType.getShape(), memRefType.getElementType());
    input = bufferization::ToTensorOp::create(
        rewriter, extractRowsOp.getLoc(), tensorType, input, rewriter.getUnitAttr(), rewriter.getUnitAttr())
                .getResult();
  }
  else if (!isa<RankedTensorType>(input.getType())) {
    extractRowsOp.emitOpError("requires a ranked tensor or memref input during Spatial-to-PIM lowering");
    return;
  }
  SmallVector<Value> outputBuffers =
      createManyEmptyTensorsLike(rewriter, extractRowsOp.getLoc(), extractRowsOp.getResultTypes());
  auto extractRows = pim::PimExtractRowsOp::create(rewriter,
      extractRowsOp.getLoc(),
      extractRowsOp.getResultTypes(),
      input,
      ValueRange(outputBuffers));
  rewriter.replaceOp(extractRowsOp, extractRows.getOutputs());
}
// Lower spatial.concat to pim.concat writing into a fresh destination buffer.
// NOTE(review): the previous body declared `concatenated` twice in the same
// scope (a compile error) — once from a dead tensor.concat expansion and once
// from the pim.concat path; only the pim path is kept.
static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(concatOp);
  auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
  // Destination buffer with the same shape/element type as the result.
  Value outputBuffer = createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), outputType).getResult();
  Value concatenated = pim::PimConcatOp::create(rewriter,
      concatOp.getLoc(),
      concatOp.getOutput().getType(),
      rewriter.getI64IntegerAttr(concatOp.getAxis()),
      concatOp.getInputs(),
      outputBuffer)
                           .getOutput();
  rewriter.replaceOp(concatOp, concatenated);
}
@@ -282,34 +271,23 @@ static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewrit
}
}
static void expandMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatMapOp> mapOps;
funcOp.walk([&](spatial::SpatMapOp mapOp) { mapOps.push_back(mapOp); });
funcOp.walk([&](spatial::SpatMapOp mapOp) {
if (mapOp->getParentOfType<pim::PimCoreOp>() || mapOp->getParentOfType<pim::PimCoreBatchOp>())
mapOps.push_back(mapOp);
});
for (auto mapOp : mapOps) {
Block& body = mapOp.getBody().front();
auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
SmallVector<Value> replacements;
replacements.reserve(mapOp.getInputs().size());
rewriter.setInsertionPoint(mapOp);
for (Value input : mapOp.getInputs()) {
IRMapping mapping;
mapping.map(body.getArgument(0), input);
auto pimMap = pim::PimMapOp::create(rewriter, mapOp.getLoc(), mapOp.getResultTypes(), mapOp.getInputs());
rewriter.inlineRegionBefore(mapOp.getBody(), pimMap.getBody(), pimMap.getBody().begin());
Value replacement = input;
for (Operation& op : body.without_terminator()) {
Operation* cloned = rewriter.clone(op, mapping);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapping.map(originalResult, clonedResult);
rewriter.setInsertionPointAfter(cloned);
}
replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
replacements.push_back(replacement);
}
rewriter.replaceOp(mapOp, replacements);
auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
rewriter.setInsertionPoint(yieldOp);
rewriter.replaceOpWithNewOp<pim::PimYieldOp>(yieldOp, yieldOp.getOutputs());
rewriter.replaceOp(mapOp, pimMap.getOutputs());
}
}
@@ -440,8 +418,28 @@ static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
}
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
auto getConcatResult = [](Operation *op) -> Value {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getResult();
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
return pimConcat.getOutput();
return {};
};
auto getConcatAxis = [](Operation *op) -> std::optional<int64_t> {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getDim();
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
return pimConcat.getAxis();
return std::nullopt;
};
auto getConcatOperands = [](Operation *op) -> OperandRange {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getOperands();
return cast<pim::PimConcatOp>(op).getInputs();
};
auto uses = value.getUses();
if (rangeLength(uses) != 1 || !isa<tensor::ConcatOp>(uses.begin()->getOwner()))
if (rangeLength(uses) != 1 || !isa<tensor::ConcatOp, pim::PimConcatOp>(uses.begin()->getOwner()))
return std::nullopt;
auto valueType = dyn_cast<ShapedType>(value.getType());
@@ -453,18 +451,19 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
Value currentValue = value;
Operation* currentUser = uses.begin()->getOwner();
while (auto concatOp = dyn_cast<tensor::ConcatOp>(currentUser)) {
while (isa<tensor::ConcatOp, pim::PimConcatOp>(currentUser)) {
size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
int64_t axis = concatOp.getDim();
for (Value operand : concatOp.getOperands().take_front(operandIndex))
int64_t axis = *getConcatAxis(currentUser);
for (Value operand : getConcatOperands(currentUser).take_front(operandIndex))
sliceOffsets[axis] += cast<ShapedType>(operand.getType()).getShape()[axis];
auto concatType = dyn_cast<ShapedType>(concatOp.getResult().getType());
Value concatResult = getConcatResult(currentUser);
auto concatType = dyn_cast<ShapedType>(concatResult.getType());
if (!concatType || !concatType.hasStaticShape())
return std::nullopt;
concatShape.assign(concatType.getShape().begin(), concatType.getShape().end());
currentValue = concatOp.getResult();
currentValue = concatResult;
auto currentUses = currentValue.getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
@@ -638,7 +637,6 @@ void SpatialToPimPass::runOnOperation() {
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
expandMapOps(funcOp, rewriter);
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect,
@@ -687,6 +685,8 @@ void SpatialToPimPass::runOnOperation() {
runOnComputeBatchOp(computeBatchOp, rewriter);
}
lowerMapOps(funcOp, rewriter);
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
receiveOps.push_back(op);
@@ -1317,7 +1317,8 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
Operation* onlyUser = *op->getUsers().begin();
isExclusivelyOwnedByReturnChain =
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatCompute>(onlyUser) || isChannelUseChainOp(onlyUser);
isa<func::ReturnOp, tensor::ConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|| isChannelUseChainOp(onlyUser);
}
if (!isExclusivelyOwnedByReturnChain)
return;
@@ -1341,6 +1342,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
markOpToRemove(concatOp);
for (Value operand : concatOp.getOperands())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
markOpToRemove(concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
}
};

View File

@@ -6,6 +6,8 @@ add_subdirectory(Transforms/Bufferization)
add_pim_library(PimOps
PimOps.hpp
PimOps.cpp
PimOpsAsm.cpp
PimOpsVerify.cpp
EXCLUDE_FROM_OM_LIBS

View File

@@ -50,9 +50,7 @@ def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSi
Variadic<PimTensor>:$inputs
);
let assemblyFormat = [{
`lanes` $laneCount `(` $weights `)` `[` $inputs `]` attr-dict regions `:` type($weights) `[` type($inputs) `]` `->` `(` `)`
}];
let hasCustomAssemblyFormat = 1;
}
def PimHaltOp : PimOp<"halt", [Terminator]> {
@@ -63,6 +61,48 @@ def PimHaltOp : PimOp<"halt", [Terminator]> {
}];
}
def PimYieldOp : PimOp<"yield", [Terminator]> {
  // Region terminator for Pim region ops (e.g. pim.map); carries the values
  // yielded out of the enclosing region.
  let summary = "Yield results from a Pim region";
  let arguments = (ins
    Variadic<PimTensor>:$outputs
  );
  let hasCustomAssemblyFormat = 1;
}
def PimMapOp : PimOp<"map", [SingleBlock]> {
  let summary = "Apply the same lane-local region to many independent tensors";
  // One result per input; the custom assembly prints a single shared
  // input type and a single shared output type, so operands are homogeneous.
  let arguments = (ins
    Variadic<PimTensor>:$inputs
  );
  let results = (outs
    Variadic<PimTensor>:$outputs
  );
  // Single-block body terminated by pim.yield.
  let regions = (region SizedRegion<1>:$body);
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Tensor Utilities
//===----------------------------------------------------------------------===//
def PimEmptyManyOp : PimOp<"empty_many", []> {
  let summary = "Create many identical empty tensors";
  // All results share one tensor type; the custom assembly prints it once
  // followed by a repeat count (`<type> x <count>`).
  let results = (outs
    Variadic<AnyRankedTensor>:$outputs
  );
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
@@ -81,6 +121,18 @@ def PimSendOp : PimOp<"send", []> {
}];
}
def PimSendManyOp : PimOp<"send_many", []> {
  let summary = "Send multiple tensors to target cores";
  // targetCoreIds[i] names the destination core for inputs[i].
  let arguments = (ins
    DenseI32ArrayAttr:$targetCoreIds,
    Variadic<PimTensor>:$inputs
  );
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}
def PimSendBatchOp : PimOp<"send_batch", []> {
let summary = "Send a per-lane tensor to target cores from a batched core";
@@ -90,9 +142,19 @@ def PimSendBatchOp : PimOp<"send_batch", []> {
DenseI32ArrayAttr:$targetCoreIds
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
}];
let hasCustomAssemblyFormat = 1;
}
def PimSendManyBatchOp : PimOp<"send_many_batch", []> {
  let summary = "Send multiple per-lane tensors to target cores from a batched core";
  // targetCoreIds is a flattened array; the Spatial-to-PIM lowering packs one
  // run of laneCount ids per input (inputs-major layout).
  let arguments = (ins
    DenseI32ArrayAttr:$targetCoreIds,
    Variadic<PimTensor>:$inputs
  );
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
@@ -119,6 +181,28 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
}];
}
def PimReceiveManyOp : PimOp<"receive_many", [DestinationStyleOpInterface]> {
  let summary = "Receive multiple tensors from source cores";
  // Destination-style: outputs[i] is tied to outputBuffers[i], and
  // sourceCoreIds[i] names the core producing outputs[i].
  let arguments = (ins
    Variadic<PimTensor>:$outputBuffers,
    DenseI32ArrayAttr:$sourceCoreIds
  );
  let results = (outs
    Variadic<PimTensor>:$outputs
  );
  // DestinationStyleOpInterface hook: the inits are the output buffers.
  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputBuffersMutable();
    }
  }];
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}
def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
let summary = "Receive per-lane tensors from source cores into a batched core";
@@ -138,9 +222,29 @@ def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
}
}];
let assemblyFormat = [{
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
let hasCustomAssemblyFormat = 1;
}
def PimReceiveManyBatchOp : PimOp<"receive_many_batch", [DestinationStyleOpInterface]> {
  let summary = "Receive multiple per-lane tensors from source cores into a batched core";
  // Batched-core counterpart of pim.receive_many; sourceCoreIds is a
  // flattened per-lane array (see the send_many_batch layout).
  let arguments = (ins
    Variadic<PimTensor>:$outputBuffers,
    DenseI32ArrayAttr:$sourceCoreIds
  );
  let results = (outs
    Variadic<PimTensor>:$outputs
  );
  // DestinationStyleOpInterface hook: the inits are the output buffers.
  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputBuffersMutable();
    }
  }];
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
@@ -247,6 +351,55 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
}];
}
//===----------------------------------------------------------------------===//
// Tensor utilities
//===----------------------------------------------------------------------===//
def PimExtractRowsOp : PimOp<"extract_rows", [DestinationStyleOpInterface]> {
  let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors";
  // Destination-style: outputs[i] is tied to outputBuffers[i] and receives
  // row i of $input.
  let arguments = (ins
    PimTensor:$input,
    Variadic<PimTensor>:$outputBuffers
  );
  let results = (outs
    Variadic<PimTensor>:$outputs
  );
  // DestinationStyleOpInterface hook: the inits are the output buffers.
  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputBuffersMutable();
    }
  }];
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}
def PimConcatOp : PimOp<"concat", [DestinationStyleOpInterface]> {
  let summary = "Concatenate tensors";
  // Concatenates $inputs along $axis; destination-style, so the single result
  // is tied to $outputBuffer.
  let arguments = (ins
    I64Attr:$axis,
    Variadic<PimTensor>:$inputs,
    PimTensor:$outputBuffer
  );
  let results = (outs
    PimTensor:$output
  );
  // DestinationStyleOpInterface hook: the init is the output buffer.
  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputBufferMutable();
    }
  }];
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//

View File

@@ -1,19 +1,5 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {

View File

@@ -0,0 +1,486 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
namespace {
using namespace onnx_mlir::compact_asm;
// Parse-time convenience: build a DenseI32ArrayAttr via the parser's builder.
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
  return parser.getBuilder().getDenseI32ArrayAttr(values);
}
// Print ` <keyword> ` followed by the core ids in run-length-compressed form.
static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<int32_t> coreIds) {
  printer << " " << keyword << " ";
  printCompressedIntegerList(printer, coreIds);
}
// Parse `<keyword> <compressed id list>` when the keyword is present; succeed
// without touching `coreIds` when it is absent (the list is optional syntax).
static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keyword, SmallVectorImpl<int32_t>& coreIds) {
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();
  return parseCompressedIntegerList(parser, coreIds);
}
} // namespace
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
  // Custom syntax:
  //   pim.core_batch lanes <n> (<weights>) [<inputs>] coreIds [...] {attrs}
  //     <region> : (<weight types>) [<input types>] -> ()
  printer << " lanes " << getLaneCount() << " ";
  // Number of weight operands owned by each lane (guarded against a zero
  // lane count to avoid division by zero).
  size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
  // When every lane repeats the same weight tuple, print one tuple with a
  // repeat marker instead of the fully flattened operand list.
  if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
    printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Paren);
  else
    printCompressedValueList(printer, getWeights(), ListDelimiter::Paren);
  printer << " ";
  printCompressedValueList(printer, getInputs(), ListDelimiter::Square);
  // coreIds gets positional syntax, so it is elided from the attr-dict below,
  // together with the attributes implied by the custom format.
  if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
    printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef());
  printer.printOptionalAttrDict(
      (*this)->getAttrs(),
      {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
  printer << " ";
  printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  printer << " : ";
  // The type list mirrors the operand list, including tuple-run compression.
  if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
    printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Paren);
  else
    printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Paren);
  printer << " ";
  printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Square);
  printer << " -> ()";
}
ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  // Inverse of print():
  //   lanes <n> (<weights>) [<inputs>] [coreIds [...]] {attrs}
  //     <region> : (<weight types>) [<input types>] -> ()
  int32_t laneCount = 0;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  SmallVector<Type> weightTypes;
  SmallVector<Type> inputTypes;
  SmallVector<int32_t> coreIds;
  if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)
      || parseCompressedOrTupleOperandList(parser, ListDelimiter::Paren, weights)
      || parseCompressedOperandList(parser, ListDelimiter::Square, inputs))
    return failure();
  // Optional positional core-id list; it may alternatively appear in the
  // attr-dict (but not both — checked below).
  bool hasCoreIds = succeeded(parser.parseOptionalKeyword("coreIds"));
  if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  Region* body = result.addRegion();
  if (parser.parseRegion(*body))
    return failure();
  if (parser.parseColon() || parseCompressedOrTupleTypeList(parser, ListDelimiter::Paren, weightTypes)
      || parseCompressedTypeList(parser, ListDelimiter::Square, inputTypes) || parser.parseArrow()
      || parser.parseLParen() || parser.parseRParen())
    return failure();
  if (weights.size() != weightTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
  if (inputs.size() != inputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
    return parser.emitError(parser.getCurrentLocation(),
        "coreIds cannot be specified both positionally and in attr-dict");
  auto& builder = parser.getBuilder();
  // NOTE(review): unlike coreIds, laneCount/operandSegmentSizes are appended
  // without checking for duplicates supplied via the attr-dict — confirm the
  // op verifier rejects such input.
  result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
  result.addAttribute("operandSegmentSizes",
      builder.getDenseI32ArrayAttr(
          {static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
  if (hasCoreIds)
    result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
  // Weights are resolved first so operand order matches the segment sizes.
  if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
      || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) {
    return failure();
  }
  return success();
}
void PimYieldOp::print(OpAsmPrinter& printer) {
  // Syntax: pim.yield <compressed operands> {attrs} : <compressed types>
  printer << " ";
  printCompressedValueSequence(printer, getOutputs());
  printer.printOptionalAttrDict((*this)->getAttrs());
  // The colon is printed even with zero operands; the parser accepts an empty
  // type sequence after it.
  printer << " : ";
  printCompressedTypeSequence(printer, getOutputs().getTypes());
}
ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) {
  // Syntax: pim.yield <compressed operands> {attrs} : <compressed types>
  SmallVector<OpAsmParser::UnresolvedOperand> yieldedValues;
  SmallVector<Type> yieldedTypes;
  // The operand list is optional: probe for a first operand before committing
  // to parsing a (possibly run-length compressed) sequence.
  OpAsmParser::UnresolvedOperand leadingOperand;
  OptionalParseResult probed = parser.parseOptionalOperand(leadingOperand);
  if (probed.has_value()) {
    if (failed(*probed))
      return failure();
    if (parseCompressedOperandEntryWithFirst(parser, leadingOperand, yieldedValues))
      return failure();
    // Remaining comma-separated entries, each possibly compressed.
    while (succeeded(parser.parseOptionalComma())) {
      if (parseOneCompressedOperandEntry(parser, yieldedValues))
        return failure();
    }
  }
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeSequence(parser, yieldedTypes, /*allowEmpty=*/true))
    return failure();
  if (yieldedValues.size() != yieldedTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match");
  return parser.resolveOperands(yieldedValues, yieldedTypes, parser.getCurrentLocation(), result.operands);
}
void PimMapOp::print(OpAsmPrinter& printer) {
  // Syntax:
  //   pim.map (<blockArg> = <input>, ...) {attrs} : <in type> -> <out type>
  //     <region>
  printer << " ";
  printArgumentBindings(printer, getBody().front(), getInputs());
  printer.printOptionalAttrDict((*this)->getAttrs());
  printer << " : ";
  // A single input/output type is printed: the parse side replicates one type
  // across all inputs and results. front() assumes at least one input — the
  // parser enforces this, and the op verifier is expected to as well.
  printer.printType(getInputs().front().getType());
  printer << " -> ";
  printer.printType(getOutputs().front().getType());
  printer << " ";
  printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
ParseResult PimMapOp::parse(OpAsmParser& parser, OperationState& result) {
  // Syntax:
  //   pim.map (<arg> = <input>, ...) {attrs} : <in type> -> <out type> <region>
  SmallVector<OpAsmParser::Argument> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  Type inputType;
  Type outputType;
  if (parseArgumentBindings(parser, regionArgs, inputs))
    return failure();
  if (inputs.empty())
    return parser.emitError(parser.getCurrentLocation(), "map requires at least one input");
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
      || parser.parseArrow() || parser.parseType(outputType))
    return failure();
  // One type pair is replicated: every input shares inputType and every result
  // shares outputType (one result per input).
  SmallVector<Type> inputTypes(inputs.size(), inputType);
  SmallVector<Type> outputTypes(inputs.size(), outputType);
  if (regionArgs.size() != inputs.size())
    return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
  if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputTypes);
  // Region entry-block arguments carry the replicated input type.
  applyArgumentTypes(inputTypes, regionArgs);
  Region* body = result.addRegion();
  return parser.parseRegion(*body, regionArgs);
}
void PimEmptyManyOp::print(OpAsmPrinter& printer) {
  // Syntax: pim.empty_many {attrs} : <tensor type> x <count>
  printer.printOptionalAttrDict((*this)->getAttrs());
  printer << " : ";
  // All results share one type; parse() enforces a positive count, so there is
  // at least one result to sample here.
  printer.printType(getOutputs().front().getType());
  // Keep a space on both sides of the repeat marker: the parser matches the
  // bare keyword `x` followed by an integer, and the previous "x3" form lexes
  // as a single identifier token that parseKeyword("x") cannot match, breaking
  // the print/parse round trip.
  printer << " x " << getOutputs().size();
}
ParseResult PimEmptyManyOp::parse(OpAsmParser& parser, OperationState& result) {
  // Syntax: pim.empty_many {attrs} : <tensor type> x <count>
  Type replicatedType;
  int64_t replication = 0;
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parser.parseType(replicatedType) || parser.parseKeyword("x") || parser.parseInteger(replication))
    return failure();
  if (replication <= 0)
    return parser.emitError(parser.getCurrentLocation(), "result count after 'x' must be positive");
  // Replicate the single printed type across all results.
  result.addTypes(SmallVector<Type>(replication, replicatedType));
  return success();
}
void PimSendBatchOp::print(OpAsmPrinter& printer) {
  // Syntax: pim.send_batch <input> to [<core ids>] {attrs} : <input type>
  printer << " ";
  printer.printOperand(getInput());
  // targetCoreIds is printed positionally after `to`, so it is elided from
  // the attr-dict below.
  printCoreIdList(printer, "to", getTargetCoreIds());
  printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
  printer << " : ";
  printer.printType(getInput().getType());
}
ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  // Syntax: pim.send_batch <input> to [<core ids>] {attrs} : <input type>
  OpAsmParser::UnresolvedOperand payload;
  Type payloadType;
  SmallVector<int32_t> parsedTargets;
  if (parser.parseOperand(payload) || parseOptionalCoreIdList(parser, "to", parsedTargets)
      || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(payloadType))
    return failure();
  // The core-id list may come from the custom syntax or the attr-dict — never
  // from both at once.
  if (!parsedTargets.empty()) {
    if (result.attributes.get("targetCoreIds"))
      return parser.emitError(parser.getCurrentLocation(),
          "targetCoreIds cannot be specified both positionally and in attr-dict");
    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, parsedTargets));
  }
  return parser.resolveOperand(payload, payloadType, result.operands);
}
void PimSendManyOp::print(OpAsmPrinter& printer) {
  // Syntax: pim.send_many <inputs> to [<core ids>] {attrs} : <input types>
  printer << " ";
  printCompressedValueSequence(printer, getInputs());
  // targetCoreIds is rendered positionally, hence elided from the attr-dict.
  printCoreIdList(printer, "to", getTargetCoreIds());
  printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
  printer << " : ";
  printCompressedTypeSequence(printer, TypeRange(getInputs()));
}
ParseResult PimSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
  // Syntax: pim.send_many <inputs> to [<core ids>] {attrs} : <input types>
  SmallVector<OpAsmParser::UnresolvedOperand> payloads;
  SmallVector<Type> payloadTypes;
  SmallVector<int32_t> parsedTargets;
  if (parseCompressedOperandSequence(parser, payloads) || parseOptionalCoreIdList(parser, "to", parsedTargets)
      || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeSequence(parser, payloadTypes, /*allowEmpty=*/false))
    return failure();
  if (payloads.size() != payloadTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  // Accept the core-id list either positionally or via attr-dict, never both.
  if (!parsedTargets.empty()) {
    if (result.attributes.get("targetCoreIds"))
      return parser.emitError(parser.getCurrentLocation(),
          "targetCoreIds cannot be specified both positionally and in attr-dict");
    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, parsedTargets));
  }
  return parser.resolveOperands(payloads, payloadTypes, parser.getCurrentLocation(), result.operands);
}
void PimSendManyBatchOp::print(OpAsmPrinter& printer) {
  // Syntax (batched-core variant, same shape as pim.send_many):
  //   pim.send_many_batch <inputs> to [<core ids>] {attrs} : <input types>
  printer << " ";
  printCompressedValueSequence(printer, getInputs());
  // targetCoreIds is rendered positionally, hence elided from the attr-dict.
  printCoreIdList(printer, "to", getTargetCoreIds());
  printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
  printer << " : ";
  printCompressedTypeSequence(printer, TypeRange(getInputs()));
}
ParseResult PimSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  // Mirrors PimSendManyOp::parse: <inputs> to [<ids>] {attrs} : <types>
  SmallVector<OpAsmParser::UnresolvedOperand> laneInputs;
  SmallVector<Type> laneInputTypes;
  SmallVector<int32_t> positionalTargets;
  if (parseCompressedOperandSequence(parser, laneInputs)
      || parseOptionalCoreIdList(parser, "to", positionalTargets))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeSequence(parser, laneInputTypes, /*allowEmpty=*/false))
    return failure();
  if (laneInputs.size() != laneInputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  // The core-id list is exclusive: positional syntax or attr-dict, not both.
  bool hasPositionalTargets = !positionalTargets.empty();
  if (hasPositionalTargets && result.attributes.get("targetCoreIds"))
    return parser.emitError(parser.getCurrentLocation(),
        "targetCoreIds cannot be specified both positionally and in attr-dict");
  if (hasPositionalTargets)
    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, positionalTargets));
  return parser.resolveOperands(laneInputs, laneInputTypes, parser.getCurrentLocation(), result.operands);
}
// Custom printer: `[from [ids]] into (%buf...) {attrs} : types`.
void PimReceiveManyOp::print(OpAsmPrinter& printer) {
  printCoreIdList(printer, "from", getSourceCoreIds());
  printer << " into ";
  printOpenDelimiter(printer, ListDelimiter::Paren);
  printCompressedValueSequence(printer, getOutputBuffers());
  printCloseDelimiter(printer, ListDelimiter::Paren);
  // sourceCoreIds is already printed via the `from` clause; elide it here.
  printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
  printer << " : ";
  // Only result types are printed; the verifier requires the buffers to have
  // the same types, and the parser reuses them for both.
  printCompressedTypeSequence(printer, getOutputs().getTypes());
}
// Custom parser mirroring PimReceiveManyOp::print. The single printed type
// list is used both to resolve the buffer operands and as the result types.
ParseResult PimReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
  SmallVector<Type> outputTypes;
  SmallVector<int32_t> sourceCoreIds;
  if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
      || parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
      || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
    return failure();
  if (outputBuffers.size() != outputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
  // Core ids may be given positionally or in the attr-dict, but not both.
  if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
    return parser.emitError(parser.getCurrentLocation(),
        "sourceCoreIds cannot be specified both positionally and in attr-dict");
  if (!sourceCoreIds.empty())
    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
  if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputTypes);
  return success();
}
// Custom printer: `[from [ids]] into (%buf) {attrs} : bufType -> resultType`.
// Unlike the "many" variants, buffer and result types are printed separately.
void PimReceiveBatchOp::print(OpAsmPrinter& printer) {
  printCoreIdList(printer, "from", getSourceCoreIds());
  printer << " into ";
  printOpenDelimiter(printer, ListDelimiter::Paren);
  printer.printOperand(getOutputBuffer());
  printCloseDelimiter(printer, ListDelimiter::Paren);
  // sourceCoreIds is already printed via the `from` clause; elide it here.
  printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
  printer << " : ";
  printer.printType(getOutputBuffer().getType());
  printer << " -> ";
  printer.printType(getOutput().getType());
}
// Custom parser mirroring PimReceiveBatchOp::print: resolves the single
// buffer operand against its own type and adds the separate result type.
ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  OpAsmParser::UnresolvedOperand outputBuffer;
  Type outputBufferType;
  Type outputType;
  SmallVector<int32_t> sourceCoreIds;
  if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
      || parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
      || parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
      || parser.parseType(outputType))
    return failure();
  // Core ids may be given positionally or in the attr-dict, but not both.
  if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
    return parser.emitError(parser.getCurrentLocation(),
        "sourceCoreIds cannot be specified both positionally and in attr-dict");
  if (!sourceCoreIds.empty())
    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
  if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
    return failure();
  result.addTypes(outputType);
  return success();
}
// Custom printer: `[from [ids]] into (%buf...) {attrs} : types`.
// Same format as PimReceiveManyOp; the batch semantics differ only in how
// the verifier sizes the core-id list (values times parent laneCount).
void PimReceiveManyBatchOp::print(OpAsmPrinter& printer) {
  printCoreIdList(printer, "from", getSourceCoreIds());
  printer << " into ";
  printOpenDelimiter(printer, ListDelimiter::Paren);
  printCompressedValueSequence(printer, getOutputBuffers());
  printCloseDelimiter(printer, ListDelimiter::Paren);
  // sourceCoreIds is already printed via the `from` clause; elide it here.
  printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
  printer << " : ";
  printCompressedTypeSequence(printer, getOutputs().getTypes());
}
// Custom parser mirroring PimReceiveManyBatchOp::print. The single printed
// type list resolves the buffer operands and supplies the result types.
ParseResult PimReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
  SmallVector<Type> outputTypes;
  SmallVector<int32_t> sourceCoreIds;
  if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
      || parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
      || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
      || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
    return failure();
  if (outputBuffers.size() != outputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
  // Core ids may be given positionally or in the attr-dict, but not both.
  if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
    return parser.emitError(parser.getCurrentLocation(),
        "sourceCoreIds cannot be specified both positionally and in attr-dict");
  if (!sourceCoreIds.empty())
    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
  if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputTypes);
  return success();
}
// Custom printer: `%input into (%buf...) {attrs} : inType -> outTypes`.
void PimExtractRowsOp::print(OpAsmPrinter& printer) {
  printer << " ";
  printer.printOperand(getInput());
  printer << " into ";
  printOpenDelimiter(printer, ListDelimiter::Paren);
  printCompressedValueSequence(printer, getOutputBuffers());
  printCloseDelimiter(printer, ListDelimiter::Paren);
  printer.printOptionalAttrDict((*this)->getAttrs());
  printer << " : ";
  printer.printType(getInput().getType());
  printer << " -> ";
  // Result types only; the verifier requires buffers to be compatible.
  printCompressedTypeSequence(printer, getOutputs().getTypes());
}
// Custom parser mirroring PimExtractRowsOp::print. The printed output types
// resolve the buffer operands and supply the result types.
ParseResult PimExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) {
  OpAsmParser::UnresolvedOperand input;
  SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
  Type inputType;
  SmallVector<Type> outputTypes;
  if (parser.parseOperand(input) || parser.parseKeyword("into") || parser.parseLParen()
      || parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
      || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
      || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
    return failure();
  if (outputBuffers.size() != outputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
  // Operand order: input first, then the destination buffers.
  if (parser.resolveOperand(input, inputType, result.operands)
      || parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputTypes);
  return success();
}
// Custom printer: `axis N %in... into %buf {attrs} : (inTypes) -> outType`.
void PimConcatOp::print(OpAsmPrinter& printer) {
  // `axis` is printed positionally, so it is elided from the attr-dict below.
  printer << " axis " << getAxis() << " ";
  printCompressedValueSequence(printer, getInputs());
  printer << " into ";
  printer.printOperand(getOutputBuffer());
  printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()});
  printer << " : ";
  printOpenDelimiter(printer, ListDelimiter::Paren);
  printCompressedTypeSequence(printer, TypeRange(getInputs()));
  printCloseDelimiter(printer, ListDelimiter::Paren);
  printer << " -> ";
  printer.printType(getOutput().getType());
}
// Custom parser mirroring PimConcatOp::print. `axis` is parsed positionally
// and converted to an i64 attribute; a duplicate in the attr-dict is an
// error. The printed result type is reused to resolve the buffer operand.
ParseResult PimConcatOp::parse(OpAsmParser& parser, OperationState& result) {
  int64_t axis = 0;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  OpAsmParser::UnresolvedOperand outputBuffer;
  SmallVector<Type> inputTypes;
  Type outputType;
  if (parser.parseKeyword("axis") || parser.parseInteger(axis) || parseCompressedOperandSequence(parser, inputs)
      || parser.parseKeyword("into") || parser.parseOperand(outputBuffer)
      || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseLParen()
      || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false) || parser.parseRParen()
      || parser.parseArrow() || parser.parseType(outputType))
    return failure();
  if (inputs.size() != inputTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
  if (result.attributes.get("axis"))
    return parser.emitError(parser.getCurrentLocation(), "axis cannot be specified both positionally and in attr-dict");
  result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axis));
  // Operand order: inputs first, then the destination buffer.
  if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)
      || parser.resolveOperand(outputBuffer, outputType, result.operands))
    return failure();
  result.addTypes(outputType);
  return success();
}
} // namespace pim
} // namespace onnx_mlir

View File

@@ -0,0 +1,268 @@
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
namespace {
// Checks that exactly one core id was recorded per communicated value.
static LogicalResult verifyManyCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
  if (coreIds.size() == valueCount)
    return success();
  return op->emitError("core id metadata length must match the number of values");
}
// True when both types use the same shaped container kind: both ranked
// tensors, or both memrefs.
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
  if (isa<RankedTensorType>(lhs))
    return isa<RankedTensorType>(rhs);
  return isa<MemRefType>(lhs) && isa<MemRefType>(rhs);
}
// Checks that two types are structurally compatible: both shaped, same
// container kind, same element type, and same shape. Emits `message` on any
// mismatch.
static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type rhs, StringRef message) {
  auto lhsShaped = dyn_cast<ShapedType>(lhs);
  auto rhsShaped = dyn_cast<ShapedType>(rhs);
  bool compatible = lhsShaped && rhsShaped && haveSameShapedContainerKind(lhs, rhs)
      && lhsShaped.getElementType() == rhsShaped.getElementType() && lhsShaped.getShape() == rhsShaped.getShape();
  if (!compatible)
    return op->emitError(message);
  return success();
}
// Checks that all communicated values carry "the same" type. Values whose
// types differ nominally are still accepted when both are shaped types with
// the same container kind, element type, and shape (e.g. memrefs that differ
// only in layout). `kind` is the op mnemonic used in diagnostics.
static LogicalResult verifyManyCommunicationTypes(Operation* op, TypeRange types, StringRef kind) {
  if (types.empty())
    return op->emitError() << kind << " must carry at least one value";
  Type firstType = types.front();
  // Cache the classification of the first type; every other type is compared
  // against it.
  auto firstShapedType = dyn_cast<ShapedType>(firstType);
  bool firstIsTensor = isa<RankedTensorType>(firstType);
  bool firstIsMemRef = isa<MemRefType>(firstType);
  for (Type type : types.drop_front())
    if (type != firstType) {
      // Nominally different: fall back to structural compatibility checks.
      auto shapedType = dyn_cast<ShapedType>(type);
      if (!firstShapedType || !shapedType)
        return op->emitError() << kind << " values must all have the same type";
      if (firstIsTensor != isa<RankedTensorType>(type) || firstIsMemRef != isa<MemRefType>(type))
        return op->emitError() << kind << " values must all use the same shaped container kind";
      if (firstShapedType.getElementType() != shapedType.getElementType() || firstShapedType.getShape() != shapedType.getShape())
        return op->emitError() << kind << " values must all have the same shape and element type";
    }
  return success();
}
// Returns the laneCount of the enclosing pim.core_batch op, or failure when
// the op is not nested inside one.
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
  if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>())
    return coreBatchOp.getLaneCount();
  return failure();
}
// Batch variant of the size check: the op must sit inside pim.core_batch and
// carry one core id per value per lane.
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op,
    ArrayRef<int32_t> coreIds,
    size_t valueCount) {
  auto laneCount = getParentBatchLaneCount(op);
  if (failed(laneCount))
    return op->emitError("must be nested inside pim.core_batch");
  size_t expectedSize = valueCount * static_cast<size_t>(*laneCount);
  if (coreIds.size() != expectedSize)
    return op->emitError("core id metadata length must match the number of values times parent laneCount");
  return success();
}
} // namespace
// Verifier: at least one result, the first result is a ranked tensor, and
// every result shares that exact type (so all are ranked tensors).
LogicalResult PimEmptyManyOp::verify() {
  auto outputs = getOutputs();
  if (outputs.empty())
    return emitError("must produce at least one output");
  Type referenceType = outputs.front().getType();
  if (!isa<RankedTensorType>(referenceType))
    return emitError("outputs must all be ranked tensor types");
  for (Value value : outputs.drop_front())
    if (value.getType() != referenceType)
      return emitError("outputs must all have the same type");
  return success();
}
// Verifier for pim.map: homogeneous inputs and outputs (one output per
// input), and a body region that takes one argument of the input type and
// yields one value of the output type via pim.yield.
LogicalResult PimMapOp::verify() {
  if (getInputs().empty())
    return emitError("requires at least one input");
  if (getOutputs().size() != getInputs().size())
    return emitError("number of outputs must match number of inputs");
  // All inputs must share one type; all outputs must share one type.
  Type inputType = getInputs().front().getType();
  for (Value input : getInputs().drop_front())
    if (input.getType() != inputType)
      return emitError("all inputs must have the same type");
  Type outputType = getOutputs().front().getType();
  for (Value output : getOutputs().drop_front())
    if (output.getType() != outputType)
      return emitError("all outputs must have the same type");
  // The body is a single-argument scalar-per-element mapping region.
  Block& block = getBody().front();
  if (block.getNumArguments() != 1)
    return emitError("body must have exactly one block argument");
  if (block.getArgument(0).getType() != inputType)
    return emitError("body block argument type must match input type");
  auto yieldOp = dyn_cast_or_null<PimYieldOp>(block.getTerminator());
  if (!yieldOp)
    return emitError("body must terminate with pim.yield");
  if (yieldOp.getNumOperands() != 1)
    return emitError("body yield must produce exactly one value");
  if (yieldOp.getOperand(0).getType() != outputType)
    return emitError("body yield type must match output type");
  return success();
}
// Verifier: one target core id per input, and homogeneous input types.
LogicalResult PimSendManyOp::verify() {
  Operation* op = getOperation();
  if (failed(verifyManyCommunicationSizes(op, getTargetCoreIds(), getInputs().size())))
    return failure();
  return verifyManyCommunicationTypes(op, getInputs().getTypes(), "send_many");
}
// Verifier: core ids sized as inputs times parent laneCount, and
// homogeneous input types.
LogicalResult PimSendManyBatchOp::verify() {
  Operation* op = getOperation();
  if (failed(verifyManyBatchCommunicationSizes(op, getTargetCoreIds(), getInputs().size())))
    return failure();
  return verifyManyCommunicationTypes(op, getInputs().getTypes(), "send_many_batch");
}
// Verifier for pim.receive_many: one buffer per result, one source core id
// per result, homogeneous buffer and result types, and each buffer matching
// its paired result exactly.
LogicalResult PimReceiveManyOp::verify() {
  if (getOutputBuffers().size() != getOutputs().size())
    return emitError("number of output buffers must match the number of outputs");
  if (failed(verifyManyCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size())))
    return failure();
  if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many")))
    return failure();
  if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many")))
    return failure();
  // Pairwise nominal equality (stricter than the structural checks above).
  for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs()))
    if (outputBuffer.getType() != output.getType())
      return emitError("output buffers and outputs must have matching types");
  return success();
}
// Verifier for pim.receive_many_batch: same as receive_many, except the core
// id list is sized as results times the parent pim.core_batch laneCount.
LogicalResult PimReceiveManyBatchOp::verify() {
  if (getOutputBuffers().size() != getOutputs().size())
    return emitError("number of output buffers must match the number of outputs");
  if (failed(verifyManyBatchCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size())))
    return failure();
  if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many_batch")))
    return failure();
  if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many_batch")))
    return failure();
  // Pairwise nominal equality (stricter than the structural checks above).
  for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs()))
    if (outputBuffer.getType() != output.getType())
      return emitError("output buffers and outputs must have matching types");
  return success();
}
// Verifier for pim.extract_rows: the rank-2 input is split into one 1xC
// output per row. Checks buffer/result pairing, container kind, element
// type, and column count; row/column comparisons are skipped when the input
// dimension is dynamic.
LogicalResult PimExtractRowsOp::verify() {
  if (getOutputBuffers().size() != getOutputs().size())
    return emitError("number of output buffers must match the number of outputs");
  auto inputType = dyn_cast<ShapedType>(getInput().getType());
  if (!inputType || !inputType.hasRank() || inputType.getRank() != 2)
    return emitError("input must be a rank-2 shaped type");
  int64_t numRows = inputType.getShape()[0];
  int64_t numCols = inputType.getShape()[1];
  Type elementType = inputType.getElementType();
  // Only enforce the row count when it is statically known (>= 0).
  if (numRows >= 0 && static_cast<int64_t>(getOutputs().size()) != numRows)
    return emitError("number of outputs must match the number of input rows");
  for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs())) {
    if (failed(verifyCompatibleShapedTypes(
            getOperation(), outputBuffer.getType(), output.getType(), "output buffers and outputs must match")))
      return failure();
    auto outputType = dyn_cast<ShapedType>(output.getType());
    if (!outputType || !outputType.hasRank() || outputType.getRank() != 2)
      return emitError("outputs must all be rank-2 shaped types");
    if (!haveSameShapedContainerKind(getInput().getType(), output.getType()))
      return emitError("outputs must use the same shaped container kind as the input");
    if (outputType.getElementType() != elementType)
      return emitError("output element types must match input element type");
    auto outputShape = outputType.getShape();
    // NOTE(review): a dynamic output row/column extent is rejected here;
    // outputs appear required to be statically 1xC — confirm intended.
    if (outputShape[0] != 1)
      return emitError("each output must have exactly one row");
    if (numCols >= 0 && outputShape[1] != numCols)
      return emitError("output column count must match input column count");
  }
  return success();
}
// Verifier for pim.concat: the destination buffer and the result must be
// structurally compatible; every input must match the output in container
// kind, rank, element type, and all non-concatenated dimensions; and the
// output's size along `axis` must equal the sum of the inputs' sizes there.
//
// Fix: previously, a single dynamic input along `axis` disabled the sum
// check entirely. Dynamic extents are non-negative, so the statically-known
// partial sum must still not exceed a static output dimension; that bound is
// now enforced.
LogicalResult PimConcatOp::verify() {
  if (getInputs().empty())
    return emitError("requires at least one input");
  if (failed(verifyCompatibleShapedTypes(
          getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
    return failure();
  auto outputType = dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType || !outputType.hasRank())
    return emitError("output must be a ranked shaped type");
  int64_t axis = getAxis();
  int64_t rank = outputType.getRank();
  if (axis < 0 || axis >= rank)
    return emitError("axis must be within the output rank");
  int64_t concatenatedDimSize = 0;   // sum of statically-known sizes along axis
  bool concatenatedDimDynamic = false; // any input dynamic along axis?
  Type outputElementType = outputType.getElementType();
  for (Value input : getInputs()) {
    auto inputType = dyn_cast<ShapedType>(input.getType());
    if (!inputType || !inputType.hasRank())
      return emitError("inputs must be ranked shaped types");
    if (!haveSameShapedContainerKind(input.getType(), getOutput().getType()))
      return emitError("inputs and output must use the same shaped container kind");
    if (inputType.getRank() != rank)
      return emitError("all inputs must have the same rank as the output");
    if (inputType.getElementType() != outputElementType)
      return emitError("all inputs must have the same element type as the output");
    // Every non-concatenated dimension must agree with the output (when both
    // sides are static).
    for (int64_t dim = 0; dim < rank; ++dim) {
      if (dim == axis)
        continue;
      int64_t inputDim = inputType.getDimSize(dim);
      int64_t outputDim = outputType.getDimSize(dim);
      if (!ShapedType::isDynamic(inputDim) && !ShapedType::isDynamic(outputDim) && inputDim != outputDim)
        return emitError("non-concatenated dimensions must match the output shape");
    }
    int64_t inputConcatDim = inputType.getDimSize(axis);
    if (ShapedType::isDynamic(inputConcatDim)) {
      concatenatedDimDynamic = true;
      continue;
    }
    concatenatedDimSize += inputConcatDim;
  }
  int64_t outputConcatDim = outputType.getDimSize(axis);
  if (!ShapedType::isDynamic(outputConcatDim)) {
    // Fully static: the sum must match exactly.
    if (!concatenatedDimDynamic && concatenatedDimSize != outputConcatDim)
      return emitError("output concatenated dimension must equal the sum of input sizes");
    // Partially dynamic: the static part of the sum cannot already exceed
    // the static output size, since dynamic extents are non-negative.
    if (concatenatedDimDynamic && concatenatedDimSize > outputConcatDim)
      return emitError("output concatenated dimension must equal the sum of input sizes");
  }
  return success();
}
} // namespace pim
} // namespace onnx_mlir

View File

@@ -173,6 +173,235 @@ struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<Receive
}
};
// One-Shot Bufferization model for pim.receive_many. Each tensor result is
// bufferized in place into its destination buffer: the op is recreated on
// memrefs, and every memref result is exposed back to tensor-land through a
// bufferization.to_tensor op.
struct ReceiveManyOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveManyOpInterface, PimReceiveManyOp> {
  // DPS init operands (the destination buffers) are written, not read.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
  }
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto receiveOp = cast<PimReceiveManyOp>(op);
    SmallVector<Value> outputBuffers;
    SmallVector<Type> resultTypes;
    SmallVector<Value> tensorResults;
    outputBuffers.reserve(receiveOp.getOutputBuffers().size());
    resultTypes.reserve(receiveOp.getOutputBuffers().size());
    tensorResults.reserve(receiveOp.getOutputBuffers().size());
    // Resolve each destination tensor to its underlying memref; the memref
    // types become the new op's result types.
    for (Value outputBuffer : receiveOp.getOutputBuffers()) {
      auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state);
      if (failed(outputBufferOpt))
        return failure();
      outputBuffers.push_back(*outputBufferOpt);
      resultTypes.push_back(outputBufferOpt->getType());
    }
    auto newOp = PimReceiveManyOp::create(
        rewriter, receiveOp.getLoc(), TypeRange(resultTypes), ValueRange(outputBuffers), receiveOp.getSourceCoreIdsAttr());
    // Wrap every memref result back into the original tensor type.
    for (auto [bufferResult, tensorResult] : llvm::zip(newOp.getOutputs(), receiveOp.getOutputs())) {
      auto tensorType = cast<RankedTensorType>(tensorResult.getType());
      auto toTensor =
          bufferization::ToTensorOp::create(rewriter, receiveOp.getLoc(), tensorType, bufferResult, UnitAttr(), UnitAttr());
      tensorResults.push_back(toTensor.getResult());
    }
    rewriter.replaceOp(receiveOp, tensorResults);
    return success();
  }
};
// One-Shot Bufferization model for pim.receive_many_batch. Identical in
// structure to ReceiveManyOpInterface: bufferize each destination in place,
// recreate the op on memrefs, and re-expose the results as tensors.
struct ReceiveManyBatchOpInterface
    : DstBufferizableOpInterfaceExternalModel<ReceiveManyBatchOpInterface, PimReceiveManyBatchOp> {
  // DPS init operands (the destination buffers) are written, not read.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
  }
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto receiveOp = cast<PimReceiveManyBatchOp>(op);
    SmallVector<Value> outputBuffers;
    SmallVector<Type> resultTypes;
    SmallVector<Value> tensorResults;
    outputBuffers.reserve(receiveOp.getOutputBuffers().size());
    resultTypes.reserve(receiveOp.getOutputBuffers().size());
    tensorResults.reserve(receiveOp.getOutputBuffers().size());
    // Resolve each destination tensor to its underlying memref.
    for (Value outputBuffer : receiveOp.getOutputBuffers()) {
      auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state);
      if (failed(outputBufferOpt))
        return failure();
      outputBuffers.push_back(*outputBufferOpt);
      resultTypes.push_back(outputBufferOpt->getType());
    }
    auto newOp = PimReceiveManyBatchOp::create(rewriter,
        receiveOp.getLoc(),
        TypeRange(resultTypes),
        ValueRange(outputBuffers),
        receiveOp.getSourceCoreIdsAttr());
    // Wrap every memref result back into the original tensor type.
    for (auto [bufferResult, tensorResult] : llvm::zip(newOp.getOutputs(), receiveOp.getOutputs())) {
      auto tensorType = cast<RankedTensorType>(tensorResult.getType());
      auto toTensor =
          bufferization::ToTensorOp::create(rewriter, receiveOp.getLoc(), tensorType, bufferResult, UnitAttr(), UnitAttr());
      tensorResults.push_back(toTensor.getResult());
    }
    rewriter.replaceOp(receiveOp, tensorResults);
    return success();
  }
};
// One-Shot Bufferization model for pim.extract_rows: the input is read, the
// row buffers are written in place. The input memref is forced into a
// contiguous layout before the op is recreated on memrefs.
struct ExtractRowsOpInterface : DstBufferizableOpInterfaceExternalModel<ExtractRowsOpInterface, PimExtractRowsOp> {
  // Only the non-init operand (the input) is read.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
  }
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto extractRowsOp = cast<PimExtractRowsOp>(op);
    auto inputOpt = getBuffer(rewriter, extractRowsOp.getInput(), options, state);
    if (failed(inputOpt))
      return failure();
    SmallVector<Value> outputBuffers;
    SmallVector<Type> resultTypes;
    outputBuffers.reserve(extractRowsOp.getOutputBuffers().size());
    resultTypes.reserve(extractRowsOp.getOutputBuffers().size());
    // Resolve each destination tensor to its memref; memref types become the
    // new result types.
    for (Value outputBuffer : extractRowsOp.getOutputBuffers()) {
      auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state);
      if (failed(outputBufferOpt))
        return failure();
      outputBuffers.push_back(*outputBufferOpt);
      resultTypes.push_back(outputBufferOpt->getType());
    }
    auto newOp = PimExtractRowsOp::create(rewriter,
        extractRowsOp.getLoc(),
        TypeRange(resultTypes),
        materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
        ValueRange(outputBuffers));
    rewriter.replaceOp(extractRowsOp, newOp.getOutputs());
    return success();
  }
};
// One-Shot Bufferization model for pim.concat: inputs are read (after being
// made layout-contiguous), the destination buffer is written in place, and
// the op is replaced by a memref-typed clone.
struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInterface, PimConcatOp> {
  // Only the non-init operands (the inputs) are read.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
  }
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto concatOp = cast<PimConcatOp>(op);
    SmallVector<Value> inputs;
    inputs.reserve(concatOp.getInputs().size());
    // Each input buffer is forced into a contiguous layout.
    for (Value input : concatOp.getInputs()) {
      auto inputOpt = getBuffer(rewriter, input, options, state);
      if (failed(inputOpt))
        return failure();
      inputs.push_back(materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter));
    }
    auto outputBufferOpt = getBuffer(rewriter, concatOp.getOutputBuffer(), options, state);
    if (failed(outputBufferOpt))
      return failure();
    replaceOpWithNewBufferizedOp<PimConcatOp>(
        rewriter, op, outputBufferOpt->getType(), concatOp.getAxisAttr(), ValueRange(inputs), *outputBufferOpt);
    return success();
  }
};
// One-Shot Bufferization model for pim.map. All operands are treated as
// reads and nothing aliases through the op; the body's single block argument
// aliases the first input (equivalent buffers). bufferize() converts tensor
// inputs to memrefs, derives memref result types from the tensor outputs,
// moves the body into a new op, and bufferizes the block signature.
struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, PimMapOp> {
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
  // Results do not alias any operand.
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return {};
  }
  // The body's first (and only) block argument is equivalent to input #0.
  AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
    auto mapOp = cast<PimMapOp>(op);
    auto bbArg = dyn_cast<BlockArgument>(value);
    if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 || mapOp.getInputs().empty())
      return {};
    return {{&mapOp->getOpOperand(0), BufferRelation::Equivalent}};
  }
  bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
  // The buffer type of the block argument mirrors the first input: reuse its
  // buffer-like type if it already has one, otherwise build an identity-layout
  // memref from its shape.
  FailureOr<BufferLikeType>
  getBufferType(Operation* op,
      Value value,
      const BufferizationOptions& options,
      const BufferizationState& state,
      SmallVector<Value>& invocationStack) const {
    auto mapOp = cast<PimMapOp>(op);
    auto bbArg = dyn_cast<BlockArgument>(value);
    if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 || mapOp.getInputs().empty())
      return failure();
    auto inputType = dyn_cast<BufferLikeType>(mapOp.getInputs().front().getType());
    if (inputType)
      return inputType;
    auto shapedType = cast<ShapedType>(mapOp.getInputs().front().getType());
    return BufferLikeType(MemRefType::get(shapedType.getShape(), shapedType.getElementType()));
  }
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto mapOp = cast<PimMapOp>(op);
    SmallVector<Value> inputs;
    SmallVector<Type> resultTypes;
    inputs.reserve(mapOp.getInputs().size());
    resultTypes.reserve(mapOp.getOutputs().size());
    // Tensor inputs become memrefs; non-tensor inputs pass through unchanged.
    for (Value input : mapOp.getInputs()) {
      if (isa<TensorType>(input.getType())) {
        auto inputOpt = getBuffer(rewriter, input, options, state);
        if (failed(inputOpt))
          return failure();
        inputs.push_back(*inputOpt);
      }
      else {
        inputs.push_back(input);
      }
    }
    // Result types: identity-layout memrefs with the tensor outputs' shapes.
    for (Value output : mapOp.getOutputs()) {
      auto shapedType = cast<ShapedType>(output.getType());
      resultTypes.push_back(MemRefType::get(shapedType.getShape(), shapedType.getElementType()));
    }
    rewriter.setInsertionPoint(mapOp);
    auto newOp = PimMapOp::create(rewriter, mapOp.getLoc(), TypeRange(resultTypes), ValueRange(inputs));
    // Move the body over, then rewrite its block signature to memref types.
    rewriter.inlineRegionBefore(mapOp.getBody(), newOp.getBody(), newOp.getBody().begin());
    for (Block& block : newOp.getBody())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state)))
        return failure();
    rewriter.replaceOp(mapOp, newOp.getOutputs());
    return success();
  }
};
struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOpInterface, PimCoreBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return true;
@@ -435,9 +664,14 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimMapOp::attachInterface<MapOpInterface>(*ctx);
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
PimReceiveManyOp::attachInterface<ReceiveManyOpInterface>(*ctx);
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
PimReceiveManyBatchOp::attachInterface<ReceiveManyBatchOpInterface>(*ctx);
PimExtractRowsOp::attachInterface<ExtractRowsOpInterface>(*ctx);
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);

View File

@@ -3,7 +3,9 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Threading.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Casting.h"
@@ -45,6 +47,23 @@ private:
void PimBufferizationPass::runOnOperation() {
auto moduleOp = getOperation();
{
SmallVector<pim::PimEmptyManyOp> emptyManyOps;
moduleOp.walk([&](pim::PimEmptyManyOp emptyManyOp) { emptyManyOps.push_back(emptyManyOp); });
IRRewriter rewriter(moduleOp.getContext());
for (auto emptyManyOp : emptyManyOps) {
SmallVector<Value> replacementValues;
replacementValues.reserve(emptyManyOp.getOutputs().size());
rewriter.setInsertionPoint(emptyManyOp);
for (Value output : emptyManyOp.getOutputs()) {
auto outputType = cast<RankedTensorType>(output.getType());
replacementValues.push_back(
tensor::EmptyOp::create(rewriter, emptyManyOp.getLoc(), outputType.getShape(), outputType.getElementType()));
}
rewriter.replaceOp(emptyManyOp, replacementValues);
}
}
// Refactor this into a function
{
auto funcOp = getPimEntryFunc(moduleOp);

View File

@@ -8,6 +8,7 @@
#include <string>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -23,23 +24,23 @@ enum class ListDelimiter {
};
static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseLSquare();
return parser.parseLParen();
return onnx_mlir::compact_asm::parseOpenDelimiter(
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
}
static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseOptionalRSquare();
return parser.parseOptionalRParen();
return onnx_mlir::compact_asm::parseOptionalCloseDelimiter(
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
}
static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
printer << (delimiter == ListDelimiter::Square ? "[" : "(");
onnx_mlir::compact_asm::printOpenDelimiter(
printer, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
}
static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
printer << (delimiter == ListDelimiter::Square ? "]" : ")");
onnx_mlir::compact_asm::printCloseDelimiter(
printer, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter));
}
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
@@ -51,31 +52,8 @@ static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<EntryT>& entries,
ParseEntryFn parseEntry) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
EntryT entry;
if (parseEntry(entry))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
entries.push_back(entry);
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
break;
if (parser.parseComma())
return failure();
}
return success();
return onnx_mlir::compact_asm::parseCompressedRepeatedList(
parser, static_cast<onnx_mlir::compact_asm::ListDelimiter>(delimiter), entries, parseEntry);
}
template <typename IntT>
@@ -388,156 +366,32 @@ static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty);
/// Returns true when `values` is a non-empty whole number of identical
/// tuples of `tupleSize` consecutive elements (so the list can be printed
/// once with an "x N" repeat suffix).
/// Fix: the previous body duplicated the shared helper's logic (including a
/// needless copy of the range into a SmallVector) and left the delegation
/// unreachable after its own `return`.
static bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
  return onnx_mlir::compact_asm::hasRepeatedTuple(values, tupleSize);
}
/// Type overload of hasRepeatedTuple: true when `types` is a non-empty whole
/// number of identical `tupleSize`-element tuples.
/// Fix: same dead-code cleanup as the ValueRange overload — the duplicated
/// local implementation shadowed an unreachable delegation call.
static bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
  return onnx_mlir::compact_asm::hasRepeatedTuple(types, tupleSize);
}
/// Prints `values` as a single repeated tuple run, e.g. "[(%a, %b) x3]".
/// Caller guarantees hasRepeatedTuple(values, tupleSize).
/// Fix: the previous body printed the run inline and then invoked the shared
/// helper, which prints the whole run again — producing the output twice.
static void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize) {
  onnx_mlir::compact_asm::printValueTupleRun(
      printer, values, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square);
}
/// Prints `types` as a single repeated tuple run, e.g. "[(i32, f32) x4]".
/// Caller guarantees hasRepeatedTuple(types, tupleSize).
/// Fix: same double-print defect as printValueTupleRun — the inline printing
/// and the shared-helper call both executed; keep only the helper.
static void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize) {
  onnx_mlir::compact_asm::printTypeTupleRun(
      printer, types, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square);
}
/// Parses a square-bracketed operand list that is either a plain compressed
/// list of entries (each with an optional "x N" repeat) or a list of
/// parenthesized tuples, each tuple optionally followed by "x N"; expanded
/// operands are appended to `operands`.
/// Fix: the previous body contained the full legacy implementation whose
/// `while (true)` loop always returned, leaving the delegation to the shared
/// compact-asm helper unreachable — keep only the delegation.
static ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  return onnx_mlir::compact_asm::parseCompressedOrTupleOperandList(
      parser, onnx_mlir::compact_asm::ListDelimiter::Square, operands);
}
/// Type counterpart of parseCompressedOrTupleOperandList: parses a
/// square-bracketed type list in either compressed or tuple-run form and
/// appends the expanded types to `types`.
/// Fix: same dead-code cleanup — the duplicated legacy loop never fell
/// through to the shared-helper delegation, which was unreachable.
static ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, SmallVectorImpl<Type>& types) {
  return onnx_mlir::compact_asm::parseCompressedOrTupleTypeList(
      parser, onnx_mlir::compact_asm::ListDelimiter::Square, types);
}
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,

View File

@@ -1,9 +1,12 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
@@ -31,6 +34,35 @@ static int64_t getValueSizeInBytes(Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
// Inlines every pim.map region found in `funcOp`: for each input value the
// map body is cloned once with the body's block argument replaced by that
// input, and the map op's results are replaced by the per-input yielded
// values. The pim.map ops themselves are removed.
// NOTE(review): assumes each map body has a single block argument and yields
// exactly one value (yieldOp.getOperand(0)) — confirm the op verifier
// enforces this.
static void expandPimMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Collect the ops first: expanding while walking would mutate the IR
  // under the walker.
  SmallVector<pim::PimMapOp> mapOps;
  funcOp.walk([&](pim::PimMapOp mapOp) { mapOps.push_back(mapOp); });
  for (auto mapOp : mapOps) {
    Block& body = mapOp.getBody().front();
    auto yieldOp = cast<pim::PimYieldOp>(body.getTerminator());
    SmallVector<Value> replacements;
    replacements.reserve(mapOp.getInputs().size());
    rewriter.setInsertionPoint(mapOp);
    for (Value input : mapOp.getInputs()) {
      // Fresh mapping per input: the body is expanded once per element,
      // each time with the block argument bound to a different input.
      IRMapping mapping;
      mapping.map(body.getArgument(0), input);
      for (Operation& op : body.without_terminator()) {
        Operation* cloned = rewriter.clone(op, mapping);
        // Record original -> cloned results so later body ops (and the
        // yield lookup below) resolve to the cloned values.
        for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
          mapping.map(originalResult, clonedResult);
        // Keep emitting clones in program order, after the one just made.
        rewriter.setInsertionPointAfter(cloned);
      }
      // lookupOrDefault: a yielded value defined outside the body (or the
      // block argument itself) falls back to its mapped/original value.
      replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0)));
    }
    rewriter.replaceOp(mapOp, replacements);
  }
}
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
@@ -41,13 +73,15 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
OpBuilder rewriter(moduleOp.getContext());
IRRewriter rewriter(moduleOp.getContext());
bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
if (funcOp.isExternal())
continue;
expandPimMapOps(funcOp, rewriter);
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
@@ -113,6 +147,45 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
}
}
}
SmallVector<Operation*> hostCompactOps;
for (Operation& op : funcOp.getBody().front())
if (isa<pim::PimExtractRowsOp, pim::PimConcatOp>(op))
hostCompactOps.push_back(&op);
for (Operation* op : hostCompactOps) {
rewriter.setInsertionPoint(op);
if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op)) {
auto inputType = dyn_cast<ShapedType>(extractRowsOp.getInput().getType());
if (!inputType || !inputType.hasStaticShape() || inputType.getRank() != 2) {
extractRowsOp.emitOpError("host-side extract_rows lowering requires a static rank-2 input");
hasFailure = true;
continue;
}
int64_t numCols = inputType.getDimSize(1);
SmallVector<Value> replacementRows;
replacementRows.reserve(extractRowsOp.getOutputs().size());
for (auto rowIndex : llvm::seq<size_t>(0, extractRowsOp.getOutputs().size())) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)),
rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
replacementRows.push_back(memref::SubViewOp::create(
rewriter, extractRowsOp.getLoc(), extractRowsOp.getInput(), offsets, sizes, strides)
.getResult());
}
extractRowsOp->replaceAllUsesWith(ValueRange(replacementRows));
extractRowsOp->erase();
continue;
}
auto concatOp = cast<pim::PimConcatOp>(op);
concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
hasFailure = true;
}
}
if (hasFailure) {