This commit is contained in:
+34
-16
@@ -2,6 +2,7 @@
|
||||
#define PIM_DIALECT_H
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
@@ -24,7 +25,8 @@ def PimTensor :
|
||||
// Execution
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
|
||||
def PimCoreOp : PimOp<"core", [SingleBlock,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Execute a block on a PIM core";
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
@@ -34,12 +36,16 @@ def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
|
||||
I32Attr:$coreId
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)`
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSizedOperandSegments]> {
|
||||
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Execute equivalent batched core bodies";
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
@@ -50,6 +56,13 @@ def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSi
|
||||
Variadic<PimTensor>:$inputs
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::BlockArgument getLaneArgument();
|
||||
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||
::mlir::BlockArgument getInputArgument(unsigned idx);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
@@ -81,11 +94,11 @@ def PimSendOp : PimOp<"send", []> {
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
I32Attr:$size,
|
||||
I32Attr:$targetCoreId
|
||||
Index:$targetCoreId
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
|
||||
`(` $input `,` $targetCoreId `)` attr-dict `:` type($input) `->` `(` `)`
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -131,7 +144,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
let arguments = (ins
|
||||
PimTensor:$outputBuffer,
|
||||
I32Attr:$size,
|
||||
I32Attr:$sourceCoreId
|
||||
Index:$sourceCoreId
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -145,7 +158,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
|
||||
`(` $outputBuffer `,` $sourceCoreId `)` attr-dict `:` type($outputBuffer) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -219,10 +232,10 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from host memory into device memory";
|
||||
|
||||
let arguments = (ins
|
||||
Index:$deviceTargetOffset,
|
||||
Index:$hostSourceOffset,
|
||||
PimTensor:$deviceTarget,
|
||||
PimTensor:$hostSource,
|
||||
I32Attr:$deviceTargetOffset,
|
||||
I32Attr:$hostSourceOffset,
|
||||
I32Attr:$size
|
||||
);
|
||||
|
||||
@@ -237,7 +250,9 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
|
||||
`[` $deviceTargetOffset `,` $hostSourceOffset `]`
|
||||
`(` $deviceTarget `,` $hostSource `)` attr-dict
|
||||
`:` type($deviceTarget) `,` type($hostSource) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -271,10 +286,10 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from device memory into host memory";
|
||||
|
||||
let arguments = (ins
|
||||
Index:$hostTargetOffset,
|
||||
Index:$deviceSourceOffset,
|
||||
PimTensor:$hostTarget,
|
||||
PimTensor:$deviceSource,
|
||||
I32Attr:$hostTargetOffset,
|
||||
I32Attr:$deviceSourceOffset,
|
||||
I32Attr:$size
|
||||
);
|
||||
|
||||
@@ -289,7 +304,9 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output)
|
||||
`[` $hostTargetOffset `,` $deviceSourceOffset `]`
|
||||
`(` $hostTarget `,` $deviceSource `)` attr-dict
|
||||
`:` type($hostTarget) `,` type($deviceSource) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -374,7 +391,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
let summary = "Vector-matrix multiplication: c = a * b";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$weightIndex,
|
||||
PimTensor:$weight,
|
||||
PimTensor:$input,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
@@ -391,7 +408,8 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
`[` $weight `]` `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($weight) `,` type($input) `,`
|
||||
type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,41 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
#include <string>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
BlockArgument PimCoreOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
|
||||
|
||||
void PimCoreOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
}
|
||||
|
||||
BlockArgument PimCoreBatchOp::getLaneArgument() { return getBody().front().getArgument(0); }
|
||||
|
||||
BlockArgument PimCoreBatchOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
|
||||
|
||||
BlockArgument PimCoreBatchOp::getInputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(1 + getWeights().size() + idx);
|
||||
}
|
||||
|
||||
void PimCoreBatchOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
setNameFn(getLaneArgument(), "lane");
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
|
||||
}
|
||||
|
||||
void PimDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
|
||||
@@ -20,6 +20,80 @@ static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int3
|
||||
return parser.getBuilder().getDenseI32ArrayAttr(values);
|
||||
}
|
||||
|
||||
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
|
||||
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||
}
|
||||
|
||||
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
||||
printer << "(";
|
||||
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printer.printOperand(argument);
|
||||
}
|
||||
printer << ")";
|
||||
}
|
||||
|
||||
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRParen()))
|
||||
return success();
|
||||
|
||||
OpAsmParser::Argument argument;
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
}
|
||||
return parser.parseRParen();
|
||||
}
|
||||
|
||||
static void printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
||||
printCompressedValueList(printer, arguments, delimiter);
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, delimiter);
|
||||
}
|
||||
|
||||
static ParseResult parseBoundValueList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
|
||||
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
||||
switch (currentDelimiter) {
|
||||
case ListDelimiter::Paren:
|
||||
return parser.parseRParen();
|
||||
case ListDelimiter::Square:
|
||||
return parser.parseRSquare();
|
||||
}
|
||||
llvm_unreachable("unsupported delimiter");
|
||||
};
|
||||
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<int32_t> coreIds) {
|
||||
printer << " " << keyword << " ";
|
||||
printCompressedIntegerList(printer, coreIds);
|
||||
@@ -33,15 +107,76 @@ static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keywor
|
||||
|
||||
} // namespace
|
||||
|
||||
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " lanes " << getLaneCount() << " ";
|
||||
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
|
||||
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Paren);
|
||||
else
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Paren);
|
||||
void PimCoreOp::print(OpAsmPrinter& printer) {
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
weightArgs.push_back(getWeightArgument(index));
|
||||
|
||||
printer << " ";
|
||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Square);
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " coreId " << getCoreId();
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getCoreIdAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " -> () ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult PimCoreOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<Type> weightTypes;
|
||||
int32_t coreId = 0;
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||
if (hasCoreId && parser.parseInteger(coreId))
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (hasCoreId && result.attributes.get("coreId"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreId cannot be specified both positionally and in attr-dict");
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
|
||||
if (hasCoreId)
|
||||
result.addAttribute("coreId", getI32Attr(parser, coreId));
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
return parser.parseRegion(*body, weightArgs);
|
||||
}
|
||||
|
||||
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getLaneArgument());
|
||||
printer << " = 0 to " << getLaneCount() << " ";
|
||||
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
weightArgs.push_back(getWeightArgument(index));
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
inputArgs.push_back(getInputArgument(index));
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef());
|
||||
@@ -49,51 +184,57 @@ void PimCoreBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
printer << " : ";
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
|
||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Paren);
|
||||
else
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Paren);
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Square);
|
||||
printer << " -> ()";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> () ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
int64_t lowerBound = 0;
|
||||
int32_t laneCount = 0;
|
||||
OpAsmParser::Argument laneArg;
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int32_t> coreIds;
|
||||
|
||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)
|
||||
|| parseCompressedOrTupleOperandList(parser, ListDelimiter::Paren, weights)
|
||||
|| parseCompressedOperandList(parser, ListDelimiter::Square, inputs))
|
||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
if (lowerBound != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "core_batch currently requires a zero lower bound");
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)
|
||||
|| parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("coreIds"));
|
||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes))
|
||||
return failure();
|
||||
|
||||
Region* body = result.addRegion();
|
||||
if (parser.parseRegion(*body))
|
||||
return failure();
|
||||
|
||||
if (parser.parseColon() || parseCompressedOrTupleTypeList(parser, ListDelimiter::Paren, weightTypes)
|
||||
|| parseCompressedTypeList(parser, ListDelimiter::Square, inputTypes) || parser.parseArrow()
|
||||
|| parser.parseLParen() || parser.parseRParen())
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of input bindings and input operands must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreIds cannot be specified both positionally and in attr-dict");
|
||||
@@ -110,7 +251,15 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) {
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
|
||||
Region* body = result.addRegion();
|
||||
laneArg.type = builder.getIndexType();
|
||||
regionArgs.push_back(laneArg);
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
llvm::append_range(regionArgs, weightArgs);
|
||||
applyArgumentTypes(inputTypes, inputArgs);
|
||||
llvm::append_range(regionArgs, inputArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void PimYieldOp::print(OpAsmPrinter& printer) {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
@@ -14,6 +16,52 @@ namespace pim {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 3;
|
||||
if (isa<PimMemCopyHostToDevBatchOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 2;
|
||||
return false;
|
||||
}
|
||||
|
||||
static Region* getParentRegion(Value value) {
|
||||
if (auto blockArgument = dyn_cast<BlockArgument>(value))
|
||||
return blockArgument.getParentRegion();
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
return definingOp ? definingOp->getParentRegion() : nullptr;
|
||||
}
|
||||
|
||||
static bool isDefinedInsideRegion(Value value, Region& region) {
|
||||
Region* parentRegion = getParentRegion(value);
|
||||
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
||||
}
|
||||
|
||||
static bool isConstantExternalValue(Value value) {
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
|
||||
}
|
||||
|
||||
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
|
||||
bool hasFailure = false;
|
||||
region.walk([&](Operation* op) {
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value value = operand.get();
|
||||
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value)
|
||||
|| isExplicitHostOperand(op, operand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
InFlightDiagnostic diagnostic =
|
||||
ownerOp->emitOpError() << kind << " body may only directly reference external constants";
|
||||
diagnostic.attachNote(op->getLoc())
|
||||
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
||||
hasFailure = true;
|
||||
}
|
||||
});
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
|
||||
return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs));
|
||||
}
|
||||
@@ -78,24 +126,46 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
|
||||
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
|
||||
if (weightIndex >= coreOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
|
||||
if (weightIndex >= coreBatchOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
return failure();
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
|
||||
auto shapedType = dyn_cast<ShapedType>(weight.getType());
|
||||
if (!shapedType)
|
||||
return failure();
|
||||
return shapedType.getShape();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult PimCoreOp::verify() {
|
||||
Block& block = getBody().front();
|
||||
if (block.getNumArguments() != getWeights().size())
|
||||
return emitError("core body must have one block argument per weight");
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("core weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
|
||||
}
|
||||
|
||||
LogicalResult PimCoreBatchOp::verify() {
|
||||
if (getLaneCount() <= 0)
|
||||
return emitError("laneCount must be positive");
|
||||
Block& block = getBody().front();
|
||||
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("core_batch body must have lane, weight, and input block arguments");
|
||||
if (!getLaneArgument().getType().isIndex())
|
||||
return emitError("core_batch first block argument must have index type");
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("core_batch weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
if (getInputArgument(inputIndex).getType() != input.getType())
|
||||
return emitError("core_batch input block argument types must match input operand types exactly");
|
||||
}
|
||||
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimSendTensorOp::verify() {
|
||||
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
|
||||
}
|
||||
@@ -126,9 +196,9 @@ LogicalResult PimVMMOp::verify() {
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex());
|
||||
auto matrixShapeOpt = getWeightShapeForVMM(getWeight());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");
|
||||
return emitError("weight must be a shaped value");
|
||||
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
|
||||
|
||||
auto vectorType = dyn_cast<ShapedType>(getInput().getType());
|
||||
|
||||
@@ -38,10 +38,10 @@ struct MemCopyHostToDevOpInterface
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
|
||||
memCopyHostToDevOp,
|
||||
deviceTargetMemRef.getType(),
|
||||
memCopyHostToDevOp.getDeviceTargetOffset(),
|
||||
memCopyHostToDevOp.getHostSourceOffset(),
|
||||
deviceTargetMemRef,
|
||||
hostSourceMemRef,
|
||||
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
|
||||
memCopyHostToDevOp.getHostSourceOffsetAttr(),
|
||||
memCopyHostToDevOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
@@ -96,10 +96,10 @@ struct MemCopyDevToHostOpInterface
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
|
||||
memCopyDevToHostOp,
|
||||
hostTargetMemRef.getType(),
|
||||
memCopyDevToHostOp.getHostTargetOffset(),
|
||||
memCopyDevToHostOp.getDeviceSourceOffset(),
|
||||
hostTargetMemRef,
|
||||
deviceSourceMemRef,
|
||||
memCopyDevToHostOp.getHostTargetOffsetAttr(),
|
||||
memCopyDevToHostOp.getDeviceSourceOffsetAttr(),
|
||||
memCopyDevToHostOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
@@ -151,12 +151,8 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimReceiveOp>(rewriter,
|
||||
op,
|
||||
outputBufferOpt->getType(),
|
||||
*outputBufferOpt,
|
||||
receiveOp.getSizeAttr(),
|
||||
receiveOp.getSourceCoreIdAttr());
|
||||
replaceOpWithNewBufferizedOp<PimReceiveOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -302,7 +298,7 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
|
||||
op,
|
||||
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
|
||||
sendOp.getSizeAttr(),
|
||||
sendOp.getTargetCoreIdAttr());
|
||||
sendOp.getTargetCoreId());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -368,6 +364,37 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
|
||||
return {};
|
||||
}
|
||||
|
||||
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
|
||||
auto coreOp = cast<PimCoreOp>(op);
|
||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
|
||||
return {};
|
||||
|
||||
unsigned weightIndex = bbArg.getArgNumber();
|
||||
return {
|
||||
{&coreOp->getOpOperand(weightIndex), BufferRelation::Equivalent}
|
||||
};
|
||||
}
|
||||
|
||||
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
|
||||
|
||||
FailureOr<BufferLikeType> getBufferType(Operation* op,
|
||||
Value value,
|
||||
const BufferizationOptions& options,
|
||||
const BufferizationState& state,
|
||||
SmallVector<Value>& invocationStack) const {
|
||||
auto coreOp = cast<PimCoreOp>(op);
|
||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
|
||||
return failure();
|
||||
|
||||
Value tiedWeight = coreOp.getWeights()[bbArg.getArgNumber()];
|
||||
if (auto memRefType = dyn_cast<BufferLikeType>(tiedWeight.getType()))
|
||||
return memRefType;
|
||||
|
||||
return bufferization::getBufferType(tiedWeight, options, state, invocationStack);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
@@ -375,7 +402,10 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
|
||||
auto coreOp = cast<PimCoreOp>(op);
|
||||
|
||||
bool alreadyBufferized =
|
||||
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); });
|
||||
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
|
||||
&& llvm::all_of(coreOp.getBody().front().getArguments(), [](BlockArgument arg) {
|
||||
return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
|
||||
});
|
||||
if (alreadyBufferized)
|
||||
return success();
|
||||
|
||||
@@ -420,9 +450,17 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
||||
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
||||
return {};
|
||||
|
||||
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber();
|
||||
unsigned argNumber = bbArg.getArgNumber();
|
||||
if (argNumber == 0)
|
||||
return {};
|
||||
|
||||
unsigned weightCount = coreBatchOp.getWeights().size();
|
||||
unsigned operandIndex = argNumber - 1;
|
||||
if (argNumber > weightCount + 1)
|
||||
operandIndex = weightCount + (argNumber - 1 - weightCount);
|
||||
|
||||
return {
|
||||
{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}
|
||||
{&coreBatchOp->getOpOperand(operandIndex), BufferRelation::Equivalent}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -438,11 +476,21 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
||||
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
||||
return failure();
|
||||
|
||||
Value tiedInput = coreBatchOp.getInputs()[bbArg.getArgNumber()];
|
||||
if (auto memRefType = dyn_cast<BufferLikeType>(tiedInput.getType()))
|
||||
unsigned argNumber = bbArg.getArgNumber();
|
||||
if (argNumber == 0)
|
||||
return failure();
|
||||
|
||||
Value tiedOperand;
|
||||
unsigned weightCount = coreBatchOp.getWeights().size();
|
||||
if (argNumber <= weightCount)
|
||||
tiedOperand = coreBatchOp.getWeights()[argNumber - 1];
|
||||
else
|
||||
tiedOperand = coreBatchOp.getInputs()[argNumber - 1 - weightCount];
|
||||
|
||||
if (auto memRefType = dyn_cast<BufferLikeType>(tiedOperand.getType()))
|
||||
return memRefType;
|
||||
|
||||
return bufferization::getBufferType(tiedInput, options, state, invocationStack);
|
||||
return bufferization::getBufferType(tiedOperand, options, state, invocationStack);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
@@ -454,8 +502,9 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
||||
bool alreadyBufferized =
|
||||
llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
|
||||
&& llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa<BufferLikeType>(input.getType()); })
|
||||
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(),
|
||||
[](BlockArgument arg) { return isa<BufferLikeType>(arg.getType()); });
|
||||
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(), [](BlockArgument arg) {
|
||||
return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
|
||||
});
|
||||
if (alreadyBufferized)
|
||||
return success();
|
||||
|
||||
@@ -553,6 +602,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
||||
BufferizationState& state) const {
|
||||
auto vmmOp = cast<PimVMMOp>(op);
|
||||
|
||||
auto weightOpt = getBufferOrValue(rewriter, vmmOp.getWeight(), options, state);
|
||||
if (failed(weightOpt))
|
||||
return failure();
|
||||
|
||||
auto inputOpt = getBufferOrValue(rewriter, vmmOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
@@ -564,7 +617,7 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
|
||||
rewriter, op, outputBufferOpt->getType(), *weightOpt, contiguousInput, *outputBufferOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||
|
||||
@@ -9,19 +10,62 @@ namespace onnx_mlir::spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); }
|
||||
static FailureOr<int64_t> getConstantI64(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return constantValue.getSExtValue();
|
||||
}
|
||||
|
||||
static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); }
|
||||
static FailureOr<int32_t> getConstantI32(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||
}
|
||||
|
||||
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelSendOp sendOp) {
|
||||
return getConstantI64(sendOp.getChannelId());
|
||||
}
|
||||
|
||||
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelReceiveOp receiveOp) {
|
||||
return getConstantI64(receiveOp.getChannelId());
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getSourceCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getSourceCoreId()); }
|
||||
|
||||
static FailureOr<int32_t> getSourceCoreId(SpatChannelReceiveOp receiveOp) {
|
||||
return getConstantI32(receiveOp.getSourceCoreId());
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getTargetCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getTargetCoreId()); }
|
||||
|
||||
static FailureOr<int32_t> getTargetCoreId(SpatChannelReceiveOp receiveOp) {
|
||||
return getConstantI32(receiveOp.getTargetCoreId());
|
||||
}
|
||||
|
||||
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
|
||||
if (!endpoints.send || !endpoints.receive)
|
||||
return failure();
|
||||
|
||||
if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) {
|
||||
FailureOr<int32_t> sendSourceCoreId = getSourceCoreId(endpoints.send);
|
||||
FailureOr<int32_t> receiveSourceCoreId = getSourceCoreId(endpoints.receive);
|
||||
if (failed(sendSourceCoreId) || failed(receiveSourceCoreId)) {
|
||||
endpoints.send.emitOpError("channel endpoints must use constant sourceCoreId operands");
|
||||
return failure();
|
||||
}
|
||||
if (*sendSourceCoreId != *receiveSourceCoreId) {
|
||||
endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
|
||||
return failure();
|
||||
}
|
||||
if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) {
|
||||
|
||||
FailureOr<int32_t> sendTargetCoreId = getTargetCoreId(endpoints.send);
|
||||
FailureOr<int32_t> receiveTargetCoreId = getTargetCoreId(endpoints.receive);
|
||||
if (failed(sendTargetCoreId) || failed(receiveTargetCoreId)) {
|
||||
endpoints.send.emitOpError("channel endpoints must use constant targetCoreId operands");
|
||||
return failure();
|
||||
}
|
||||
if (*sendTargetCoreId != *receiveTargetCoreId) {
|
||||
endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
|
||||
return failure();
|
||||
}
|
||||
@@ -46,20 +90,26 @@ Channels::Channels(func::FuncOp funcOp) {
|
||||
Channels::ChannelId Channels::allocate() { return nextChannelId++; }
|
||||
|
||||
void Channels::insertSend(SpatChannelSendOp sendOp) {
|
||||
ChannelId channelId = getChannelId(sendOp);
|
||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
||||
endpoints[channelId].send = sendOp;
|
||||
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
nextChannelId = std::max(nextChannelId, *channelId + 1);
|
||||
endpoints[*channelId].send = sendOp;
|
||||
}
|
||||
|
||||
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
|
||||
ChannelId channelId = getChannelId(receiveOp);
|
||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
||||
endpoints[channelId].receive = receiveOp;
|
||||
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
nextChannelId = std::max(nextChannelId, *channelId + 1);
|
||||
endpoints[*channelId].receive = receiveOp;
|
||||
}
|
||||
|
||||
void Channels::eraseSend(SpatChannelSendOp sendOp) {
|
||||
ChannelId channelId = getChannelId(sendOp);
|
||||
auto it = endpoints.find(channelId);
|
||||
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
auto it = endpoints.find(*channelId);
|
||||
if (it == endpoints.end())
|
||||
return;
|
||||
it->second.send = {};
|
||||
@@ -68,8 +118,10 @@ void Channels::eraseSend(SpatChannelSendOp sendOp) {
|
||||
}
|
||||
|
||||
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
|
||||
ChannelId channelId = getChannelId(receiveOp);
|
||||
auto it = endpoints.find(channelId);
|
||||
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
auto it = endpoints.find(*channelId);
|
||||
if (it == endpoints.end())
|
||||
return;
|
||||
it->second.receive = {};
|
||||
@@ -85,14 +137,20 @@ FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
|
||||
}
|
||||
|
||||
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
|
||||
auto endpointsOr = lookup(getChannelId(sendOp));
|
||||
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||
if (failed(channelId))
|
||||
return failure();
|
||||
auto endpointsOr = lookup(*channelId);
|
||||
if (failed(endpointsOr) || !endpointsOr->receive)
|
||||
return failure();
|
||||
return endpointsOr->receive;
|
||||
}
|
||||
|
||||
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
|
||||
auto endpointsOr = lookup(getChannelId(receiveOp));
|
||||
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||
if (failed(channelId))
|
||||
return failure();
|
||||
auto endpointsOr = lookup(*channelId);
|
||||
if (failed(endpointsOr) || !endpointsOr->send)
|
||||
return failure();
|
||||
return endpointsOr->send;
|
||||
|
||||
@@ -2,8 +2,12 @@
|
||||
#define SPATIAL_DIALECT_H
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/IR/RegionKindInterface.td"
|
||||
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
def SpatialDialect : Dialect {
|
||||
let name = "spat";
|
||||
@@ -22,7 +26,9 @@ def SpatTensor :
|
||||
// Execution
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||
def SpatCompute : SpatOp<"compute",
|
||||
[SingleBlock, AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Compute region with attached constant weights";
|
||||
|
||||
let arguments = (ins
|
||||
@@ -36,14 +42,20 @@ def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||
::mlir::BlockArgument getInputArgument(unsigned idx);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasFolder = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatComputeBatch : SpatOp<"compute_batch",
|
||||
[SingleBlock, AttrSizedOperandSegments]> {
|
||||
let summary = "Compressed batch of independent equivalent compute lanes";
|
||||
[SingleBlock, AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$laneCount,
|
||||
@@ -57,10 +69,41 @@ def SpatComputeBatch : SpatOp<"compute_batch",
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::BlockArgument getLaneArgument();
|
||||
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||
::mlir::BlockArgument getInputArgument(unsigned idx);
|
||||
::mlir::BlockArgument getOutputArgument(unsigned idx);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatInParallelOp : SpatOp<"in_parallel", [
|
||||
Pure,
|
||||
Terminator,
|
||||
DeclareOpInterfaceMethods<InParallelOpInterface>,
|
||||
HasParent<"SpatComputeBatch">,
|
||||
] # GraphRegionNoTerminator.traits> {
|
||||
let summary = "Parallel combining terminator for resultful spat.compute_batch";
|
||||
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasVerifier = 1;
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
let builders = [
|
||||
OpBuilder<(ins)>,
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
|
||||
::mlir::OpResult getParentResult(int64_t idx);
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
||||
let summary = "Yield results from a compute region";
|
||||
|
||||
@@ -110,14 +153,14 @@ def SpatChannelSendOp : SpatOp<"channel_send", []> {
|
||||
let summary = "Send a tensor through a logical channel";
|
||||
|
||||
let arguments = (ins
|
||||
I64Attr:$channelId,
|
||||
I32Attr:$sourceCoreId,
|
||||
I32Attr:$targetCoreId,
|
||||
Index:$channelId,
|
||||
Index:$sourceCoreId,
|
||||
Index:$targetCoreId,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input)
|
||||
$input `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -125,9 +168,9 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||
let summary = "Receive a tensor from a logical channel";
|
||||
|
||||
let arguments = (ins
|
||||
I64Attr:$channelId,
|
||||
I32Attr:$sourceCoreId,
|
||||
I32Attr:$targetCoreId
|
||||
Index:$channelId,
|
||||
Index:$sourceCoreId,
|
||||
Index:$targetCoreId
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -135,31 +178,33 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict `:` type($output)
|
||||
`channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
|
||||
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", [AttrSizedOperandSegments]> {
|
||||
let summary = "Send equal contiguous chunks of one tensor through logical channels";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", [AttrSizedOperandSegments]> {
|
||||
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -167,44 +212,50 @@ def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Send per-lane tensors through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
|
||||
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Receive a per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -212,16 +263,18 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> {
|
||||
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -229,7 +282,9 @@ def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []>
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -240,7 +295,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
|
||||
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$weightIndex,
|
||||
SpatTensor:$weight,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
@@ -251,7 +306,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -259,7 +314,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> {
|
||||
let summary = "Matrix-vector multiplication within a weighted compute operation";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$weightIndex,
|
||||
SpatTensor:$weight,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
@@ -270,7 +325,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> {
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,74 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
#include <string>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
BlockArgument SpatCompute::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
|
||||
|
||||
BlockArgument SpatCompute::getInputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(getWeights().size() + idx);
|
||||
}
|
||||
|
||||
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
|
||||
}
|
||||
|
||||
BlockArgument SpatComputeBatch::getLaneArgument() { return getBody().front().getArgument(0); }
|
||||
|
||||
BlockArgument SpatComputeBatch::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
|
||||
|
||||
BlockArgument SpatComputeBatch::getInputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(1 + getWeights().size() + idx);
|
||||
}
|
||||
|
||||
BlockArgument SpatComputeBatch::getOutputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(1 + getWeights().size() + getInputs().size() + idx);
|
||||
}
|
||||
|
||||
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
setNameFn(getLaneArgument(), "lane");
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||
if (index == 0) {
|
||||
setNameFn(getOutputArgument(index), "out");
|
||||
continue;
|
||||
}
|
||||
setNameFn(getOutputArgument(index), ("out" + std::to_string(index)).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Region* bodyRegion = result.addRegion();
|
||||
builder.createBlock(bodyRegion);
|
||||
}
|
||||
|
||||
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
|
||||
|
||||
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() {
|
||||
return getRegion().front().getOperations();
|
||||
}
|
||||
|
||||
void SpatialDialect::initialize() {
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
|
||||
@@ -5,7 +5,9 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/RegionKindInterface.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
@@ -23,22 +23,6 @@ static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred,
|
||||
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||
}
|
||||
|
||||
static void printChannelMetadata(OpAsmPrinter& printer,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds) {
|
||||
printer << " channels ";
|
||||
printCompressedIntegerList(printer, channelIds);
|
||||
printer << " from ";
|
||||
printCompressedIntegerList(printer, sourceCoreIds);
|
||||
printer << " to ";
|
||||
printCompressedIntegerList(printer, targetCoreIds);
|
||||
}
|
||||
|
||||
static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef<int64_t> values) {
|
||||
return parser.getBuilder().getDenseI64ArrayAttr(values);
|
||||
}
|
||||
|
||||
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
|
||||
return parser.getBuilder().getDenseI32ArrayAttr(values);
|
||||
}
|
||||
@@ -47,94 +31,89 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
template <typename TensorSendOpTy>
|
||||
static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
||||
printer << " ";
|
||||
printer.printOperand(op.getInput());
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(),
|
||||
op.getSourceCoreIdsAttrName().getValue(),
|
||||
op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getInput().getType());
|
||||
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
||||
printer << "(";
|
||||
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printer.printOperand(argument);
|
||||
}
|
||||
printer << ")";
|
||||
}
|
||||
|
||||
template <typename TensorReceiveOpTy>
|
||||
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(),
|
||||
op.getSourceCoreIdsAttrName().getValue(),
|
||||
op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getOutput().getType());
|
||||
}
|
||||
|
||||
static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input))
|
||||
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRParen()))
|
||||
return success();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
OpAsmParser::Argument argument;
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
return parser.parseRParen();
|
||||
}
|
||||
|
||||
static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) {
|
||||
Type outputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
static void applyBatchRegionArgumentTypes(ArrayRef<Type> inputTypes,
|
||||
ArrayRef<Type> weightTypes,
|
||||
ArrayRef<Type> outputTypes,
|
||||
OpAsmParser::Argument& laneArg,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& weightArgs,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& inputArgs,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& outputArgs,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& regionArgs,
|
||||
Builder& builder) {
|
||||
laneArg.type = builder.getIndexType();
|
||||
regionArgs.push_back(laneArg);
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
llvm::append_range(regionArgs, weightArgs);
|
||||
applyArgumentTypes(inputTypes, inputArgs);
|
||||
applyArgumentTypes(outputTypes, outputArgs);
|
||||
llvm::append_range(regionArgs, inputArgs);
|
||||
llvm::append_range(regionArgs, outputArgs);
|
||||
}
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
static void
|
||||
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
||||
printCompressedValueList(printer, arguments, delimiter);
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, delimiter);
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
|
||||
static ParseResult parseBoundValueList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
result.addTypes(outputType);
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
||||
switch (currentDelimiter) {
|
||||
case ListDelimiter::Paren:
|
||||
return parser.parseRParen();
|
||||
case ListDelimiter::Square:
|
||||
return parser.parseRSquare();
|
||||
}
|
||||
llvm_unreachable("unsupported delimiter");
|
||||
};
|
||||
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) {
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -243,9 +222,17 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
|
||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
weightArgs.push_back(getWeightArgument(index));
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
inputArgs.push_back(getInputArgument(index));
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
printer << " coreId " << coreIdAttr.getInt();
|
||||
@@ -264,6 +251,7 @@ void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
}
|
||||
|
||||
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
@@ -272,10 +260,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t coreId = 0;
|
||||
|
||||
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||
@@ -292,9 +281,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (regionArgs.size() != inputs.size())
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
@@ -313,19 +304,39 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
applyArgumentTypes(inputTypes, inputArgs);
|
||||
llvm::append_range(regionArgs, weightArgs);
|
||||
llvm::append_range(regionArgs, inputArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
printer << " lanes " << getLaneCount() << " ";
|
||||
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
|
||||
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Square);
|
||||
else
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
printer.printOperand(getLaneArgument());
|
||||
printer << " = 0 to " << getLaneCount();
|
||||
|
||||
printer << " ";
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
weightArgs.push_back(getWeightArgument(index));
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
inputArgs.push_back(getInputArgument(index));
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (getNumResults() != 0) {
|
||||
printer << " shared_outs";
|
||||
SmallVector<BlockArgument> outputArgs;
|
||||
outputArgs.reserve(getNumResults());
|
||||
for (unsigned index = 0; index < getNumResults(); ++index)
|
||||
outputArgs.push_back(getOutputArgument(index));
|
||||
printBlockArgumentList(printer, outputArgs);
|
||||
}
|
||||
|
||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
||||
printer << " coreIds ";
|
||||
@@ -337,10 +348,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||
|
||||
printer << " : ";
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
|
||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Square);
|
||||
else
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
@@ -350,7 +358,12 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
}
|
||||
|
||||
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||
int64_t lowerBound = 0;
|
||||
int32_t laneCount = 0;
|
||||
OpAsmParser::Argument laneArg;
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
SmallVector<OpAsmParser::Argument> outputArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
@@ -359,14 +372,21 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int32_t> coreIds;
|
||||
|
||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
|
||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
if (lowerBound != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
if (parseCompressedOrTupleOperandList(parser, ListDelimiter::Square, weights))
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
|
||||
if (parseBlockArgumentList(parser, outputArgs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
@@ -381,10 +401,15 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (regionArgs.size() != inputs.size())
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (outputArgs.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"number of shared output bindings and result types must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreIds cannot be specified both positionally and in attr-dict");
|
||||
@@ -403,119 +428,28 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
applyBatchRegionArgumentTypes(
|
||||
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorSendOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
|
||||
void SpatInParallelOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getInput().getType());
|
||||
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input))
|
||||
ParseResult SpatInParallelOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
auto& builder = parser.getBuilder();
|
||||
std::unique_ptr<Region> region = std::make_unique<Region>();
|
||||
SmallVector<OpAsmParser::Argument, 4> regionArgs;
|
||||
if (parser.parseRegion(*region, regionArgs))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorSendOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorReceiveOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getOutput().getType());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
Type outputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorReceiveOp(parser, result);
|
||||
if (region->empty())
|
||||
OpBuilder(builder.getContext()).createBlock(region.get());
|
||||
result.addRegion(std::move(region));
|
||||
return parser.parseOptionalAttrDict(result.attributes);
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
@@ -82,20 +85,11 @@ inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
|
||||
if (auto computeOp = weightedOp->getParentOfType<SpatCompute>())
|
||||
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
if (auto coreOp = weightedOp->getParentOfType<pim::PimCoreOp>())
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
if (auto batchOp = weightedOp->getParentOfType<SpatComputeBatch>()) {
|
||||
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
return failure();
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
|
||||
auto shapedType = dyn_cast<ShapedType>(weight.getType());
|
||||
if (!shapedType)
|
||||
return failure();
|
||||
return shapedType.getShape();
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
@@ -105,15 +99,86 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
return batchOp.getLaneCount();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
||||
if (batchOp.getNumResults() == 0)
|
||||
return false;
|
||||
auto blockArg = dyn_cast<BlockArgument>(value);
|
||||
if (!blockArg || blockArg.getOwner() != &batchOp.getBody().front())
|
||||
return false;
|
||||
|
||||
unsigned argNumber = blockArg.getArgNumber();
|
||||
unsigned firstOutputArg = batchOp.getOutputArgument(0).getArgNumber();
|
||||
return argNumber >= firstOutputArg && argNumber < firstOutputArg + batchOp.getNumResults();
|
||||
}
|
||||
|
||||
static bool isConstantIndexLike(Value value) {
|
||||
APInt constantValue;
|
||||
return matchPattern(value, m_ConstantInt(&constantValue));
|
||||
}
|
||||
|
||||
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
|
||||
if (value == laneArg || isConstantIndexLike(value))
|
||||
return true;
|
||||
|
||||
auto addOp = value.getDefiningOp<arith::AddIOp>();
|
||||
if (!addOp)
|
||||
return false;
|
||||
return (addOp.getLhs() == laneArg && isConstantIndexLike(addOp.getRhs()))
|
||||
|| (addOp.getRhs() == laneArg && isConstantIndexLike(addOp.getLhs()));
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgument laneArg, StringRef kind) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(sliceOp.getSource().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
|
||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
|
||||
if (!sliceOp.hasUnitStride())
|
||||
return sliceOp.emitOpError() << kind << " requires unit strides";
|
||||
|
||||
for (int64_t size : sliceOp.getStaticSizes())
|
||||
if (ShapedType::isDynamic(size))
|
||||
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
||||
|
||||
auto offsets = sliceOp.getOffsets();
|
||||
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
|
||||
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
|
||||
if (!supported)
|
||||
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::ParallelInsertSliceOp sliceOp,
|
||||
BlockArgument laneArg,
|
||||
StringRef kind) {
|
||||
RankedTensorType sourceType = sliceOp.getSourceType();
|
||||
RankedTensorType destType = sliceOp.getDestType();
|
||||
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
|
||||
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
|
||||
if (!sliceOp.hasUnitStride())
|
||||
return sliceOp.emitOpError() << kind << " requires unit strides";
|
||||
|
||||
for (int64_t size : sliceOp.getStaticSizes())
|
||||
if (ShapedType::isDynamic(size))
|
||||
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
||||
|
||||
auto offsets = sliceOp.getOffsets();
|
||||
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
|
||||
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
|
||||
if (!supported)
|
||||
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorChannelSizes(
|
||||
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
if (channelIds.empty())
|
||||
if (channelCount == 0)
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
@@ -125,40 +190,34 @@ static LogicalResult verifyTensorChannelSizes(Operation* op,
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0)
|
||||
if (totalBytes % static_cast<int64_t>(channelCount) != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchChannelSizes(Operation* op,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
static LogicalResult
|
||||
verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside spat.compute_batch");
|
||||
if (channelIds.size() != static_cast<size_t>(*laneCount))
|
||||
if (channelCount != static_cast<size_t>(*laneCount))
|
||||
return op->emitError("channel metadata length must match parent laneCount");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
static LogicalResult verifyTensorBatchChannelSizes(
|
||||
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside spat.compute_batch");
|
||||
if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0)
|
||||
if (channelCount == 0 || channelCount % static_cast<size_t>(*laneCount) != 0)
|
||||
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
@@ -169,7 +228,7 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount;
|
||||
int64_t chunkCount = static_cast<int64_t>(channelCount) / *laneCount;
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % chunkCount != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
||||
@@ -177,28 +236,59 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return op->emitError("body must terminate with spat.yield");
|
||||
if (outputTypes.empty()) {
|
||||
static Region* getParentRegion(Value value) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
return blockArg.getOwner()->getParent();
|
||||
if (Operation* definingOp = value.getDefiningOp())
|
||||
return definingOp->getParentRegion();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool isDefinedInsideRegion(Value value, Region& region) {
|
||||
Region* parentRegion = getParentRegion(value);
|
||||
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
||||
}
|
||||
|
||||
static bool isConstantExternalValue(Value value) {
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
|
||||
}
|
||||
|
||||
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
|
||||
bool hasFailure = false;
|
||||
region.walk([&](Operation* op) {
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value value = operand.get();
|
||||
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
|
||||
continue;
|
||||
|
||||
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
||||
<< kind << " body may only directly reference external constants";
|
||||
diagnostic.attachNote(op->getLoc()) << "non-constant external operand #" << operand.getOperandNumber()
|
||||
<< " is used by " << op->getName();
|
||||
hasFailure = true;
|
||||
}
|
||||
});
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
|
||||
if (batchOp.getNumResults() == 0) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return batchOp.emitError("resultless compute_batch body must terminate with spat.yield");
|
||||
if (yieldOp.getNumOperands() != 0)
|
||||
return op->emitError("body yield must be empty when compute_batch has no results");
|
||||
return batchOp.emitError("resultless compute_batch body yield must be empty");
|
||||
}
|
||||
else {
|
||||
if (yieldOp.getNumOperands() != 1)
|
||||
return op->emitError("body yield must produce exactly one value");
|
||||
if (yieldOp.getOperand(0).getType() != outputTypes[0])
|
||||
return op->emitError("body yield type must match output type");
|
||||
else if (!isa_and_nonnull<SpatInParallelOp>(block.getTerminator())) {
|
||||
return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
|
||||
}
|
||||
|
||||
BlockArgument laneArg = batchOp.getLaneArgument();
|
||||
for (auto& bodyOp : block) {
|
||||
if (auto wvmm = dyn_cast<SpatVMMOp>(&bodyOp))
|
||||
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
|
||||
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
|
||||
if (auto wmvm = dyn_cast<SpatMVMOp>(&bodyOp))
|
||||
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
|
||||
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
|
||||
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
|
||||
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, laneArg, "tensor.extract_slice")))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
@@ -206,9 +296,9 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
|
||||
} // namespace
|
||||
|
||||
LogicalResult SpatMVMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatMVMOp was not within a SpatCompute or Core op");
|
||||
return emitError("weight must be a shaped value");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
@@ -221,9 +311,9 @@ LogicalResult SpatMVMOp::verify() {
|
||||
}
|
||||
|
||||
LogicalResult SpatVMMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatVMMOp was not within a SpatCompute or Core op");
|
||||
return emitError("weight must be a shaped value");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
@@ -347,13 +437,26 @@ LogicalResult verifyComputeResultsUses(Operation* op) {
|
||||
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
|
||||
});
|
||||
})) {
|
||||
return op->emitError("ComputeResult used directly inside another Compute" );
|
||||
return op->emitError("ComputeResult used directly inside another Compute");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatCompute::verify() {
|
||||
auto& block = getBody().front();
|
||||
unsigned expectedArgCount = getWeights().size() + getInputs().size();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute body must have weight and input block arguments");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("compute weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
if (getInputArgument(inputIndex).getType() != input.getType())
|
||||
return emitError("compute input block argument types must match input operand types exactly");
|
||||
}
|
||||
|
||||
if (block.mightHaveTerminator()) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
@@ -386,9 +489,11 @@ LogicalResult SpatCompute::verify() {
|
||||
}
|
||||
}
|
||||
|
||||
for (auto arg : block.getArguments())
|
||||
if (arg.use_empty())
|
||||
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
|
||||
if (getInputArgument(inputIndex).use_empty())
|
||||
return emitError("ComputeOp block argument is not used");
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
|
||||
return failure();
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
return success();
|
||||
@@ -397,44 +502,46 @@ LogicalResult SpatCompute::verify() {
|
||||
LogicalResult SpatChannelSendTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_send_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_receive_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
return verifyBatchChannelSizes(
|
||||
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_send_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
return verifyBatchChannelSizes(
|
||||
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_receive_tensor_batch");
|
||||
}
|
||||
|
||||
@@ -444,35 +551,6 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
return emitError("laneCount must be positive");
|
||||
|
||||
auto laneCountSz = static_cast<size_t>(count);
|
||||
if (getWeights().size() % laneCountSz != 0)
|
||||
return emitError("number of weights must be a multiple of laneCount");
|
||||
|
||||
if (!getInputs().empty() && getInputs().size() != laneCountSz)
|
||||
return emitError("number of inputs must be either 0 or laneCount");
|
||||
if (!getOutputs().empty() && getOutputs().size() != laneCountSz)
|
||||
return emitError("number of outputs must be either 0 or laneCount");
|
||||
|
||||
size_t weightsPerLane = getWeights().size() / laneCountSz;
|
||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) {
|
||||
Type weightType = getWeights()[weightIndex].getType();
|
||||
for (size_t lane = 1; lane < laneCountSz; ++lane)
|
||||
if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType)
|
||||
return emitError("corresponding weights across lanes must have the same type");
|
||||
}
|
||||
|
||||
if (!getInputs().empty()) {
|
||||
Type inputType = getInputs()[0].getType();
|
||||
for (Value in : getInputs().drop_front())
|
||||
if (in.getType() != inputType)
|
||||
return emitError("all inputs must have the same type");
|
||||
}
|
||||
|
||||
if (!getOutputs().empty()) {
|
||||
Type outputType = getOutputs()[0].getType();
|
||||
for (Value out : getOutputs().drop_front())
|
||||
if (out.getType() != outputType)
|
||||
return emitError("all outputs must have the same type");
|
||||
}
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
|
||||
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
||||
@@ -482,27 +560,64 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
return emitError("compute_batch coreIds array length must match laneCount");
|
||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
||||
return emitError("compute_batch coreIds values must be non-negative");
|
||||
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
|
||||
DenseSet<int32_t> seenCoreIds;
|
||||
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
||||
if (!seenCoreIds.insert(coreId).second)
|
||||
return emitError("compute_batch coreIds values must be distinct");
|
||||
return emitError("compute_batch coreIds values must be unique");
|
||||
}
|
||||
|
||||
Block& block = getBody().front();
|
||||
if (getInputs().empty()) {
|
||||
if (block.getNumArguments() != 0)
|
||||
return emitError("compute_batch body must have no block arguments when there are no inputs");
|
||||
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute_batch body must have lane, weight, input, and output block arguments");
|
||||
if (!getLaneArgument().getType().isIndex())
|
||||
return emitError("compute_batch first block argument must have index type");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("compute_batch weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
else {
|
||||
if (block.getNumArguments() != 1)
|
||||
return emitError("compute_batch body must have exactly one block argument");
|
||||
if (block.getArgument(0).getType() != getInputs()[0].getType())
|
||||
return emitError("body block argument type must match input type");
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
BlockArgument blockArg = getInputArgument(inputIndex);
|
||||
if (blockArg.getType() != input.getType())
|
||||
return emitError("compute_batch input block argument types must match input operand types exactly");
|
||||
}
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
|
||||
BlockArgument blockArg = getOutputArgument(resultIndex);
|
||||
if (blockArg.getType() != resultType)
|
||||
return emitError("compute_batch output block argument types must match result types exactly");
|
||||
}
|
||||
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
|
||||
return failure();
|
||||
return verifyBatchBody(*this, block);
|
||||
}
|
||||
|
||||
LogicalResult SpatInParallelOp::verify() {
|
||||
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
|
||||
if (!batchOp)
|
||||
return emitOpError("expected spat.compute_batch parent");
|
||||
if (batchOp.getNumResults() == 0)
|
||||
return emitOpError("requires a resultful spat.compute_batch parent");
|
||||
|
||||
BlockArgument laneArg = batchOp.getLaneArgument();
|
||||
for (Operation& op : getRegion().front().getOperations()) {
|
||||
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
|
||||
if (!insertSliceOp)
|
||||
return emitOpError("expected only tensor.parallel_insert_slice ops");
|
||||
|
||||
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, laneArg, "tensor.parallel_insert_slice")))
|
||||
return failure();
|
||||
|
||||
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
|
||||
for (OpOperand& destination : destinations)
|
||||
if (!isBatchOutputArgument(batchOp, destination.get()))
|
||||
return op.emitOpError("may only insert into a compute_batch output block argument");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
+777
-107
File diff suppressed because it is too large
Load Diff
@@ -167,21 +167,20 @@ bool isTrivialSerialMergeCandidate(SpatCompute compute) {
|
||||
return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size();
|
||||
}
|
||||
|
||||
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SpatCompute target, ValueRange sourceWeights) {
|
||||
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights, ValueRange sourceWeights) {
|
||||
DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(target.getWeights()))
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(targetWeights))
|
||||
targetWeightIndices[weight].push_back(weightIndex);
|
||||
|
||||
DenseMap<Value, size_t> usedSourceWeightOccurrences;
|
||||
SmallVector<size_t> sourceToTargetIndex;
|
||||
sourceToTargetIndex.reserve(sourceWeights.size());
|
||||
auto targetWeights = target.getWeightsMutable();
|
||||
for (Value weight : sourceWeights) {
|
||||
size_t occurrence = usedSourceWeightOccurrences[weight]++;
|
||||
auto& matchingIndices = targetWeightIndices[weight];
|
||||
if (occurrence >= matchingIndices.size()) {
|
||||
size_t newIndex = target.getWeights().size();
|
||||
targetWeights.append(weight);
|
||||
size_t newIndex = targetWeights.size();
|
||||
targetWeights.push_back(weight);
|
||||
matchingIndices.push_back(newIndex);
|
||||
sourceToTargetIndex.push_back(newIndex);
|
||||
continue;
|
||||
@@ -213,37 +212,36 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
auto& computeUse = *compute->getUses().begin();
|
||||
auto child = cast<SpatCompute>(computeUse.getOwner());
|
||||
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
|
||||
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
|
||||
auto childInputIndex = computeUse.getOperandNumber() - child.getWeights().size();
|
||||
|
||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||
auto newCompute = SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
||||
SmallVector<Value> mergedWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<size_t> childWeightToNewIndex = appendMissingWeightsAndBuildIndexMap(mergedWeights, child.getWeights());
|
||||
SmallVector<Value> mergedInputs(compute.getInputs().begin(), compute.getInputs().end());
|
||||
auto newCompute = SpatCompute::create(rewriter, loc, child.getResultTypes(), mergedWeights, mergedInputs);
|
||||
Block* newBody = rewriter.createBlock(&newCompute.getBodyRegion());
|
||||
for (Value weight : mergedWeights)
|
||||
newBody->addArgument(weight.getType(), loc);
|
||||
for (Value input : mergedInputs)
|
||||
newBody->addArgument(input.getType(), loc);
|
||||
|
||||
IRMapping mapper;
|
||||
SmallVector<size_t> childWeightToNewIndex = appendMissingWeightsAndBuildIndexMap(newCompute, child.getWeights());
|
||||
for (auto [weightIndex, _] : llvm::enumerate(compute.getWeights()))
|
||||
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
|
||||
for (auto [inputIndex, _] : llvm::enumerate(compute.getInputs()))
|
||||
mapper.map(compute.getInputArgument(inputIndex), newCompute.getInputArgument(inputIndex));
|
||||
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights()))
|
||||
mapper.map(weight, *std::next(newCompute.getWeights().begin(), childWeightToNewIndex[oldIndex]));
|
||||
mapper.map(child.getWeightArgument(oldIndex), newCompute.getWeightArgument(childWeightToNewIndex[oldIndex]));
|
||||
|
||||
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
||||
auto newTerminator = newCompute.getBody().front().getTerminator();
|
||||
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
|
||||
newTerminator->erase();
|
||||
rewriter.setInsertionPointToEnd(newBody);
|
||||
auto computeYield = cast<spatial::SpatYieldOp>(compute.getBody().front().getTerminator());
|
||||
for (Operation& op : compute.getBody().front().without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
mapper.map(child.getInputArgument(childInputIndex), mapper.lookupOrDefault(computeYield.getOperand(usedResult)));
|
||||
|
||||
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
||||
auto remapWeightIndex = [&](auto weightedOp) {
|
||||
auto oldIndex = weightedOp.getWeightIndex();
|
||||
assert(static_cast<size_t>(oldIndex) < childWeightToNewIndex.size() && "weight index out of range");
|
||||
weightedOp.setWeightIndex(childWeightToNewIndex[oldIndex]);
|
||||
};
|
||||
|
||||
for (auto& op : child.getBody().front()) {
|
||||
auto newInst = rewriter.clone(op, mapper);
|
||||
if (auto weightedMvmOp = dyn_cast<spatial::SpatMVMOp>(newInst))
|
||||
remapWeightIndex(weightedMvmOp);
|
||||
if (auto weightedVmmOp = dyn_cast<spatial::SpatVMMOp>(newInst))
|
||||
remapWeightIndex(weightedVmmOp);
|
||||
}
|
||||
rewriter.setInsertionPointToEnd(newBody);
|
||||
for (auto& op : child.getBody().front())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
child.replaceAllUsesWith(newCompute);
|
||||
toErase.insert(child);
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
@@ -61,6 +62,66 @@ std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
|
||||
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
|
||||
|
||||
static FailureOr<int64_t> getConstantI64Value(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return constantValue.getSExtValue();
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getConstantI32Value(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||
}
|
||||
|
||||
static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op,
|
||||
uint64_t& channelId,
|
||||
uint32_t& sourceCoreId,
|
||||
uint32_t& targetCoreId) {
|
||||
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
|
||||
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
|
||||
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
|
||||
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
|
||||
return false;
|
||||
channelId = static_cast<uint64_t>(*constantChannelId);
|
||||
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
|
||||
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op,
|
||||
uint64_t& channelId,
|
||||
uint32_t& sourceCoreId,
|
||||
uint32_t& targetCoreId) {
|
||||
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
|
||||
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
|
||||
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
|
||||
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
|
||||
return false;
|
||||
channelId = static_cast<uint64_t>(*constantChannelId);
|
||||
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
|
||||
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
|
||||
return true;
|
||||
}
|
||||
|
||||
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
|
||||
SmallVector<Value> constants;
|
||||
constants.reserve(values.size());
|
||||
for (int64_t value : values)
|
||||
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
|
||||
return constants;
|
||||
}
|
||||
|
||||
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int32_t> values, OperationFolder& folder) {
|
||||
SmallVector<Value> constants;
|
||||
constants.reserve(values.size());
|
||||
for (int32_t value : values)
|
||||
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
|
||||
return constants;
|
||||
}
|
||||
|
||||
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
|
||||
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
|
||||
return static_cast<uint64_t>(phaseAttr.getInt());
|
||||
@@ -206,8 +267,215 @@ bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
||||
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
|
||||
}
|
||||
|
||||
struct BatchYieldInfo {
|
||||
Value yieldedValue;
|
||||
tensor::ParallelInsertSliceOp insertSlice;
|
||||
};
|
||||
|
||||
static bool isHostOnlyBatchResultUser(Operation* user) {
|
||||
return isa<func::ReturnOp,
|
||||
spatial::SpatConcatOp,
|
||||
tensor::ExtractSliceOp,
|
||||
tensor::CastOp,
|
||||
tensor::CollapseShapeOp,
|
||||
tensor::ExpandShapeOp>(user);
|
||||
}
|
||||
|
||||
static FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> collectBatchYieldInfo(SpatComputeBatch batchOp) {
|
||||
Block& block = batchOp.getBody().front();
|
||||
auto inParallel = dyn_cast<spatial::SpatInParallelOp>(block.getTerminator());
|
||||
if (!inParallel)
|
||||
return failure();
|
||||
|
||||
DenseMap<BlockArgument, BatchYieldInfo> batchYieldByOutputArg;
|
||||
for (Operation& op : inParallel.getRegion().front()) {
|
||||
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
|
||||
if (!insertSlice)
|
||||
return failure();
|
||||
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
|
||||
if (!outputArg || outputArg.getOwner() != &block)
|
||||
return failure();
|
||||
batchYieldByOutputArg[outputArg] = {insertSlice.getSource(), insertSlice};
|
||||
}
|
||||
return batchYieldByOutputArg;
|
||||
}
|
||||
|
||||
static FailureOr<SpatComputeBatch> cloneBatchAsResultless(SpatComputeBatch batchOp, IRRewriter& rewriter) {
|
||||
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
if (!coreIdsAttr)
|
||||
return failure();
|
||||
|
||||
Block& oldBlock = batchOp.getBody().front();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto newBatch = SpatComputeBatch::create(rewriter,
|
||||
batchOp.getLoc(),
|
||||
TypeRange {},
|
||||
rewriter.getI32IntegerAttr(batchOp.getLaneCount()),
|
||||
batchOp.getWeights(),
|
||||
batchOp.getInputs());
|
||||
newBatch.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, coreIdsAttr);
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
|
||||
blockArgLocs.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
|
||||
blockArgTypes.push_back(batchOp.getLaneArgument().getType());
|
||||
blockArgLocs.push_back(batchOp.getLaneArgument().getLoc());
|
||||
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex) {
|
||||
blockArgTypes.push_back(batchOp.getWeightArgument(weightIndex).getType());
|
||||
blockArgLocs.push_back(batchOp.getWeightArgument(weightIndex).getLoc());
|
||||
}
|
||||
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex) {
|
||||
blockArgTypes.push_back(batchOp.getInputArgument(inputIndex).getType());
|
||||
blockArgLocs.push_back(batchOp.getInputArgument(inputIndex).getLoc());
|
||||
}
|
||||
|
||||
Block* newBlock =
|
||||
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
mapper.map(batchOp.getLaneArgument(), newBatch.getLaneArgument());
|
||||
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
|
||||
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch.getWeightArgument(weightIndex));
|
||||
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
|
||||
mapper.map(batchOp.getInputArgument(inputIndex), newBatch.getInputArgument(inputIndex));
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator()) {
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
}
|
||||
|
||||
return newBatch;
|
||||
}
|
||||
|
||||
static LogicalResult materializeBatchResultCommunication(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
SmallVector<SpatComputeBatch> batches(funcOp.getOps<SpatComputeBatch>());
|
||||
|
||||
for (auto batchOp : batches) {
|
||||
if (batchOp.getNumResults() == 0)
|
||||
continue;
|
||||
|
||||
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
if (!coreIdsAttr)
|
||||
return batchOp.emitOpError("missing coreIds while materializing batch result communication");
|
||||
|
||||
FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> batchYieldInfo = collectBatchYieldInfo(batchOp);
|
||||
if (failed(batchYieldInfo))
|
||||
return batchOp.emitOpError("failed to collect per-result yielded values from compute_batch body");
|
||||
|
||||
FailureOr<SpatComputeBatch> newBatch = cloneBatchAsResultless(batchOp, rewriter);
|
||||
if (failed(newBatch))
|
||||
return batchOp.emitOpError("failed to clone resultful compute_batch as resultless");
|
||||
|
||||
Block& oldBlock = batchOp.getBody().front();
|
||||
Block& newBlock = newBatch->getBody().front();
|
||||
IRMapping mapper;
|
||||
mapper.map(batchOp.getLaneArgument(), newBatch->getLaneArgument());
|
||||
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
|
||||
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch->getWeightArgument(weightIndex));
|
||||
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
|
||||
mapper.map(batchOp.getInputArgument(inputIndex), newBatch->getInputArgument(inputIndex));
|
||||
auto oldIt = oldBlock.begin();
|
||||
auto newIt = newBlock.begin();
|
||||
for (; oldIt != oldBlock.end() && newIt != newBlock.end(); ++oldIt, ++newIt)
|
||||
for (auto [oldResult, newResult] : llvm::zip(oldIt->getResults(), newIt->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
|
||||
SmallVector<int32_t> sourceCoreIds(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
rewriter.setInsertionPointToEnd(&newBlock);
|
||||
|
||||
for (unsigned resultIndex = 0; resultIndex < batchOp.getNumResults(); ++resultIndex) {
|
||||
BlockArgument outputArg = batchOp.getOutputArgument(resultIndex);
|
||||
auto yieldInfoIt = batchYieldInfo->find(outputArg);
|
||||
if (yieldInfoIt == batchYieldInfo->end())
|
||||
return batchOp.emitOpError(
|
||||
"missing yielded value for compute_batch result during communication materialization");
|
||||
Value mappedYieldedValue = mapper.lookup(yieldInfoIt->second.yieldedValue);
|
||||
|
||||
DenseMap<int32_t, SmallVector<OpOperand*>> computeUsesByTargetCore;
|
||||
SmallVector<OpOperand*> hostUses;
|
||||
for (OpOperand& use : batchOp.getResult(resultIndex).getUses()) {
|
||||
if (auto computeOp = dyn_cast<SpatCompute>(use.getOwner())) {
|
||||
auto coreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
|
||||
if (!coreIdAttr)
|
||||
return batchOp.emitOpError("compute user of compute_batch result is missing coreId");
|
||||
computeUsesByTargetCore[static_cast<int32_t>(coreIdAttr.getInt())].push_back(&use);
|
||||
continue;
|
||||
}
|
||||
if (isHostOnlyBatchResultUser(use.getOwner())) {
|
||||
hostUses.push_back(&use);
|
||||
continue;
|
||||
}
|
||||
return batchOp.emitOpError("unsupported user of compute_batch result during communication materialization")
|
||||
<< ": " << use.getOwner()->getName();
|
||||
}
|
||||
|
||||
auto createReceiveForUses = [&](ArrayRef<OpOperand*> uses, ArrayRef<int32_t> targetCoreIds) -> LogicalResult {
|
||||
if (uses.empty())
|
||||
return success();
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
channelIds.reserve(sourceCoreIds.size());
|
||||
for ([[maybe_unused]] int32_t sourceCoreId : sourceCoreIds)
|
||||
channelIds.push_back(nextChannelId++);
|
||||
SmallVector<Value> sendChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
|
||||
SmallVector<Value> sendSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> sendTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
|
||||
|
||||
spatial::SpatChannelSendBatchOp::create(rewriter,
|
||||
batchOp.getLoc(),
|
||||
sendChannelIdValues,
|
||||
sendSourceCoreIdValues,
|
||||
sendTargetCoreIdValues,
|
||||
mappedYieldedValue);
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointAfter(newBatch->getOperation());
|
||||
SmallVector<Value> receiveChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
|
||||
SmallVector<Value> receiveSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> receiveTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
|
||||
auto received = spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||
batchOp.getLoc(),
|
||||
batchOp.getResult(resultIndex).getType(),
|
||||
receiveChannelIdValues,
|
||||
receiveSourceCoreIdValues,
|
||||
receiveTargetCoreIdValues);
|
||||
for (OpOperand* use : uses)
|
||||
use->set(received.getOutput());
|
||||
rewriter.setInsertionPointToEnd(&newBlock);
|
||||
return success();
|
||||
};
|
||||
|
||||
for (auto& [targetCoreId, uses] : computeUsesByTargetCore) {
|
||||
SmallVector<int32_t> targetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), targetCoreId);
|
||||
if (failed(createReceiveForUses(uses, targetCoreIds)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (!hostUses.empty()) {
|
||||
SmallVector<int32_t> hostTargetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), 0);
|
||||
if (failed(createReceiveForUses(hostUses, hostTargetCoreIds)))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.setInsertionPointToEnd(&newBlock);
|
||||
spatial::SpatYieldOp::create(rewriter, batchOp.getLoc(), ValueRange {});
|
||||
rewriter.eraseOp(batchOp);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
||||
DenseSet<Operation*> consumed;
|
||||
DenseMap<Operation*, size_t> computeOrder;
|
||||
@@ -316,8 +584,10 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
||||
entries.reserve(group.size());
|
||||
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
||||
auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]);
|
||||
entries.push_back(
|
||||
{groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()});
|
||||
BatchReceiveEntry entry;
|
||||
if (!getScalarChannelMetadata(groupReceive, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
|
||||
return;
|
||||
entries.push_back(entry);
|
||||
++opIts[groupIndex];
|
||||
}
|
||||
SmallVector<int64_t> channelIds;
|
||||
@@ -331,12 +601,15 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
}
|
||||
SmallVector<Value> channelIdValues = createIndexConstants(receiveOp, channelIds, constantFolder);
|
||||
SmallVector<Value> sourceCoreIdValues = createIndexConstants(receiveOp, sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> targetCoreIdValues = createIndexConstants(receiveOp, targetCoreIds, constantFolder);
|
||||
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
|
||||
receiveOp.getLoc(),
|
||||
receiveOp.getOutput().getType(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
channelIdValues,
|
||||
sourceCoreIdValues,
|
||||
targetCoreIdValues);
|
||||
mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
@@ -351,7 +624,10 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
||||
entries.reserve(group.size());
|
||||
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
||||
auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]);
|
||||
entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()});
|
||||
BatchSendEntry entry;
|
||||
if (!getScalarChannelMetadata(groupSend, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
|
||||
return;
|
||||
entries.push_back(entry);
|
||||
++opIts[groupIndex];
|
||||
}
|
||||
SmallVector<int64_t> channelIds;
|
||||
@@ -365,11 +641,14 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
}
|
||||
SmallVector<Value> channelIdValues = createIndexConstants(sendOp, channelIds, constantFolder);
|
||||
SmallVector<Value> sourceCoreIdValues = createIndexConstants(sendOp, sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> targetCoreIdValues = createIndexConstants(sendOp, targetCoreIds, constantFolder);
|
||||
spatial::SpatChannelSendBatchOp::create(rewriter,
|
||||
sendOp.getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
channelIdValues,
|
||||
sourceCoreIdValues,
|
||||
targetCoreIdValues,
|
||||
mapper.lookup(sendOp.getInput()));
|
||||
continue;
|
||||
}
|
||||
@@ -452,6 +731,11 @@ LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextC
|
||||
ScopedMergePhaseTimer timer("cleanup-dead-packing-ops");
|
||||
cleanupDeadPackingOps(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("materialize-batch-result-communication");
|
||||
if (failed(materializeBatchResultCommunication(funcOp, nextChannelId)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
@@ -30,7 +31,7 @@ enum class RegularStepKind {
|
||||
|
||||
struct RegularStep {
|
||||
RegularStepKind kind;
|
||||
int32_t weightIndex = 0;
|
||||
Value weight;
|
||||
Value invariantOperand;
|
||||
Type resultType;
|
||||
};
|
||||
@@ -73,15 +74,90 @@ static uint64_t getEndpointKey(uint32_t sourceCoreId, uint32_t targetCoreId) {
|
||||
return (static_cast<uint64_t>(sourceCoreId) << 32) | static_cast<uint64_t>(targetCoreId);
|
||||
}
|
||||
|
||||
static void appendChannelAttrs(SmallVectorImpl<int64_t>& channelIds,
|
||||
SmallVectorImpl<int32_t>& sourceCoreIds,
|
||||
SmallVectorImpl<int32_t>& targetCoreIds,
|
||||
uint64_t channelId,
|
||||
uint32_t sourceCoreId,
|
||||
uint32_t targetCoreId) {
|
||||
channelIds.push_back(static_cast<int64_t>(channelId));
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(targetCoreId));
|
||||
static FailureOr<int64_t> getConstantI64Value(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return constantValue.getSExtValue();
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getConstantI32Value(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||
}
|
||||
|
||||
static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op,
|
||||
uint64_t& channelId,
|
||||
uint32_t& sourceCoreId,
|
||||
uint32_t& targetCoreId) {
|
||||
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
|
||||
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
|
||||
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
|
||||
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
|
||||
return false;
|
||||
channelId = static_cast<uint64_t>(*constantChannelId);
|
||||
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
|
||||
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op,
|
||||
uint64_t& channelId,
|
||||
uint32_t& sourceCoreId,
|
||||
uint32_t& targetCoreId) {
|
||||
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
|
||||
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
|
||||
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
|
||||
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
|
||||
return false;
|
||||
channelId = static_cast<uint64_t>(*constantChannelId);
|
||||
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
|
||||
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
|
||||
return true;
|
||||
}
|
||||
|
||||
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
|
||||
SmallVector<Value> constants;
|
||||
constants.reserve(values.size());
|
||||
for (int64_t value : values)
|
||||
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
|
||||
return constants;
|
||||
}
|
||||
|
||||
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int32_t> values, OperationFolder& folder) {
|
||||
SmallVector<Value> constants;
|
||||
constants.reserve(values.size());
|
||||
for (int32_t value : values)
|
||||
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
|
||||
return constants;
|
||||
}
|
||||
|
||||
static SmallVector<Operation*> getScalarChannelMetadataDefs(Operation* channelOp, unsigned metadataOperandCount) {
|
||||
SmallVector<Operation*> defs;
|
||||
defs.reserve(metadataOperandCount);
|
||||
for (unsigned operandIndex = 0; operandIndex < metadataOperandCount; ++operandIndex) {
|
||||
Operation* def = channelOp->getOperand(operandIndex).getDefiningOp();
|
||||
auto constantOp = dyn_cast_or_null<arith::ConstantOp>(def);
|
||||
if (!constantOp || def->getBlock() != channelOp->getBlock())
|
||||
continue;
|
||||
defs.push_back(def);
|
||||
}
|
||||
llvm::sort(defs, [](Operation* lhs, Operation* rhs) { return lhs->isBeforeInBlock(rhs); });
|
||||
return defs;
|
||||
}
|
||||
|
||||
static void moveScalarChannelBundleBefore(Operation* channelOp, Operation* insertionPoint) {
|
||||
for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3))
|
||||
metadataDef->moveBefore(insertionPoint);
|
||||
channelOp->moveBefore(insertionPoint);
|
||||
}
|
||||
|
||||
static void moveScalarChannelBundleBefore(Operation* channelOp, Block* block, Block::iterator insertionPoint) {
|
||||
for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3))
|
||||
metadataDef->moveBefore(block, insertionPoint);
|
||||
channelOp->moveBefore(block, insertionPoint);
|
||||
}
|
||||
|
||||
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
|
||||
@@ -196,7 +272,7 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter
|
||||
}
|
||||
|
||||
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
|
||||
return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand
|
||||
return lhs.kind == rhs.kind && lhs.weight == rhs.weight && lhs.invariantOperand == rhs.invariantOperand
|
||||
&& lhs.resultType == rhs.resultType;
|
||||
}
|
||||
|
||||
@@ -227,8 +303,7 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
chunk.input = startOp.getInput();
|
||||
chunk.output = startOp.getOutput();
|
||||
chunk.ops.push_back(startOp.getOperation());
|
||||
chunk.steps.push_back(
|
||||
{RegularStepKind::Wvmm, static_cast<int32_t>(startOp.getWeightIndex()), Value(), startOp.getOutput().getType()});
|
||||
chunk.steps.push_back({RegularStepKind::Wvmm, startOp.getWeight(), Value(), startOp.getOutput().getType()});
|
||||
|
||||
Value currentValue = startOp.getOutput();
|
||||
while (currentValue.hasOneUse()) {
|
||||
@@ -241,9 +316,9 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
break;
|
||||
|
||||
if (vaddOp.getLhs() == currentValue)
|
||||
chunk.steps.push_back({RegularStepKind::VAddLhs, 0, vaddOp.getRhs(), vaddOp.getOutput().getType()});
|
||||
chunk.steps.push_back({RegularStepKind::VAddLhs, Value(), vaddOp.getRhs(), vaddOp.getOutput().getType()});
|
||||
else if (vaddOp.getRhs() == currentValue)
|
||||
chunk.steps.push_back({RegularStepKind::VAddRhs, 0, vaddOp.getLhs(), vaddOp.getOutput().getType()});
|
||||
chunk.steps.push_back({RegularStepKind::VAddRhs, Value(), vaddOp.getLhs(), vaddOp.getOutput().getType()});
|
||||
else
|
||||
break;
|
||||
|
||||
@@ -255,7 +330,8 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
return chunk;
|
||||
}
|
||||
|
||||
static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
||||
static RegularCompactionResult
|
||||
compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run, OperationFolder& constantFolder) {
|
||||
assert(!run.empty() && "expected a non-empty regular chunk run");
|
||||
const RegularChunk& anchorChunk = run.front();
|
||||
RegularCompactionResult result;
|
||||
@@ -275,9 +351,9 @@ static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, Arra
|
||||
auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(run.size()));
|
||||
auto packedInit = tensor::EmptyOp::create(
|
||||
rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType());
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 0);
|
||||
auto upper = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), run.size());
|
||||
auto step = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 1);
|
||||
auto zero = getOrCreateHostIndexConstant(anchorChunk.startOp, 0, constantFolder);
|
||||
auto upper = getOrCreateHostIndexConstant(anchorChunk.startOp, static_cast<int64_t>(run.size()), constantFolder);
|
||||
auto step = getOrCreateHostIndexConstant(anchorChunk.startOp, 1, constantFolder);
|
||||
auto loop =
|
||||
scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||
|
||||
@@ -290,8 +366,7 @@ static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, Arra
|
||||
|
||||
Value inputRowOffset = iv;
|
||||
if (inputType.getDimSize(0) != 1) {
|
||||
auto rowsPerValue =
|
||||
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), inputType.getDimSize(0));
|
||||
auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, inputType.getDimSize(0), constantFolder);
|
||||
inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
|
||||
}
|
||||
|
||||
@@ -320,8 +395,7 @@ static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, Arra
|
||||
Value mappedOutput = mapping.lookup(anchorChunk.output);
|
||||
Value outputRowOffset = iv;
|
||||
if (outputType.getDimSize(0) != 1) {
|
||||
auto rowsPerValue =
|
||||
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), outputType.getDimSize(0));
|
||||
auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, outputType.getDimSize(0), constantFolder);
|
||||
outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
|
||||
}
|
||||
|
||||
@@ -389,35 +463,50 @@ void orderBilateralChannelOps(func::FuncOp funcOp) {
|
||||
Block& block = compute.getBody().front();
|
||||
SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves;
|
||||
DenseMap<uint64_t, Operation*> firstForwardedSendByEndpoint;
|
||||
Operation* firstForwardedSend = nullptr;
|
||||
|
||||
for (Operation& op : block) {
|
||||
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&op)) {
|
||||
if (sendOp.getSourceCoreId() == static_cast<uint32_t>(coreId)
|
||||
&& isForwardedChannelPayload(sendOp.getInput(), block)) {
|
||||
uint64_t key = getEndpointKey(sendOp.getSourceCoreId(), sendOp.getTargetCoreId());
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
if (getScalarChannelMetadata(sendOp, channelId, sourceCoreId, targetCoreId)
|
||||
&& sourceCoreId == static_cast<uint32_t>(coreId) && isForwardedChannelPayload(sendOp.getInput(), block)) {
|
||||
if (!firstForwardedSend)
|
||||
firstForwardedSend = sendOp.getOperation();
|
||||
uint64_t key = getEndpointKey(sourceCoreId, targetCoreId);
|
||||
firstForwardedSendByEndpoint.try_emplace(key, sendOp.getOperation());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
|
||||
if (!receiveOp || receiveOp.getTargetCoreId() != static_cast<uint32_t>(coreId)
|
||||
|| receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId)
|
||||
|| targetCoreId != static_cast<uint32_t>(coreId) || sourceCoreId >= static_cast<uint32_t>(coreId)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint64_t key = getEndpointKey(static_cast<uint32_t>(coreId), receiveOp.getSourceCoreId());
|
||||
uint64_t key = getEndpointKey(static_cast<uint32_t>(coreId), sourceCoreId);
|
||||
auto firstMatchingSend = firstForwardedSendByEndpoint.find(key);
|
||||
if (firstMatchingSend != firstForwardedSendByEndpoint.end())
|
||||
moves.push_back({receiveOp, firstMatchingSend->second});
|
||||
else if (firstForwardedSend && firstForwardedSend->isBeforeInBlock(receiveOp))
|
||||
moves.push_back({receiveOp, firstForwardedSend});
|
||||
}
|
||||
|
||||
for (auto [receiveOp, insertionPoint] : moves)
|
||||
receiveOp->moveBefore(insertionPoint);
|
||||
moveScalarChannelBundleBefore(receiveOp, insertionPoint);
|
||||
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||
if (!receiveOp || receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId)
|
||||
|| sourceCoreId >= static_cast<uint32_t>(coreId)) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
@@ -425,18 +514,32 @@ void orderBilateralChannelOps(func::FuncOp funcOp) {
|
||||
Type outputType = receiveOp.getOutput().getType();
|
||||
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
|
||||
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
|
||||
uint64_t currentChannelId = 0;
|
||||
uint32_t currentSourceCoreId = 0;
|
||||
uint32_t currentTargetCoreId = 0;
|
||||
return current.getOutput().getType() == outputType
|
||||
&& current.getSourceCoreId() < static_cast<uint32_t>(coreId);
|
||||
&& getScalarChannelMetadata(current, currentChannelId, currentSourceCoreId, currentTargetCoreId)
|
||||
&& currentSourceCoreId < static_cast<uint32_t>(coreId);
|
||||
});
|
||||
|
||||
if (run.ops.size() > 1) {
|
||||
SmallVector<spatial::SpatChannelReceiveOp> sorted(run.ops);
|
||||
llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
|
||||
return lhs.getSourceCoreId() > rhs.getSourceCoreId();
|
||||
uint64_t lhsChannelId = 0;
|
||||
uint32_t lhsSourceCoreId = 0;
|
||||
uint32_t lhsTargetCoreId = 0;
|
||||
uint64_t rhsChannelId = 0;
|
||||
uint32_t rhsSourceCoreId = 0;
|
||||
uint32_t rhsTargetCoreId = 0;
|
||||
bool lhsHasMetadata = getScalarChannelMetadata(lhs, lhsChannelId, lhsSourceCoreId, lhsTargetCoreId);
|
||||
bool rhsHasMetadata = getScalarChannelMetadata(rhs, rhsChannelId, rhsSourceCoreId, rhsTargetCoreId);
|
||||
if (!lhsHasMetadata || !rhsHasMetadata)
|
||||
return false;
|
||||
return lhsSourceCoreId > rhsSourceCoreId;
|
||||
});
|
||||
Block::iterator insertIt = run.end;
|
||||
for (auto op : sorted)
|
||||
op->moveBefore(&block, insertIt);
|
||||
moveScalarChannelBundleBefore(op, &block, insertIt);
|
||||
}
|
||||
|
||||
it = run.end;
|
||||
@@ -446,6 +549,7 @@ void orderBilateralChannelOps(func::FuncOp funcOp) {
|
||||
|
||||
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
|
||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
Block& block = compute.getBody().front();
|
||||
@@ -461,7 +565,14 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
bool hasRepeatedEndpoint = false;
|
||||
DenseSet<uint64_t> seenEndpoints;
|
||||
for (auto op : run.ops) {
|
||||
uint64_t endpointKey = getEndpointKey(op.getSourceCoreId(), op.getTargetCoreId());
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
|
||||
hasRepeatedEndpoint = true;
|
||||
break;
|
||||
}
|
||||
uint64_t endpointKey = getEndpointKey(sourceCoreId, targetCoreId);
|
||||
if (!seenEndpoints.insert(endpointKey).second) {
|
||||
hasRepeatedEndpoint = true;
|
||||
break;
|
||||
@@ -478,8 +589,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
};
|
||||
SmallVector<ReceiveEntry> sortedEntries;
|
||||
sortedEntries.reserve(run.ops.size());
|
||||
for (auto [originalIndex, op] : llvm::enumerate(run.ops))
|
||||
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
for (auto [originalIndex, op] : llvm::enumerate(run.ops)) {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
|
||||
sortedEntries.clear();
|
||||
break;
|
||||
}
|
||||
sortedEntries.push_back({op, originalIndex, sourceCoreId, targetCoreId, channelId});
|
||||
}
|
||||
if (sortedEntries.empty()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
@@ -488,8 +611,9 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
sourceCoreIds.reserve(sortedEntries.size());
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
for (ReceiveEntry& entry : sortedEntries) {
|
||||
appendChannelAttrs(
|
||||
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
|
||||
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
|
||||
@@ -506,13 +630,11 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
: RankedTensorType {};
|
||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||
run.ops.front().getLoc(),
|
||||
packedType,
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
|
||||
SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
|
||||
auto compactReceive = spatial::SpatChannelReceiveTensorOp::create(
|
||||
rewriter, run.ops.front().getLoc(), packedType, channelIdValues, sourceCoreIdValues, targetCoreIdValues);
|
||||
if (concatOp && concatPackedType) {
|
||||
replaceConcatRunWithPackedValue(concatOp,
|
||||
concatStartIndex,
|
||||
@@ -551,8 +673,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
};
|
||||
SmallVector<SendEntry> sortedEntries;
|
||||
sortedEntries.reserve(run.ops.size());
|
||||
for (auto op : run.ops)
|
||||
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
for (auto op : run.ops) {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
|
||||
sortedEntries.clear();
|
||||
break;
|
||||
}
|
||||
sortedEntries.push_back({op, sourceCoreId, targetCoreId, channelId});
|
||||
}
|
||||
if (sortedEntries.empty()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
@@ -563,20 +697,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
inputs.reserve(sortedEntries.size());
|
||||
for (SendEntry& entry : sortedEntries) {
|
||||
appendChannelAttrs(
|
||||
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
|
||||
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
inputs.push_back(entry.op.getInput());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||
if (packedInput) {
|
||||
spatial::SpatChannelSendTensorOp::create(rewriter,
|
||||
run.ops.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
packedInput);
|
||||
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
|
||||
SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
|
||||
spatial::SpatChannelSendTensorOp::create(
|
||||
rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput);
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
@@ -606,9 +740,9 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
});
|
||||
|
||||
if (run.ops.size() > 1) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Value> channelIds;
|
||||
SmallVector<Value> sourceCoreIds;
|
||||
SmallVector<Value> targetCoreIds;
|
||||
for (auto op : run.ops) {
|
||||
llvm::append_range(channelIds, op.getChannelIds());
|
||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||
@@ -629,13 +763,8 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
: RankedTensorType {};
|
||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
||||
run.ops.front().getLoc(),
|
||||
packedType,
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
auto compactReceive = spatial::SpatChannelReceiveTensorBatchOp::create(
|
||||
rewriter, run.ops.front().getLoc(), packedType, channelIds, sourceCoreIds, targetCoreIds);
|
||||
if (concatOp && concatPackedType) {
|
||||
replaceConcatRunWithPackedValue(
|
||||
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
|
||||
@@ -663,9 +792,9 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
});
|
||||
|
||||
if (run.ops.size() > 1) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Value> channelIds;
|
||||
SmallVector<Value> sourceCoreIds;
|
||||
SmallVector<Value> targetCoreIds;
|
||||
SmallVector<Value> inputs;
|
||||
inputs.reserve(run.ops.size());
|
||||
for (auto op : run.ops) {
|
||||
@@ -678,12 +807,8 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||
if (packedInput) {
|
||||
spatial::SpatChannelSendTensorBatchOp::create(rewriter,
|
||||
run.ops.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
packedInput);
|
||||
spatial::SpatChannelSendTensorBatchOp::create(
|
||||
rewriter, run.ops.front().getLoc(), channelIds, sourceCoreIds, targetCoreIds, packedInput);
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
@@ -700,6 +825,7 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
|
||||
void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
|
||||
auto compactInBlock = [&](Block& block) {
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
@@ -740,7 +866,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
for (const RegularChunk& chunk : run)
|
||||
originalOpCount += chunk.ops.size();
|
||||
|
||||
RegularCompactionResult result = compactRegularChunkRun(rewriter, run);
|
||||
RegularCompactionResult result = compactRegularChunkRun(rewriter, run, constantFolder);
|
||||
if (result.changed) {
|
||||
assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run");
|
||||
if (!result.resumeAfter) {
|
||||
@@ -763,6 +889,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
|
||||
void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
|
||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
Block& block = compute.getBody().front();
|
||||
@@ -784,7 +911,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
|
||||
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||
auto run = collectConsecutiveRun<spatial::SpatVMMOp>(it, block.end(), [&](spatial::SpatVMMOp current) {
|
||||
if (current.getWeightIndex() != wvmmOp.getWeightIndex()
|
||||
if (current.getWeight() != wvmmOp.getWeight()
|
||||
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
||||
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
||||
|| current.getOutput().getType() != wvmmOp.getOutput().getType())
|
||||
@@ -851,9 +978,9 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
|
||||
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 0);
|
||||
auto upper = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), runLength);
|
||||
auto step = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 1);
|
||||
auto zero = getOrCreateHostIndexConstant(run.ops.front(), 0, constantFolder);
|
||||
auto upper = getOrCreateHostIndexConstant(run.ops.front(), runLength, constantFolder);
|
||||
auto step = getOrCreateHostIndexConstant(run.ops.front(), 1, constantFolder);
|
||||
auto packedInit =
|
||||
tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
||||
auto loop =
|
||||
@@ -868,7 +995,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
|
||||
Value sourceRow = iv;
|
||||
if (firstRow != 0) {
|
||||
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), firstRow);
|
||||
auto firstRowValue = getOrCreateHostIndexConstant(run.ops.front(), firstRow, constantFolder);
|
||||
sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue);
|
||||
}
|
||||
|
||||
@@ -883,7 +1010,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
extractSizes,
|
||||
extractStrides);
|
||||
auto loopWvmm = spatial::SpatVMMOp::create(
|
||||
rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
||||
rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeight(), extractedRow.getResult());
|
||||
|
||||
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
|
||||
|
||||
@@ -23,31 +23,31 @@ using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
Weight getComputeBodyWeight(Region &body) {
|
||||
Weight getComputeBodyWeight(Region& body) {
|
||||
constexpr Weight kOperationWeight = 100;
|
||||
Weight numOperations = 0;
|
||||
for (auto &block : body)
|
||||
for ([[maybe_unused]] auto &op : block)
|
||||
for (auto& block : body)
|
||||
for ([[maybe_unused]] auto& op : block)
|
||||
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
||||
return checkedMultiply(numOperations, kOperationWeight);
|
||||
}
|
||||
|
||||
CrossbarUsage getComputeBodyCrossbarUsage(Region &body) {
|
||||
CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
for (auto &block : body)
|
||||
for (auto &op : block)
|
||||
for (auto& block : body)
|
||||
for (auto& op : block)
|
||||
if (isa<SpatVMMOp>(op))
|
||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||
return crossbarUsage;
|
||||
}
|
||||
|
||||
bool isUsedAsWeightOnly(Operation *producerOp) {
|
||||
bool isUsedAsWeightOnly(Operation* producerOp) {
|
||||
if (producerOp->getNumResults() == 0)
|
||||
return false;
|
||||
for (Value result : producerOp->getResults()) {
|
||||
if (result.use_empty())
|
||||
return false;
|
||||
for (Operation *user : result.getUsers()) {
|
||||
for (Operation* user : result.getUsers()) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(user)) {
|
||||
if (!llvm::is_contained(compute.getWeights(), result))
|
||||
return false;
|
||||
@@ -66,7 +66,7 @@ bool isUsedAsWeightOnly(Operation *producerOp) {
|
||||
|
||||
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
|
||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||
for (const ComputeGraphEdge &edge : edges) {
|
||||
for (const ComputeGraphEdge& edge : edges) {
|
||||
if (edge.source == edge.target)
|
||||
continue;
|
||||
auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost);
|
||||
@@ -76,9 +76,9 @@ std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> ed
|
||||
|
||||
std::vector<ComputeGraphEdge> aggregatedEdges;
|
||||
aggregatedEdges.reserve(edgeWeights.size());
|
||||
for (const auto &[key, weight] : edgeWeights)
|
||||
for (const auto& [key, weight] : edgeWeights)
|
||||
aggregatedEdges.push_back({key.first, key.second, weight});
|
||||
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge &lhs, const ComputeGraphEdge &rhs) {
|
||||
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge& lhs, const ComputeGraphEdge& rhs) {
|
||||
if (lhs.source != rhs.source)
|
||||
return lhs.source < rhs.source;
|
||||
return lhs.target < rhs.target;
|
||||
@@ -88,33 +88,33 @@ std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> ed
|
||||
|
||||
} // namespace
|
||||
|
||||
Weight getComputeInstanceWeight(const ComputeInstance &instance) {
|
||||
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return getSpatComputeWeight(spatCompute);
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
|
||||
}
|
||||
|
||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance) {
|
||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return getSpatComputeCrossbarUsage(spatCompute);
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()),
|
||||
static_cast<CrossbarUsage>(instance.laneCount));
|
||||
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
|
||||
}
|
||||
|
||||
ComputeGraph buildComputeGraph(Operation *entryOp) {
|
||||
ComputeGraph buildComputeGraph(Operation* entryOp) {
|
||||
ComputeGraph graph;
|
||||
|
||||
for (Region ®ion : entryOp->getRegions()) {
|
||||
for (Block &block : region) {
|
||||
for (Operation &op : block) {
|
||||
for (Region& region : entryOp->getRegions()) {
|
||||
for (Block& block : region) {
|
||||
for (Operation& op : block) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
||||
if (isUsedAsWeightOnly(spatCompute.getOperation()))
|
||||
continue;
|
||||
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
|
||||
size_t index = graph.nodes.size();
|
||||
graph.nodes.push_back({instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||
graph.nodes.push_back(
|
||||
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||
graph.instanceToIndex[instance] = index;
|
||||
continue;
|
||||
}
|
||||
@@ -135,9 +135,21 @@ ComputeGraph buildComputeGraph(Operation *entryOp) {
|
||||
}
|
||||
|
||||
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
|
||||
for (const auto &[targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
||||
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
||||
for (Value input : getComputeInstanceInputs(node.instance)) {
|
||||
auto producerInstance = getComputeProducerInstance(input);
|
||||
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
|
||||
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
|
||||
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
|
||||
auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane));
|
||||
if (producerIt == graph.instanceToIndex.end())
|
||||
continue;
|
||||
rawEdges.push_back(
|
||||
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto producerInstance = getComputeProducerInstance(input, &node.instance);
|
||||
if (!producerInstance)
|
||||
continue;
|
||||
auto producerIt = graph.instanceToIndex.find(*producerInstance);
|
||||
@@ -152,7 +164,7 @@ ComputeGraph buildComputeGraph(Operation *entryOp) {
|
||||
graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end());
|
||||
graph.successors.assign(graph.nodes.size(), {});
|
||||
graph.predecessors.assign(graph.nodes.size(), {});
|
||||
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||
for (const ComputeGraphEdge& edge : graph.edges) {
|
||||
graph.successors[edge.source].push_back({edge.target, edge.transferCost});
|
||||
graph.predecessors[edge.target].push_back({edge.source, edge.transferCost});
|
||||
}
|
||||
@@ -160,7 +172,7 @@ ComputeGraph buildComputeGraph(Operation *entryOp) {
|
||||
return graph;
|
||||
}
|
||||
|
||||
bool verifyAcyclic(const ComputeGraph &graph) {
|
||||
bool verifyAcyclic(const ComputeGraph& graph) {
|
||||
std::vector<size_t> remainingParents(graph.nodes.size(), 0);
|
||||
std::queue<size_t> readyNodes;
|
||||
for (size_t node = 0; node < graph.nodes.size(); ++node) {
|
||||
@@ -174,7 +186,7 @@ bool verifyAcyclic(const ComputeGraph &graph) {
|
||||
size_t node = readyNodes.front();
|
||||
readyNodes.pop();
|
||||
++visited;
|
||||
for (const auto &[child, weight] : graph.successors[node]) {
|
||||
for (const auto& [child, weight] : graph.successors[node]) {
|
||||
(void) weight;
|
||||
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
||||
if (--remainingParents[child] == 0)
|
||||
|
||||
+101
-38
@@ -1,6 +1,8 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
|
||||
#include "ComputeInstanceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
@@ -18,48 +20,91 @@ size_t getSchedulingCpuBudget() {
|
||||
|
||||
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||
assert(laneCount > 0 && "laneCount must be positive");
|
||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
||||
return static_cast<size_t>(laneCount);
|
||||
}
|
||||
|
||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
||||
size_t totalLanes = batch.getLaneCount();
|
||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||
size_t baseChunkSize = totalLanes / chunkCount;
|
||||
size_t largeChunkCount = totalLanes % chunkCount;
|
||||
|
||||
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
|
||||
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
|
||||
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
|
||||
assert(chunkIndex < static_cast<size_t>(batch.getLaneCount()) && "chunkIndex out of range");
|
||||
return {batch.getOperation(), static_cast<uint32_t>(chunkIndex), 1};
|
||||
}
|
||||
|
||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
||||
size_t totalLanes = batch.getLaneCount();
|
||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||
size_t baseChunkSize = totalLanes / chunkCount;
|
||||
size_t largeChunkCount = totalLanes % chunkCount;
|
||||
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
|
||||
|
||||
size_t chunkIndex = 0;
|
||||
if (static_cast<size_t>(lane) < largeChunkSpan)
|
||||
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
|
||||
else
|
||||
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
|
||||
return getBatchChunkForIndex(batch, chunkIndex);
|
||||
assert(lane < static_cast<uint32_t>(batch.getLaneCount()) && "lane out of range");
|
||||
return {batch.getOperation(), lane, 1};
|
||||
}
|
||||
|
||||
std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
||||
Operation *op = value.getDefiningOp();
|
||||
static std::optional<uint32_t> getConstantExtractLane(tensor::ExtractSliceOp extract) {
|
||||
if (extract.getMixedOffsets().empty())
|
||||
return std::nullopt;
|
||||
|
||||
OpFoldResult offset = extract.getMixedOffsets().front();
|
||||
if (Attribute attr = llvm::dyn_cast<Attribute>(offset)) {
|
||||
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||
if (!intAttr || intAttr.getInt() < 0)
|
||||
return std::nullopt;
|
||||
return static_cast<uint32_t>(intAttr.getInt());
|
||||
}
|
||||
|
||||
Value offsetValue = llvm::cast<Value>(offset);
|
||||
if (auto constantIndex = offsetValue.getDefiningOp<arith::ConstantIndexOp>()) {
|
||||
if (constantIndex.value() < 0)
|
||||
return std::nullopt;
|
||||
return static_cast<uint32_t>(constantIndex.value());
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static std::optional<ProducerValueRef> getResultfulBatchProducerValueRef(SpatComputeBatch batch,
|
||||
const ComputeInstance* consumerInstance) {
|
||||
if (!consumerInstance)
|
||||
return std::nullopt;
|
||||
if (!isa<SpatComputeBatch>(consumerInstance->op))
|
||||
return std::nullopt;
|
||||
if (consumerInstance->laneStart + consumerInstance->laneCount > static_cast<uint32_t>(batch.getLaneCount()))
|
||||
return std::nullopt;
|
||||
return ProducerValueRef {
|
||||
{batch.getOperation(), consumerInstance->laneStart, consumerInstance->laneCount},
|
||||
0
|
||||
};
|
||||
}
|
||||
|
||||
std::optional<ProducerValueRef> getProducerValueRef(Value value, const ComputeInstance* consumerInstance) {
|
||||
Operation* op = value.getDefiningOp();
|
||||
if (!op)
|
||||
return std::nullopt;
|
||||
|
||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
Value source = extract.getSource();
|
||||
auto batch = dyn_cast_or_null<SpatComputeBatch>(source.getDefiningOp());
|
||||
if (batch && batch.getNumResults() != 0) {
|
||||
if (std::optional<uint32_t> lane = getConstantExtractLane(extract)) {
|
||||
if (*lane >= static_cast<uint32_t>(batch.getLaneCount()))
|
||||
return std::nullopt;
|
||||
return ProducerValueRef {
|
||||
{batch.getOperation(), *lane, 1},
|
||||
0
|
||||
};
|
||||
}
|
||||
return getResultfulBatchProducerValueRef(batch, consumerInstance);
|
||||
}
|
||||
|
||||
value = source;
|
||||
op = value.getDefiningOp();
|
||||
if (!op)
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||
return ProducerValueRef {
|
||||
ComputeInstance {compute.getOperation(), 0, 1},
|
||||
static_cast<size_t>(cast<OpResult>(value).getResultNumber())
|
||||
static_cast<size_t>(cast<OpResult>(value).getResultNumber())
|
||||
};
|
||||
}
|
||||
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
|
||||
if (batch.getNumResults() != 0)
|
||||
return getResultfulBatchProducerValueRef(batch, consumerInstance);
|
||||
uint32_t lane = cast<OpResult>(value).getResultNumber();
|
||||
ComputeInstance instance = getBatchChunkForLane(batch, lane);
|
||||
size_t resultIndex = lane - instance.laneStart;
|
||||
@@ -69,42 +114,60 @@ std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<ComputeInstance> getComputeProducerInstance(Value value) {
|
||||
if (std::optional<ProducerValueRef> producer = getProducerValueRef(value))
|
||||
std::optional<ComputeInstance> getComputeProducerInstance(Value value, const ComputeInstance* consumerInstance) {
|
||||
if (std::optional<ProducerValueRef> producer = getProducerValueRef(value, consumerInstance))
|
||||
return producer->instance;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance &instance) {
|
||||
llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||
return llvm::SmallVector<Value, 4>(compute.getInputs().begin(), compute.getInputs().end());
|
||||
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
if (batch.getNumResults() != 0)
|
||||
return llvm::SmallVector<Value, 4>(batch.getInputs().begin(), batch.getInputs().end());
|
||||
|
||||
assert(batch.getInputs().size() % static_cast<size_t>(batch.getLaneCount()) == 0
|
||||
&& "resultless compute_batch inputs must be evenly partitioned by lane");
|
||||
size_t inputsPerLane = batch.getInputs().size() / static_cast<size_t>(batch.getLaneCount());
|
||||
llvm::SmallVector<Value, 4> inputs;
|
||||
inputs.reserve(instance.laneCount);
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||
if (!batch.getInputs().empty())
|
||||
inputs.push_back(batch.getInputs()[lane]);
|
||||
inputs.reserve(instance.laneCount * inputsPerLane);
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) {
|
||||
size_t firstInput = static_cast<size_t>(lane) * inputsPerLane;
|
||||
inputs.append(batch.getInputs().begin() + firstInput, batch.getInputs().begin() + firstInput + inputsPerLane);
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance &instance) {
|
||||
llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance& instance) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||
return llvm::SmallVector<Value, 4>(compute.getWeights().begin(), compute.getWeights().end());
|
||||
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
if (batch.getNumResults() != 0)
|
||||
return llvm::SmallVector<Value, 4>(batch.getWeights().begin(), batch.getWeights().end());
|
||||
|
||||
assert(batch.getWeights().size() % static_cast<size_t>(batch.getLaneCount()) == 0
|
||||
&& "resultless compute_batch weights must be evenly partitioned by lane");
|
||||
size_t weightsPerLane = batch.getWeights().size() / static_cast<size_t>(batch.getLaneCount());
|
||||
llvm::SmallVector<Value, 4> weights;
|
||||
weights.reserve(instance.laneCount);
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||
weights.push_back(batch.getWeights()[lane]);
|
||||
weights.reserve(instance.laneCount * weightsPerLane);
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) {
|
||||
size_t firstWeight = static_cast<size_t>(lane) * weightsPerLane;
|
||||
weights.append(batch.getWeights().begin() + firstWeight, batch.getWeights().begin() + firstWeight + weightsPerLane);
|
||||
}
|
||||
return weights;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance) {
|
||||
llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance& instance) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||
return llvm::SmallVector<Value, 4>(compute.getResults().begin(), compute.getResults().end());
|
||||
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
if (batch.getNumResults() != 0)
|
||||
return llvm::SmallVector<Value, 4>(batch.getResults().begin(), batch.getResults().end());
|
||||
|
||||
llvm::SmallVector<Value, 4> outputs;
|
||||
outputs.reserve(instance.laneCount);
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||
@@ -113,14 +176,14 @@ llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance
|
||||
return outputs;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance) {
|
||||
llvm::SmallVector<Type, 4> getComputeInstanceOutputTypes(const ComputeInstance& instance) {
|
||||
llvm::SmallVector<Type, 4> outputTypes;
|
||||
for (Value output : getComputeInstanceOutputValues(instance))
|
||||
outputTypes.push_back(output.getType());
|
||||
return outputTypes;
|
||||
}
|
||||
|
||||
Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance) {
|
||||
Block& getComputeInstanceTemplateBlock(const ComputeInstance& instance) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||
return compute.getBody().front();
|
||||
return cast<SpatComputeBatch>(instance.op).getBody().front();
|
||||
|
||||
+4
-2
@@ -26,8 +26,10 @@ size_t getBatchChunkTargetCount(int32_t laneCount);
|
||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
|
||||
|
||||
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value);
|
||||
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value);
|
||||
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value,
|
||||
const ComputeInstance *consumerInstance = nullptr);
|
||||
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value,
|
||||
const ComputeInstance *consumerInstance = nullptr);
|
||||
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
|
||||
|
||||
Reference in New Issue
Block a user