refactor spatial ops
All checks were successful
Validate Operations / validate-operations (push) Successful in 24m55s
All checks were successful
Validate Operations / validate-operations (push) Successful in 24m55s
This commit is contained in:
@@ -4,6 +4,9 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
|
||||
add_pim_library(SpatialOps
|
||||
Channels.cpp
|
||||
SpatialOps.cpp
|
||||
SpatialOpsAsm.cpp
|
||||
SpatialOpsVerify.cpp
|
||||
SpatialOpsCanonicalization.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
912
src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp
Normal file
912
src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp
Normal file
@@ -0,0 +1,912 @@
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
enum class ListDelimiter {
|
||||
Square,
|
||||
Paren
|
||||
};
|
||||
|
||||
static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||
if (delimiter == ListDelimiter::Square)
|
||||
return parser.parseLSquare();
|
||||
return parser.parseLParen();
|
||||
}
|
||||
|
||||
static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||
if (delimiter == ListDelimiter::Square)
|
||||
return parser.parseOptionalRSquare();
|
||||
return parser.parseOptionalRParen();
|
||||
}
|
||||
|
||||
static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
|
||||
printer << (delimiter == ListDelimiter::Square ? "[" : "(");
|
||||
}
|
||||
|
||||
static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) {
|
||||
printer << (delimiter == ListDelimiter::Square ? "]" : ")");
|
||||
}
|
||||
|
||||
template <typename EntryT, typename ParseEntryFn>
|
||||
static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<EntryT>& entries,
|
||||
ParseEntryFn parseEntry) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
while (true) {
|
||||
EntryT entry;
|
||||
if (parseEntry(entry))
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
entries.push_back(entry);
|
||||
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
break;
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
|
||||
if (parser.parseLSquare())
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRSquare()))
|
||||
return success();
|
||||
|
||||
while (true) {
|
||||
int64_t first = 0;
|
||||
if (parser.parseInteger(first))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
int64_t last = 0;
|
||||
if (parser.parseInteger(last) || last < first)
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
|
||||
|
||||
int64_t step = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("by"))) {
|
||||
if (parser.parseInteger(step) || step <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
|
||||
}
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
if ((last - first) % step != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"range end must be reachable from start using the given step");
|
||||
|
||||
for (int64_t value = first; value <= last; value += step)
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
values.push_back(static_cast<IntT>(value));
|
||||
}
|
||||
else {
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
values.push_back(static_cast<IntT>(first));
|
||||
}
|
||||
|
||||
if (succeeded(parser.parseOptionalRSquare()))
|
||||
break;
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename RangeT, typename PrintEntryFn>
|
||||
static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
|
||||
for (size_t index = 0; index < entries.size();) {
|
||||
size_t runEnd = index + 1;
|
||||
while (runEnd < entries.size() && entries[runEnd] == entries[index])
|
||||
++runEnd;
|
||||
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printEntry(entries[index]);
|
||||
size_t runLength = runEnd - index;
|
||||
if (runLength > 1)
|
||||
printer << " x" << runLength;
|
||||
index = runEnd;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
|
||||
printer << "[";
|
||||
for (size_t index = 0; index < values.size();) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
|
||||
auto findEqualRunEnd = [&](size_t start) {
|
||||
size_t end = start + 1;
|
||||
while (end < values.size() && values[end] == values[start])
|
||||
++end;
|
||||
return end;
|
||||
};
|
||||
|
||||
size_t firstRunEnd = findEqualRunEnd(index);
|
||||
size_t repeatCount = firstRunEnd - index;
|
||||
size_t progressionEnd = firstRunEnd;
|
||||
int64_t step = 0;
|
||||
IntT lastValue = values[index];
|
||||
|
||||
if (firstRunEnd < values.size()) {
|
||||
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
|
||||
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[index]);
|
||||
if (step > 0 && secondRunEnd - firstRunEnd == repeatCount) {
|
||||
progressionEnd = secondRunEnd;
|
||||
lastValue = values[firstRunEnd];
|
||||
size_t currentRunStart = secondRunEnd;
|
||||
while (currentRunStart < values.size()) {
|
||||
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
|
||||
if (currentRunEnd - currentRunStart != repeatCount)
|
||||
break;
|
||||
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
|
||||
break;
|
||||
lastValue = values[currentRunStart];
|
||||
progressionEnd = currentRunEnd;
|
||||
currentRunStart = currentRunEnd;
|
||||
}
|
||||
}
|
||||
else {
|
||||
step = 0;
|
||||
}
|
||||
}
|
||||
|
||||
size_t progressionValueCount = repeatCount == 0 ? 0 : (progressionEnd - index) / repeatCount;
|
||||
if (progressionEnd > firstRunEnd && progressionValueCount >= 3) {
|
||||
printer << values[index] << " to " << lastValue;
|
||||
if (step != 1)
|
||||
printer << " by " << step;
|
||||
if (repeatCount > 1)
|
||||
printer << " x" << repeatCount;
|
||||
index = progressionEnd;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (repeatCount > 1) {
|
||||
printer << values[index] << " x" << repeatCount;
|
||||
index = firstRunEnd;
|
||||
continue;
|
||||
}
|
||||
|
||||
printer << values[index];
|
||||
index = firstRunEnd;
|
||||
}
|
||||
printer << "]";
|
||||
}
|
||||
|
||||
static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
for (size_t index = 0; index < values.size();) {
|
||||
size_t equalRunEnd = index + 1;
|
||||
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
|
||||
++equalRunEnd;
|
||||
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
if (equalRunEnd - index > 1) {
|
||||
printer.printOperand(values[index]);
|
||||
printer << " x" << (equalRunEnd - index);
|
||||
index = equalRunEnd;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t rangeEnd = index + 1;
|
||||
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
|
||||
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|
||||
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index))
|
||||
break;
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
|
||||
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|
||||
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index))
|
||||
break;
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
|
||||
printer.printOperand(values[index]);
|
||||
if (rangeEnd - index >= 3) {
|
||||
printer << " to ";
|
||||
printer.printOperand(values[rangeEnd - 1]);
|
||||
}
|
||||
else if (rangeEnd - index == 2) {
|
||||
printer << ", ";
|
||||
printer.printOperand(values[index + 1]);
|
||||
}
|
||||
index = rangeEnd;
|
||||
}
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
|
||||
OpAsmParser::UnresolvedOperand firstOperand,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
OpAsmParser::UnresolvedOperand lastOperand;
|
||||
if (parser.parseOperand(lastOperand))
|
||||
return failure();
|
||||
if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number)
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
|
||||
for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number)
|
||||
operands.push_back({firstOperand.location, firstOperand.name, number});
|
||||
}
|
||||
else {
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
operands.push_back(firstOperand);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
OpAsmParser::UnresolvedOperand firstOperand;
|
||||
if (parser.parseOperand(firstOperand))
|
||||
return failure();
|
||||
return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOperandList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
while (true) {
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
break;
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
|
||||
for (size_t index = 0; index < values.size();) {
|
||||
size_t equalRunEnd = index + 1;
|
||||
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
|
||||
++equalRunEnd;
|
||||
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
if (equalRunEnd - index > 1) {
|
||||
printer.printOperand(values[index]);
|
||||
printer << " x" << (equalRunEnd - index);
|
||||
index = equalRunEnd;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t rangeEnd = index + 1;
|
||||
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
|
||||
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|
||||
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index))
|
||||
break;
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
|
||||
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|
||||
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index))
|
||||
break;
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
|
||||
printer.printOperand(values[index]);
|
||||
if (rangeEnd - index >= 3) {
|
||||
printer << " to ";
|
||||
printer.printOperand(values[rangeEnd - 1]);
|
||||
}
|
||||
else if (rangeEnd - index == 2) {
|
||||
printer << ", ";
|
||||
printer.printOperand(values[index + 1]);
|
||||
}
|
||||
index = rangeEnd;
|
||||
}
|
||||
}
|
||||
|
||||
static void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
|
||||
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
|
||||
}
|
||||
|
||||
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
|
||||
Type firstType;
|
||||
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
|
||||
if (!firstTypeResult.has_value()) {
|
||||
if (allowEmpty)
|
||||
return success();
|
||||
return parser.emitError(parser.getCurrentLocation(), "expected type");
|
||||
}
|
||||
if (failed(*firstTypeResult))
|
||||
return failure();
|
||||
|
||||
auto appendType = [&](Type type) -> ParseResult {
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
types.push_back(type);
|
||||
return success();
|
||||
};
|
||||
|
||||
if (appendType(firstType))
|
||||
return failure();
|
||||
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
Type nextType;
|
||||
if (parser.parseType(nextType) || appendType(nextType))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printChannelMetadata(OpAsmPrinter& printer,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds) {
|
||||
printer << " channels ";
|
||||
printCompressedIntegerList(printer, channelIds);
|
||||
printer << " from ";
|
||||
printCompressedIntegerList(printer, sourceCoreIds);
|
||||
printer << " to ";
|
||||
printCompressedIntegerList(printer, targetCoreIds);
|
||||
}
|
||||
|
||||
static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef<int64_t> values) {
|
||||
return parser.getBuilder().getDenseI64ArrayAttr(values);
|
||||
}
|
||||
|
||||
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
|
||||
return parser.getBuilder().getDenseI32ArrayAttr(values);
|
||||
}
|
||||
|
||||
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
static void buildImplicitRegionArgs(OpAsmParser& parser,
|
||||
ArrayRef<Type> inputTypes,
|
||||
SmallVectorImpl<std::string>& generatedNames,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
generatedNames.reserve(inputTypes.size());
|
||||
arguments.reserve(inputTypes.size());
|
||||
for (auto [index, inputType] : llvm::enumerate(inputTypes)) {
|
||||
generatedNames.push_back("arg" + std::to_string(index + 1));
|
||||
OpAsmParser::Argument arg;
|
||||
arg.ssaName = {parser.getCurrentLocation(), generatedNames.back(), 0};
|
||||
arg.type = inputType;
|
||||
arguments.push_back(arg);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getOutputs());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getOutputs().getTypes());
|
||||
}
|
||||
|
||||
ParseResult SpatYieldOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> outputs;
|
||||
SmallVector<Type> outputTypes;
|
||||
|
||||
OpAsmParser::UnresolvedOperand firstOutput;
|
||||
OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput);
|
||||
if (firstOutputResult.has_value()) {
|
||||
if (failed(*firstOutputResult))
|
||||
return failure();
|
||||
if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedOperandEntry(parser, outputs))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
return failure();
|
||||
|
||||
if (outputs.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match");
|
||||
|
||||
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void SpatExtractRowsOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
printer << " : ";
|
||||
printer.printType(getInput().getType());
|
||||
printer << " -> ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
}
|
||||
|
||||
ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<Type> outputTypes;
|
||||
|
||||
if (parser.parseOperand(input) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parser.parseType(inputType) || parser.parseArrow()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (parser.resolveOperand(input, inputType, result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatConcatOp::print(OpAsmPrinter& printer) {
|
||||
printer << " axis " << getAxis();
|
||||
printer << " args = ";
|
||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
printer.printType(getOutput().getType());
|
||||
}
|
||||
|
||||
ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
int64_t axis = 0;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> inputTypes;
|
||||
Type outputType;
|
||||
|
||||
if (parser.parseKeyword("axis") || parser.parseInteger(axis))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("args"))) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
|
||||
return failure();
|
||||
}
|
||||
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (result.attributes.get("axis"))
|
||||
return parser.emitError(parser.getCurrentLocation(), "axis cannot be specified both positionally and in attr-dict");
|
||||
|
||||
result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axis));
|
||||
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
printer << " args = ";
|
||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
printer << " core_id " << coreIdAttr.getInt();
|
||||
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(),
|
||||
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
||||
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<std::string> generatedArgNames;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t coreId = 0;
|
||||
|
||||
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("args"))) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
|
||||
return failure();
|
||||
}
|
||||
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
bool hasCoreId = succeeded(parser.parseOptionalKeyword("core_id"));
|
||||
if (hasCoreId && parser.parseInteger(coreId))
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"core_id cannot be specified both positionally and in attr-dict");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute(
|
||||
"operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
if (hasCoreId)
|
||||
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
printer << " lanes " << getLaneCount() << " ";
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
printer << " args = ";
|
||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName)) {
|
||||
printer << " core_ids ";
|
||||
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
||||
}
|
||||
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
||||
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||
int32_t laneCount = 0;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<std::string> generatedArgNames;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int32_t> coreIds;
|
||||
|
||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
|
||||
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("args"))) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
|
||||
return failure();
|
||||
}
|
||||
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("core_ids"));
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(), "core_id cannot be specified both in core_ids and attr-dict");
|
||||
|
||||
auto& builder = parser.getBuilder();
|
||||
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
||||
result.addAttribute(
|
||||
"operandSegmentSizes",
|
||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||
if (hasCoreIds)
|
||||
result.addAttribute(onnx_mlir::kCoreIdAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatChannelSendManyOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||
}
|
||||
|
||||
ParseResult SpatChannelSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveManyOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getInput().getType());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getOutput().getType());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
Type outputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
35
src/PIM/Dialect/Spatial/SpatialOpsCanonicalization.cpp
Normal file
35
src/PIM/Dialect/Spatial/SpatialOpsCanonicalization.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
#include "mlir/IR/Block.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||
Block& block = getBody().front();
|
||||
if (!llvm::hasSingleElement(block))
|
||||
return failure();
|
||||
|
||||
auto yieldOp = dyn_cast<SpatYieldOp>(block.front());
|
||||
if (!yieldOp)
|
||||
return failure();
|
||||
|
||||
for (Value yieldedValue : yieldOp.getOperands()) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
|
||||
if (blockArg.getOwner() == &block) {
|
||||
results.push_back(getOperand(blockArg.getArgNumber()));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push_back(yieldedValue);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
433
src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp
Normal file
433
src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp
Normal file
@@ -0,0 +1,433 @@
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
|
||||
ArrayRef<int64_t>& matrixShape,
|
||||
ArrayRef<int64_t>& vectorShape,
|
||||
ArrayRef<int64_t>& outputShape) {
|
||||
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
|
||||
return emitter->emitError("matrix, vector and output must have rank 2");
|
||||
|
||||
int64_t N = matrixShape[0];
|
||||
int64_t M = matrixShape[1];
|
||||
if (N <= 0 || M <= 0)
|
||||
return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");
|
||||
|
||||
int64_t vectorM = vectorShape[0];
|
||||
int64_t vector1 = vectorShape[1];
|
||||
if (vectorM != M || vector1 != 1)
|
||||
return emitter->emitError("vector shape must be (M, 1)");
|
||||
|
||||
int64_t outputN = outputShape[0];
|
||||
int64_t output1 = outputShape[1];
|
||||
if (outputN != N || output1 != 1)
|
||||
return emitter->emitError("output shape must be (N, 1)");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
|
||||
ArrayRef<int64_t>& matrixShape,
|
||||
ArrayRef<int64_t>& vectorShape,
|
||||
ArrayRef<int64_t>& outputShape) {
|
||||
if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4)
|
||||
return emitter->emitError("matrix, vector and output must have rank 4");
|
||||
|
||||
int64_t N = matrixShape[0];
|
||||
int64_t M = matrixShape[1];
|
||||
int64_t matrix1First = matrixShape[2];
|
||||
int64_t matrix1Second = matrixShape[3];
|
||||
if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1)
|
||||
return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");
|
||||
|
||||
int64_t vector1First = vectorShape[0];
|
||||
int64_t vectorM = vectorShape[1];
|
||||
int64_t vector1Second = vectorShape[2];
|
||||
int64_t vector1Third = vectorShape[3];
|
||||
if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) {
|
||||
if (vector1First == 1 && vector1Second == 1 && vector1Third == 1 && ignoreConcatError == true) {
|
||||
// This is ok, it was caused by the simplification of the concat error.
|
||||
}
|
||||
else {
|
||||
return emitter->emitError("vector shape must be (1, M, 1, 1)");
|
||||
}
|
||||
}
|
||||
|
||||
int64_t output1First = outputShape[0];
|
||||
int64_t outputN = outputShape[1];
|
||||
int64_t output1Second = outputShape[2];
|
||||
int64_t output1Third = outputShape[3];
|
||||
if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1)
|
||||
return emitter->emitError("output shape must be (1, N, 1, 1)");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
|
||||
if (auto computeOp = dyn_cast<SpatCompute>(weightedOp->getParentOp()))
|
||||
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(weightedOp->getParentOp()))
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
if (auto batchOp = dyn_cast<SpatComputeBatch>(weightedOp->getParentOp())) {
|
||||
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
auto batchOp = op->getParentOfType<SpatComputeBatch>();
|
||||
if (!batchOp)
|
||||
return failure();
|
||||
return batchOp.getLaneCount();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyChannelSizes(Operation* op,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
size_t valueCount) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
if (channelIds.size() != valueCount)
|
||||
return op->emitError("channel metadata length must match the number of values");
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types, StringRef kind) {
|
||||
if (types.empty())
|
||||
return op->emitError() << kind << " must carry at least one value";
|
||||
|
||||
Type firstType = types.front();
|
||||
for (Type type : types.drop_front())
|
||||
if (type != firstType)
|
||||
return op->emitError() << kind << " values must all have the same type";
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchChannelSizes(Operation* op,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside spat.compute_batch");
|
||||
if (channelIds.size() != static_cast<size_t>(*laneCount))
|
||||
return op->emitError("channel metadata length must match parent laneCount");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return op->emitError("body must terminate with spat.yield");
|
||||
if (outputTypes.empty()) {
|
||||
if (yieldOp.getNumOperands() != 0)
|
||||
return op->emitError("body yield must be empty when compute_batch has no results");
|
||||
}
|
||||
else {
|
||||
if (yieldOp.getNumOperands() != 1)
|
||||
return op->emitError("body yield must produce exactly one value");
|
||||
if (yieldOp.getOperand(0).getType() != outputTypes[0])
|
||||
return op->emitError("body yield type must match output type");
|
||||
}
|
||||
|
||||
for (auto& bodyOp : block) {
|
||||
if (auto wvmm = dyn_cast<SpatWeightedVMMOp>(&bodyOp))
|
||||
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
|
||||
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
|
||||
if (auto wmvm = dyn_cast<SpatWeightedMVMOp>(&bodyOp))
|
||||
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
|
||||
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult SpatWeightedMVMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
|
||||
if (matrixShape.size() == 2)
|
||||
return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
|
||||
if (matrixShape.size() == 4)
|
||||
return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
|
||||
return emitError("matrix rank must be 2 or 4");
|
||||
}
|
||||
|
||||
LogicalResult SpatWeightedVMMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
|
||||
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
|
||||
return emitError("matrix, vector and output must have rank 2");
|
||||
|
||||
int64_t N = matrixShape[0];
|
||||
int64_t M = matrixShape[1];
|
||||
if (N <= 0 || M <= 0)
|
||||
return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
|
||||
|
||||
int64_t vector1 = vectorShape[0];
|
||||
int64_t vectorN = vectorShape[1];
|
||||
if (vectorN != N || vector1 != 1)
|
||||
return emitError("vector shape must be (1, N)");
|
||||
|
||||
int64_t output1 = outputShape[0];
|
||||
int64_t outputM = outputShape[1];
|
||||
if (outputM != M || output1 != 1)
|
||||
return emitError("output shape must be (1, M)");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatVAddOp::verify() {
|
||||
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
|
||||
return failure();
|
||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||
}
|
||||
|
||||
LogicalResult SpatVMaxOp::verify() {
|
||||
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
|
||||
return failure();
|
||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||
}
|
||||
|
||||
LogicalResult SpatExtractRowsOp::verify() {
|
||||
auto inputType = dyn_cast<ShapedType>(getInput().getType());
|
||||
if (!inputType || !inputType.hasRank() || inputType.getRank() != 2)
|
||||
return emitError("input must be a rank-2 shaped type");
|
||||
|
||||
int64_t numRows = inputType.getShape()[0];
|
||||
int64_t numCols = inputType.getShape()[1];
|
||||
Type elementType = inputType.getElementType();
|
||||
|
||||
if (numRows >= 0 && static_cast<int64_t>(getNumResults()) != numRows)
|
||||
return emitError("number of outputs must match the number of input rows");
|
||||
|
||||
for (Type output : getResultTypes()) {
|
||||
auto outputType = dyn_cast<ShapedType>(output);
|
||||
if (!outputType || !outputType.hasRank() || outputType.getRank() != 2)
|
||||
return emitError("outputs must all be rank-2 shaped types");
|
||||
if (outputType.getElementType() != elementType)
|
||||
return emitError("output element types must match input element type");
|
||||
auto outputShape = outputType.getShape();
|
||||
if (outputShape[0] != 1)
|
||||
return emitError("each output must have exactly one row");
|
||||
if (numCols >= 0 && outputShape[1] != numCols)
|
||||
return emitError("output column count must match input column count");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatConcatOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
|
||||
auto outputType = dyn_cast<ShapedType>(getOutput().getType());
|
||||
if (!outputType || !outputType.hasRank())
|
||||
return emitError("output must be a ranked shaped type");
|
||||
|
||||
int64_t axis = getAxis();
|
||||
int64_t rank = outputType.getRank();
|
||||
if (axis < 0 || axis >= rank)
|
||||
return emitError("axis must be within the output rank");
|
||||
|
||||
int64_t concatenatedDimSize = 0;
|
||||
bool concatenatedDimDynamic = false;
|
||||
Type outputElementType = outputType.getElementType();
|
||||
|
||||
for (Value input : getInputs()) {
|
||||
auto inputType = dyn_cast<ShapedType>(input.getType());
|
||||
if (!inputType || !inputType.hasRank())
|
||||
return emitError("inputs must be ranked shaped types");
|
||||
if (inputType.getRank() != rank)
|
||||
return emitError("all inputs must have the same rank as the output");
|
||||
if (inputType.getElementType() != outputElementType)
|
||||
return emitError("all inputs must have the same element type as the output");
|
||||
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
if (dim == axis)
|
||||
continue;
|
||||
int64_t inputDim = inputType.getDimSize(dim);
|
||||
int64_t outputDim = outputType.getDimSize(dim);
|
||||
if (!ShapedType::isDynamic(inputDim) && !ShapedType::isDynamic(outputDim) && inputDim != outputDim)
|
||||
return emitError("non-concatenated dimensions must match the output shape");
|
||||
}
|
||||
|
||||
int64_t inputConcatDim = inputType.getDimSize(axis);
|
||||
if (ShapedType::isDynamic(inputConcatDim)) {
|
||||
concatenatedDimDynamic = true;
|
||||
continue;
|
||||
}
|
||||
concatenatedDimSize += inputConcatDim;
|
||||
}
|
||||
|
||||
int64_t outputConcatDim = outputType.getDimSize(axis);
|
||||
if (!concatenatedDimDynamic && !ShapedType::isDynamic(outputConcatDim) && concatenatedDimSize != outputConcatDim)
|
||||
return emitError("output concatenated dimension must equal the sum of input sizes");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatCompute::verify() {
|
||||
auto& block = getBody().front();
|
||||
if (block.mightHaveTerminator()) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return emitError("ComputeOp must have a single yield operation");
|
||||
|
||||
auto resultTypes = getResultTypes();
|
||||
auto yieldTypes = yieldOp->getOperandTypes();
|
||||
if (resultTypes.size() != yieldTypes.size())
|
||||
return emitError("ComputeOp must have same number of results as yieldOp operands");
|
||||
|
||||
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
|
||||
auto resultType = std::get<0>(it);
|
||||
auto yieldType = std::get<1>(it);
|
||||
|
||||
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType)))
|
||||
return emitError("ComputeOp output must be of the same type as yieldOp operand");
|
||||
|
||||
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
|
||||
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
|
||||
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding())
|
||||
return emitError("ComputeOp output must have the same encoding as yieldOp operand");
|
||||
}
|
||||
else {
|
||||
return emitError("ComputeOp output has an encoding while yieldOp operand does not have one");
|
||||
}
|
||||
}
|
||||
else if (dyn_cast<RankedTensorType>(yieldType)) {
|
||||
return emitError("ComputeOp output must not have an encoding if yieldOp operand has one");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto arg : block.getArguments())
|
||||
if (arg.use_empty())
|
||||
return emitError("ComputeOp block argument is not used");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendManyOp::verify() {
|
||||
if (failed(verifyManyChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveManyOp::verify() {
|
||||
if (failed(verifyManyChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
}
|
||||
|
||||
LogicalResult SpatComputeBatch::verify() {
|
||||
int32_t count = getLaneCount();
|
||||
if (count <= 0)
|
||||
return emitError("laneCount must be positive");
|
||||
|
||||
auto laneCountSz = static_cast<size_t>(count);
|
||||
if (getWeights().size() % laneCountSz != 0)
|
||||
return emitError("number of weights must be a multiple of laneCount");
|
||||
|
||||
if (!getInputs().empty() && getInputs().size() != laneCountSz)
|
||||
return emitError("number of inputs must be either 0 or laneCount");
|
||||
if (!getOutputs().empty() && getOutputs().size() != laneCountSz)
|
||||
return emitError("number of outputs must be either 0 or laneCount");
|
||||
|
||||
size_t weightsPerLane = getWeights().size() / laneCountSz;
|
||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) {
|
||||
Type weightType = getWeights()[weightIndex].getType();
|
||||
for (size_t lane = 1; lane < laneCountSz; ++lane)
|
||||
if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType)
|
||||
return emitError("corresponding weights across lanes must have the same type");
|
||||
}
|
||||
|
||||
if (!getInputs().empty()) {
|
||||
Type inputType = getInputs()[0].getType();
|
||||
for (Value in : getInputs().drop_front())
|
||||
if (in.getType() != inputType)
|
||||
return emitError("all inputs must have the same type");
|
||||
}
|
||||
|
||||
if (!getOutputs().empty()) {
|
||||
Type outputType = getOutputs()[0].getType();
|
||||
for (Value out : getOutputs().drop_front())
|
||||
if (out.getType() != outputType)
|
||||
return emitError("all outputs must have the same type");
|
||||
}
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdAttrName)) {
|
||||
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
||||
if (!coreIdsAttr)
|
||||
return emitError("compute_batch core_id attribute must be a dense i32 array");
|
||||
if (coreIdsAttr.size() != laneCountSz)
|
||||
return emitError("compute_batch core_id array length must match laneCount");
|
||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
|
||||
return emitError("compute_batch core_id values must be positive");
|
||||
}
|
||||
|
||||
Block& block = getBody().front();
|
||||
if (getInputs().empty()) {
|
||||
if (block.getNumArguments() != 0)
|
||||
return emitError("compute_batch body must have no block arguments when there are no inputs");
|
||||
}
|
||||
else {
|
||||
if (block.getNumArguments() != 1)
|
||||
return emitError("compute_batch body must have exactly one block argument");
|
||||
if (block.getArgument(0).getType() != getInputs()[0].getType())
|
||||
return emitError("body block argument type must match input type");
|
||||
}
|
||||
|
||||
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user