diff --git a/src/PIM/Common/IR/CompactAsmUtils.hpp b/src/PIM/Common/IR/CompactAsmUtils.hpp new file mode 100644 index 0000000..1a1fc84 --- /dev/null +++ b/src/PIM/Common/IR/CompactAsmUtils.hpp @@ -0,0 +1,755 @@ +#ifndef ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP +#define ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP + +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" + +namespace onnx_mlir { +namespace compact_asm { + +using namespace mlir; + +enum class ListDelimiter { + Square, + Paren +}; + +inline ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { + if (delimiter == ListDelimiter::Square) + return parser.parseLSquare(); + return parser.parseLParen(); +} + +inline ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { + if (delimiter == ListDelimiter::Square) + return parser.parseOptionalRSquare(); + return parser.parseOptionalRParen(); +} + +inline void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { + printer << (delimiter == ListDelimiter::Square ? "[" : "("); +} + +inline void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { + printer << (delimiter == ListDelimiter::Square ? "]" : ")"); +} + +template +inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& entries, + ParseEntryFn parseEntry) { + if (parseOpenDelimiter(parser, delimiter)) + return failure(); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + + while (true) { + EntryT entry; + if (parseEntry(entry)) + return failure(); + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + entries.push_back(entry); + + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + if (parser.parseComma()) + return failure(); + } +} + +template +inline ParseResult parseCompressedIntegerEntries(OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& values) { + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + + while (true) { + if (succeeded(parser.parseOptionalLParen())) { + SmallVector subgroup; + if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup)) + return failure(); + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(values, subgroup); + } + else { + int64_t first = 0; + if (parser.parseInteger(first)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("to"))) { + int64_t last = 0; + if (parser.parseInteger(last) || last < first) + return parser.emitError(parser.getCurrentLocation(), "invalid ascending range"); + + int64_t step = 1; + if (succeeded(parser.parseOptionalKeyword("by"))) { + if (parser.parseInteger(step) || step <= 0) + return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive"); + } + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + if ((last - first) % step != 0) { + return parser.emitError( + parser.getCurrentLocation(), "range end must be reachable from start using the given step"); + } + + for (int64_t value = first; value <= last; value += step) + for (int64_t index = 0; index < repeatCount; ++index) + values.push_back(static_cast(value)); + } + else { + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + values.push_back(static_cast(first)); + } + } + + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + if (parser.parseComma()) + return failure(); + } +} + +template +inline ParseResult parseCompressedIntegerSequence(OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& values) { + if (parseOpenDelimiter(parser, delimiter)) + return failure(); + return parseCompressedIntegerEntries(parser, delimiter, values); +} + +template +inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) { + for (size_t index = 0; index < entries.size();) { + size_t runEnd = index + 1; + while (runEnd < entries.size() && entries[runEnd] == entries[index]) + ++runEnd; + + if (index != 0) + printer << ", "; + printEntry(entries[index]); + size_t runLength = runEnd - index; + if (runLength > 1) + printer << " x" << runLength; + index = runEnd; + } +} + +template +inline void printCompressedIntegerSequence(OpAsmPrinter& printer, + ArrayRef values, + ListDelimiter delimiter) { + struct FlatCompression { + enum class Kind { + Single, + EqualRun, + Progression + }; + + Kind kind = Kind::Single; + size_t covered = 1; + size_t repeatCount = 1; + size_t progressionValueCount = 1; + int64_t step = 1; + IntT firstValue {}; + IntT lastValue {}; + }; + + auto computeFlatCompression = [&](size_t start) { + FlatCompression compression; + compression.firstValue = values[start]; + compression.lastValue = values[start]; + + auto findEqualRunEnd = [&](size_t runStart) { + size_t runEnd = runStart + 1; + while (runEnd < values.size() && values[runEnd] == values[runStart]) + ++runEnd; + return runEnd; + }; + + size_t firstRunEnd = findEqualRunEnd(start); + compression.repeatCount = firstRunEnd - start; + size_t progressionEnd = firstRunEnd; + int64_t step = 0; + IntT lastValue = values[start]; + + if (firstRunEnd < values.size()) { + size_t secondRunEnd = findEqualRunEnd(firstRunEnd); + step = static_cast(values[firstRunEnd]) - static_cast(values[start]); + if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) { + progressionEnd = secondRunEnd; + lastValue = values[firstRunEnd]; + size_t currentRunStart = secondRunEnd; + while (currentRunStart < values.size()) { + size_t currentRunEnd = findEqualRunEnd(currentRunStart); + if (currentRunEnd - currentRunStart != compression.repeatCount) + break; + if (static_cast(values[currentRunStart]) != static_cast(lastValue) + step) + break; + lastValue = values[currentRunStart]; + progressionEnd = currentRunEnd; + currentRunStart = currentRunEnd; + } + } + else { + step = 0; + } + } + + compression.covered = 1; + if (progressionEnd > firstRunEnd) { + size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount; + if (progressionValueCount >= 3) { + compression.kind = FlatCompression::Kind::Progression; + compression.covered = progressionEnd - start; + compression.progressionValueCount = progressionValueCount; + compression.step = step; + compression.lastValue = lastValue; + return compression; + } + } + + if (compression.repeatCount > 1) { + compression.kind = FlatCompression::Kind::EqualRun; + compression.covered = compression.repeatCount; + return compression; + } + + return compression; + }; + + auto findRepeatedSublist = [&](size_t start) { + size_t bestLength = 0; + size_t bestRepeatCount = 1; + size_t remaining = values.size() - start; + + for (size_t length = 2; length * 2 <= remaining; ++length) { + size_t repeatCount = 1; + ArrayRef candidate = values.slice(start, length); + while (start + (repeatCount + 1) * length <= values.size() + && llvm::equal(candidate, values.slice(start + repeatCount * length, length))) { + ++repeatCount; + } + + if (repeatCount <= 1) + continue; + + size_t covered = length * repeatCount; + size_t bestCovered = bestLength * bestRepeatCount; + if (covered > bestCovered || (covered == bestCovered && length < bestLength)) { + bestLength = length; + bestRepeatCount = repeatCount; + } + } + + return std::pair(bestLength, bestRepeatCount); + }; + + printOpenDelimiter(printer, delimiter); + for (size_t index = 0; index < values.size();) { + if (index != 0) + printer << ", "; + + FlatCompression flat = computeFlatCompression(index); + auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index); + size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount; + if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) { + printCompressedIntegerSequence(printer, values.slice(index, sublistLength), ListDelimiter::Paren); + printer << " x" << sublistRepeatCount; + index += repeatedSublistCoverage; + continue; + } + + switch (flat.kind) { + case FlatCompression::Kind::Progression: + printer << flat.firstValue << " to " << flat.lastValue; + if (flat.step != 1) + printer << " by " << flat.step; + if (flat.repeatCount > 1) + printer << " x" << flat.repeatCount; + index += flat.covered; + break; + case FlatCompression::Kind::EqualRun: + printer << flat.firstValue << " x" << flat.repeatCount; + index += flat.covered; + break; + case FlatCompression::Kind::Single: + printer << flat.firstValue; + index += flat.covered; + break; + } + } + printCloseDelimiter(printer, delimiter); +} + +template +inline ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl& values) { + return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values); +} + +template +inline void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef values) { + printCompressedIntegerSequence(printer, values, ListDelimiter::Square); +} + +inline void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) { + for (size_t index = 0; index < values.size();) { + size_t equalRunEnd = index + 1; + while (equalRunEnd < values.size() && values[equalRunEnd] == values[index]) + ++equalRunEnd; + + if (index != 0) + printer << ", "; + if (equalRunEnd - index > 1) { + printer.printOperand(values[index]); + printer << " x" << (equalRunEnd - index); + index = equalRunEnd; + continue; + } + + size_t rangeEnd = index + 1; + if (auto firstResult = dyn_cast(values[index])) { + while (rangeEnd < values.size()) { + auto nextResult = dyn_cast(values[rangeEnd]); + if (!nextResult || nextResult.getOwner() != firstResult.getOwner() + || nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) { + break; + } + ++rangeEnd; + } + } + else if (auto firstArg = dyn_cast(values[index])) { + while (rangeEnd < values.size()) { + auto nextArg = dyn_cast(values[rangeEnd]); + if (!nextArg || nextArg.getOwner() != firstArg.getOwner() + || nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) { + break; + } + ++rangeEnd; + } + } + + printer.printOperand(values[index]); + if (rangeEnd - index >= 3) { + printer << " to "; + printer.printOperand(values[rangeEnd - 1]); + } + else if (rangeEnd - index == 2) { + printer << ", "; + printer.printOperand(values[index + 1]); + } + index = rangeEnd; + } +} + +inline void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) { + printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); }); +} + +inline void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) { + printOpenDelimiter(printer, delimiter); + printCompressedValueSequence(printer, values); + printCloseDelimiter(printer, delimiter); +} + +inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) { + printOpenDelimiter(printer, delimiter); + printCompressedTypeSequence(printer, types); + printCloseDelimiter(printer, delimiter); +} + +inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, + SmallVectorImpl& types, + bool allowEmpty) { + Type firstType; + OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType); + if (!firstTypeResult.has_value()) { + if (allowEmpty) + return success(); + return parser.emitError(parser.getCurrentLocation(), "expected type"); + } + if (failed(*firstTypeResult)) + return failure(); + + auto appendType = [&](Type type) -> ParseResult { + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + types.push_back(type); + return success(); + }; + + if (appendType(firstType)) + return failure(); + while (succeeded(parser.parseOptionalComma())) { + Type nextType; + if (parser.parseType(nextType) || appendType(nextType)) + return failure(); + } + return success(); +} + +inline ParseResult parseCompressedOperandEntryWithFirst( + OpAsmParser& parser, + OpAsmParser::UnresolvedOperand firstOperand, + SmallVectorImpl& operands) { + if (succeeded(parser.parseOptionalKeyword("to"))) { + OpAsmParser::UnresolvedOperand lastOperand; + if (parser.parseOperand(lastOperand)) + return failure(); + if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number) + return parser.emitError(parser.getCurrentLocation(), "invalid operand range"); + for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number) + operands.push_back({firstOperand.location, firstOperand.name, number}); + return success(); + } + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + operands.push_back(firstOperand); + return success(); +} + +inline ParseResult parseOneCompressedOperandEntry( + OpAsmParser& parser, + SmallVectorImpl& operands) { + OpAsmParser::UnresolvedOperand firstOperand; + if (parser.parseOperand(firstOperand)) + return failure(); + return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands); +} + +inline ParseResult parseCompressedOperandList(OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& operands) { + if (parseOpenDelimiter(parser, delimiter)) + return failure(); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + + while (true) { + if (parseOneCompressedOperandEntry(parser, operands)) + return failure(); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + if (parser.parseComma()) + return failure(); + } +} + +inline ParseResult parseCompressedOperandSequence( + OpAsmParser& parser, + SmallVectorImpl& operands) { + if (parseOneCompressedOperandEntry(parser, operands)) + return failure(); + while (succeeded(parser.parseOptionalComma())) + if (parseOneCompressedOperandEntry(parser, operands)) + return failure(); + return success(); +} + +inline ParseResult parseCompressedTypeList(OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& types) { + if (parseOpenDelimiter(parser, delimiter)) + return failure(); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + + if (parseCompressedTypeSequence(parser, types, /*allowEmpty=*/false)) + return failure(); + return parseOptionalCloseDelimiter(parser, delimiter); +} + +inline bool hasRepeatedTuple(ValueRange values, size_t tupleSize) { + if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0) + return false; + + SmallVector valueVec(values.begin(), values.end()); + ArrayRef tuple(valueVec.data(), tupleSize); + for (size_t index = tupleSize; index < values.size(); index += tupleSize) + if (!llvm::equal(tuple, ArrayRef(valueVec).slice(index, tupleSize))) + return false; + return true; +} + +inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) { + if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0) + return false; + + SmallVector typeVec(types.begin(), types.end()); + ArrayRef tuple(typeVec.data(), tupleSize); + for (size_t index = tupleSize; index < types.size(); index += tupleSize) + if (!llvm::equal(tuple, ArrayRef(typeVec).slice(index, tupleSize))) + return false; + return true; +} + +inline void printValueTupleRun(OpAsmPrinter& printer, + ValueRange values, + size_t tupleSize, + ListDelimiter delimiter) { + printOpenDelimiter(printer, delimiter); + printOpenDelimiter(printer, ListDelimiter::Paren); + for (size_t index = 0; index < tupleSize; ++index) { + if (index != 0) + printer << ", "; + printer.printOperand(values[index]); + } + printCloseDelimiter(printer, ListDelimiter::Paren); + printer << " x" << (values.size() / tupleSize); + printCloseDelimiter(printer, delimiter); +} + +inline void printTypeTupleRun(OpAsmPrinter& printer, + TypeRange types, + size_t tupleSize, + ListDelimiter delimiter) { + printOpenDelimiter(printer, delimiter); + printOpenDelimiter(printer, ListDelimiter::Paren); + for (size_t index = 0; index < tupleSize; ++index) { + if (index != 0) + printer << ", "; + printer.printType(types[index]); + } + printCloseDelimiter(printer, ListDelimiter::Paren); + printer << " x" << (types.size() / tupleSize); + printCloseDelimiter(printer, delimiter); +} + +inline ParseResult parseCompressedOrTupleOperandList( + OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& operands) { + if (parseOpenDelimiter(parser, delimiter)) + return failure(); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + + if (succeeded(parser.parseOptionalLParen())) { + SmallVector tupleOperands; + if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen()) + return failure(); + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(operands, tupleOperands); + + while (succeeded(parser.parseOptionalComma())) { + if (parser.parseLParen()) + return failure(); + tupleOperands.clear(); + if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen()) + return failure(); + + repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(operands, tupleOperands); + } + return parseOptionalCloseDelimiter(parser, delimiter); + } + + while (true) { + if (parseOneCompressedOperandEntry(parser, operands)) + return failure(); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + if (parser.parseComma()) + return failure(); + } +} + +inline ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& types) { + if (parseOpenDelimiter(parser, delimiter)) + return failure(); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + + if (succeeded(parser.parseOptionalLParen())) { + SmallVector tupleTypes; + if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen()) + return failure(); + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(types, tupleTypes); + + while (succeeded(parser.parseOptionalComma())) { + if (parser.parseLParen()) + return failure(); + tupleTypes.clear(); + if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen()) + return failure(); + + repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(types, tupleTypes); + } + return parseOptionalCloseDelimiter(parser, delimiter); + } + + while (true) { + Type type; + if (parser.parseType(type)) + return failure(); + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + types.push_back(type); + + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + if (parser.parseComma()) + return failure(); + } +} + +inline void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) { + if (block.getNumArguments() == 0) { + printer << "() = ()"; + return; + } + + if (block.getNumArguments() == 1) { + printer.printOperand(block.getArgument(0)); + printer << " = "; + printCompressedValueList(printer, operands, ListDelimiter::Paren); + return; + } + + printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren); + printer << " = "; + printCompressedValueList(printer, operands, ListDelimiter::Paren); +} + +inline ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser, + OpAsmParser::Argument firstArgument, + SmallVectorImpl& arguments) { + if (succeeded(parser.parseOptionalKeyword("to"))) { + OpAsmParser::Argument lastArgument; + if (parser.parseArgument(lastArgument)) + return failure(); + if (firstArgument.ssaName.name != lastArgument.ssaName.name + || firstArgument.ssaName.number > lastArgument.ssaName.number) { + return parser.emitError(parser.getCurrentLocation(), "invalid argument range"); + } + for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) { + OpAsmParser::Argument argument; + argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number}; + arguments.push_back(argument); + } + return success(); + } + + arguments.push_back(firstArgument); + return success(); +} + +inline ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser, + SmallVectorImpl& arguments) { + OpAsmParser::Argument firstArgument; + if (parser.parseArgument(firstArgument)) + return failure(); + return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments); +} + +inline void applyArgumentTypes(ArrayRef inputTypes, SmallVectorImpl& arguments) { + for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes)) + argument.type = inputType; +} + +inline ParseResult parseArgumentBindings(OpAsmParser& parser, + SmallVectorImpl& arguments, + SmallVectorImpl& operands) { + if (succeeded(parser.parseOptionalLParen())) { + if (succeeded(parser.parseOptionalRParen())) { + if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) + return failure(); + return success(); + } + + OpAsmParser::Argument firstArgument; + if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments)) + return failure(); + while (succeeded(parser.parseOptionalComma())) + if (parseOneCompressedArgumentEntry(parser, arguments)) + return failure(); + if (parser.parseRParen() || parser.parseEqual() + || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) { + return failure(); + } + return success(); + } + + OpAsmParser::Argument argument; + if (parser.parseArgument(argument) || parser.parseEqual() + || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) { + return failure(); + } + arguments.push_back(argument); + return success(); +} + +} // namespace compact_asm +} // namespace onnx_mlir + +#endif diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 22fdca7..fed460b 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -37,6 +37,11 @@ using namespace llvm; using namespace mlir; using namespace onnx_mlir; +static size_t getValueSizeInBytes(mlir::Value value) { + auto type = cast(value.getType()); + return type.getNumElements() * type.getElementTypeBitWidth() / 8; +} + MemEntry* PimMemory::gatherMemEntry(mlir::Value value) { auto type = cast(value.getType()); assert("Only static shape is supported" && type.hasStaticShape()); @@ -382,10 +387,75 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue "recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize()); } +void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const { + for (auto [outputBuffer, sourceCoreId] : llvm::zip(receiveManyOp.getOutputBuffers(), receiveManyOp.getSourceCoreIds())) + emitCommunicationOp("recv", addressOf(outputBuffer, knowledge), sourceCoreId, getValueSizeInBytes(outputBuffer)); +} + void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const { emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize()); } +void PimCodeGen::codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const { + for (auto [input, targetCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds())) + emitCommunicationOp("send", addressOf(input, knowledge), targetCoreId, getValueSizeInBytes(input)); +} + +void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const { + auto inputType = cast(extractRowsOp.getInput().getType()); + assert(inputType.hasStaticShape() && inputType.getRank() == 2 && "extract_rows codegen requires static rank-2 input"); + + size_t elementSize = inputType.getElementTypeBitWidth() / 8; + size_t rowSizeInBytes = static_cast(inputType.getDimSize(1)) * elementSize; + size_t inputAddr = addressOf(extractRowsOp.getInput(), knowledge); + + for (auto [rowIndex, outputBuffer] : llvm::enumerate(extractRowsOp.getOutputBuffers())) + emitMemCopyOp("lmv", + addressOf(outputBuffer, knowledge), + 0, + inputAddr, + rowIndex * rowSizeInBytes, + rowSizeInBytes, + "len"); +} + +void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const { + auto outputType = cast(concatOp.getOutputBuffer().getType()); + assert(outputType.hasStaticShape() && "concat codegen requires static output shape"); + + int64_t axis = concatOp.getAxis(); + ArrayRef outputShape = outputType.getShape(); + size_t elementSize = outputType.getElementTypeBitWidth() / 8; + size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge); + + size_t outerCount = 1; + for (int64_t dim = 0; dim < axis; ++dim) + outerCount *= static_cast(outputShape[dim]); + + size_t innerCount = 1; + for (size_t dim = static_cast(axis) + 1; dim < outputShape.size(); ++dim) + innerCount *= static_cast(outputShape[dim]); + + size_t outputConcatDim = static_cast(outputShape[axis]); + size_t concatOffset = 0; + for (mlir::Value input : concatOp.getInputs()) { + auto inputType = cast(input.getType()); + assert(inputType.hasStaticShape() && "concat codegen requires static input shapes"); + + size_t inputConcatDim = static_cast(inputType.getDimSize(axis)); + size_t blockSizeInBytes = inputConcatDim * innerCount * elementSize; + size_t inputAddr = addressOf(input, knowledge); + + for (size_t outerIndex = 0; outerIndex < outerCount; ++outerIndex) { + size_t dstOffset = (outerIndex * outputConcatDim + concatOffset) * innerCount * elementSize; + size_t srcOffset = outerIndex * inputConcatDim * innerCount * elementSize; + emitMemCopyOp("lmv", outputAddr, dstOffset, inputAddr, srcOffset, blockSizeInBytes, "len"); + } + + concatOffset += inputConcatDim; + } +} + template void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, @@ -396,11 +466,6 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, // TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix) } -static size_t getValueSizeInBytes(mlir::Value value) { - auto type = cast(value.getType()); - return type.getNumElements() * type.getElementTypeBitWidth() / 8; -} - void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const { auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge); auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge); @@ -682,6 +747,25 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor continue; } + if (auto sendManyBatchOp = dyn_cast(op)) { + SmallVector laneTargetCoreIds; + laneTargetCoreIds.reserve(sendManyBatchOp.getInputs().size()); + for (auto valueIndex : llvm::seq(0, sendManyBatchOp.getInputs().size())) + laneTargetCoreIds.push_back( + sendManyBatchOp.getTargetCoreIds()[valueIndex * laneCount + static_cast(lane)]); + + SmallVector mappedInputs; + mappedInputs.reserve(sendManyBatchOp.getInputs().size()); + for (mlir::Value input : sendManyBatchOp.getInputs()) + mappedInputs.push_back(mapper.lookup(input)); + + pim::PimSendManyOp::create(builder, + sendManyBatchOp.getLoc(), + builder.getDenseI32ArrayAttr(laneTargetCoreIds), + ValueRange(mappedInputs)); + continue; + } + if (auto receiveBatchOp = dyn_cast(op)) { auto scalarReceive = pim::PimReceiveOp::create(builder, @@ -694,6 +778,29 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor continue; } + if (auto receiveManyBatchOp = dyn_cast(op)) { + SmallVector laneSourceCoreIds; + laneSourceCoreIds.reserve(receiveManyBatchOp.getOutputs().size()); + for (auto valueIndex : llvm::seq(0, receiveManyBatchOp.getOutputs().size())) + laneSourceCoreIds.push_back( + receiveManyBatchOp.getSourceCoreIds()[valueIndex * laneCount + static_cast(lane)]); + + SmallVector mappedOutputBuffers; + mappedOutputBuffers.reserve(receiveManyBatchOp.getOutputBuffers().size()); + for (mlir::Value outputBuffer : receiveManyBatchOp.getOutputBuffers()) + mappedOutputBuffers.push_back(mapper.lookup(outputBuffer)); + + auto scalarReceiveMany = + pim::PimReceiveManyOp::create(builder, + receiveManyBatchOp.getLoc(), + receiveManyBatchOp->getResultTypes(), + ValueRange(mappedOutputBuffers), + builder.getDenseI32ArrayAttr(laneSourceCoreIds)); + for (auto [originalOutput, scalarOutput] : llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs())) + mapper.map(originalOutput, scalarOutput); + continue; + } + if (auto memcpBatchOp = dyn_cast(op)) { mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource()); if (!hostSource) @@ -812,8 +919,16 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenLmvOp(lmvOp, knowledge); else if (auto receiveOp = dyn_cast(op)) coreCodeGen.codeGenReceiveOp(receiveOp, knowledge); + else if (auto receiveManyOp = dyn_cast(op)) + coreCodeGen.codeGenReceiveManyOp(receiveManyOp, knowledge); else if (auto sendOp = dyn_cast(op)) coreCodeGen.codeGenSendOp(sendOp, knowledge); + else if (auto sendManyOp = dyn_cast(op)) + coreCodeGen.codeGenSendManyOp(sendManyOp, knowledge); + else if (auto extractRowsOp = dyn_cast(op)) + coreCodeGen.codeGenExtractRowsOp(extractRowsOp, knowledge); + else if (auto concatOp = dyn_cast(op)) + coreCodeGen.codeGenConcatOp(concatOp, knowledge); else if (auto vmmOp = dyn_cast(op)) coreCodeGen.codeGenMVMLikeOp(vmmOp.getWeightIndex(), vmmOp, true, knowledge); else if (auto mvmOp = dyn_cast(op)) diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index 83ca886..ae11df9 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -116,7 +116,11 @@ public: void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const; void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const; + void codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const; void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const; + void codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const; + void codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const; + void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const; template void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge); diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index e0f44c8..82401af 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -150,116 +150,105 @@ static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewri static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) { rewriter.setInsertionPoint(sendManyOp); - for (auto [input, targetCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds())) { - auto sizeAttr = getTensorSizeInBytesAttr(rewriter, input); - auto targetCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(targetCoreId)); - PimSendOp::create(rewriter, sendManyOp.getLoc(), input, sizeAttr, targetCoreIdAttr); - } + SmallVector targetCoreIds; + targetCoreIds.reserve(sendManyOp.getTargetCoreIds().size()); + for (int32_t targetCoreId : sendManyOp.getTargetCoreIds()) + targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); + PimSendManyOp::create( + rewriter, sendManyOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), sendManyOp.getInputs()); rewriter.eraseOp(sendManyOp); } +static SmallVector createManyEmptyTensorsLike(IRRewriter& rewriter, + Location loc, + TypeRange outputTypes) { + SmallVector tensorTypes; + tensorTypes.reserve(outputTypes.size()); + for (Type outputType : outputTypes) + tensorTypes.push_back(outputType); + + auto emptyMany = pim::PimEmptyManyOp::create(rewriter, loc, TypeRange(tensorTypes)); + return SmallVector(emptyMany.getOutputs().begin(), emptyMany.getOutputs().end()); +} + static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) { - SmallVector replacements; - replacements.reserve(receiveManyOp.getNumResults()); - rewriter.setInsertionPoint(receiveManyOp); - for (auto [output, sourceCoreId] : llvm::zip(receiveManyOp.getOutputs(), receiveManyOp.getSourceCoreIds())) { - auto outputType = cast(output.getType()); - auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyOp.getLoc(), outputType); - auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output); - auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sourceCoreId)); - Value received = - PimReceiveOp::create( - rewriter, receiveManyOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr) - .getOutput(); - replacements.push_back(received); - } + SmallVector sourceCoreIds; + sourceCoreIds.reserve(receiveManyOp.getSourceCoreIds().size()); + for (int32_t sourceCoreId : receiveManyOp.getSourceCoreIds()) + sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId)); + SmallVector outputBuffers = createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes()); - rewriter.replaceOp(receiveManyOp, ValueRange(replacements)); + auto receiveMany = PimReceiveManyOp::create(rewriter, + receiveManyOp.getLoc(), + receiveManyOp.getResultTypes(), + ValueRange(outputBuffers), + rewriter.getDenseI32ArrayAttr(sourceCoreIds)); + rewriter.replaceOp(receiveManyOp, receiveMany.getOutputs()); } static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp, int32_t laneCount, IRMapping& mapper, IRRewriter& rewriter) { - auto targetCoreIds = sendManyBatchOp.getTargetCoreIds(); - for (auto [valueIndex, input] : llvm::enumerate(sendManyBatchOp.getInputs())) { - size_t metadataOffset = valueIndex * static_cast(laneCount); - auto targetSlice = targetCoreIds.slice(metadataOffset, laneCount); - pim::PimSendBatchOp::create(rewriter, - sendManyBatchOp.getLoc(), - mapper.lookup(input), - getTensorSizeInBytesAttr(rewriter, mapper.lookup(input)), - rewriter.getDenseI32ArrayAttr(targetSlice)); - } + SmallVector targetCoreIds; + targetCoreIds.reserve(sendManyBatchOp.getTargetCoreIds().size()); + for (int32_t targetCoreId : sendManyBatchOp.getTargetCoreIds()) + targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); + SmallVector mappedInputs; + mappedInputs.reserve(sendManyBatchOp.getInputs().size()); + for (Value input : sendManyBatchOp.getInputs()) + mappedInputs.push_back(mapper.lookup(input)); + pim::PimSendManyBatchOp::create(rewriter, + sendManyBatchOp.getLoc(), + rewriter.getDenseI32ArrayAttr(targetCoreIds), + ValueRange(mappedInputs)); } static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp, int32_t laneCount, IRMapping& mapper, IRRewriter& rewriter) { - auto sourceCoreIds = receiveManyBatchOp.getSourceCoreIds(); - for (auto [valueIndex, output] : llvm::enumerate(receiveManyBatchOp.getOutputs())) { - size_t metadataOffset = valueIndex * static_cast(laneCount); - auto sourceSlice = sourceCoreIds.slice(metadataOffset, laneCount); - auto outputType = cast(output.getType()); - auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyBatchOp.getLoc(), outputType); - auto received = pim::PimReceiveBatchOp::create(rewriter, - receiveManyBatchOp.getLoc(), - outputBuffer.getType(), - outputBuffer, - getTensorSizeInBytesAttr(rewriter, output), - rewriter.getDenseI32ArrayAttr(sourceSlice)) - .getOutput(); + SmallVector sourceCoreIds; + sourceCoreIds.reserve(receiveManyBatchOp.getSourceCoreIds().size()); + for (int32_t sourceCoreId : receiveManyBatchOp.getSourceCoreIds()) + sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId)); + SmallVector outputBuffers = + createManyEmptyTensorsLike(rewriter, receiveManyBatchOp.getLoc(), receiveManyBatchOp.getResultTypes()); + + auto receiveMany = pim::PimReceiveManyBatchOp::create(rewriter, + receiveManyBatchOp.getLoc(), + receiveManyBatchOp.getResultTypes(), + ValueRange(outputBuffers), + rewriter.getDenseI32ArrayAttr(sourceCoreIds)); + for (auto [output, received] : llvm::zip(receiveManyBatchOp.getOutputs(), receiveMany.getOutputs())) mapper.map(output, received); - } } static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) { - Value input = extractRowsOp.getInput(); - RankedTensorType inputType; - if (auto tensorType = dyn_cast(input.getType())) { - inputType = tensorType; - } - else if (auto memRefType = dyn_cast(input.getType())) { - inputType = RankedTensorType::get(memRefType.getShape(), memRefType.getElementType()); - rewriter.setInsertionPoint(extractRowsOp); - input = bufferization::ToTensorOp::create( - rewriter, extractRowsOp.getLoc(), inputType, input, rewriter.getUnitAttr(), rewriter.getUnitAttr()) - .getResult(); - } - else { - extractRowsOp.emitOpError("requires a ranked tensor or memref input during Spatial-to-PIM lowering"); - return; - } - int64_t numCols = inputType.getDimSize(1); - - SmallVector replacements; - replacements.reserve(extractRowsOp.getNumResults()); - rewriter.setInsertionPoint(extractRowsOp); - for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) { - auto outputType = dyn_cast(output.getType()); - if (!outputType) { - extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering"); - return; - } - SmallVector offsets = {rewriter.getIndexAttr(static_cast(rowIndex)), - rewriter.getIndexAttr(0)}; - SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)}; - SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - auto rowSlice = - tensor::ExtractSliceOp::create(rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides); - replacements.push_back(rowSlice.getResult()); - } + SmallVector outputBuffers = + createManyEmptyTensorsLike(rewriter, extractRowsOp.getLoc(), extractRowsOp.getResultTypes()); - rewriter.replaceOp(extractRowsOp, ValueRange(replacements)); + auto extractRows = pim::PimExtractRowsOp::create(rewriter, + extractRowsOp.getLoc(), + extractRowsOp.getResultTypes(), + extractRowsOp.getInput(), + ValueRange(outputBuffers)); + rewriter.replaceOp(extractRowsOp, extractRows.getOutputs()); } static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) { rewriter.setInsertionPoint(concatOp); - Value concatenated = - tensor::ConcatOp::create(rewriter, concatOp.getLoc(), concatOp.getAxis(), concatOp.getInputs()).getResult(); + auto outputType = cast(concatOp.getOutput().getType()); + Value outputBuffer = createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), outputType).getResult(); + Value concatenated = pim::PimConcatOp::create(rewriter, + concatOp.getLoc(), + concatOp.getOutput().getType(), + rewriter.getI64IntegerAttr(concatOp.getAxis()), + concatOp.getInputs(), + outputBuffer) + .getOutput(); rewriter.replaceOp(concatOp, concatenated); } @@ -282,34 +271,23 @@ static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewrit } } -static void expandMapOps(func::FuncOp funcOp, IRRewriter& rewriter) { +static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) { SmallVector mapOps; - funcOp.walk([&](spatial::SpatMapOp mapOp) { mapOps.push_back(mapOp); }); + funcOp.walk([&](spatial::SpatMapOp mapOp) { + if (mapOp->getParentOfType() || mapOp->getParentOfType()) + mapOps.push_back(mapOp); + }); for (auto mapOp : mapOps) { Block& body = mapOp.getBody().front(); - auto yieldOp = cast(body.getTerminator()); - - SmallVector replacements; - replacements.reserve(mapOp.getInputs().size()); rewriter.setInsertionPoint(mapOp); - for (Value input : mapOp.getInputs()) { - IRMapping mapping; - mapping.map(body.getArgument(0), input); + auto pimMap = pim::PimMapOp::create(rewriter, mapOp.getLoc(), mapOp.getResultTypes(), mapOp.getInputs()); + rewriter.inlineRegionBefore(mapOp.getBody(), pimMap.getBody(), pimMap.getBody().begin()); - Value replacement = input; - for (Operation& op : body.without_terminator()) { - Operation* cloned = rewriter.clone(op, mapping); - for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) - mapping.map(originalResult, clonedResult); - rewriter.setInsertionPointAfter(cloned); - } - - replacement = mapping.lookupOrDefault(yieldOp.getOperand(0)); - replacements.push_back(replacement); - } - - rewriter.replaceOp(mapOp, replacements); + auto yieldOp = cast(body.getTerminator()); + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp(yieldOp, yieldOp.getOutputs()); + rewriter.replaceOp(mapOp, pimMap.getOutputs()); } } @@ -440,8 +418,28 @@ static std::optional analyzeReturnUse(Value value) { } static std::optional analyzeConcatReturnUse(Value value) { + auto getConcatResult = [](Operation *op) -> Value { + if (auto tensorConcat = dyn_cast(op)) + return tensorConcat.getResult(); + if (auto pimConcat = dyn_cast(op)) + return pimConcat.getOutput(); + return {}; + }; + auto getConcatAxis = [](Operation *op) -> std::optional { + if (auto tensorConcat = dyn_cast(op)) + return tensorConcat.getDim(); + if (auto pimConcat = dyn_cast(op)) + return pimConcat.getAxis(); + return std::nullopt; + }; + auto getConcatOperands = [](Operation *op) -> OperandRange { + if (auto tensorConcat = dyn_cast(op)) + return tensorConcat.getOperands(); + return cast(op).getInputs(); + }; + auto uses = value.getUses(); - if (rangeLength(uses) != 1 || !isa(uses.begin()->getOwner())) + if (rangeLength(uses) != 1 || !isa(uses.begin()->getOwner())) return std::nullopt; auto valueType = dyn_cast(value.getType()); @@ -453,18 +451,19 @@ static std::optional analyzeConcatReturnUse(Value value) { Value currentValue = value; Operation* currentUser = uses.begin()->getOwner(); - while (auto concatOp = dyn_cast(currentUser)) { + while (isa(currentUser)) { size_t operandIndex = currentValue.getUses().begin()->getOperandNumber(); - int64_t axis = concatOp.getDim(); - for (Value operand : concatOp.getOperands().take_front(operandIndex)) + int64_t axis = *getConcatAxis(currentUser); + for (Value operand : getConcatOperands(currentUser).take_front(operandIndex)) sliceOffsets[axis] += cast(operand.getType()).getShape()[axis]; - auto concatType = dyn_cast(concatOp.getResult().getType()); + Value concatResult = getConcatResult(currentUser); + auto concatType = dyn_cast(concatResult.getType()); if (!concatType || !concatType.hasStaticShape()) return std::nullopt; concatShape.assign(concatType.getShape().begin(), concatType.getShape().end()); - currentValue = concatOp.getResult(); + currentValue = concatResult; auto currentUses = currentValue.getUses(); if (rangeLength(currentUses) != 1) return std::nullopt; @@ -638,7 +637,6 @@ void SpatialToPimPass::runOnOperation() { func::FuncOp funcOp = *entryFunc; IRRewriter rewriter(&getContext()); - expandMapOps(funcOp, rewriter); ConversionTarget target(*ctx); target.addLegalDialect receiveOps; for (auto op : funcOp.getOps()) receiveOps.push_back(op); @@ -1317,7 +1317,8 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) { Operation* onlyUser = *op->getUsers().begin(); isExclusivelyOwnedByReturnChain = - isa(onlyUser) || isChannelUseChainOp(onlyUser); + isa(onlyUser) + || isChannelUseChainOp(onlyUser); } if (!isExclusivelyOwnedByReturnChain) return; @@ -1341,6 +1342,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri markOpToRemove(concatOp); for (Value operand : concatOp.getOperands()) markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); + return; + } + + if (auto concatOp = dyn_cast(op)) { + markOpToRemove(concatOp); + for (Value operand : concatOp.getInputs()) + markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); } }; diff --git a/src/PIM/Dialect/Pim/CMakeLists.txt b/src/PIM/Dialect/Pim/CMakeLists.txt index 80373ca..26e9974 100644 --- a/src/PIM/Dialect/Pim/CMakeLists.txt +++ b/src/PIM/Dialect/Pim/CMakeLists.txt @@ -6,6 +6,8 @@ add_subdirectory(Transforms/Bufferization) add_pim_library(PimOps PimOps.hpp PimOps.cpp + PimOpsAsm.cpp + PimOpsVerify.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 3ecc353..7e27354 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -50,9 +50,7 @@ def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSi Variadic:$inputs ); - let assemblyFormat = [{ - `lanes` $laneCount `(` $weights `)` `[` $inputs `]` attr-dict regions `:` type($weights) `[` type($inputs) `]` `->` `(` `)` - }]; + let hasCustomAssemblyFormat = 1; } def PimHaltOp : PimOp<"halt", [Terminator]> { @@ -63,6 +61,48 @@ def PimHaltOp : PimOp<"halt", [Terminator]> { }]; } +def PimYieldOp : PimOp<"yield", [Terminator]> { + let summary = "Yield results from a Pim region"; + + let arguments = (ins + Variadic:$outputs + ); + + let hasCustomAssemblyFormat = 1; +} + +def PimMapOp : PimOp<"map", [SingleBlock]> { + let summary = "Apply the same lane-local region to many independent tensors"; + + let arguments = (ins + Variadic:$inputs + ); + + let results = (outs + Variadic:$outputs + ); + + let regions = (region SizedRegion<1>:$body); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +//===----------------------------------------------------------------------===// +// Tensor Utilities +//===----------------------------------------------------------------------===// + +def PimEmptyManyOp : PimOp<"empty_many", []> { + let summary = "Create many identical empty tensors"; + + let results = (outs + Variadic:$outputs + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + //===----------------------------------------------------------------------===// // Communication //===----------------------------------------------------------------------===// @@ -81,6 +121,18 @@ def PimSendOp : PimOp<"send", []> { }]; } +def PimSendManyOp : PimOp<"send_many", []> { + let summary = "Send multiple tensors to target cores"; + + let arguments = (ins + DenseI32ArrayAttr:$targetCoreIds, + Variadic:$inputs + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + def PimSendBatchOp : PimOp<"send_batch", []> { let summary = "Send a per-lane tensor to target cores from a batched core"; @@ -90,9 +142,19 @@ def PimSendBatchOp : PimOp<"send_batch", []> { DenseI32ArrayAttr:$targetCoreIds ); - let assemblyFormat = [{ - `(` $input `)` attr-dict `:` type($input) `->` `(` `)` - }]; + let hasCustomAssemblyFormat = 1; +} + +def PimSendManyBatchOp : PimOp<"send_many_batch", []> { + let summary = "Send multiple per-lane tensors to target cores from a batched core"; + + let arguments = (ins + DenseI32ArrayAttr:$targetCoreIds, + Variadic:$inputs + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; } def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> { @@ -119,6 +181,28 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> { }]; } +def PimReceiveManyOp : PimOp<"receive_many", [DestinationStyleOpInterface]> { + let summary = "Receive multiple tensors from source cores"; + + let arguments = (ins + Variadic:$outputBuffers, + DenseI32ArrayAttr:$sourceCoreIds + ); + + let results = (outs + Variadic:$outputs + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBuffersMutable(); + } + }]; + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> { let summary = "Receive per-lane tensors from source cores into a batched core"; @@ -138,9 +222,29 @@ def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> { } }]; - let assemblyFormat = [{ - `(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output) + let hasCustomAssemblyFormat = 1; +} + +def PimReceiveManyBatchOp : PimOp<"receive_many_batch", [DestinationStyleOpInterface]> { + let summary = "Receive multiple per-lane tensors from source cores into a batched core"; + + let arguments = (ins + Variadic:$outputBuffers, + DenseI32ArrayAttr:$sourceCoreIds + ); + + let results = (outs + Variadic:$outputs + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBuffersMutable(); + } }]; + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; } def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> { @@ -247,6 +351,55 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> { }]; } +//===----------------------------------------------------------------------===// +// Tensor utilities +//===----------------------------------------------------------------------===// + +def PimExtractRowsOp : PimOp<"extract_rows", [DestinationStyleOpInterface]> { + let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors"; + + let arguments = (ins + PimTensor:$input, + Variadic:$outputBuffers + ); + + let results = (outs + Variadic:$outputs + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBuffersMutable(); + } + }]; + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def PimConcatOp : PimOp<"concat", [DestinationStyleOpInterface]> { + let summary = "Concatenate tensors"; + + let arguments = (ins + I64Attr:$axis, + Variadic:$inputs, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + //===----------------------------------------------------------------------===// // Math //===----------------------------------------------------------------------===// diff --git a/src/PIM/Dialect/Pim/PimOps.cpp b/src/PIM/Dialect/Pim/PimOps.cpp index 1c59c9a..5168fda 100644 --- a/src/PIM/Dialect/Pim/PimOps.cpp +++ b/src/PIM/Dialect/Pim/PimOps.cpp @@ -1,19 +1,5 @@ -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/IntegerSet.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" - -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallBitVector.h" - #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" -using namespace mlir; - namespace onnx_mlir { namespace pim { diff --git a/src/PIM/Dialect/Pim/PimOpsAsm.cpp b/src/PIM/Dialect/Pim/PimOpsAsm.cpp new file mode 100644 index 0000000..dc491ca --- /dev/null +++ b/src/PIM/Dialect/Pim/PimOpsAsm.cpp @@ -0,0 +1,486 @@ +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Value.h" + +#include "llvm/Support/LogicalResult.h" + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace pim { + +namespace { +using namespace onnx_mlir::compact_asm; + +static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef values) { + return parser.getBuilder().getDenseI32ArrayAttr(values); +} + +static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef coreIds) { + printer << " " << keyword << " "; + printCompressedIntegerList(printer, coreIds); +} + +static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keyword, SmallVectorImpl& coreIds) { + if (failed(parser.parseOptionalKeyword(keyword))) + return success(); + return parseCompressedIntegerList(parser, coreIds); +} + +} // namespace + +void PimCoreBatchOp::print(OpAsmPrinter& printer) { + printer << " lanes " << getLaneCount() << " "; + size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast(getLaneCount()) : 0; + if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane)) + printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Paren); + else + printCompressedValueList(printer, getWeights(), ListDelimiter::Paren); + printer << " "; + printCompressedValueList(printer, getInputs(), ListDelimiter::Square); + + if (auto coreIdsAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) + printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef()); + + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName}); + printer << " "; + printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); + printer << " : "; + if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane)) + printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Paren); + else + printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Paren); + printer << " "; + printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Square); + printer << " -> ()"; +} + +ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) { + int32_t laneCount = 0; + SmallVector weights; + SmallVector inputs; + SmallVector weightTypes; + SmallVector inputTypes; + SmallVector coreIds; + + if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount) + || parseCompressedOrTupleOperandList(parser, ListDelimiter::Paren, weights) + || parseCompressedOperandList(parser, ListDelimiter::Square, inputs)) + return failure(); + + bool hasCoreIds = succeeded(parser.parseOptionalKeyword("coreIds")); + if (hasCoreIds && parseCompressedIntegerList(parser, coreIds)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + Region* body = result.addRegion(); + if (parser.parseRegion(*body)) + return failure(); + + if (parser.parseColon() || parseCompressedOrTupleTypeList(parser, ListDelimiter::Paren, weightTypes) + || parseCompressedTypeList(parser, ListDelimiter::Square, inputTypes) || parser.parseArrow() + || parser.parseLParen() || parser.parseRParen()) + return failure(); + + if (weights.size() != weightTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName)) + return parser.emitError(parser.getCurrentLocation(), + "coreIds cannot be specified both positionally and in attr-dict"); + + auto& builder = parser.getBuilder(); + result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount)); + result.addAttribute("operandSegmentSizes", + builder.getDenseI32ArrayAttr( + {static_cast(weights.size()), static_cast(inputs.size())})); + if (hasCoreIds) + result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds)); + + if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands) + || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) { + return failure(); + } + return success(); +} + +void PimYieldOp::print(OpAsmPrinter& printer) { + printer << " "; + printCompressedValueSequence(printer, getOutputs()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : "; + printCompressedTypeSequence(printer, getOutputs().getTypes()); +} + +ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector outputs; + SmallVector outputTypes; + + OpAsmParser::UnresolvedOperand firstOutput; + OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput); + if (firstOutputResult.has_value()) { + if (failed(*firstOutputResult)) + return failure(); + if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs)) + return failure(); + while (succeeded(parser.parseOptionalComma())) + if (parseOneCompressedOperandEntry(parser, outputs)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) + return failure(); + + if (outputs.size() != outputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match"); + + return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands); +} + +void PimMapOp::print(OpAsmPrinter& printer) { + printer << " "; + printArgumentBindings(printer, getBody().front(), getInputs()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : "; + printer.printType(getInputs().front().getType()); + printer << " -> "; + printer.printType(getOutputs().front().getType()); + printer << " "; + printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); +} + +ParseResult PimMapOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector regionArgs; + SmallVector inputs; + Type inputType; + Type outputType; + + if (parseArgumentBindings(parser, regionArgs, inputs)) + return failure(); + if (inputs.empty()) + return parser.emitError(parser.getCurrentLocation(), "map requires at least one input"); + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType) + || parser.parseArrow() || parser.parseType(outputType)) + return failure(); + + SmallVector inputTypes(inputs.size(), inputType); + SmallVector outputTypes(inputs.size(), outputType); + if (regionArgs.size() != inputs.size()) + return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); + if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputTypes); + + applyArgumentTypes(inputTypes, regionArgs); + Region* body = result.addRegion(); + return parser.parseRegion(*body, regionArgs); +} + +void PimEmptyManyOp::print(OpAsmPrinter& printer) { + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : "; + printer.printType(getOutputs().front().getType()); + printer << " x" << getOutputs().size(); +} + +ParseResult PimEmptyManyOp::parse(OpAsmParser& parser, OperationState& result) { + Type outputType; + int64_t resultCount = 0; + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType) + || parser.parseKeyword("x") || parser.parseInteger(resultCount)) + return failure(); + + if (resultCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "result count after 'x' must be positive"); + + SmallVector resultTypes(resultCount, outputType); + result.addTypes(resultTypes); + return success(); +} + +void PimSendBatchOp::print(OpAsmPrinter& printer) { + printer << " "; + printer.printOperand(getInput()); + printCoreIdList(printer, "to", getTargetCoreIds()); + printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printer.printType(getInput().getType()); +} + +ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) { + OpAsmParser::UnresolvedOperand input; + Type inputType; + SmallVector targetCoreIds; + + if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds) + || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)) + return failure(); + + if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds")) + return parser.emitError(parser.getCurrentLocation(), + "targetCoreIds cannot be specified both positionally and in attr-dict"); + if (!targetCoreIds.empty()) + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + + return parser.resolveOperand(input, inputType, result.operands); +} + +void PimSendManyOp::print(OpAsmPrinter& printer) { + printer << " "; + printCompressedValueSequence(printer, getInputs()); + printCoreIdList(printer, "to", getTargetCoreIds()); + printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printCompressedTypeSequence(printer, TypeRange(getInputs())); +} + +ParseResult PimSendManyOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector inputs; + SmallVector inputTypes; + SmallVector targetCoreIds; + + if (parseCompressedOperandSequence(parser, inputs) || parseOptionalCoreIdList(parser, "to", targetCoreIds) + || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false)) + return failure(); + + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds")) + return parser.emitError(parser.getCurrentLocation(), + "targetCoreIds cannot be specified both positionally and in attr-dict"); + if (!targetCoreIds.empty()) + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + + return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands); +} + +void PimSendManyBatchOp::print(OpAsmPrinter& printer) { + printer << " "; + printCompressedValueSequence(printer, getInputs()); + printCoreIdList(printer, "to", getTargetCoreIds()); + printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printCompressedTypeSequence(printer, TypeRange(getInputs())); +} + +ParseResult PimSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector inputs; + SmallVector inputTypes; + SmallVector targetCoreIds; + + if (parseCompressedOperandSequence(parser, inputs) || parseOptionalCoreIdList(parser, "to", targetCoreIds) + || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false)) + return failure(); + + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds")) + return parser.emitError(parser.getCurrentLocation(), + "targetCoreIds cannot be specified both positionally and in attr-dict"); + if (!targetCoreIds.empty()) + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + + return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands); +} + +void PimReceiveManyOp::print(OpAsmPrinter& printer) { + printCoreIdList(printer, "from", getSourceCoreIds()); + printer << " into "; + printOpenDelimiter(printer, ListDelimiter::Paren); + printCompressedValueSequence(printer, getOutputBuffers()); + printCloseDelimiter(printer, ListDelimiter::Paren); + printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()}); + printer << " : "; + printCompressedTypeSequence(printer, getOutputs().getTypes()); +} + +ParseResult PimReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector outputBuffers; + SmallVector outputTypes; + SmallVector sourceCoreIds; + + if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen() + || parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen() + || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) + return failure(); + + if (outputBuffers.size() != outputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match"); + if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds")) + return parser.emitError(parser.getCurrentLocation(), + "sourceCoreIds cannot be specified both positionally and in attr-dict"); + if (!sourceCoreIds.empty()) + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + + if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputTypes); + return success(); +} + +void PimReceiveBatchOp::print(OpAsmPrinter& printer) { + printCoreIdList(printer, "from", getSourceCoreIds()); + printer << " into "; + printOpenDelimiter(printer, ListDelimiter::Paren); + printer.printOperand(getOutputBuffer()); + printCloseDelimiter(printer, ListDelimiter::Paren); + printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()}); + printer << " : "; + printer.printType(getOutputBuffer().getType()); + printer << " -> "; + printer.printType(getOutput().getType()); +} + +ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) { + OpAsmParser::UnresolvedOperand outputBuffer; + Type outputBufferType; + Type outputType; + SmallVector sourceCoreIds; + + if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen() + || parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) + || parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow() + || parser.parseType(outputType)) + return failure(); + + if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds")) + return parser.emitError(parser.getCurrentLocation(), + "sourceCoreIds cannot be specified both positionally and in attr-dict"); + if (!sourceCoreIds.empty()) + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + + if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands)) + return failure(); + result.addTypes(outputType); + return success(); +} + +void PimReceiveManyBatchOp::print(OpAsmPrinter& printer) { + printCoreIdList(printer, "from", getSourceCoreIds()); + printer << " into "; + printOpenDelimiter(printer, ListDelimiter::Paren); + printCompressedValueSequence(printer, getOutputBuffers()); + printCloseDelimiter(printer, ListDelimiter::Paren); + printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()}); + printer << " : "; + printCompressedTypeSequence(printer, getOutputs().getTypes()); +} + +ParseResult PimReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector outputBuffers; + SmallVector outputTypes; + SmallVector sourceCoreIds; + + if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen() + || parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen() + || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) + return failure(); + + if (outputBuffers.size() != outputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match"); + if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds")) + return parser.emitError(parser.getCurrentLocation(), + "sourceCoreIds cannot be specified both positionally and in attr-dict"); + if (!sourceCoreIds.empty()) + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + + if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputTypes); + return success(); +} + +void PimExtractRowsOp::print(OpAsmPrinter& printer) { + printer << " "; + printer.printOperand(getInput()); + printer << " into "; + printOpenDelimiter(printer, ListDelimiter::Paren); + printCompressedValueSequence(printer, getOutputBuffers()); + printCloseDelimiter(printer, ListDelimiter::Paren); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : "; + printer.printType(getInput().getType()); + printer << " -> "; + printCompressedTypeSequence(printer, getOutputs().getTypes()); +} + +ParseResult PimExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) { + OpAsmParser::UnresolvedOperand input; + SmallVector outputBuffers; + Type inputType; + SmallVector outputTypes; + + if (parser.parseOperand(input) || parser.parseKeyword("into") || parser.parseLParen() + || parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen() + || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType) + || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) + return failure(); + + if (outputBuffers.size() != outputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match"); + if (parser.resolveOperand(input, inputType, result.operands) + || parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputTypes); + return success(); +} + +void PimConcatOp::print(OpAsmPrinter& printer) { + printer << " axis " << getAxis() << " "; + printCompressedValueSequence(printer, getInputs()); + printer << " into "; + printer.printOperand(getOutputBuffer()); + printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()}); + printer << " : "; + printOpenDelimiter(printer, ListDelimiter::Paren); + printCompressedTypeSequence(printer, TypeRange(getInputs())); + printCloseDelimiter(printer, ListDelimiter::Paren); + printer << " -> "; + printer.printType(getOutput().getType()); +} + +ParseResult PimConcatOp::parse(OpAsmParser& parser, OperationState& result) { + int64_t axis = 0; + SmallVector inputs; + OpAsmParser::UnresolvedOperand outputBuffer; + SmallVector inputTypes; + Type outputType; + + if (parser.parseKeyword("axis") || parser.parseInteger(axis) || parseCompressedOperandSequence(parser, inputs) + || parser.parseKeyword("into") || parser.parseOperand(outputBuffer) + || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseLParen() + || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false) || parser.parseRParen() + || parser.parseArrow() || parser.parseType(outputType)) + return failure(); + + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (result.attributes.get("axis")) + return parser.emitError(parser.getCurrentLocation(), "axis cannot be specified both positionally and in attr-dict"); + + result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axis)); + if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands) + || parser.resolveOperand(outputBuffer, outputType, result.operands)) + return failure(); + result.addTypes(outputType); + return success(); +} + +} // namespace pim +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/PimOpsVerify.cpp b/src/PIM/Dialect/Pim/PimOpsVerify.cpp new file mode 100644 index 0000000..85ed3a8 --- /dev/null +++ b/src/PIM/Dialect/Pim/PimOpsVerify.cpp @@ -0,0 +1,268 @@ +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/TypeUtilities.h" + +#include "llvm/Support/LogicalResult.h" + +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace pim { + +namespace { + +static LogicalResult verifyManyCommunicationSizes(Operation* op, ArrayRef coreIds, size_t valueCount) { + if (coreIds.size() != valueCount) + return op->emitError("core id metadata length must match the number of values"); + return success(); +} + +static bool haveSameShapedContainerKind(Type lhs, Type rhs) { + return (isa(lhs) && isa(rhs)) || (isa(lhs) && isa(rhs)); +} + +static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type rhs, StringRef message) { + auto lhsShaped = dyn_cast(lhs); + auto rhsShaped = dyn_cast(rhs); + if (!lhsShaped || !rhsShaped || !haveSameShapedContainerKind(lhs, rhs)) + return op->emitError(message); + if (lhsShaped.getElementType() != rhsShaped.getElementType() || lhsShaped.getShape() != rhsShaped.getShape()) + return op->emitError(message); + return success(); +} + +static LogicalResult verifyManyCommunicationTypes(Operation* op, TypeRange types, StringRef kind) { + if (types.empty()) + return op->emitError() << kind << " must carry at least one value"; + + Type firstType = types.front(); + auto firstShapedType = dyn_cast(firstType); + bool firstIsTensor = isa(firstType); + bool firstIsMemRef = isa(firstType); + for (Type type : types.drop_front()) + if (type != firstType) { + auto shapedType = dyn_cast(type); + if (!firstShapedType || !shapedType) + return op->emitError() << kind << " values must all have the same type"; + if (firstIsTensor != isa(type) || firstIsMemRef != isa(type)) + return op->emitError() << kind << " values must all use the same shaped container kind"; + if (firstShapedType.getElementType() != shapedType.getElementType() || firstShapedType.getShape() != shapedType.getShape()) + return op->emitError() << kind << " values must all have the same shape and element type"; + } + return success(); +} + +static FailureOr getParentBatchLaneCount(Operation* op) { + auto coreBatchOp = op->getParentOfType(); + if (!coreBatchOp) + return failure(); + return coreBatchOp.getLaneCount(); +} + +static LogicalResult verifyManyBatchCommunicationSizes(Operation* op, + ArrayRef coreIds, + size_t valueCount) { + auto laneCount = getParentBatchLaneCount(op); + if (failed(laneCount)) + return op->emitError("must be nested inside pim.core_batch"); + if (coreIds.size() != valueCount * static_cast(*laneCount)) + return op->emitError("core id metadata length must match the number of values times parent laneCount"); + return success(); +} + +} // namespace + +LogicalResult PimEmptyManyOp::verify() { + if (getOutputs().empty()) + return emitError("must produce at least one output"); + + Type firstType = getOutputs().front().getType(); + auto firstTensorType = dyn_cast(firstType); + if (!firstTensorType) + return emitError("outputs must all be ranked tensor types"); + + for (Value output : getOutputs().drop_front()) + if (output.getType() != firstType) + return emitError("outputs must all have the same type"); + + return success(); +} + +LogicalResult PimMapOp::verify() { + if (getInputs().empty()) + return emitError("requires at least one input"); + if (getOutputs().size() != getInputs().size()) + return emitError("number of outputs must match number of inputs"); + + Type inputType = getInputs().front().getType(); + for (Value input : getInputs().drop_front()) + if (input.getType() != inputType) + return emitError("all inputs must have the same type"); + + Type outputType = getOutputs().front().getType(); + for (Value output : getOutputs().drop_front()) + if (output.getType() != outputType) + return emitError("all outputs must have the same type"); + + Block& block = getBody().front(); + if (block.getNumArguments() != 1) + return emitError("body must have exactly one block argument"); + if (block.getArgument(0).getType() != inputType) + return emitError("body block argument type must match input type"); + + auto yieldOp = dyn_cast_or_null(block.getTerminator()); + if (!yieldOp) + return emitError("body must terminate with pim.yield"); + if (yieldOp.getNumOperands() != 1) + return emitError("body yield must produce exactly one value"); + if (yieldOp.getOperand(0).getType() != outputType) + return emitError("body yield type must match output type"); + + return success(); +} + +LogicalResult PimSendManyOp::verify() { + if (failed(verifyManyCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size()))) + return failure(); + return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many"); +} + +LogicalResult PimSendManyBatchOp::verify() { + if (failed(verifyManyBatchCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size()))) + return failure(); + return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many_batch"); +} + +LogicalResult PimReceiveManyOp::verify() { + if (getOutputBuffers().size() != getOutputs().size()) + return emitError("number of output buffers must match the number of outputs"); + if (failed(verifyManyCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size()))) + return failure(); + + if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many"))) + return failure(); + if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many"))) + return failure(); + + for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs())) + if (outputBuffer.getType() != output.getType()) + return emitError("output buffers and outputs must have matching types"); + + return success(); +} + +LogicalResult PimReceiveManyBatchOp::verify() { + if (getOutputBuffers().size() != getOutputs().size()) + return emitError("number of output buffers must match the number of outputs"); + if (failed(verifyManyBatchCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size()))) + return failure(); + + if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many_batch"))) + return failure(); + if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many_batch"))) + return failure(); + + for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs())) + if (outputBuffer.getType() != output.getType()) + return emitError("output buffers and outputs must have matching types"); + + return success(); +} + +LogicalResult PimExtractRowsOp::verify() { + if (getOutputBuffers().size() != getOutputs().size()) + return emitError("number of output buffers must match the number of outputs"); + + auto inputType = dyn_cast(getInput().getType()); + if (!inputType || !inputType.hasRank() || inputType.getRank() != 2) + return emitError("input must be a rank-2 shaped type"); + + int64_t numRows = inputType.getShape()[0]; + int64_t numCols = inputType.getShape()[1]; + Type elementType = inputType.getElementType(); + + if (numRows >= 0 && static_cast(getOutputs().size()) != numRows) + return emitError("number of outputs must match the number of input rows"); + + for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs())) { + if (failed(verifyCompatibleShapedTypes( + getOperation(), outputBuffer.getType(), output.getType(), "output buffers and outputs must match"))) + return failure(); + + auto outputType = dyn_cast(output.getType()); + if (!outputType || !outputType.hasRank() || outputType.getRank() != 2) + return emitError("outputs must all be rank-2 shaped types"); + if (!haveSameShapedContainerKind(getInput().getType(), output.getType())) + return emitError("outputs must use the same shaped container kind as the input"); + if (outputType.getElementType() != elementType) + return emitError("output element types must match input element type"); + auto outputShape = outputType.getShape(); + if (outputShape[0] != 1) + return emitError("each output must have exactly one row"); + if (numCols >= 0 && outputShape[1] != numCols) + return emitError("output column count must match input column count"); + } + + return success(); +} + +LogicalResult PimConcatOp::verify() { + if (getInputs().empty()) + return emitError("requires at least one input"); + + if (failed(verifyCompatibleShapedTypes( + getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match"))) + return failure(); + + auto outputType = dyn_cast(getOutput().getType()); + if (!outputType || !outputType.hasRank()) + return emitError("output must be a ranked shaped type"); + + int64_t axis = getAxis(); + int64_t rank = outputType.getRank(); + if (axis < 0 || axis >= rank) + return emitError("axis must be within the output rank"); + + int64_t concatenatedDimSize = 0; + bool concatenatedDimDynamic = false; + Type outputElementType = outputType.getElementType(); + + for (Value input : getInputs()) { + auto inputType = dyn_cast(input.getType()); + if (!inputType || !inputType.hasRank()) + return emitError("inputs must be ranked shaped types"); + if (!haveSameShapedContainerKind(input.getType(), getOutput().getType())) + return emitError("inputs and output must use the same shaped container kind"); + if (inputType.getRank() != rank) + return emitError("all inputs must have the same rank as the output"); + if (inputType.getElementType() != outputElementType) + return emitError("all inputs must have the same element type as the output"); + + for (int64_t dim = 0; dim < rank; ++dim) { + if (dim == axis) + continue; + int64_t inputDim = inputType.getDimSize(dim); + int64_t outputDim = outputType.getDimSize(dim); + if (!ShapedType::isDynamic(inputDim) && !ShapedType::isDynamic(outputDim) && inputDim != outputDim) + return emitError("non-concatenated dimensions must match the output shape"); + } + + int64_t inputConcatDim = inputType.getDimSize(axis); + if (ShapedType::isDynamic(inputConcatDim)) { + concatenatedDimDynamic = true; + continue; + } + concatenatedDimSize += inputConcatDim; + } + + int64_t outputConcatDim = outputType.getDimSize(axis); + if (!concatenatedDimDynamic && !ShapedType::isDynamic(outputConcatDim) && concatenatedDimSize != outputConcatDim) + return emitError("output concatenated dimension must equal the sum of input sizes"); + + return success(); +} + +} // namespace pim +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index 481d0af..f5a11f3 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -173,6 +173,235 @@ struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return !cast(op).isDpsInit(&opOperand); + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto receiveOp = cast(op); + SmallVector outputBuffers; + SmallVector resultTypes; + SmallVector tensorResults; + outputBuffers.reserve(receiveOp.getOutputBuffers().size()); + resultTypes.reserve(receiveOp.getOutputBuffers().size()); + tensorResults.reserve(receiveOp.getOutputBuffers().size()); + + for (Value outputBuffer : receiveOp.getOutputBuffers()) { + auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state); + if (failed(outputBufferOpt)) + return failure(); + outputBuffers.push_back(*outputBufferOpt); + resultTypes.push_back(outputBufferOpt->getType()); + } + + auto newOp = PimReceiveManyOp::create( + rewriter, receiveOp.getLoc(), TypeRange(resultTypes), ValueRange(outputBuffers), receiveOp.getSourceCoreIdsAttr()); + + for (auto [bufferResult, tensorResult] : llvm::zip(newOp.getOutputs(), receiveOp.getOutputs())) { + auto tensorType = cast(tensorResult.getType()); + auto toTensor = + bufferization::ToTensorOp::create(rewriter, receiveOp.getLoc(), tensorType, bufferResult, UnitAttr(), UnitAttr()); + tensorResults.push_back(toTensor.getResult()); + } + + rewriter.replaceOp(receiveOp, tensorResults); + return success(); + } +}; + +struct ReceiveManyBatchOpInterface +: DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return !cast(op).isDpsInit(&opOperand); + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto receiveOp = cast(op); + SmallVector outputBuffers; + SmallVector resultTypes; + SmallVector tensorResults; + outputBuffers.reserve(receiveOp.getOutputBuffers().size()); + resultTypes.reserve(receiveOp.getOutputBuffers().size()); + tensorResults.reserve(receiveOp.getOutputBuffers().size()); + + for (Value outputBuffer : receiveOp.getOutputBuffers()) { + auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state); + if (failed(outputBufferOpt)) + return failure(); + outputBuffers.push_back(*outputBufferOpt); + resultTypes.push_back(outputBufferOpt->getType()); + } + + auto newOp = PimReceiveManyBatchOp::create(rewriter, + receiveOp.getLoc(), + TypeRange(resultTypes), + ValueRange(outputBuffers), + receiveOp.getSourceCoreIdsAttr()); + + for (auto [bufferResult, tensorResult] : llvm::zip(newOp.getOutputs(), receiveOp.getOutputs())) { + auto tensorType = cast(tensorResult.getType()); + auto toTensor = + bufferization::ToTensorOp::create(rewriter, receiveOp.getLoc(), tensorType, bufferResult, UnitAttr(), UnitAttr()); + tensorResults.push_back(toTensor.getResult()); + } + + rewriter.replaceOp(receiveOp, tensorResults); + return success(); + } +}; + +struct ExtractRowsOpInterface : DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return !cast(op).isDpsInit(&opOperand); + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto extractRowsOp = cast(op); + auto inputOpt = getBuffer(rewriter, extractRowsOp.getInput(), options, state); + if (failed(inputOpt)) + return failure(); + + SmallVector outputBuffers; + SmallVector resultTypes; + outputBuffers.reserve(extractRowsOp.getOutputBuffers().size()); + resultTypes.reserve(extractRowsOp.getOutputBuffers().size()); + + for (Value outputBuffer : extractRowsOp.getOutputBuffers()) { + auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state); + if (failed(outputBufferOpt)) + return failure(); + outputBuffers.push_back(*outputBufferOpt); + resultTypes.push_back(outputBufferOpt->getType()); + } + + auto newOp = PimExtractRowsOp::create(rewriter, + extractRowsOp.getLoc(), + TypeRange(resultTypes), + materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), + ValueRange(outputBuffers)); + rewriter.replaceOp(extractRowsOp, newOp.getOutputs()); + return success(); + } +}; + +struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return !cast(op).isDpsInit(&opOperand); + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto concatOp = cast(op); + SmallVector inputs; + inputs.reserve(concatOp.getInputs().size()); + for (Value input : concatOp.getInputs()) { + auto inputOpt = getBuffer(rewriter, input, options, state); + if (failed(inputOpt)) + return failure(); + inputs.push_back(materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter)); + } + + auto outputBufferOpt = getBuffer(rewriter, concatOp.getOutputBuffer(), options, state); + if (failed(outputBufferOpt)) + return failure(); + + replaceOpWithNewBufferizedOp( + rewriter, op, outputBufferOpt->getType(), concatOp.getAxisAttr(), ValueRange(inputs), *outputBufferOpt); + return success(); + } +}; + +struct MapOpInterface : BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } + + bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } + + AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return {}; + } + + AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const { + auto mapOp = cast(op); + auto bbArg = dyn_cast(value); + if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 || mapOp.getInputs().empty()) + return {}; + + return {{&mapOp->getOpOperand(0), BufferRelation::Equivalent}}; + } + + bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; } + + FailureOr + getBufferType(Operation* op, + Value value, + const BufferizationOptions& options, + const BufferizationState& state, + SmallVector& invocationStack) const { + auto mapOp = cast(op); + auto bbArg = dyn_cast(value); + if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 || mapOp.getInputs().empty()) + return failure(); + + auto inputType = dyn_cast(mapOp.getInputs().front().getType()); + if (inputType) + return inputType; + + auto shapedType = cast(mapOp.getInputs().front().getType()); + return BufferLikeType(MemRefType::get(shapedType.getShape(), shapedType.getElementType())); + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto mapOp = cast(op); + + SmallVector inputs; + SmallVector resultTypes; + inputs.reserve(mapOp.getInputs().size()); + resultTypes.reserve(mapOp.getOutputs().size()); + + for (Value input : mapOp.getInputs()) { + if (isa(input.getType())) { + auto inputOpt = getBuffer(rewriter, input, options, state); + if (failed(inputOpt)) + return failure(); + inputs.push_back(*inputOpt); + } + else { + inputs.push_back(input); + } + } + + for (Value output : mapOp.getOutputs()) { + auto shapedType = cast(output.getType()); + resultTypes.push_back(MemRefType::get(shapedType.getShape(), shapedType.getElementType())); + } + + rewriter.setInsertionPoint(mapOp); + auto newOp = PimMapOp::create(rewriter, mapOp.getLoc(), TypeRange(resultTypes), ValueRange(inputs)); + rewriter.inlineRegionBefore(mapOp.getBody(), newOp.getBody(), newOp.getBody().begin()); + for (Block& block : newOp.getBody()) + if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state))) + return failure(); + + rewriter.replaceOp(mapOp, newOp.getOutputs()); + return success(); + } +}; + struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; @@ -435,9 +664,14 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel(*ctx); PimCoreBatchOp::attachInterface(*ctx); PimReceiveOp::attachInterface(*ctx); + PimReceiveManyOp::attachInterface(*ctx); PimReceiveBatchOp::attachInterface(*ctx); + PimReceiveManyBatchOp::attachInterface(*ctx); + PimExtractRowsOp::attachInterface(*ctx); + PimConcatOp::attachInterface(*ctx); PimMemCopyHostToDevOp::attachInterface(*ctx); PimMemCopyHostToDevBatchOp::attachInterface(*ctx); PimMemCopyDevToHostOp::attachInterface(*ctx); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index 866453f..c9421c0 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -3,7 +3,9 @@ #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Threading.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/Casting.h" @@ -45,6 +47,23 @@ private: void PimBufferizationPass::runOnOperation() { auto moduleOp = getOperation(); + { + SmallVector emptyManyOps; + moduleOp.walk([&](pim::PimEmptyManyOp emptyManyOp) { emptyManyOps.push_back(emptyManyOp); }); + + IRRewriter rewriter(moduleOp.getContext()); + for (auto emptyManyOp : emptyManyOps) { + SmallVector replacementValues; + replacementValues.reserve(emptyManyOp.getOutputs().size()); + rewriter.setInsertionPoint(emptyManyOp); + for (Value output : emptyManyOp.getOutputs()) { + auto outputType = cast(output.getType()); + replacementValues.push_back( + tensor::EmptyOp::create(rewriter, emptyManyOp.getLoc(), outputType.getShape(), outputType.getElementType())); + } + rewriter.replaceOp(emptyManyOp, replacementValues); + } + } // Refactor this into a function { auto funcOp = getPimEntryFunc(moduleOp); diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index d9ac3c5..f83fa4e 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -8,6 +8,7 @@ #include #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -23,23 +24,23 @@ enum class ListDelimiter { }; static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { - if (delimiter == ListDelimiter::Square) - return parser.parseLSquare(); - return parser.parseLParen(); + return onnx_mlir::compact_asm::parseOpenDelimiter( + parser, static_cast(delimiter)); } static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { - if (delimiter == ListDelimiter::Square) - return parser.parseOptionalRSquare(); - return parser.parseOptionalRParen(); + return onnx_mlir::compact_asm::parseOptionalCloseDelimiter( + parser, static_cast(delimiter)); } static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { - printer << (delimiter == ListDelimiter::Square ? "[" : "("); + onnx_mlir::compact_asm::printOpenDelimiter( + printer, static_cast(delimiter)); } static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { - printer << (delimiter == ListDelimiter::Square ? "]" : ")"); + onnx_mlir::compact_asm::printCloseDelimiter( + printer, static_cast(delimiter)); } static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) { @@ -51,31 +52,8 @@ static ParseResult parseCompressedRepeatedList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& entries, ParseEntryFn parseEntry) { - if (parseOpenDelimiter(parser, delimiter)) - return failure(); - if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) - return success(); - - while (true) { - EntryT entry; - if (parseEntry(entry)) - return failure(); - - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t index = 0; index < repeatCount; ++index) - entries.push_back(entry); - - if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) - break; - if (parser.parseComma()) - return failure(); - } - - return success(); + return onnx_mlir::compact_asm::parseCompressedRepeatedList( + parser, static_cast(delimiter), entries, parseEntry); } template @@ -388,156 +366,32 @@ static ParseResult parseCompressedOperandSequence(OpAsmParser& parser, static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl& types, bool allowEmpty); static bool hasRepeatedTuple(ValueRange values, size_t tupleSize) { - if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0) - return false; - - SmallVector valueVec(values.begin(), values.end()); - ArrayRef tuple(valueVec.data(), tupleSize); - for (size_t index = tupleSize; index < values.size(); index += tupleSize) - if (!llvm::equal(tuple, ArrayRef(valueVec).slice(index, tupleSize))) - return false; - return true; + return onnx_mlir::compact_asm::hasRepeatedTuple(values, tupleSize); } static bool hasRepeatedTuple(TypeRange types, size_t tupleSize) { - if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0) - return false; - - SmallVector typeVec(types.begin(), types.end()); - ArrayRef tuple(typeVec.data(), tupleSize); - for (size_t index = tupleSize; index < types.size(); index += tupleSize) - if (!llvm::equal(tuple, ArrayRef(typeVec).slice(index, tupleSize))) - return false; - return true; + return onnx_mlir::compact_asm::hasRepeatedTuple(types, tupleSize); } static void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize) { - printer << "["; - printOpenDelimiter(printer, ListDelimiter::Paren); - for (size_t index = 0; index < tupleSize; ++index) { - if (index != 0) - printer << ", "; - printer.printOperand(values[index]); - } - printCloseDelimiter(printer, ListDelimiter::Paren); - printer << " x" << (values.size() / tupleSize) << "]"; + onnx_mlir::compact_asm::printValueTupleRun( + printer, values, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square); } static void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize) { - printer << "["; - printOpenDelimiter(printer, ListDelimiter::Paren); - for (size_t index = 0; index < tupleSize; ++index) { - if (index != 0) - printer << ", "; - printer.printType(types[index]); - } - printCloseDelimiter(printer, ListDelimiter::Paren); - printer << " x" << (types.size() / tupleSize) << "]"; + onnx_mlir::compact_asm::printTypeTupleRun( + printer, types, tupleSize, onnx_mlir::compact_asm::ListDelimiter::Square); } static ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser, SmallVectorImpl& operands) { - if (parser.parseLSquare()) - return failure(); - if (succeeded(parser.parseOptionalRSquare())) - return success(); - - if (succeeded(parser.parseOptionalLParen())) { - SmallVector tupleOperands; - if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen()) - return failure(); - - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t repeat = 0; repeat < repeatCount; ++repeat) - llvm::append_range(operands, tupleOperands); - - while (succeeded(parser.parseOptionalComma())) { - if (parser.parseLParen()) - return failure(); - tupleOperands.clear(); - if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen()) - return failure(); - - repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t repeat = 0; repeat < repeatCount; ++repeat) - llvm::append_range(operands, tupleOperands); - } - return parser.parseRSquare(); - } - - while (true) { - if (parseOneCompressedOperandEntry(parser, operands)) - return failure(); - if (succeeded(parser.parseOptionalRSquare())) - return success(); - if (parser.parseComma()) - return failure(); - } + return onnx_mlir::compact_asm::parseCompressedOrTupleOperandList( + parser, onnx_mlir::compact_asm::ListDelimiter::Square, operands); } static ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, SmallVectorImpl& types) { - if (parser.parseLSquare()) - return failure(); - if (succeeded(parser.parseOptionalRSquare())) - return success(); - - if (succeeded(parser.parseOptionalLParen())) { - SmallVector tupleTypes; - if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen()) - return failure(); - - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t repeat = 0; repeat < repeatCount; ++repeat) - llvm::append_range(types, tupleTypes); - - while (succeeded(parser.parseOptionalComma())) { - if (parser.parseLParen()) - return failure(); - tupleTypes.clear(); - if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen()) - return failure(); - - repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t repeat = 0; repeat < repeatCount; ++repeat) - llvm::append_range(types, tupleTypes); - } - return parser.parseRSquare(); - } - - while (true) { - Type type; - if (parser.parseType(type)) - return failure(); - - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t repeat = 0; repeat < repeatCount; ++repeat) - types.push_back(type); - - if (succeeded(parser.parseOptionalRSquare())) - return success(); - if (parser.parseComma()) - return failure(); - } + return onnx_mlir::compact_asm::parseCompressedOrTupleTypeList( + parser, onnx_mlir::compact_asm::ListDelimiter::Square, types); } static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser, diff --git a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp index dbf8fd1..ca39577 100644 --- a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp +++ b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp @@ -1,9 +1,12 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" @@ -31,6 +34,35 @@ static int64_t getValueSizeInBytes(Value value) { return type.getNumElements() * type.getElementTypeBitWidth() / 8; } +static void expandPimMapOps(func::FuncOp funcOp, IRRewriter& rewriter) { + SmallVector mapOps; + funcOp.walk([&](pim::PimMapOp mapOp) { mapOps.push_back(mapOp); }); + + for (auto mapOp : mapOps) { + Block& body = mapOp.getBody().front(); + auto yieldOp = cast(body.getTerminator()); + + SmallVector replacements; + replacements.reserve(mapOp.getInputs().size()); + rewriter.setInsertionPoint(mapOp); + for (Value input : mapOp.getInputs()) { + IRMapping mapping; + mapping.map(body.getArgument(0), input); + + for (Operation& op : body.without_terminator()) { + Operation* cloned = rewriter.clone(op, mapping); + for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) + mapping.map(originalResult, clonedResult); + rewriter.setInsertionPointAfter(cloned); + } + + replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0))); + } + + rewriter.replaceOp(mapOp, replacements); + } +} + struct MaterializeHostConstantsPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass) @@ -41,13 +73,15 @@ struct MaterializeHostConstantsPass : PassWrapper()) { if (funcOp.isExternal()) continue; + expandPimMapOps(funcOp, rewriter); + for (pim::PimCoreOp coreOp : funcOp.getOps()) { DenseMap>> materializedValues; @@ -113,6 +147,45 @@ struct MaterializeHostConstantsPass : PassWrapper hostCompactOps; + for (Operation& op : funcOp.getBody().front()) + if (isa(op)) + hostCompactOps.push_back(&op); + + for (Operation* op : hostCompactOps) { + rewriter.setInsertionPoint(op); + + if (auto extractRowsOp = dyn_cast(op)) { + auto inputType = dyn_cast(extractRowsOp.getInput().getType()); + if (!inputType || !inputType.hasStaticShape() || inputType.getRank() != 2) { + extractRowsOp.emitOpError("host-side extract_rows lowering requires a static rank-2 input"); + hasFailure = true; + continue; + } + + int64_t numCols = inputType.getDimSize(1); + SmallVector replacementRows; + replacementRows.reserve(extractRowsOp.getOutputs().size()); + for (auto rowIndex : llvm::seq(0, extractRowsOp.getOutputs().size())) { + SmallVector offsets = {rewriter.getIndexAttr(static_cast(rowIndex)), + rewriter.getIndexAttr(0)}; + SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)}; + SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + replacementRows.push_back(memref::SubViewOp::create( + rewriter, extractRowsOp.getLoc(), extractRowsOp.getInput(), offsets, sizes, strides) + .getResult()); + } + + extractRowsOp->replaceAllUsesWith(ValueRange(replacementRows)); + extractRowsOp->erase(); + continue; + } + + auto concatOp = cast(op); + concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization"); + hasFailure = true; + } } if (hasFailure) {