Files
Raptor/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp
T
NiccoloN a50e77ff38
Validate Operations / validate-operations (push) Has been cancelled
refactorone
2026-05-20 19:06:41 +02:00

457 lines
18 KiB
C++

#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/LogicalResult.h"
#include <string>
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
namespace {
using namespace onnx_mlir::compact_asm;
/// Try to consume one of two spellings of an optional keyword.
///
/// `preferred` is attempted first; `legacy` is attempted only when
/// `preferred` is not present. Returns true if either spelling was consumed
/// and false when neither matched.
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
  if (succeeded(parser.parseOptionalKeyword(preferred)))
    return true;
  return succeeded(parser.parseOptionalKeyword(legacy));
}
/// Build a DenseI32ArrayAttr holding `values`, using the parser's builder.
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
  Builder& attrBuilder = parser.getBuilder();
  return attrBuilder.getDenseI32ArrayAttr(values);
}
/// Build a 32-bit IntegerAttr holding `value`, using the parser's builder.
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
  Builder& attrBuilder = parser.getBuilder();
  return attrBuilder.getI32IntegerAttr(value);
}
/// Print `arguments` as a parenthesized, comma-separated list of SSA names,
/// e.g. `(%a, %b)`. An empty list is printed as `()`.
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
  printer << "(";
  llvm::interleaveComma(arguments, printer, [&](BlockArgument blockArg) { printer.printOperand(blockArg); });
  printer << ")";
}
/// Parse a parenthesized, comma-separated (possibly empty) list of region
/// argument names, e.g. `(%a, %b)` or `()`, appending each to `arguments`.
///
/// No types or attribute dictionaries are accepted on the arguments; callers
/// assign types afterwards.
///
/// This delegates to OpAsmParser::parseArgumentList so that every entry is
/// parsed into a fresh Argument. Reusing a single Argument object across
/// iterations would let optional per-argument state (such as a trailing
/// `loc(...)`) from one entry carry over to later entries that omit it.
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
  return parser.parseArgumentList(arguments, OpAsmParser::Delimiter::Paren);
}
/// Assign types to the parsed region arguments of a batch op and collect them
/// into `regionArgs` in entry-block order.
///
/// The lane argument is given the builtin `index` type. The weight, input and
/// output argument lists are typed through applyArgumentTypes with
/// `weightTypes`, `inputTypes` and `outputTypes` respectively.
///
/// Resulting order appended to `regionArgs`: lane, weights, inputs, outputs.
static void applyBatchRegionArgumentTypes(ArrayRef<Type> inputTypes,
                                          ArrayRef<Type> weightTypes,
                                          ArrayRef<Type> outputTypes,
                                          OpAsmParser::Argument& laneArg,
                                          SmallVectorImpl<OpAsmParser::Argument>& weightArgs,
                                          SmallVectorImpl<OpAsmParser::Argument>& inputArgs,
                                          SmallVectorImpl<OpAsmParser::Argument>& outputArgs,
                                          SmallVectorImpl<OpAsmParser::Argument>& regionArgs,
                                          Builder& builder) {
  // First give every argument its type ...
  laneArg.type = builder.getIndexType();
  applyArgumentTypes(weightTypes, weightArgs);
  applyArgumentTypes(inputTypes, inputArgs);
  applyArgumentTypes(outputTypes, outputArgs);

  // ... then lay them out in the order the region expects.
  regionArgs.push_back(laneArg);
  regionArgs.append(weightArgs.begin(), weightArgs.end());
  regionArgs.append(inputArgs.begin(), inputArgs.end());
  regionArgs.append(outputArgs.begin(), outputArgs.end());
}
/// Print a binding of region arguments to operands as
/// `<arguments> = <operands>`.
///
/// Both sides are rendered by printCompressedValueList using the same
/// `delimiter`.
static void printBoundValueList(OpAsmPrinter& printer,
                                ValueRange arguments,
                                ValueRange operands,
                                ListDelimiter delimiter) {
  // Left-hand side: the region-local names.
  printCompressedValueList(printer, arguments, delimiter);

  printer << " = ";

  // Right-hand side: the values they are bound to.
  printCompressedValueList(printer, operands, delimiter);
}
/// Parse a binding of region arguments to operands:
/// `<open> [argument-entries] <close> = <operand-list>`.
///
/// The argument side may be empty (`[]` / `()`); otherwise it is one or more
/// comma-separated entries, each handled by parseOneCompressedArgumentEntry.
/// The operand side is parsed by parseCompressedOperandList with the same
/// `delimiter`.
///
/// This function does not check that `arguments` and `operands` end up with
/// the same number of elements; callers perform that check.
static ParseResult parseBoundValueList(OpAsmParser& parser,
                                       ListDelimiter delimiter,
                                       SmallVectorImpl<OpAsmParser::Argument>& arguments,
                                       SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
  // Consumes the mandatory closing token that matches `delimiter`.
  auto expectCloseDelimiter = [&]() -> ParseResult {
    switch (delimiter) {
    case ListDelimiter::Paren:
      return parser.parseRParen();
    case ListDelimiter::Square:
      return parser.parseRSquare();
    }
    llvm_unreachable("unsupported delimiter");
  };

  if (parseOpenDelimiter(parser, delimiter))
    return failure();

  // An immediately closed list means there are no argument entries to read.
  const bool noArgumentEntries = succeeded(parseOptionalCloseDelimiter(parser, delimiter));
  if (!noArgumentEntries) {
    do {
      if (parseOneCompressedArgumentEntry(parser, arguments))
        return failure();
    } while (succeeded(parser.parseOptionalComma()));

    if (expectCloseDelimiter())
      return failure();
  }

  // Both forms continue with `= <operand-list>`.
  if (parser.parseEqual())
    return failure();
  if (parseCompressedOperandList(parser, delimiter, operands))
    return failure();
  return success();
}
} // namespace
/// Print the yield op as ` <outputs> attr-dict : <output-types>`.
///
/// Outputs and their types are rendered by the compressed sequence helpers.
void SpatYieldOp::print(OpAsmPrinter& printer) {
  auto yielded = getOutputs();

  printer << " ";
  printCompressedValueSequence(printer, yielded);

  printer.printOptionalAttrDict((*this)->getAttrs());

  printer << " : ";
  printCompressedTypeSequence(printer, yielded.getTypes());
}
/// Parse the yield op: `[operand-entries] attr-dict : [types]`.
///
/// The operand list is optional. When a first operand is present it is handed
/// to parseCompressedOperandEntryWithFirst, and any further comma-separated
/// entries go through parseOneCompressedOperandEntry. The type sequence may
/// be empty. The number of parsed operands must equal the number of parsed
/// types.
ParseResult SpatYieldOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> yielded;
  SmallVector<Type> yieldedTypes;

  // Probe for a leading operand; its absence means "no outputs".
  OpAsmParser::UnresolvedOperand leadingOperand;
  OptionalParseResult leadingResult = parser.parseOptionalOperand(leadingOperand);
  if (leadingResult.has_value()) {
    if (failed(*leadingResult))
      return failure();
    if (parseCompressedOperandEntryWithFirst(parser, leadingOperand, yielded))
      return failure();
    while (succeeded(parser.parseOptionalComma())) {
      if (parseOneCompressedOperandEntry(parser, yielded))
        return failure();
    }
  }

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  if (parseCompressedTypeSequence(parser, yieldedTypes, /*allowEmpty=*/true))
    return failure();

  if (yielded.size() != yieldedTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match");

  return parser.resolveOperands(yielded, yieldedTypes, parser.getCurrentLocation(), result.operands);
}
/// Print the op as ` <input> attr-dict : <input-type> -> <result-types>`.
void SpatExtractRowsOp::print(OpAsmPrinter& printer) {
  Value source = getInput();

  printer << " ";
  printer.printOperand(source);
  printer.printOptionalAttrDict((*this)->getAttrs());

  printer << " : ";
  printer.printType(source.getType());

  printer << " -> ";
  printCompressedTypeSequence(printer, getResultTypes());
}
/// Parse the op: `<input> attr-dict : <input-type> -> <result-types>`.
///
/// The result type sequence must be non-empty (allowEmpty is false).
ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) {
  OpAsmParser::UnresolvedOperand source;
  if (parser.parseOperand(source))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  Type sourceType;
  if (parser.parseType(sourceType) || parser.parseArrow())
    return failure();

  SmallVector<Type> rowTypes;
  if (parseCompressedTypeSequence(parser, rowTypes, /*allowEmpty=*/false))
    return failure();

  if (parser.resolveOperand(source, sourceType, result.operands))
    return failure();

  result.addTypes(rowTypes);
  return success();
}
/// Print the op as
/// ` axis <N> <inputs> attr-dict : (<input-types>) -> <output-type>`.
///
/// The axis attribute is printed positionally and therefore left out of the
/// attribute dictionary.
void SpatConcatOp::print(OpAsmPrinter& printer) {
  printer << " axis " << getAxis() << " ";
  printCompressedValueSequence(printer, getInputs());

  StringRef axisAttrName = getAxisAttrName().getValue();
  printer.printOptionalAttrDict((*this)->getAttrs(), {axisAttrName});

  printer << " : ";
  printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);

  printer << " -> ";
  printer.printType(getOutput().getType());
}
/// Parse the op:
/// `axis <N> <inputs> attr-dict : (<input-types>) -> <output-type>`.
///
/// Fails when the number of inputs differs from the number of input types, or
/// when `axis` also appears in the attribute dictionary. The axis is stored
/// as a 64-bit integer attribute named "axis".
ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
  int64_t axisValue = 0;
  if (parser.parseKeyword("axis") || parser.parseInteger(axisValue))
    return failure();

  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  if (parseCompressedOperandSequence(parser, operands))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  SmallVector<Type> operandTypes;
  auto parseOneType = [&](Type& type) { return parser.parseType(type); };
  if (parseCompressedRepeatedList(parser, ListDelimiter::Paren, operandTypes, parseOneType))
    return failure();

  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();

  if (operands.size() != operandTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");

  // The positional axis is the only accepted spelling.
  if (result.attributes.get("axis"))
    return parser.emitError(parser.getCurrentLocation(), "axis cannot be specified both positionally and in attr-dict");
  result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axisValue));

  if (parser.resolveOperands(operands, operandTypes, parser.getCurrentLocation(), result.operands))
    return failure();

  result.addTypes(resultType);
  return success();
}
/// Print the compute op as
/// ` [<weight-args>] = [<weights>] (<input-args>) = (<inputs>)
///   [coreId <N>] attr-dict : [<weight-types>] (<input-types>)
///   -> <result-types> <region>`.
///
/// Region entry-block arguments are not printed with the region because they
/// already appear on the left-hand side of the two bindings.
///
/// The operand segment sizes attribute is always omitted from the attribute
/// dictionary. The core id attribute is omitted only when it was printed
/// positionally, i.e. when it is an IntegerAttr. If it is present with any
/// other attribute kind it stays in the attribute dictionary, so the printed
/// form never silently drops it.
void SpatCompute::print(OpAsmPrinter& printer) {
  printer << " ";

  // Weight bindings: region argument names on the left, operands on the right.
  SmallVector<Value> weightArgs;
  weightArgs.reserve(getWeights().size());
  for (unsigned index = 0; index < getWeights().size(); ++index)
    weightArgs.push_back(getWeightArgument(index));
  printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);

  printer << " ";

  // Input bindings, same shape as the weight bindings.
  SmallVector<Value> inputArgs;
  inputArgs.reserve(getInputs().size());
  for (unsigned index = 0; index < getInputs().size(); ++index)
    inputArgs.push_back(getInputArgument(index));
  printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);

  SmallVector<StringRef, 2> elidedAttrs;
  elidedAttrs.push_back(getOperandSegmentSizesAttrName().getValue());
  if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
    printer << " coreId " << coreIdAttr.getInt();
    // Printed positionally, so keep it out of the attribute dictionary.
    elidedAttrs.push_back(onnx_mlir::kCoreIdAttrName);
  }
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);

  printer << " : ";
  printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
  printer << " ";
  printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
  printer << " -> ";
  printCompressedTypeSequence(printer, getResultTypes());
  printer << " ";
  printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
