From 5b9bb0c191d59bbd8d7b3bd5b8d37666802ad470 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 4 May 2026 14:19:30 +0200 Subject: [PATCH] refactor spatial ops --- src/PIM/Dialect/Spatial/CMakeLists.txt | 3 + src/PIM/Dialect/Spatial/SpatialOps.cpp | 1390 ----------------- src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp | 912 +++++++++++ .../Spatial/SpatialOpsCanonicalization.cpp | 35 + src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp | 433 +++++ 5 files changed, 1383 insertions(+), 1390 deletions(-) create mode 100644 src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp create mode 100644 src/PIM/Dialect/Spatial/SpatialOpsCanonicalization.cpp create mode 100644 src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index 473e8ec..50f5ce0 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -4,6 +4,9 @@ add_onnx_mlir_dialect_doc(spat Spatial.td) add_pim_library(SpatialOps Channels.cpp SpatialOps.cpp + SpatialOpsAsm.cpp + SpatialOpsVerify.cpp + SpatialOpsCanonicalization.cpp Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp Transforms/MergeComputeNodes/DCPGraph/Graph.cpp Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index a1cb7aa..aae95d6 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -1,29 +1,3 @@ -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/IntegerSet.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Value.h" -#include 
"mlir/Support/LLVM.h" - -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/LogicalResult.h" -#include "llvm/Support/raw_ostream.h" - -#include -#include - -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -43,1370 +17,6 @@ void SpatialDialect::initialize() { >(); } - -namespace { - -enum class ListDelimiter { - Square, - Paren -}; - -static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { - if (delimiter == ListDelimiter::Square) - return parser.parseLSquare(); - return parser.parseLParen(); -} - -static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { - if (delimiter == ListDelimiter::Square) - return parser.parseOptionalRSquare(); - return parser.parseOptionalRParen(); -} - -static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { - printer << (delimiter == ListDelimiter::Square ? "[" : "("); -} - -static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { - printer << (delimiter == ListDelimiter::Square ? 
"]" : ")"); -} - -template -static ParseResult parseCompressedRepeatedList(OpAsmParser& parser, - ListDelimiter delimiter, - SmallVectorImpl& entries, - ParseEntryFn parseEntry) { - if (parseOpenDelimiter(parser, delimiter)) - return failure(); - if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) - return success(); - - while (true) { - EntryT entry; - if (parseEntry(entry)) - return failure(); - - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t index = 0; index < repeatCount; ++index) - entries.push_back(entry); - - if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) - break; - if (parser.parseComma()) - return failure(); - } - - return success(); -} - -template -static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl& values) { - if (parser.parseLSquare()) - return failure(); - if (succeeded(parser.parseOptionalRSquare())) - return success(); - - while (true) { - int64_t first = 0; - if (parser.parseInteger(first)) - return failure(); - - if (succeeded(parser.parseOptionalKeyword("to"))) { - int64_t last = 0; - if (parser.parseInteger(last) || last < first) - return parser.emitError(parser.getCurrentLocation(), "invalid ascending range"); - - int64_t step = 1; - if (succeeded(parser.parseOptionalKeyword("by"))) { - if (parser.parseInteger(step) || step <= 0) - return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive"); - } - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - if ((last - first) % step != 0) - return parser.emitError(parser.getCurrentLocation(), "range end must be reachable 
from start using the given step"); - - for (int64_t value = first; value <= last; value += step) - for (int64_t index = 0; index < repeatCount; ++index) - values.push_back(static_cast(value)); - } - else { - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t index = 0; index < repeatCount; ++index) - values.push_back(static_cast(first)); - } - - if (succeeded(parser.parseOptionalRSquare())) - break; - if (parser.parseComma()) - return failure(); - } - - return success(); -} - -template -static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) { - for (size_t index = 0; index < entries.size();) { - size_t runEnd = index + 1; - while (runEnd < entries.size() && entries[runEnd] == entries[index]) - ++runEnd; - - if (index != 0) - printer << ", "; - printEntry(entries[index]); - size_t runLength = runEnd - index; - if (runLength > 1) - printer << " x" << runLength; - index = runEnd; - } -} - -template -static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef values) { - printer << "["; - for (size_t index = 0; index < values.size();) { - if (index != 0) - printer << ", "; - - auto findEqualRunEnd = [&](size_t start) { - size_t end = start + 1; - while (end < values.size() && values[end] == values[start]) - ++end; - return end; - }; - - size_t firstRunEnd = findEqualRunEnd(index); - size_t repeatCount = firstRunEnd - index; - size_t progressionEnd = firstRunEnd; - int64_t step = 0; - IntT lastValue = values[index]; - - if (firstRunEnd < values.size()) { - size_t secondRunEnd = findEqualRunEnd(firstRunEnd); - step = static_cast(values[firstRunEnd]) - static_cast(values[index]); - if (step > 0 && secondRunEnd - firstRunEnd == repeatCount) { - progressionEnd = secondRunEnd; - lastValue = values[firstRunEnd]; - 
size_t currentRunStart = secondRunEnd; - while (currentRunStart < values.size()) { - size_t currentRunEnd = findEqualRunEnd(currentRunStart); - if (currentRunEnd - currentRunStart != repeatCount) - break; - if (static_cast(values[currentRunStart]) != static_cast(lastValue) + step) - break; - lastValue = values[currentRunStart]; - progressionEnd = currentRunEnd; - currentRunStart = currentRunEnd; - } - } - else { - step = 0; - } - } - - size_t progressionValueCount = repeatCount == 0 ? 0 : (progressionEnd - index) / repeatCount; - if (progressionEnd > firstRunEnd && progressionValueCount >= 3) { - printer << values[index] << " to " << lastValue; - if (step != 1) - printer << " by " << step; - if (repeatCount > 1) - printer << " x" << repeatCount; - index = progressionEnd; - continue; - } - - if (repeatCount > 1) { - printer << values[index] << " x" << repeatCount; - index = firstRunEnd; - continue; - } - - printer << values[index]; - index = firstRunEnd; - } - printer << "]"; -} - -static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) { - printOpenDelimiter(printer, delimiter); - for (size_t index = 0; index < values.size();) { - size_t equalRunEnd = index + 1; - while (equalRunEnd < values.size() && values[equalRunEnd] == values[index]) - ++equalRunEnd; - - if (index != 0) - printer << ", "; - if (equalRunEnd - index > 1) { - printer.printOperand(values[index]); - printer << " x" << (equalRunEnd - index); - index = equalRunEnd; - continue; - } - - size_t rangeEnd = index + 1; - if (auto firstResult = dyn_cast(values[index])) { - while (rangeEnd < values.size()) { - auto nextResult = dyn_cast(values[rangeEnd]); - if (!nextResult || nextResult.getOwner() != firstResult.getOwner() - || nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) - break; - ++rangeEnd; - } - } - else if (auto firstArg = dyn_cast(values[index])) { - while (rangeEnd < values.size()) { - auto nextArg = 
dyn_cast(values[rangeEnd]); - if (!nextArg || nextArg.getOwner() != firstArg.getOwner() - || nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) - break; - ++rangeEnd; - } - } - - printer.printOperand(values[index]); - if (rangeEnd - index >= 3) { - printer << " to "; - printer.printOperand(values[rangeEnd - 1]); - } - else if (rangeEnd - index == 2) { - printer << ", "; - printer.printOperand(values[index + 1]); - } - index = rangeEnd; - } - printCloseDelimiter(printer, delimiter); -} - -static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) { - printOpenDelimiter(printer, delimiter); - printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); }); - printCloseDelimiter(printer, delimiter); -} - -static ParseResult parseCompressedOperandEntryWithFirst( - OpAsmParser& parser, - OpAsmParser::UnresolvedOperand firstOperand, - SmallVectorImpl& operands) { - if (succeeded(parser.parseOptionalKeyword("to"))) { - OpAsmParser::UnresolvedOperand lastOperand; - if (parser.parseOperand(lastOperand)) - return failure(); - if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number) - return parser.emitError(parser.getCurrentLocation(), "invalid operand range"); - for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number) - operands.push_back({firstOperand.location, firstOperand.name, number}); - } - else { - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t index = 0; index < repeatCount; ++index) - operands.push_back(firstOperand); - } - - return success(); -} - -static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser, - SmallVectorImpl& operands) { - OpAsmParser::UnresolvedOperand firstOperand; - if 
(parser.parseOperand(firstOperand)) - return failure(); - return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands); -} - -static ParseResult parseCompressedOperandList(OpAsmParser& parser, - ListDelimiter delimiter, - SmallVectorImpl& operands) { - if (parseOpenDelimiter(parser, delimiter)) - return failure(); - if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) - return success(); - - while (true) { - if (parseOneCompressedOperandEntry(parser, operands)) - return failure(); - if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) - break; - if (parser.parseComma()) - return failure(); - } - - return success(); -} - -static ParseResult parseCompressedOperandSequence(OpAsmParser& parser, - SmallVectorImpl& operands) { - if (parseOneCompressedOperandEntry(parser, operands)) - return failure(); - while (succeeded(parser.parseOptionalComma())) - if (parseOneCompressedOperandEntry(parser, operands)) - return failure(); - return success(); -} - -static void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) { - for (size_t index = 0; index < values.size();) { - size_t equalRunEnd = index + 1; - while (equalRunEnd < values.size() && values[equalRunEnd] == values[index]) - ++equalRunEnd; - - if (index != 0) - printer << ", "; - if (equalRunEnd - index > 1) { - printer.printOperand(values[index]); - printer << " x" << (equalRunEnd - index); - index = equalRunEnd; - continue; - } - - size_t rangeEnd = index + 1; - if (auto firstResult = dyn_cast(values[index])) { - while (rangeEnd < values.size()) { - auto nextResult = dyn_cast(values[rangeEnd]); - if (!nextResult || nextResult.getOwner() != firstResult.getOwner() - || nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) - break; - ++rangeEnd; - } - } - else if (auto firstArg = dyn_cast(values[index])) { - while (rangeEnd < values.size()) { - auto nextArg = dyn_cast(values[rangeEnd]); - if (!nextArg || nextArg.getOwner() != 
firstArg.getOwner() - || nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) - break; - ++rangeEnd; - } - } - - printer.printOperand(values[index]); - if (rangeEnd - index >= 3) { - printer << " to "; - printer.printOperand(values[rangeEnd - 1]); - } - else if (rangeEnd - index == 2) { - printer << ", "; - printer.printOperand(values[index + 1]); - } - index = rangeEnd; - } -} - -static void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) { - printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); }); -} - -static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl& types, bool allowEmpty) { - Type firstType; - OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType); - if (!firstTypeResult.has_value()) { - if (allowEmpty) - return success(); - return parser.emitError(parser.getCurrentLocation(), "expected type"); - } - if (failed(*firstTypeResult)) - return failure(); - - auto appendType = [&](Type type) -> ParseResult { - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); - } - for (int64_t index = 0; index < repeatCount; ++index) - types.push_back(type); - return success(); - }; - - if (appendType(firstType)) - return failure(); - - while (succeeded(parser.parseOptionalComma())) { - Type nextType; - if (parser.parseType(nextType) || appendType(nextType)) - return failure(); - } - - return success(); -} - -static void printChannelMetadata(OpAsmPrinter& printer, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds) { - printer << " channels "; - printCompressedIntegerList(printer, channelIds); - printer << " from "; - printCompressedIntegerList(printer, sourceCoreIds); - printer << " to "; - printCompressedIntegerList(printer, targetCoreIds); -} - 
-static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef values) { - return parser.getBuilder().getDenseI64ArrayAttr(values); -} - -static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef values) { - return parser.getBuilder().getDenseI32ArrayAttr(values); -} - -static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) { - return parser.getBuilder().getI32IntegerAttr(value); -} - -static void buildImplicitRegionArgs(OpAsmParser& parser, - ArrayRef inputTypes, - SmallVectorImpl& generatedNames, - SmallVectorImpl& arguments) { - generatedNames.reserve(inputTypes.size()); - arguments.reserve(inputTypes.size()); - for (auto [index, inputType] : llvm::enumerate(inputTypes)) { - generatedNames.push_back("arg" + std::to_string(index + 1)); - OpAsmParser::Argument arg; - arg.ssaName = {parser.getCurrentLocation(), generatedNames.back(), 0}; - arg.type = inputType; - arguments.push_back(arg); - } -} - -} // namespace - -inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter, - ArrayRef& matrixShape, - ArrayRef& vectorShape, - ArrayRef& outputShape) { - - // Verify that the matrix, vector and output shapes have rank 2 - if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2) - return emitter->emitError("matrix, vector and output must have rank 2"); - - // Verify that the matrix shape is (N, M) - int64_t N = matrixShape[0]; - int64_t M = matrixShape[1]; - if (N <= 0 || M <= 0) - return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0"); - - // Verify that the vector shape is (M, 1) - int64_t vectorM = vectorShape[0]; - int64_t vector1 = vectorShape[1]; - if (vectorM != M || vector1 != 1) - return emitter->emitError("vector shape must be (M, 1)"); - - // Verify that the output shape is (N, 1) - int64_t outputN = outputShape[0]; - int64_t output1 = outputShape[1]; - if (outputN != N || output1 != 1) - return emitter->emitError("output shape must be (N, 1)"); - - return 
success(); -} - -inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter, - ArrayRef& matrixShape, - ArrayRef& vectorShape, - ArrayRef& outputShape) { - - // Verify that the matrix, vector and output shapes have rank 4 - if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4) - return emitter->emitError("matrix, vector and output must have rank 4"); - - // Verify that the matrix shape is (N, M, 1, 1) - int64_t N = matrixShape[0]; - int64_t M = matrixShape[1]; - int64_t matrix1First = matrixShape[2]; - int64_t matrix1Second = matrixShape[3]; - if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1) - return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0"); - - // Verify that the vector shape is (1, M, 1, 1) - int64_t vector1First = vectorShape[0]; - int64_t vectorM = vectorShape[1]; - int64_t vector1Second = vectorShape[2]; - int64_t vector1Third = vectorShape[3]; - if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) { - if (vector1First == 1 && vector1Second == 1 && vector1Third == 1 && ignoreConcatError == true) { - // This is ok, it was caused by the simplification of the concat error - } - else { - return emitter->emitError("vector shape must be (1, M, 1, 1)"); - } - } - - // Verify that the output shape is (1, N, 1, 1) - int64_t output1First = outputShape[0]; - int64_t outputN = outputShape[1]; - int64_t output1Second = outputShape[2]; - int64_t output1Third = outputShape[3]; - if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1) - return emitter->emitError("output shape must be (1, N, 1, 1)"); - - return success(); -} - -llvm::FailureOr> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) { - if (auto computeOp = dyn_cast(weigthedOp->getParentOp())) - return cast(computeOp.getWeights()[weightIndex].getType()).getShape(); - - if (auto coreOp = dyn_cast(weigthedOp->getParentOp())) - return 
cast(coreOp.getWeights()[weightIndex].getType()).getShape(); - - // In compute_batch bodies, weightIndex refers to the lane-local template - // weight index, so lane 0's weight slice is representative for type checks. - if (auto batchOp = dyn_cast(weigthedOp->getParentOp())) { - if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size()) - return failure(); - return cast(batchOp.getWeights()[weightIndex].getType()).getShape(); - } - - return failure(); -} - -LogicalResult SpatWeightedMVMOp::verify() { - auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); - if (failed(matrixShapeOpt)) - return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op"); - auto matrixShape = *matrixShapeOpt; - auto vectorShape = getInput().getType().getShape(); - auto outputShape = getOutput().getType().getShape(); - - /* Two possible accepted shapes: - 1. matrix: (N, M); vector: (M, 1); output: (N, 1) - 2. matrix: (N, M, 1, 1); vector: (1, M, 1, 1); output: (1, N, 1, 1) - */ - - if (matrixShape.size() == 2) - return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape); - else if (matrixShape.size() == 4) - return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape); - else - return emitError("matrix rank must be 2 or 4"); -} - -LogicalResult SpatWeightedVMMOp::verify() { - auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); - if (failed(matrixShapeOpt)) - return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op"); - auto matrixShape = *matrixShapeOpt; - auto vectorShape = getInput().getType().getShape(); - auto outputShape = getOutput().getType().getShape(); - - /* Accepted shape: - 1. 
vector: (1, N); matrix: (N, M); output: (1, M) - */ - if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2) - return emitError("matrix, vector and output must have rank 2"); - - int64_t N = matrixShape[0]; - int64_t M = matrixShape[1]; - if (N <= 0 || M <= 0) - return emitError("matrix shape must be (N, M) with N > 0 and M > 0"); - - int64_t vector1 = vectorShape[0]; - int64_t vectorN = vectorShape[1]; - if (vectorN != N || vector1 != 1) - return emitError("vector shape must be (N, 1)"); - - int64_t output1 = outputShape[0]; - int64_t outputM = outputShape[1]; - if (outputM != M || output1 != 1) - return emitError("output shape must be (M, 1)"); - - return success(); -} - -LogicalResult SpatVAddOp::verify() { - // At least two operands - if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2))) - return failure(); - - return OpTrait::impl::verifySameOperandsAndResultType(*this); -} - -LogicalResult SpatVMaxOp::verify() { - // At least two operands - if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2))) - return failure(); - - return OpTrait::impl::verifySameOperandsAndResultType(*this); -} - -void SpatYieldOp::print(OpAsmPrinter& printer) { - printer << " "; - printCompressedValueSequence(printer, getOutputs()); - printer.printOptionalAttrDict((*this)->getAttrs()); - printer << " : "; - printCompressedTypeSequence(printer, getOutputs().getTypes()); -} - -ParseResult SpatYieldOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector outputs; - SmallVector outputTypes; - - OpAsmParser::UnresolvedOperand firstOutput; - OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput); - if (firstOutputResult.has_value()) { - if (failed(*firstOutputResult)) - return failure(); - if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs)) - return failure(); - while (succeeded(parser.parseOptionalComma())) - if (parseOneCompressedOperandEntry(parser, outputs)) - return failure(); - } - - if 
(parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) - return failure(); - - if (outputs.size() != outputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match"); - - return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands); -} - -LogicalResult SpatExtractRowsOp::verify() { - auto inputType = dyn_cast(getInput().getType()); - if (!inputType || !inputType.hasRank() || inputType.getRank() != 2) - return emitError("input must be a rank-2 shaped type"); - - int64_t numRows = inputType.getShape()[0]; - int64_t numCols = inputType.getShape()[1]; - Type elementType = inputType.getElementType(); - - if (numRows >= 0 && static_cast(getNumResults()) != numRows) - return emitError("number of outputs must match the number of input rows"); - - for (Type output : getResultTypes()) { - auto outputType = dyn_cast(output); - if (!outputType || !outputType.hasRank() || outputType.getRank() != 2) - return emitError("outputs must all be rank-2 shaped types"); - if (outputType.getElementType() != elementType) - return emitError("output element types must match input element type"); - auto outputShape = outputType.getShape(); - if (outputShape[0] != 1) - return emitError("each output must have exactly one row"); - if (numCols >= 0 && outputShape[1] != numCols) - return emitError("output column count must match input column count"); - } - - return success(); -} - -void SpatExtractRowsOp::print(OpAsmPrinter& printer) { - printer << " "; - printer.printOperand(getInput()); - printer.printOptionalAttrDict((*this)->getAttrs()); - printer << " : "; - printer.printType(getInput().getType()); - printer << " -> "; - printCompressedTypeSequence(printer, getResultTypes()); -} - -ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) { - OpAsmParser::UnresolvedOperand input; - 
Type inputType; - SmallVector outputTypes; - - if (parser.parseOperand(input) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parser.parseType(inputType) || parser.parseArrow() - || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) - return failure(); - - if (parser.resolveOperand(input, inputType, result.operands)) - return failure(); - result.addTypes(outputTypes); - return success(); -} - -LogicalResult SpatConcatOp::verify() { - if (getInputs().empty()) - return emitError("requires at least one input"); - - auto outputType = dyn_cast(getOutput().getType()); - if (!outputType || !outputType.hasRank()) - return emitError("output must be a ranked shaped type"); - - int64_t axis = getAxis(); - int64_t rank = outputType.getRank(); - if (axis < 0 || axis >= rank) - return emitError("axis must be within the output rank"); - - int64_t concatenatedDimSize = 0; - bool concatenatedDimDynamic = false; - Type outputElementType = outputType.getElementType(); - - for (Value input : getInputs()) { - auto inputType = dyn_cast(input.getType()); - if (!inputType || !inputType.hasRank()) - return emitError("inputs must be ranked shaped types"); - if (inputType.getRank() != rank) - return emitError("all inputs must have the same rank as the output"); - if (inputType.getElementType() != outputElementType) - return emitError("all inputs must have the same element type as the output"); - - for (int64_t dim = 0; dim < rank; ++dim) { - if (dim == axis) - continue; - int64_t inputDim = inputType.getDimSize(dim); - int64_t outputDim = outputType.getDimSize(dim); - if (!ShapedType::isDynamic(inputDim) && !ShapedType::isDynamic(outputDim) && inputDim != outputDim) - return emitError("non-concatenated dimensions must match the output shape"); - } - - int64_t inputConcatDim = inputType.getDimSize(axis); - if (ShapedType::isDynamic(inputConcatDim)) { - concatenatedDimDynamic = true; - continue; - } - concatenatedDimSize += inputConcatDim; - 
} - - int64_t outputConcatDim = outputType.getDimSize(axis); - if (!concatenatedDimDynamic && !ShapedType::isDynamic(outputConcatDim) && concatenatedDimSize != outputConcatDim) - return emitError("output concatenated dimension must equal the sum of input sizes"); - - return success(); -} - -void SpatConcatOp::print(OpAsmPrinter& printer) { - printer << " axis " << getAxis(); - printer << " args = "; - printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); - printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()}); - printer << " : "; - printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); - printer << " -> "; - printer.printType(getOutput().getType()); -} - -ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) { - int64_t axis = 0; - SmallVector inputs; - SmallVector inputTypes; - Type outputType; - - if (parser.parseKeyword("axis") || parser.parseInteger(axis)) - return failure(); - - if (succeeded(parser.parseOptionalKeyword("args"))) { - if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) - return failure(); - } - else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedRepeatedList( - parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) - || parser.parseArrow() || parser.parseType(outputType)) - return failure(); - - if (inputs.size() != inputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (result.attributes.get("axis")) - return parser.emitError(parser.getCurrentLocation(), "axis cannot be specified both positionally and in attr-dict"); - - result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axis)); - if (parser.resolveOperands(inputs, inputTypes, 
parser.getCurrentLocation(), result.operands)) - return failure(); - result.addTypes(outputType); - return success(); -} - -LogicalResult SpatCompute::verify() { - // Check that the terminator yields the same number and types as the compute results. - auto& block = getBody().front(); - if (block.mightHaveTerminator()) { - auto yieldOp = dyn_cast_or_null(block.getTerminator()); - if (!yieldOp) - return emitError("ComputeOp must have a single yield operation"); - - auto resultTypes = getResultTypes(); - auto yieldTypes = yieldOp->getOperandTypes(); - if (resultTypes.size() != yieldTypes.size()) { - return emitError("ComputeOp must have same number of results as yieldOp " - "operands"); - } - - for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) { - auto resultType = std::get<0>(it); - auto yieldType = std::get<1>(it); - - // Same type and compatible shape - if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) { - return emitError("ComputeOp output must be of the same type as yieldOp " - "operand"); - } - - // Same encoding - if (auto resultRankedType = dyn_cast(resultType)) { - if (auto yieldRankedType = dyn_cast(yieldType)) { - if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) { - return emitError("ComputeOp output must have the same encoding as " - "yieldOp operand"); - } - } - else { - return emitError("ComputeOp output has an encoding while yieldOp " - "operand does not have one"); - } - } - else { - // If result does not have an encoding, yield shouldn't either - if (auto yieldRankedType = dyn_cast(yieldType)) { - return emitError("ComputeOp output must not have an encoding if " - "yieldOp operand has one"); - } - } - } - } - - // Check that each block argument is used - for (auto arg : block.getArguments()) - if (arg.use_empty()) - return emitError("ComputeOp block argument is not used"); - - return success(); -} - -LogicalResult SpatCompute::fold(FoldAdaptor adaptor, 
::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) { - Block& block = getBody().front(); - if (!llvm::hasSingleElement(block)) - return failure(); - - auto yieldOp = dyn_cast(block.front()); - if (!yieldOp) - return failure(); - - for (Value yieldedValue : yieldOp.getOperands()) { - if (auto blockArg = dyn_cast(yieldedValue)) { - if (blockArg.getOwner() == &block) { - results.push_back(getOperand(blockArg.getArgNumber())); - continue; - } - } - results.push_back(yieldedValue); - } - return success(); -} - -void SpatCompute::print(OpAsmPrinter& printer) { - printer << " "; - printCompressedValueList(printer, getWeights(), ListDelimiter::Square); - printer << " args = "; - printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); - - if (auto coreIdAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) { - printer << " core_id " << coreIdAttr.getInt(); - } - - printer.printOptionalAttrDict( - (*this)->getAttrs(), - {getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName}); - - printer << " : "; - printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); - printer << " "; - printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); - printer << " -> "; - printCompressedTypeSequence(printer, getResultTypes()); - printer << " "; - printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); -} - -ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { - SmallVector regionArgs; - SmallVector generatedArgNames; - SmallVector weights; - SmallVector inputs; - SmallVector weightTypes; - SmallVector inputTypes; - SmallVector outputTypes; - int32_t coreId = 0; - - if (parseCompressedOperandList(parser, ListDelimiter::Square, weights)) - return failure(); - - if (succeeded(parser.parseOptionalKeyword("args"))) { - if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) - return failure(); - } - else if 
(parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { - return failure(); - } - - bool hasCoreId = succeeded(parser.parseOptionalKeyword("core_id")); - if (hasCoreId && parser.parseInteger(coreId)) - return failure(); - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedRepeatedList( - parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); }) - || parseCompressedRepeatedList( - parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) - || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) - return failure(); - - if (weights.size() != weightTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); - if (inputs.size() != inputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName)) - return parser.emitError(parser.getCurrentLocation(), "core_id cannot be specified both positionally and in attr-dict"); - - auto& builder = parser.getBuilder(); - result.addAttribute( - "operandSegmentSizes", - builder.getDenseI32ArrayAttr({static_cast(weights.size()), static_cast(inputs.size())})); - if (hasCoreId) - result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId)); - - if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands) - || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) - return failure(); - result.addTypes(outputTypes); - - Region* body = result.addRegion(); - buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs); - return parser.parseRegion(*body, regionArgs); -} - -static FailureOr getParentBatchLaneCount(Operation* op) { - auto batchOp = op->getParentOfType(); - if (!batchOp) - return 
failure(); - return batchOp.getLaneCount(); -} - -static LogicalResult verifyManyChannelSizes(Operation* op, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - size_t valueCount) { - if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) - return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); - if (channelIds.size() != valueCount) - return op->emitError("channel metadata length must match the number of values"); - return success(); -} - -static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types, StringRef kind) { - if (types.empty()) - return op->emitError() << kind << " must carry at least one value"; - - Type firstType = types.front(); - for (Type type : types.drop_front()) - if (type != firstType) - return op->emitError() << kind << " values must all have the same type"; - return success(); -} - -static LogicalResult verifyBatchChannelSizes(Operation* op, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds) { - if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) - return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); - - auto laneCount = getParentBatchLaneCount(op); - if (failed(laneCount)) - return op->emitError("must be nested inside spat.compute_batch"); - if (channelIds.size() != static_cast(*laneCount)) - return op->emitError("channel metadata length must match parent laneCount"); - - return success(); -} - -LogicalResult SpatChannelSendManyOp::verify() { - if (failed(verifyManyChannelSizes( - getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size()))) - return failure(); - return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many"); -} - -LogicalResult SpatChannelReceiveManyOp::verify() { - if (failed(verifyManyChannelSizes( - getOperation(), getChannelIds(), 
getSourceCoreIds(), getTargetCoreIds(), getOutputs().size()))) - return failure(); - return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many"); -} - -LogicalResult SpatChannelSendBatchOp::verify() { - return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); -} - -LogicalResult SpatChannelReceiveBatchOp::verify() { - return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); -} - -static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) { - auto yieldOp = dyn_cast_or_null(block.getTerminator()); - if (!yieldOp) - return op->emitError("body must terminate with spat.yield"); - if (outputTypes.empty()) { - if (yieldOp.getNumOperands() != 0) - return op->emitError("body yield must be empty when compute_batch has no results"); - } - else { - if (yieldOp.getNumOperands() != 1) - return op->emitError("body yield must produce exactly one value"); - if (yieldOp.getOperand(0).getType() != outputTypes[0]) - return op->emitError("body yield type must match output type"); - } - - for (auto& bodyOp : block) { - if (auto wvmm = dyn_cast(&bodyOp)) - if (wvmm.getWeightIndex() < 0 || static_cast(wvmm.getWeightIndex()) >= weightsPerLane) - return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane"); - if (auto wmvm = dyn_cast(&bodyOp)) - if (wmvm.getWeightIndex() < 0 || static_cast(wmvm.getWeightIndex()) >= weightsPerLane) - return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane"); - } - return success(); -} - -LogicalResult SpatComputeBatch::verify() { - int32_t count = getLaneCount(); - if (count <= 0) - return emitError("laneCount must be positive"); - - auto laneCountSz = static_cast(count); - if (getWeights().size() % laneCountSz != 0) - return emitError("number of weights must be a multiple of laneCount"); - - if 
(!getInputs().empty() && getInputs().size() != laneCountSz) - return emitError("number of inputs must be either 0 or laneCount"); - if (!getOutputs().empty() && getOutputs().size() != laneCountSz) - return emitError("number of outputs must be either 0 or laneCount"); - - size_t weightsPerLane = getWeights().size() / laneCountSz; - for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) { - Type weightType = getWeights()[weightIndex].getType(); - for (size_t lane = 1; lane < laneCountSz; ++lane) - if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType) - return emitError("corresponding weights across lanes must have the same type"); - } - - if (!getInputs().empty()) { - Type inputType = getInputs()[0].getType(); - for (Value in : getInputs().drop_front()) - if (in.getType() != inputType) - return emitError("all inputs must have the same type"); - } - - if (!getOutputs().empty()) { - Type outputType = getOutputs()[0].getType(); - for (Value out : getOutputs().drop_front()) - if (out.getType() != outputType) - return emitError("all outputs must have the same type"); - } - - if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdAttrName)) { - auto coreIdsAttr = dyn_cast(coreIdAttr); - if (!coreIdsAttr) - return emitError("compute_batch core_id attribute must be a dense i32 array"); - if (coreIdsAttr.size() != laneCountSz) - return emitError("compute_batch core_id array length must match laneCount"); - if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; })) - return emitError("compute_batch core_id values must be positive"); - } - - Block& block = getBody().front(); - if (getInputs().empty()) { - if (block.getNumArguments() != 0) - return emitError("compute_batch body must have no block arguments when there are no inputs"); - } - else { - if (block.getNumArguments() != 1) - return emitError("compute_batch body must have exactly one block argument"); - if (block.getArgument(0).getType() != 
getInputs()[0].getType()) - return emitError("body block argument type must match input type"); - } - - return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane); -} - -void SpatComputeBatch::print(OpAsmPrinter& printer) { - printer << " lanes " << getLaneCount() << " "; - printCompressedValueList(printer, getWeights(), ListDelimiter::Square); - printer << " args = "; - printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); - - if (auto coreIdsAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) { - printer << " core_ids "; - printCompressedIntegerList(printer, coreIdsAttr.asArrayRef()); - } - - printer.printOptionalAttrDict( - (*this)->getAttrs(), - {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName}); - - printer << " : "; - printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); - printer << " "; - printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); - printer << " -> "; - printCompressedTypeSequence(printer, getResultTypes()); - printer << " "; - printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); -} - -ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) { - int32_t laneCount = 0; - SmallVector regionArgs; - SmallVector generatedArgNames; - SmallVector weights; - SmallVector inputs; - SmallVector weightTypes; - SmallVector inputTypes; - SmallVector outputTypes; - SmallVector coreIds; - - if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)) - return failure(); - - if (parseCompressedOperandList(parser, ListDelimiter::Square, weights)) - return failure(); - - if (succeeded(parser.parseOptionalKeyword("args"))) { - if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) - return failure(); - } - else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { - return failure(); - } - - bool hasCoreIds = 
succeeded(parser.parseOptionalKeyword("core_ids")); - if (hasCoreIds && parseCompressedIntegerList(parser, coreIds)) - return failure(); - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedRepeatedList( - parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); }) - || parseCompressedRepeatedList( - parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) - || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) - return failure(); - - if (weights.size() != weightTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); - if (inputs.size() != inputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdAttrName)) - return parser.emitError(parser.getCurrentLocation(), "core_id cannot be specified both in core_ids and attr-dict"); - - auto& builder = parser.getBuilder(); - result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount)); - result.addAttribute( - "operandSegmentSizes", - builder.getDenseI32ArrayAttr({static_cast(weights.size()), static_cast(inputs.size())})); - if (hasCoreIds) - result.addAttribute(onnx_mlir::kCoreIdAttrName, getDenseI32ArrayAttr(parser, coreIds)); - - if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands) - || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) - return failure(); - result.addTypes(outputTypes); - - Region* body = result.addRegion(); - buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs); - return parser.parseRegion(*body, regionArgs); -} - -void SpatChannelSendManyOp::print(OpAsmPrinter& printer) { - printer << " "; - printCompressedValueSequence(printer, getInputs()); - 
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); - printer.printOptionalAttrDict( - (*this)->getAttrs(), - {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printCompressedTypeSequence(printer, TypeRange(getInputs())); -} - -ParseResult SpatChannelSendManyOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector inputs; - SmallVector inputTypes; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - if (parseCompressedOperandSequence(parser, inputs)) - return failure(); - - bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); - if (hasMetadata) { - if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") - || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") - || parseCompressedIntegerList(parser, targetCoreIds)) - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false)) - return failure(); - - if (inputs.size() != inputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (hasMetadata - && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") - || result.attributes.get("targetCoreIds"))) - return parser.emitError(parser.getCurrentLocation(), - "channel metadata cannot be specified both positionally and in attr-dict"); - if (hasMetadata) { - result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); - } - - return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands); -} - -void 
SpatChannelReceiveManyOp::print(OpAsmPrinter& printer) { - printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); - printer.printOptionalAttrDict( - (*this)->getAttrs(), - {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printCompressedTypeSequence(printer, getResultTypes()); -} - -ParseResult SpatChannelReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector outputTypes; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); - if (hasMetadata) { - if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") - || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") - || parseCompressedIntegerList(parser, targetCoreIds)) - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) - return failure(); - - if (hasMetadata - && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") - || result.attributes.get("targetCoreIds"))) - return parser.emitError(parser.getCurrentLocation(), - "channel metadata cannot be specified both positionally and in attr-dict"); - if (hasMetadata) { - result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); - } - - result.addTypes(outputTypes); - return success(); -} - -void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) { - printer << " "; - printer.printOperand(getInput()); - printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); - printer.printOptionalAttrDict( - 
(*this)->getAttrs(), - {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printer.printType(getInput().getType()); -} - -ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) { - OpAsmParser::UnresolvedOperand input; - Type inputType; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - if (parser.parseOperand(input)) - return failure(); - - bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); - if (hasMetadata) { - if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") - || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") - || parseCompressedIntegerList(parser, targetCoreIds)) - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)) - return failure(); - - if (hasMetadata - && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") - || result.attributes.get("targetCoreIds"))) - return parser.emitError(parser.getCurrentLocation(), - "channel metadata cannot be specified both positionally and in attr-dict"); - if (hasMetadata) { - result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); - } - - return parser.resolveOperand(input, inputType, result.operands); -} - -void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) { - printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); - printer.printOptionalAttrDict( - (*this)->getAttrs(), - {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printer.printType(getOutput().getType()); -} - 
-ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) { - Type outputType; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); - if (hasMetadata) { - if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") - || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") - || parseCompressedIntegerList(parser, targetCoreIds)) - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType)) - return failure(); - - if (hasMetadata - && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") - || result.attributes.get("targetCoreIds"))) - return parser.emitError(parser.getCurrentLocation(), - "channel metadata cannot be specified both positionally and in attr-dict"); - if (hasMetadata) { - result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); - } - - result.addTypes(outputType); - return success(); -} - } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp new file mode 100644 index 0000000..022b623 --- /dev/null +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -0,0 +1,912 @@ +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/Support/LogicalResult.h" + +#include + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace spatial { + +namespace { + +enum class ListDelimiter { + 
Square, + Paren +}; + +static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { + if (delimiter == ListDelimiter::Square) + return parser.parseLSquare(); + return parser.parseLParen(); +} + +static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { + if (delimiter == ListDelimiter::Square) + return parser.parseOptionalRSquare(); + return parser.parseOptionalRParen(); +} + +static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { + printer << (delimiter == ListDelimiter::Square ? "[" : "("); +} + +static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { + printer << (delimiter == ListDelimiter::Square ? "]" : ")"); +} + +template +static ParseResult parseCompressedRepeatedList(OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& entries, + ParseEntryFn parseEntry) { + if (parseOpenDelimiter(parser, delimiter)) + return failure(); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + + while (true) { + EntryT entry; + if (parseEntry(entry)) + return failure(); + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + entries.push_back(entry); + + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + break; + if (parser.parseComma()) + return failure(); + } + + return success(); +} + +template +static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl& values) { + if (parser.parseLSquare()) + return failure(); + if (succeeded(parser.parseOptionalRSquare())) + return success(); + + while (true) { + int64_t first = 0; + if (parser.parseInteger(first)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("to"))) { + int64_t 
last = 0; + if (parser.parseInteger(last) || last < first) + return parser.emitError(parser.getCurrentLocation(), "invalid ascending range"); + + int64_t step = 1; + if (succeeded(parser.parseOptionalKeyword("by"))) { + if (parser.parseInteger(step) || step <= 0) + return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive"); + } + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + if ((last - first) % step != 0) + return parser.emitError(parser.getCurrentLocation(), + "range end must be reachable from start using the given step"); + + for (int64_t value = first; value <= last; value += step) + for (int64_t index = 0; index < repeatCount; ++index) + values.push_back(static_cast(value)); + } + else { + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + values.push_back(static_cast(first)); + } + + if (succeeded(parser.parseOptionalRSquare())) + break; + if (parser.parseComma()) + return failure(); + } + + return success(); +} + +template +static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) { + for (size_t index = 0; index < entries.size();) { + size_t runEnd = index + 1; + while (runEnd < entries.size() && entries[runEnd] == entries[index]) + ++runEnd; + + if (index != 0) + printer << ", "; + printEntry(entries[index]); + size_t runLength = runEnd - index; + if (runLength > 1) + printer << " x" << runLength; + index = runEnd; + } +} + +template +static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef values) { + printer << "["; + for (size_t 
index = 0; index < values.size();) { + if (index != 0) + printer << ", "; + + auto findEqualRunEnd = [&](size_t start) { + size_t end = start + 1; + while (end < values.size() && values[end] == values[start]) + ++end; + return end; + }; + + size_t firstRunEnd = findEqualRunEnd(index); + size_t repeatCount = firstRunEnd - index; + size_t progressionEnd = firstRunEnd; + int64_t step = 0; + IntT lastValue = values[index]; + + if (firstRunEnd < values.size()) { + size_t secondRunEnd = findEqualRunEnd(firstRunEnd); + step = static_cast(values[firstRunEnd]) - static_cast(values[index]); + if (step > 0 && secondRunEnd - firstRunEnd == repeatCount) { + progressionEnd = secondRunEnd; + lastValue = values[firstRunEnd]; + size_t currentRunStart = secondRunEnd; + while (currentRunStart < values.size()) { + size_t currentRunEnd = findEqualRunEnd(currentRunStart); + if (currentRunEnd - currentRunStart != repeatCount) + break; + if (static_cast(values[currentRunStart]) != static_cast(lastValue) + step) + break; + lastValue = values[currentRunStart]; + progressionEnd = currentRunEnd; + currentRunStart = currentRunEnd; + } + } + else { + step = 0; + } + } + + size_t progressionValueCount = repeatCount == 0 ? 
0 : (progressionEnd - index) / repeatCount; + if (progressionEnd > firstRunEnd && progressionValueCount >= 3) { + printer << values[index] << " to " << lastValue; + if (step != 1) + printer << " by " << step; + if (repeatCount > 1) + printer << " x" << repeatCount; + index = progressionEnd; + continue; + } + + if (repeatCount > 1) { + printer << values[index] << " x" << repeatCount; + index = firstRunEnd; + continue; + } + + printer << values[index]; + index = firstRunEnd; + } + printer << "]"; +} + +static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) { + printOpenDelimiter(printer, delimiter); + for (size_t index = 0; index < values.size();) { + size_t equalRunEnd = index + 1; + while (equalRunEnd < values.size() && values[equalRunEnd] == values[index]) + ++equalRunEnd; + + if (index != 0) + printer << ", "; + if (equalRunEnd - index > 1) { + printer.printOperand(values[index]); + printer << " x" << (equalRunEnd - index); + index = equalRunEnd; + continue; + } + + size_t rangeEnd = index + 1; + if (auto firstResult = dyn_cast(values[index])) { + while (rangeEnd < values.size()) { + auto nextResult = dyn_cast(values[rangeEnd]); + if (!nextResult || nextResult.getOwner() != firstResult.getOwner() + || nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) + break; + ++rangeEnd; + } + } + else if (auto firstArg = dyn_cast(values[index])) { + while (rangeEnd < values.size()) { + auto nextArg = dyn_cast(values[rangeEnd]); + if (!nextArg || nextArg.getOwner() != firstArg.getOwner() + || nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) + break; + ++rangeEnd; + } + } + + printer.printOperand(values[index]); + if (rangeEnd - index >= 3) { + printer << " to "; + printer.printOperand(values[rangeEnd - 1]); + } + else if (rangeEnd - index == 2) { + printer << ", "; + printer.printOperand(values[index + 1]); + } + index = rangeEnd; + } + printCloseDelimiter(printer, 
delimiter); +} + +static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) { + printOpenDelimiter(printer, delimiter); + printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); }); + printCloseDelimiter(printer, delimiter); +} + +static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser, + OpAsmParser::UnresolvedOperand firstOperand, + SmallVectorImpl& operands) { + if (succeeded(parser.parseOptionalKeyword("to"))) { + OpAsmParser::UnresolvedOperand lastOperand; + if (parser.parseOperand(lastOperand)) + return failure(); + if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number) + return parser.emitError(parser.getCurrentLocation(), "invalid operand range"); + for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number) + operands.push_back({firstOperand.location, firstOperand.name, number}); + } + else { + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + operands.push_back(firstOperand); + } + + return success(); +} + +static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser, + SmallVectorImpl& operands) { + OpAsmParser::UnresolvedOperand firstOperand; + if (parser.parseOperand(firstOperand)) + return failure(); + return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands); +} + +static ParseResult parseCompressedOperandList(OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& operands) { + if (parseOpenDelimiter(parser, delimiter)) + return failure(); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + + while (true) { + if (parseOneCompressedOperandEntry(parser, operands)) + return failure(); + if 
(succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + break; + if (parser.parseComma()) + return failure(); + } + + return success(); +} + +static ParseResult parseCompressedOperandSequence(OpAsmParser& parser, + SmallVectorImpl& operands) { + if (parseOneCompressedOperandEntry(parser, operands)) + return failure(); + while (succeeded(parser.parseOptionalComma())) + if (parseOneCompressedOperandEntry(parser, operands)) + return failure(); + return success(); +} + +static void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) { + for (size_t index = 0; index < values.size();) { + size_t equalRunEnd = index + 1; + while (equalRunEnd < values.size() && values[equalRunEnd] == values[index]) + ++equalRunEnd; + + if (index != 0) + printer << ", "; + if (equalRunEnd - index > 1) { + printer.printOperand(values[index]); + printer << " x" << (equalRunEnd - index); + index = equalRunEnd; + continue; + } + + size_t rangeEnd = index + 1; + if (auto firstResult = dyn_cast(values[index])) { + while (rangeEnd < values.size()) { + auto nextResult = dyn_cast(values[rangeEnd]); + if (!nextResult || nextResult.getOwner() != firstResult.getOwner() + || nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) + break; + ++rangeEnd; + } + } + else if (auto firstArg = dyn_cast(values[index])) { + while (rangeEnd < values.size()) { + auto nextArg = dyn_cast(values[rangeEnd]); + if (!nextArg || nextArg.getOwner() != firstArg.getOwner() + || nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) + break; + ++rangeEnd; + } + } + + printer.printOperand(values[index]); + if (rangeEnd - index >= 3) { + printer << " to "; + printer.printOperand(values[rangeEnd - 1]); + } + else if (rangeEnd - index == 2) { + printer << ", "; + printer.printOperand(values[index + 1]); + } + index = rangeEnd; + } +} + +static void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) { + printCompressedEqualRuns(printer, types, 
[&](Type type) { printer.printType(type); }); +} + +static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl& types, bool allowEmpty) { + Type firstType; + OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType); + if (!firstTypeResult.has_value()) { + if (allowEmpty) + return success(); + return parser.emitError(parser.getCurrentLocation(), "expected type"); + } + if (failed(*firstTypeResult)) + return failure(); + + auto appendType = [&](Type type) -> ParseResult { + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + types.push_back(type); + return success(); + }; + + if (appendType(firstType)) + return failure(); + + while (succeeded(parser.parseOptionalComma())) { + Type nextType; + if (parser.parseType(nextType) || appendType(nextType)) + return failure(); + } + + return success(); +} + +static void printChannelMetadata(OpAsmPrinter& printer, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds) { + printer << " channels "; + printCompressedIntegerList(printer, channelIds); + printer << " from "; + printCompressedIntegerList(printer, sourceCoreIds); + printer << " to "; + printCompressedIntegerList(printer, targetCoreIds); +} + +static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef values) { + return parser.getBuilder().getDenseI64ArrayAttr(values); +} + +static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef values) { + return parser.getBuilder().getDenseI32ArrayAttr(values); +} + +static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) { + return parser.getBuilder().getI32IntegerAttr(value); +} + +static void buildImplicitRegionArgs(OpAsmParser& parser, + ArrayRef inputTypes, + SmallVectorImpl& 
generatedNames, + SmallVectorImpl& arguments) { /* Creates one region argument per entry of inputTypes, named "arg1", "arg2", ... (1-based), and stores the name strings in generatedNames. Both vectors are reserved up front; presumably this keeps the names referenced by ssaName valid while later entries are appended - TODO confirm. */ + generatedNames.reserve(inputTypes.size()); + arguments.reserve(inputTypes.size()); + for (auto [index, inputType] : llvm::enumerate(inputTypes)) { + generatedNames.push_back("arg" + std::to_string(index + 1)); + OpAsmParser::Argument arg; + arg.ssaName = {parser.getCurrentLocation(), generatedNames.back(), 0}; + arg.type = inputType; + arguments.push_back(arg); + } +} + +} // namespace + +void SpatYieldOp::print(OpAsmPrinter& printer) { /* Syntax: <compressed outputs> [attr-dict] : <compressed output types> */ + printer << " "; + printCompressedValueSequence(printer, getOutputs()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : "; + printCompressedTypeSequence(printer, getOutputs().getTypes()); +} + +ParseResult SpatYieldOp::parse(OpAsmParser& parser, OperationState& result) { /* Mirror of print. The operand list and the type list may both be empty; after run expansion their lengths must agree. */ + SmallVector outputs; + SmallVector outputTypes; + + OpAsmParser::UnresolvedOperand firstOutput; + OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput); + if (firstOutputResult.has_value()) { + if (failed(*firstOutputResult)) + return failure(); + if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs)) + return failure(); + while (succeeded(parser.parseOptionalComma())) + if (parseOneCompressedOperandEntry(parser, outputs)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) + return failure(); + + if (outputs.size() != outputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match"); + + return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands); +} + +void SpatExtractRowsOp::print(OpAsmPrinter& printer) { /* Syntax: <input> [attr-dict] : <input type> -> <compressed result types> */ + printer << " "; + printer.printOperand(getInput()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : "; + printer.printType(getInput().getType()); + printer << " -> "; + 
printCompressedTypeSequence(printer, getResultTypes()); +} + +ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) { /* Syntax: <input> [attr-dict] : <input type> -> <compressed result types>. At least one result type is required (allowEmpty is false). */ + OpAsmParser::UnresolvedOperand input; + Type inputType; + SmallVector outputTypes; + + if (parser.parseOperand(input) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parser.parseType(inputType) || parser.parseArrow() + || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) + return failure(); + + if (parser.resolveOperand(input, inputType, result.operands)) + return failure(); + result.addTypes(outputTypes); + return success(); +} + +void SpatConcatOp::print(OpAsmPrinter& printer) { /* Syntax: axis <N> args = (<inputs>) [attr-dict without axis] : (<input types>) -> <output type> */ + printer << " axis " << getAxis(); + printer << " args = "; + printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); + printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()}); + printer << " : "; + printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); + printer << " -> "; + printer.printType(getOutput().getType()); +} + +ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) { /* Mirror of print. The "args =" prefix is optional on input. The positional axis is stored as an i64 attribute named "axis" and may not also appear in the attr-dict. */ + int64_t axis = 0; + SmallVector inputs; + SmallVector inputTypes; + Type outputType; + + if (parser.parseKeyword("axis") || parser.parseInteger(axis)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("args"))) { + if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) + return failure(); + } + else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedRepeatedList( + parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) + || parser.parseArrow() || parser.parseType(outputType)) + return failure(); + + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), 
"number of inputs and input types must match"); + if (result.attributes.get("axis")) + return parser.emitError(parser.getCurrentLocation(), "axis cannot be specified both positionally and in attr-dict"); + + result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axis)); + if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputType); + return success(); +} + +void SpatCompute::print(OpAsmPrinter& printer) { /* Syntax: [<weights>] args = (<inputs>) [core_id <N>] [attr-dict] : [<weight types>] (<input types>) -> <compressed result types> <region>. The segment-size and core-id attributes are elided from the attr-dict, and entry block arguments are not printed. */ + printer << " "; + printCompressedValueList(printer, getWeights(), ListDelimiter::Square); + printer << " args = "; + printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); + + if (auto coreIdAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) + printer << " core_id " << coreIdAttr.getInt(); + + printer.printOptionalAttrDict((*this)->getAttrs(), + {getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName}); + + printer << " : "; + printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); + printer << " "; + printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); + printer << " -> "; + printCompressedTypeSequence(printer, getResultTypes()); + printer << " "; + printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); +} + +ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { /* Mirror of print. "args =" and "core_id <N>" are optional on input. */ + SmallVector regionArgs; + SmallVector generatedArgNames; + SmallVector weights; + SmallVector inputs; + SmallVector weightTypes; + SmallVector inputTypes; + SmallVector outputTypes; + int32_t coreId = 0; + + if (parseCompressedOperandList(parser, ListDelimiter::Square, weights)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("args"))) { + if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) + return failure(); + } + else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { + return failure(); + } + + bool 
hasCoreId = succeeded(parser.parseOptionalKeyword("core_id")); + if (hasCoreId && parser.parseInteger(coreId)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedRepeatedList( + parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); }) + || parseCompressedRepeatedList( + parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) + || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) + return failure(); + + if (weights.size() != weightTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName)) + return parser.emitError(parser.getCurrentLocation(), + "core_id cannot be specified both positionally and in attr-dict"); + + auto& builder = parser.getBuilder(); + result.addAttribute( /* segment sizes are recorded as {number of weights, number of inputs} */ + "operandSegmentSizes", + builder.getDenseI32ArrayAttr({static_cast(weights.size()), static_cast(inputs.size())})); + if (hasCoreId) + result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId)); + + if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands) /* weights are resolved first, inputs second */ + || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputTypes); + + Region* body = result.addRegion(); /* entry block arguments are implicit: one per input type, none for weights */ + buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs); + return parser.parseRegion(*body, regionArgs); +} + +void SpatComputeBatch::print(OpAsmPrinter& printer) { /* Syntax: lanes <N> [<weights>] args = (<inputs>) [core_ids <ints>] [attr-dict] : [<weight types>] (<input types>) -> <compressed result types> <region> */ + printer << " lanes " << getLaneCount() << " "; + printCompressedValueList(printer, getWeights(), ListDelimiter::Square); + printer << " args = "; + 
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); + + if (auto coreIdsAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) { /* the per-lane core ids are stored under the same attribute name that SpatCompute uses for its single core id */ + printer << " core_ids "; + printCompressedIntegerList(printer, coreIdsAttr.asArrayRef()); + } + + printer.printOptionalAttrDict( /* attributes already shown positionally are elided from the attr-dict */ + (*this)->getAttrs(), + {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName}); + + printer << " : "; + printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); + printer << " "; + printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); + printer << " -> "; + printCompressedTypeSequence(printer, getResultTypes()); + printer << " "; + printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); +} + +ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) { /* Mirror of print. "args =" and "core_ids <ints>" are optional on input; "lanes <N>" is mandatory. */ + int32_t laneCount = 0; + SmallVector regionArgs; + SmallVector generatedArgNames; + SmallVector weights; + SmallVector inputs; + SmallVector weightTypes; + SmallVector inputTypes; + SmallVector outputTypes; + SmallVector coreIds; + + if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)) + return failure(); + + if (parseCompressedOperandList(parser, ListDelimiter::Square, weights)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("args"))) { + if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) + return failure(); + } + else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { + return failure(); + } + + bool hasCoreIds = succeeded(parser.parseOptionalKeyword("core_ids")); + if (hasCoreIds && parseCompressedIntegerList(parser, coreIds)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedRepeatedList( + parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); }) + || parseCompressedRepeatedList( + parser, 
ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) + || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) + return failure(); + + if (weights.size() != weightTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdAttrName)) + return parser.emitError(parser.getCurrentLocation(), "core_id cannot be specified both in core_ids and attr-dict"); + + auto& builder = parser.getBuilder(); + result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount)); + result.addAttribute( /* segment sizes are recorded as {number of weights, number of inputs} */ + "operandSegmentSizes", + builder.getDenseI32ArrayAttr({static_cast(weights.size()), static_cast(inputs.size())})); + if (hasCoreIds) + result.addAttribute(onnx_mlir::kCoreIdAttrName, getDenseI32ArrayAttr(parser, coreIds)); + + if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands) + || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputTypes); + + Region* body = result.addRegion(); /* NOTE(review): this creates one implicit region argument per input type, whereas SpatComputeBatch::verify requires exactly one block argument whenever inputs are present. Confirm that textual round-tripping works for laneCount > 1. */ + buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs); + return parser.parseRegion(*body, regionArgs); +} + +void SpatChannelSendManyOp::print(OpAsmPrinter& printer) { /* Syntax: <compressed inputs> channels <ids> from <source cores> to <target cores> [attr-dict] : <compressed input types>. The three channel attributes are elided from the attr-dict. */ + printer << " "; + printCompressedValueSequence(printer, getInputs()); + printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printCompressedTypeSequence(printer, TypeRange(getInputs())); +} + +ParseResult 
SpatChannelSendManyOp::parse(OpAsmParser& parser, OperationState& result) { /* Mirror of print. The "channels ... from ... to ..." clause is optional on input; when it is present the same three attributes may not also appear in the attr-dict. At least one input and one input type are required. */ + SmallVector inputs; + SmallVector inputTypes; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + if (parseCompressedOperandSequence(parser, inputs)) + return failure(); + + bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false)) + return failure(); + + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (hasMetadata + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { /* channel ids are stored as an i64 array, core ids as i32 arrays */ + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands); +} + +void SpatChannelReceiveManyOp::print(OpAsmPrinter& printer) { /* Syntax: channels <ids> from <source cores> to <target cores> [attr-dict] : <compressed result types> */ + printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printCompressedTypeSequence(printer, 
getResultTypes()); +} + +ParseResult SpatChannelReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) { /* Mirror of print. The op takes no operands. The channel clause is optional on input, and at least one result type is required (allowEmpty is false). */ + SmallVector outputTypes; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) + return failure(); + + if (hasMetadata + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + result.addTypes(outputTypes); + return success(); +} + +void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) { /* Syntax: <input> channels <ids> from <source cores> to <target cores> [attr-dict] : <input type> */ + printer << " "; + printer.printOperand(getInput()); + printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printer.printType(getInput().getType()); +} + +ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) { /* Mirror of print: a single operand with a single type. The channel clause is optional on input. */ + OpAsmParser::UnresolvedOperand input; + Type inputType; + SmallVector channelIds; + SmallVector 
sourceCoreIds; + SmallVector targetCoreIds; + + if (parser.parseOperand(input)) + return failure(); + + bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)) + return failure(); + + if (hasMetadata /* reject metadata given both positionally and through the attr-dict */ + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + return parser.resolveOperand(input, inputType, result.operands); +} + +void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) { /* Syntax: channels <ids> from <source cores> to <target cores> [attr-dict] : <output type> */ + printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printer.printType(getOutput().getType()); +} + +ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) { /* Mirror of print: no operands and a single result type. The channel clause is optional on input. */ + Type outputType; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || 
parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType)) + return failure(); + + if (hasMetadata /* reject metadata given both positionally and through the attr-dict */ + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + result.addTypes(outputType); + return success(); +} + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/SpatialOpsCanonicalization.cpp b/src/PIM/Dialect/Spatial/SpatialOpsCanonicalization.cpp new file mode 100644 index 0000000..9abdba7 --- /dev/null +++ b/src/PIM/Dialect/Spatial/SpatialOpsCanonicalization.cpp @@ -0,0 +1,35 @@ +#include "mlir/IR/Block.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/LogicalResult.h" + +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace spatial { + +LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) { /* Folds a compute whose body holds nothing but its yield. Each yielded value that is an argument of the entry block is replaced by the matching input of this op; any other yielded value is forwarded unchanged. Fails when the body has more than one operation or its only operation is not the yield. */ + Block& block = getBody().front(); + if (!llvm::hasSingleElement(block)) + return failure(); + + auto yieldOp = dyn_cast(block.front()); + if (!yieldOp) + return failure(); + + for (Value yieldedValue : yieldOp.getOperands()) { + if (auto blockArg = dyn_cast(yieldedValue)) { + if (blockArg.getOwner() == &block) { /* entry block arguments are created from the inputs only and weights precede inputs in the operand list, so the lookup must go through getInputs() rather than getOperand() */ + results.push_back(getInputs()[blockArg.getArgNumber()]); + continue; + } + } + 
results.push_back(yieldedValue); + } + return success(); +} + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp new file mode 100644 index 0000000..04d0ccf --- /dev/null +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -0,0 +1,433 @@ +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/Support/LogicalResult.h" + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace spatial { + +namespace { + +inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter, + ArrayRef& matrixShape, + ArrayRef& vectorShape, + ArrayRef& outputShape) { /* Rank-2 shape check: matrix (N, M) with N > 0 and M > 0, vector (M, 1), output (N, 1). Errors are reported through emitter. */ + if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2) + return emitter->emitError("matrix, vector and output must have rank 2"); + + int64_t N = matrixShape[0]; + int64_t M = matrixShape[1]; + if (N <= 0 || M <= 0) + return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0"); + + int64_t vectorM = vectorShape[0]; + int64_t vector1 = vectorShape[1]; + if (vectorM != M || vector1 != 1) + return emitter->emitError("vector shape must be (M, 1)"); + + int64_t outputN = outputShape[0]; + int64_t output1 = outputShape[1]; + if (outputN != N || output1 != 1) + return emitter->emitError("output shape must be (N, 1)"); + + return success(); +} + +inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter, + ArrayRef& matrixShape, + ArrayRef& vectorShape, + ArrayRef& outputShape) { /* Rank-4 shape check: matrix (N, M, 1, 1), vector (1, M, 1, 1), output (1, N, 1, 1). */ + if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4) + 
return emitter->emitError("matrix, vector and output must have rank 4"); + + int64_t N = matrixShape[0]; + int64_t M = matrixShape[1]; + int64_t matrix1First = matrixShape[2]; + int64_t matrix1Second = matrixShape[3]; + if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1) + return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0"); + + int64_t vector1First = vectorShape[0]; + int64_t vectorM = vectorShape[1]; + int64_t vector1Second = vectorShape[2]; + int64_t vector1Third = vectorShape[3]; /* when ignoreConcatError is set, a (1, X, 1, 1) vector is accepted below even if X differs from M */ + if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) { + if (vector1First == 1 && vector1Second == 1 && vector1Third == 1 && ignoreConcatError == true) { + // This is ok, it was caused by the simplification of the concat error. + } + else { + return emitter->emitError("vector shape must be (1, M, 1, 1)"); + } + } + + int64_t output1First = outputShape[0]; + int64_t outputN = outputShape[1]; + int64_t output1Second = outputShape[2]; + int64_t output1Third = outputShape[3]; + if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1) + return emitter->emitError("output shape must be (1, N, 1, 1)"); + + return success(); +} + +static FailureOr> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) { /* Returns the shape of weight number weightIndex on the direct parent of weightedOp, trying three parent op kinds in turn; fails when the parent is none of them. NOTE(review): only the third branch bounds-checks weightIndex; the first two index getWeights() unchecked. Confirm callers guarantee the range. */ + if (auto computeOp = dyn_cast(weightedOp->getParentOp())) + return cast(computeOp.getWeights()[weightIndex].getType()).getShape(); + + if (auto coreOp = dyn_cast(weightedOp->getParentOp())) + return cast(coreOp.getWeights()[weightIndex].getType()).getShape(); + + if (auto batchOp = dyn_cast(weightedOp->getParentOp())) { + if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size()) + return failure(); + return cast(batchOp.getWeights()[weightIndex].getType()).getShape(); + } + + return failure(); +} + +static FailureOr getParentBatchLaneCount(Operation* op) { /* Returns the lane count of the enclosing batch op found by getParentOfType, or failure when there is none. */ + auto batchOp = op->getParentOfType(); + if (!batchOp) + return failure(); + return 
batchOp.getLaneCount(); +} + +static LogicalResult verifyManyChannelSizes(Operation* op, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + size_t valueCount) { /* The three metadata arrays must have equal length, and that length must equal valueCount. */ + if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) + return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); + if (channelIds.size() != valueCount) + return op->emitError("channel metadata length must match the number of values"); + return success(); +} + +static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types, StringRef kind) { /* Requires at least one type and that every type equals the first; kind is used as the prefix of the error message. */ + if (types.empty()) + return op->emitError() << kind << " must carry at least one value"; + + Type firstType = types.front(); + for (Type type : types.drop_front()) + if (type != firstType) + return op->emitError() << kind << " values must all have the same type"; + return success(); +} + +static LogicalResult verifyBatchChannelSizes(Operation* op, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds) { /* The three metadata arrays must have equal length, and that length must equal the lane count reported by getParentBatchLaneCount for the enclosing batch op. */ + if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) + return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); + + auto laneCount = getParentBatchLaneCount(op); + if (failed(laneCount)) + return op->emitError("must be nested inside spat.compute_batch"); + if (channelIds.size() != static_cast(*laneCount)) + return op->emitError("channel metadata length must match parent laneCount"); + + return success(); +} + +static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) { /* Checks the yield of a batch body against outputTypes: no yielded values when there are no outputs, otherwise exactly one value typed like outputTypes[0]. NOTE(review): getTerminator() is called without a prior mightHaveTerminator() check, unlike SpatCompute::verify. Confirm the block can never be empty here. */ + auto yieldOp = dyn_cast_or_null(block.getTerminator()); + if (!yieldOp) + return op->emitError("body must terminate with spat.yield"); + if (outputTypes.empty()) { + if (yieldOp.getNumOperands() != 0) + return op->emitError("body yield must be empty when compute_batch has no results"); + } + else { + if (yieldOp.getNumOperands() != 1) + 
return op->emitError("body yield must produce exactly one value"); + if (yieldOp.getOperand(0).getType() != outputTypes[0]) + return op->emitError("body yield type must match output type"); + } + + for (auto& bodyOp : block) { /* every weighted op directly in the body must use a weightIndex in [0, weightsPerLane) */ + if (auto wvmm = dyn_cast(&bodyOp)) + if (wvmm.getWeightIndex() < 0 || static_cast(wvmm.getWeightIndex()) >= weightsPerLane) + return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane"); + if (auto wmvm = dyn_cast(&bodyOp)) + if (wmvm.getWeightIndex() < 0 || static_cast(wmvm.getWeightIndex()) >= weightsPerLane) + return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane"); + } + return success(); +} + +} // namespace + +LogicalResult SpatWeightedMVMOp::verify() { /* Fetches the weight shape from the parent op and dispatches on its rank: rank 2 goes to mvmOpVerifySize2, rank 4 to mvmOpVerifySize4, anything else is an error. */ + auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); + if (failed(matrixShapeOpt)) + return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op"); + auto matrixShape = *matrixShapeOpt; + auto vectorShape = getInput().getType().getShape(); + auto outputShape = getOutput().getType().getShape(); + + if (matrixShape.size() == 2) + return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape); + if (matrixShape.size() == 4) + return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape); + return emitError("matrix rank must be 2 or 4"); +} + +LogicalResult SpatWeightedVMMOp::verify() { /* Rank-2 only: matrix (N, M) with N > 0 and M > 0, vector (1, N), output (1, M). */ + auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); + if (failed(matrixShapeOpt)) + return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op"); + auto matrixShape = *matrixShapeOpt; + auto vectorShape = getInput().getType().getShape(); + auto outputShape = getOutput().getType().getShape(); + + if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2) + return emitError("matrix, vector and output must have rank 2"); + + int64_t N = matrixShape[0]; + int64_t M = matrixShape[1]; + if (N 
<= 0 || M <= 0) + return emitError("matrix shape must be (N, M) with N > 0 and M > 0"); + + int64_t vector1 = vectorShape[0]; + int64_t vectorN = vectorShape[1]; + if (vectorN != N || vector1 != 1) + return emitError("vector shape must be (1, N)"); + + int64_t output1 = outputShape[0]; + int64_t outputM = outputShape[1]; + if (outputM != M || output1 != 1) + return emitError("output shape must be (1, M)"); + + return success(); +} + +LogicalResult SpatVAddOp::verify() { /* Needs at least two operands; operand and result agreement is delegated to verifySameOperandsAndResultType. */ + if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2))) + return failure(); + return OpTrait::impl::verifySameOperandsAndResultType(*this); +} + +LogicalResult SpatVMaxOp::verify() { /* Same checks as SpatVAddOp::verify. */ + if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2))) + return failure(); + return OpTrait::impl::verifySameOperandsAndResultType(*this); +} + +LogicalResult SpatExtractRowsOp::verify() { /* The input must be a ranked rank-2 shaped type. When the row count is non-negative the op must have exactly one result per row. Every result must be rank 2 with one row and the input's element type, and must match the input's column count when that count is non-negative. */ + auto inputType = dyn_cast(getInput().getType()); + if (!inputType || !inputType.hasRank() || inputType.getRank() != 2) + return emitError("input must be a rank-2 shaped type"); + + int64_t numRows = inputType.getShape()[0]; + int64_t numCols = inputType.getShape()[1]; + Type elementType = inputType.getElementType(); + + if (numRows >= 0 && static_cast(getNumResults()) != numRows) /* NOTE(review): a negative size is presumably the dynamic-dimension marker - confirm */ + return emitError("number of outputs must match the number of input rows"); + + for (Type output : getResultTypes()) { + auto outputType = dyn_cast(output); + if (!outputType || !outputType.hasRank() || outputType.getRank() != 2) + return emitError("outputs must all be rank-2 shaped types"); + if (outputType.getElementType() != elementType) + return emitError("output element types must match input element type"); + auto outputShape = outputType.getShape(); + if (outputShape[0] != 1) + return emitError("each output must have exactly one row"); + if (numCols >= 0 && outputShape[1] != numCols) + return emitError("output column count must match input column count"); + } + + return success(); +} + +LogicalResult SpatConcatOp::verify() { /* Requires at least one input and a ranked output, with axis inside [0, output rank). */ + if 
(getInputs().empty()) + return emitError("requires at least one input"); + + auto outputType = dyn_cast(getOutput().getType()); + if (!outputType || !outputType.hasRank()) + return emitError("output must be a ranked shaped type"); + + int64_t axis = getAxis(); + int64_t rank = outputType.getRank(); + if (axis < 0 || axis >= rank) + return emitError("axis must be within the output rank"); + + int64_t concatenatedDimSize = 0; /* running sum of the static input sizes along axis */ + bool concatenatedDimDynamic = false; /* set once any input is dynamic along axis, which disables the final sum check */ + Type outputElementType = outputType.getElementType(); + + for (Value input : getInputs()) { + auto inputType = dyn_cast(input.getType()); + if (!inputType || !inputType.hasRank()) + return emitError("inputs must be ranked shaped types"); + if (inputType.getRank() != rank) + return emitError("all inputs must have the same rank as the output"); + if (inputType.getElementType() != outputElementType) + return emitError("all inputs must have the same element type as the output"); + + for (int64_t dim = 0; dim < rank; ++dim) { /* off-axis dimensions are compared only when both sides are static */ + if (dim == axis) + continue; + int64_t inputDim = inputType.getDimSize(dim); + int64_t outputDim = outputType.getDimSize(dim); + if (!ShapedType::isDynamic(inputDim) && !ShapedType::isDynamic(outputDim) && inputDim != outputDim) + return emitError("non-concatenated dimensions must match the output shape"); + } + + int64_t inputConcatDim = inputType.getDimSize(axis); + if (ShapedType::isDynamic(inputConcatDim)) { + concatenatedDimDynamic = true; + continue; + } + concatenatedDimSize += inputConcatDim; + } + + int64_t outputConcatDim = outputType.getDimSize(axis); + if (!concatenatedDimDynamic && !ShapedType::isDynamic(outputConcatDim) && concatenatedDimSize != outputConcatDim) + return emitError("output concatenated dimension must equal the sum of input sizes"); + + return success(); +} + +LogicalResult SpatCompute::verify() { /* The yield checks run only when the entry block might have a terminator; the check that every block argument is used always runs. */ + auto& block = getBody().front(); + if (block.mightHaveTerminator()) { + auto yieldOp = dyn_cast_or_null(block.getTerminator()); + if (!yieldOp) + return 
emitError("ComputeOp must have a single yield operation"); + + auto resultTypes = getResultTypes(); + auto yieldTypes = yieldOp->getOperandTypes(); + if (resultTypes.size() != yieldTypes.size()) + return emitError("ComputeOp must have same number of results as yieldOp operands"); + + for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) { /* pairs are visited from last to first */ + auto resultType = std::get<0>(it); + auto yieldType = std::get<1>(it); + + if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) + return emitError("ComputeOp output must be of the same type as yieldOp operand"); + + /* NOTE(review): the encoding comparisons below run only after resultType == yieldType has been established, so their error paths look unreachable - confirm before relying on them. */ if (auto resultRankedType = dyn_cast(resultType)) { + if (auto yieldRankedType = dyn_cast(yieldType)) { + if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) + return emitError("ComputeOp output must have the same encoding as yieldOp operand"); + } + else { + return emitError("ComputeOp output has an encoding while yieldOp operand does not have one"); + } + } + else if (dyn_cast(yieldType)) { + return emitError("ComputeOp output must not have an encoding if yieldOp operand has one"); + } + } + } + + for (auto arg : block.getArguments()) /* every entry block argument must have at least one use */ + if (arg.use_empty()) + return emitError("ComputeOp block argument is not used"); + + return success(); +} + +LogicalResult SpatChannelSendManyOp::verify() { /* Metadata lengths must match the number of inputs, and all inputs must share one type. */ + if (failed(verifyManyChannelSizes( + getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size()))) + return failure(); + return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many"); +} + +LogicalResult SpatChannelReceiveManyOp::verify() { /* Metadata lengths must match the number of outputs, and all result types must be identical. */ + if (failed(verifyManyChannelSizes( + getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size()))) + return failure(); + return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many"); +} + +LogicalResult SpatChannelSendBatchOp::verify() { /* Delegates entirely to verifyBatchChannelSizes. */ + return verifyBatchChannelSizes(getOperation(), 
getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); +} + +LogicalResult SpatChannelReceiveBatchOp::verify() { + return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); +} + +LogicalResult SpatComputeBatch::verify() { + int32_t count = getLaneCount(); + if (count <= 0) + return emitError("laneCount must be positive"); + + auto laneCountSz = static_cast(count); + if (getWeights().size() % laneCountSz != 0) + return emitError("number of weights must be a multiple of laneCount"); + + if (!getInputs().empty() && getInputs().size() != laneCountSz) + return emitError("number of inputs must be either 0 or laneCount"); + if (!getOutputs().empty() && getOutputs().size() != laneCountSz) + return emitError("number of outputs must be either 0 or laneCount"); + + size_t weightsPerLane = getWeights().size() / laneCountSz; + for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) { + Type weightType = getWeights()[weightIndex].getType(); + for (size_t lane = 1; lane < laneCountSz; ++lane) + if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType) + return emitError("corresponding weights across lanes must have the same type"); + } + + if (!getInputs().empty()) { + Type inputType = getInputs()[0].getType(); + for (Value in : getInputs().drop_front()) + if (in.getType() != inputType) + return emitError("all inputs must have the same type"); + } + + if (!getOutputs().empty()) { + Type outputType = getOutputs()[0].getType(); + for (Value out : getOutputs().drop_front()) + if (out.getType() != outputType) + return emitError("all outputs must have the same type"); + } + + if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdAttrName)) { + auto coreIdsAttr = dyn_cast(coreIdAttr); + if (!coreIdsAttr) + return emitError("compute_batch core_id attribute must be a dense i32 array"); + if (coreIdsAttr.size() != laneCountSz) + return emitError("compute_batch core_id array length must match 
laneCount"); + if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; })) + return emitError("compute_batch core_id values must be positive"); + } + + Block& block = getBody().front(); + if (getInputs().empty()) { + if (block.getNumArguments() != 0) + return emitError("compute_batch body must have no block arguments when there are no inputs"); + } + else { + if (block.getNumArguments() != 1) + return emitError("compute_batch body must have exactly one block argument"); + if (block.getArgument(0).getType() != getInputs()[0].getType()) + return emitError("body block argument type must match input type"); + } + + return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane); +} + +} // namespace spatial +} // namespace onnx_mlir