/// Parse the compute op:
/// `[<weight-args>] = [<weights>] (<input-args>) = (<inputs>)
///  [(coreId | core_id) <N>] attr-dict : [<weight-types>] (<input-types>)
///  -> <result-types> <region>`.
///
/// After the syntax is consumed, the weight/input operand, type and binding
/// counts are checked against each other, and a positional core id is
/// rejected if the attribute dictionary also carries one. The region is
/// parsed with the weight arguments followed by the input arguments as its
/// entry-block arguments.
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
  // Reports `message` at the parser's current location.
  auto emitMismatch = [&](StringRef message) -> ParseResult {
    return parser.emitError(parser.getCurrentLocation(), message);
  };
  auto parseOneType = [&](Type& type) { return parser.parseType(type); };

  // --- Bindings -----------------------------------------------------------
  SmallVector<OpAsmParser::Argument> weightArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
    return failure();

  SmallVector<OpAsmParser::Argument> inputArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
    return failure();

  // --- Optional positional core id -----------------------------------------
  int32_t coreId = 0;
  const bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
  if (hasCoreId && parser.parseInteger(coreId))
    return failure();

  // --- Attribute dictionary and type signature ------------------------------
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  SmallVector<Type> weightTypes;
  if (parseCompressedRepeatedList(parser, ListDelimiter::Square, weightTypes, parseOneType))
    return failure();

  SmallVector<Type> inputTypes;
  if (parseCompressedRepeatedList(parser, ListDelimiter::Paren, inputTypes, parseOneType))
    return failure();

  SmallVector<Type> outputTypes;
  if (parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
    return failure();

  // --- Consistency checks ---------------------------------------------------
  if (weights.size() != weightTypes.size())
    return emitMismatch("number of weights and weight types must match");
  if (weightArgs.size() != weights.size())
    return emitMismatch("number of weight bindings and weight operands must match");
  if (inputs.size() != inputTypes.size())
    return emitMismatch("number of inputs and input types must match");
  if (inputArgs.size() != inputs.size())
    return emitMismatch("number of argument bindings and input operands must match");
  if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
    return emitMismatch("coreId cannot be specified both positionally and in attr-dict");

  // --- Attributes, operands and results --------------------------------------
  auto& builder = parser.getBuilder();
  const int32_t segmentSizes[] = {static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())};
  result.addAttribute("operandSegmentSizes", builder.getDenseI32ArrayAttr(segmentSizes));
  if (hasCoreId)
    result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));

  if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputTypes);

  // --- Region ------------------------------------------------------------------
  Region* body = result.addRegion();
  applyArgumentTypes(weightTypes, weightArgs);
  applyArgumentTypes(inputTypes, inputArgs);

  SmallVector<OpAsmParser::Argument> regionArgs;
  regionArgs.append(weightArgs.begin(), weightArgs.end());
  regionArgs.append(inputArgs.begin(), inputArgs.end());
  return parser.parseRegion(*body, regionArgs);
}
/// Print the batched compute op as
/// ` <lane-arg> = 0 to <laneCount> [<weight-args>] = [<weights>]
///   (<input-args>) = (<inputs>) [shared_outs(<output-args>)]
///   [coreIds <list>] attr-dict : [<weight-types>] (<input-types>)
///   -> <result-types> <region>`.
///
/// `shared_outs` is printed only when the op has results. Region entry-block
/// arguments are not printed with the region because they already appear in
/// the header.
///
/// The lane count and operand segment sizes attributes are always omitted
/// from the attribute dictionary. The core ids attribute is omitted only when
/// it was printed positionally, i.e. when it is a DenseI32ArrayAttr. If it is
/// present with any other attribute kind it stays in the attribute
/// dictionary, so the printed form never silently drops it.
void SpatComputeBatch::print(OpAsmPrinter& printer) {
  printer << " ";
  printer.printOperand(getLaneArgument());
  printer << " = 0 to " << getLaneCount();
  printer << " ";

  // Weight bindings: region argument names on the left, operands on the right.
  SmallVector<Value> weightArgs;
  weightArgs.reserve(getWeights().size());
  for (unsigned index = 0; index < getWeights().size(); ++index)
    weightArgs.push_back(getWeightArgument(index));
  printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);

  printer << " ";

  // Input bindings, same shape as the weight bindings.
  SmallVector<Value> inputArgs;
  inputArgs.reserve(getInputs().size());
  for (unsigned index = 0; index < getInputs().size(); ++index)
    inputArgs.push_back(getInputArgument(index));
  printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);

  // One output region argument per op result.
  if (getNumResults() != 0) {
    printer << " shared_outs";
    SmallVector<BlockArgument> outputArgs;
    outputArgs.reserve(getNumResults());
    for (unsigned index = 0; index < getNumResults(); ++index)
      outputArgs.push_back(getOutputArgument(index));
    printBlockArgumentList(printer, outputArgs);
  }

  SmallVector<StringRef, 3> elidedAttrs;
  elidedAttrs.push_back(getLaneCountAttrName().getValue());
  elidedAttrs.push_back(getOperandSegmentSizesAttrName().getValue());
  if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
    printer << " coreIds ";
    printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
    // Printed positionally, so keep it out of the attribute dictionary.
    elidedAttrs.push_back(onnx_mlir::kCoreIdsAttrName);
  }
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);

  printer << " : ";
  printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
  printer << " ";
  printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
  printer << " -> ";
  printCompressedTypeSequence(printer, getResultTypes());
  printer << " ";
  printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
/// Parse the batched compute op:
/// `<lane-arg> = <lb> to <laneCount> [<weight-args>] = [<weights>]
///  (<input-args>) = (<inputs>) [shared_outs(<output-args>)]
///  [(coreIds | core_ids) <list>] attr-dict : <weight-types> (<input-types>)
///  -> <result-types> <region>`.
///
/// The lower bound must be the literal 0. After the syntax is consumed, the
/// weight/input operand, type and binding counts are checked against each
/// other, the number of `shared_outs` bindings must equal the number of
/// result types, and positional core ids are rejected if the attribute
/// dictionary also carries them. The region's entry-block arguments are
/// assembled by applyBatchRegionArgumentTypes.
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
  // Reports `message` at the parser's current location.
  auto emitMismatch = [&](StringRef message) -> ParseResult {
    return parser.emitError(parser.getCurrentLocation(), message);
  };

  // --- Lane range: `%lane = 0 to N` ----------------------------------------
  OpAsmParser::Argument laneArg;
  int64_t lowerBound = 0;
  int32_t laneCount = 0;
  if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound))
    return failure();
  if (parser.parseKeyword("to") || parser.parseInteger(laneCount))
    return failure();
  if (lowerBound != 0)
    return emitMismatch("compute_batch currently requires a zero lower bound");

  // --- Bindings -----------------------------------------------------------
  SmallVector<OpAsmParser::Argument> weightArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> weights;
  if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
    return failure();

  SmallVector<OpAsmParser::Argument> inputArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> inputs;
  if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
    return failure();

  SmallVector<OpAsmParser::Argument> outputArgs;
  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
    if (parseBlockArgumentList(parser, outputArgs))
      return failure();
  }

  // --- Optional positional core ids -----------------------------------------
  SmallVector<int32_t> coreIds;
  const bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
  if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
    return failure();

  // --- Attribute dictionary and type signature ------------------------------
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
    return failure();

  SmallVector<Type> weightTypes;
  if (parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes))
    return failure();

  SmallVector<Type> inputTypes;
  auto parseOneType = [&](Type& type) { return parser.parseType(type); };
  if (parseCompressedRepeatedList(parser, ListDelimiter::Paren, inputTypes, parseOneType))
    return failure();

  SmallVector<Type> outputTypes;
  if (parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
    return failure();

  // --- Consistency checks ---------------------------------------------------
  if (weights.size() != weightTypes.size())
    return emitMismatch("number of weights and weight types must match");
  if (weightArgs.size() != weights.size())
    return emitMismatch("number of weight bindings and weight operands must match");
  if (inputs.size() != inputTypes.size())
    return emitMismatch("number of inputs and input types must match");
  if (inputArgs.size() != inputs.size())
    return emitMismatch("number of argument bindings and input operands must match");
  if (outputArgs.size() != outputTypes.size())
    return emitMismatch("number of shared output bindings and result types must match");
  if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
    return emitMismatch("coreIds cannot be specified both positionally and in attr-dict");

  // --- Attributes, operands and results --------------------------------------
  auto& builder = parser.getBuilder();
  result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));

  const int32_t segmentSizes[] = {static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())};
  result.addAttribute("operandSegmentSizes", builder.getDenseI32ArrayAttr(segmentSizes));

  if (hasCoreIds)
    result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));

  if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(outputTypes);

  // --- Region ------------------------------------------------------------------
  Region* body = result.addRegion();
  SmallVector<OpAsmParser::Argument> regionArgs;
  applyBatchRegionArgumentTypes(
      inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, builder);
  return parser.parseRegion(*body, regionArgs);
}
/// Print the op as ` <region> attr-dict`.
///
/// The region is printed without entry-block arguments and without block
/// terminators.
void SpatInParallelOp::print(OpAsmPrinter& printer) {
  Region& body = getRegion();

  printer << " ";
  printer.printRegion(body, /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false);
  printer.printOptionalAttrDict((*this)->getAttrs());
}
/// Parse the op: `<region> attr-dict`.
///
/// The region is parsed with no entry-block arguments. If it comes back with
/// no blocks, a single empty block is created before the region is attached
/// to the operation state.
ParseResult SpatInParallelOp::parse(OpAsmParser& parser, OperationState& result) {
  auto body = std::make_unique<Region>();
  if (parser.parseRegion(*body, /*arguments=*/{}))
    return failure();

  // Make sure the region has at least one block.
  if (body->empty()) {
    OpBuilder blockBuilder(parser.getBuilder().getContext());
    blockBuilder.createBlock(body.get());
  }

  result.addRegion(std::move(body));
  return parser.parseOptionalAttrDict(result.attributes);
}
} // namespace spatial
} // namespace onnx_mlir