huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
@@ -102,23 +102,6 @@ def SpatConcatOp : SpatOp<"concat", []> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatMapOp : SpatOp<"map", [SingleBlock]> {
|
||||
let summary = "Apply the same lane-local region to many independent tensors";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<SpatTensor>:$inputs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<SpatTensor>:$outputs
|
||||
);
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Communication
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -156,22 +139,25 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendManyOp : SpatOp<"channel_send_many", []> {
|
||||
let summary = "Send multiple tensors through logical channels";
|
||||
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
|
||||
let summary = "Send equal contiguous chunks of one tensor through logical channels";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<SpatTensor>:$inputs
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> {
|
||||
let summary = "Receive multiple tensors from logical channels";
|
||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
||||
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
@@ -180,11 +166,14 @@ def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> {
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<SpatTensor>:$outputs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
||||
@@ -201,18 +190,21 @@ def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatChannelSendManyBatchOp : SpatOp<"channel_send_many_batch", []> {
|
||||
let summary = "Send multiple per-lane tensors through logical channels in a batch body";
|
||||
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
|
||||
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<SpatTensor>:$inputs
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||
@@ -232,8 +224,8 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> {
|
||||
let summary = "Receive multiple per-lane tensors through logical channels in a batch body";
|
||||
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> {
|
||||
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
@@ -242,11 +234,14 @@ def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> {
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<SpatTensor>:$outputs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -129,9 +129,8 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
if (parser.parseKeyword("axis") || parser.parseInteger(axis))
|
||||
return failure();
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs)) {
|
||||
if (parseCompressedOperandSequence(parser, inputs))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
@@ -151,46 +150,6 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatMapOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
printer << " : ";
|
||||
printer.printType(getInputs().front().getType());
|
||||
printer << " -> ";
|
||||
printer.printType(getOutputs().front().getType());
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult SpatMapOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
Type inputType;
|
||||
Type outputType;
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
return failure();
|
||||
if (inputs.empty())
|
||||
return parser.emitError(parser.getCurrentLocation(), "map requires at least one input");
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|
||||
|| parser.parseArrow() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
SmallVector<Type> inputTypes(inputs.size(), inputType);
|
||||
SmallVector<Type> outputTypes(inputs.size(), outputType);
|
||||
if (regionArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
Region* body = result.addRegion();
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
@@ -357,97 +316,6 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatChannelSendManyOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||
}
|
||||
|
||||
ParseResult SpatChannelSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveManyOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
@@ -494,55 +362,6 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelSendManyBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||
}
|
||||
|
||||
ParseResult SpatChannelSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parseCompressedOperandSequence(parser, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
@@ -584,47 +403,5 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelReceiveManyBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeSequence(printer, getResultTypes());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -105,26 +105,28 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
return batchOp.getLaneCount();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyChannelSizes(Operation* op,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
size_t valueCount) {
|
||||
static LogicalResult verifyTensorChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
if (channelIds.size() != valueCount)
|
||||
return op->emitError("channel metadata length must match the number of values");
|
||||
return success();
|
||||
}
|
||||
if (channelIds.empty())
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types, StringRef kind) {
|
||||
if (types.empty())
|
||||
return op->emitError() << kind << " must carry at least one value";
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor";
|
||||
|
||||
Type firstType = types.front();
|
||||
for (Type type : types.drop_front())
|
||||
if (type != firstType)
|
||||
return op->emitError() << kind << " values must all have the same type";
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -144,19 +146,33 @@ static LogicalResult verifyBatchChannelSizes(Operation* op,
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyBatchChannelSizes(Operation* op,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
size_t valueCount) {
|
||||
static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside spat.compute_batch");
|
||||
if (channelIds.size() != valueCount * static_cast<size_t>(*laneCount))
|
||||
return op->emitError("channel metadata length must match the number of values times parent laneCount");
|
||||
if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0)
|
||||
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor";
|
||||
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount;
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % chunkCount != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -323,39 +339,6 @@ LogicalResult SpatConcatOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatMapOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
if (getOutputs().size() != getInputs().size())
|
||||
return emitError("number of outputs must match number of inputs");
|
||||
|
||||
Type inputType = getInputs().front().getType();
|
||||
for (Value input : getInputs().drop_front())
|
||||
if (input.getType() != inputType)
|
||||
return emitError("all inputs must have the same type");
|
||||
|
||||
Type outputType = getOutputs().front().getType();
|
||||
for (Value output : getOutputs().drop_front())
|
||||
if (output.getType() != outputType)
|
||||
return emitError("all outputs must have the same type");
|
||||
|
||||
Block& block = getBody().front();
|
||||
if (block.getNumArguments() != 1)
|
||||
return emitError("body must have exactly one block argument");
|
||||
if (block.getArgument(0).getType() != inputType)
|
||||
return emitError("body block argument type must match input type");
|
||||
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return emitError("body must terminate with spat.yield");
|
||||
if (yieldOp.getNumOperands() != 1)
|
||||
return emitError("body yield must produce exactly one value");
|
||||
if (yieldOp.getOperand(0).getType() != outputType)
|
||||
return emitError("body yield type must match output type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatCompute::verify() {
|
||||
auto& block = getBody().front();
|
||||
if (block.mightHaveTerminator()) {
|
||||
@@ -397,40 +380,48 @@ LogicalResult SpatCompute::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendManyOp::verify() {
|
||||
if (failed(verifyManyChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many");
|
||||
LogicalResult SpatChannelSendTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
"channel_send_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveManyOp::verify() {
|
||||
if (failed(verifyManyChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many");
|
||||
LogicalResult SpatChannelReceiveTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
"channel_receive_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendManyBatchOp::verify() {
|
||||
if (failed(verifyManyBatchChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many_batch");
|
||||
LogicalResult SpatChannelSendTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
"channel_send_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveManyBatchOp::verify() {
|
||||
if (failed(verifyManyBatchChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many_batch");
|
||||
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
"channel_receive_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult SpatComputeBatch::verify() {
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
@@ -31,6 +32,8 @@ namespace {
|
||||
using SpatCompute = onnx_mlir::spatial::SpatCompute;
|
||||
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;
|
||||
|
||||
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
|
||||
|
||||
struct VirtualNode {
|
||||
SmallVector<size_t, 4> originalComputeIndices;
|
||||
Weight weight = 0;
|
||||
@@ -719,11 +722,12 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
|
||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
|
||||
size_t iteration = 0;
|
||||
bool debugCoarsening = isDcpCoarsenDebugEnabled();
|
||||
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
|
||||
size_t oldNodeCount = virtualGraph.nodes.size();
|
||||
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
|
||||
if (windowSchedule.mergeGroups.empty()) {
|
||||
if (oldNodeCount >= 200)
|
||||
if (debugCoarsening && oldNodeCount >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
|
||||
iteration,
|
||||
@@ -737,7 +741,7 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
std::vector<size_t> oldToNewNode;
|
||||
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
|
||||
return false;
|
||||
if (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200)
|
||||
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
|
||||
iteration,
|
||||
@@ -755,7 +759,7 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
|
||||
while (virtualGraph.nodes.size() > 1) {
|
||||
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
|
||||
if (virtualGraph.nodes.size() >= 200)
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
@@ -764,7 +768,7 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
iteration++;
|
||||
TimingInfo timing = computeTiming(virtualGraph);
|
||||
if (!timing.valid) {
|
||||
if (virtualGraph.nodes.size() >= 200)
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
@@ -776,7 +780,7 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
|
||||
|
||||
if (selectedNodes.size() < 2) {
|
||||
if (virtualGraph.nodes.size() >= 200)
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
|
||||
iteration,
|
||||
virtualGraph.nodes.size(),
|
||||
@@ -786,7 +790,7 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
|
||||
if (tryCoarsenSelectedNodes(selectedNodes))
|
||||
continue;
|
||||
if (virtualGraph.nodes.size() >= 200)
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
|
||||
@@ -59,11 +59,7 @@ struct DenseMapInfo<ComputeInstance> {
|
||||
static ComputeInstance getTombstoneKey() {
|
||||
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static unsigned getHashValue(const ComputeInstance& v) {
|
||||
return llvm::hash_combine(v.op, v.laneStart, v.laneCount);
|
||||
}
|
||||
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) {
|
||||
return a == b;
|
||||
}
|
||||
static unsigned getHashValue(const ComputeInstance& v) { return llvm::hash_combine(v.op, v.laneStart, v.laneCount); }
|
||||
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; }
|
||||
};
|
||||
} // namespace llvm
|
||||
|
||||
@@ -38,9 +38,11 @@ void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += t
|
||||
void DcpProgressLogger::printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const {
|
||||
if (!logProgress)
|
||||
return;
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n",
|
||||
totalTasks, readyCount, maxCpuCount, xbarsCapacity);
|
||||
llvm::errs() << llvm::formatv("[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n",
|
||||
totalTasks,
|
||||
readyCount,
|
||||
maxCpuCount,
|
||||
xbarsCapacity);
|
||||
}
|
||||
|
||||
void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
|
||||
@@ -72,18 +74,17 @@ void DcpProgressLogger::printProgress(
|
||||
double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast<double>(completedTasks) / totalTasks);
|
||||
|
||||
bool done = completedTasks == totalTasks;
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n",
|
||||
completedTasks,
|
||||
totalTasks,
|
||||
percent,
|
||||
readyCount,
|
||||
cpuCount,
|
||||
maxCpuCount,
|
||||
xbarsUsed,
|
||||
xbarsAvailable,
|
||||
llvm::formatv("elapsed={0}", formatDuration(elapsedSeconds)).str(),
|
||||
done ? "" : llvm::formatv(" eta={0}", formatDuration(etaSeconds)).str());
|
||||
llvm::errs() << llvm::formatv("[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n",
|
||||
completedTasks,
|
||||
totalTasks,
|
||||
percent,
|
||||
readyCount,
|
||||
cpuCount,
|
||||
maxCpuCount,
|
||||
xbarsUsed,
|
||||
xbarsAvailable,
|
||||
llvm::formatv("elapsed={0}", formatDuration(elapsedSeconds)).str(),
|
||||
done ? "" : llvm::formatv(" eta={0}", formatDuration(etaSeconds)).str());
|
||||
lastProgressPrint = now;
|
||||
}
|
||||
|
||||
@@ -100,9 +101,7 @@ void DcpProgressLogger::printProgress(size_t, CPU, int, size_t, size_t, bool) {}
|
||||
|
||||
#endif
|
||||
|
||||
void dumpGraphDot(const std::vector<TaskDCP>& nodes,
|
||||
const std::vector<std::list<TaskDCP*>>& cpuTasks,
|
||||
CPU lastCpu) {
|
||||
void dumpGraphDot(const std::vector<TaskDCP>& nodes, const std::vector<std::list<TaskDCP*>>& cpuTasks, CPU lastCpu) {
|
||||
static int dumpIndex = 0;
|
||||
std::string outputDir = onnx_mlir::getOutputDir();
|
||||
if (outputDir.empty())
|
||||
|
||||
@@ -9,9 +9,9 @@
|
||||
#include "Task.hpp"
|
||||
#include "Utils.hpp"
|
||||
|
||||
// Uncomment to enable DCP progress logging and per-phase profiling during
|
||||
// development. When disabled the logger methods are no-ops and the helpers
|
||||
// compile away.
|
||||
// Define DCP_DEBUG_ENABLED locally when debugging DCP progress and per-phase
|
||||
// profiling. In normal builds the logger methods are no-ops and helpers compile
|
||||
// away.
|
||||
#define DCP_DEBUG_ENABLED
|
||||
|
||||
#ifdef DCP_DEBUG_ENABLED
|
||||
@@ -33,10 +33,11 @@ public:
|
||||
|
||||
void printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const;
|
||||
void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const;
|
||||
void printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount,
|
||||
size_t xbarsUsed, size_t xbarsAvailable, bool force);
|
||||
void
|
||||
printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount, size_t xbarsUsed, size_t xbarsAvailable, bool force);
|
||||
|
||||
#ifdef DCP_DEBUG_ENABLED
|
||||
|
||||
private:
|
||||
static std::string formatDuration(double seconds);
|
||||
|
||||
@@ -51,8 +52,6 @@ private:
|
||||
#endif
|
||||
};
|
||||
|
||||
void dumpGraphDot(const std::vector<TaskDCP>& nodes,
|
||||
const std::vector<std::list<TaskDCP*>>& cpuTasks,
|
||||
CPU lastCpu);
|
||||
void dumpGraphDot(const std::vector<TaskDCP>& nodes, const std::vector<std::list<TaskDCP*>>& cpuTasks, CPU lastCpu);
|
||||
|
||||
} // namespace dcp_graph
|
||||
|
||||
@@ -149,14 +149,6 @@ static SmallVector<int32_t> getMaterializedBatchCoreIds(size_t startCpu, size_t
|
||||
return coreIds;
|
||||
}
|
||||
|
||||
static SmallVector<int32_t> getBatchCoreIds(Operation* op, size_t laneCount) {
|
||||
if (auto coreIdsAttr = op->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
if (auto coreIdAttr = op->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return SmallVector<int32_t>(laneCount, static_cast<int32_t>(coreIdAttr.getInt()));
|
||||
return {};
|
||||
}
|
||||
|
||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return static_cast<int32_t>(coreIdAttr.getInt());
|
||||
@@ -245,312 +237,6 @@ static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
||||
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
|
||||
}
|
||||
|
||||
static void sinkChannelsIntoBatchComputes(func::FuncOp funcOp,
|
||||
IRRewriter& rewriter,
|
||||
SmallVectorImpl<Operation*>& opsToErase,
|
||||
int64_t& nextChannelId) {
|
||||
SmallVector<SpatComputeBatch> batches(funcOp.getOps<SpatComputeBatch>());
|
||||
|
||||
for (auto batch : batches) {
|
||||
if (batch.getInputs().empty() && batch.getResults().empty())
|
||||
continue;
|
||||
|
||||
if (batch.getInputs().size() != static_cast<size_t>(batch.getLaneCount()))
|
||||
continue;
|
||||
if (batch.getResults().size() != static_cast<size_t>(batch.getLaneCount()))
|
||||
continue;
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveOp> inputReceives;
|
||||
inputReceives.reserve(batch.getInputs().size());
|
||||
bool allInputsAreReceives = true;
|
||||
for (Value input : batch.getInputs()) {
|
||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
|
||||
if (!receiveOp) {
|
||||
allInputsAreReceives = false;
|
||||
break;
|
||||
}
|
||||
inputReceives.push_back(receiveOp);
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatChannelSendOp> resultSends;
|
||||
resultSends.reserve(batch.getResults().size());
|
||||
bool allResultsAreSingleSends = true;
|
||||
for (Value result : batch.getResults()) {
|
||||
if (!result.hasOneUse()) {
|
||||
allResultsAreSingleSends = false;
|
||||
break;
|
||||
}
|
||||
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(*result.getUsers().begin());
|
||||
if (!sendOp) {
|
||||
allResultsAreSingleSends = false;
|
||||
break;
|
||||
}
|
||||
resultSends.push_back(sendOp);
|
||||
}
|
||||
|
||||
if (!allInputsAreReceives || !allResultsAreSingleSends)
|
||||
continue;
|
||||
|
||||
Block& oldBlock = batch.getBody().front();
|
||||
if (oldBlock.getNumArguments() != 1)
|
||||
continue;
|
||||
|
||||
SmallVector<Value> newWeights(batch.getWeights().begin(), batch.getWeights().end());
|
||||
rewriter.setInsertionPointAfter(batch);
|
||||
auto newBatch = SpatComputeBatch::create(rewriter,
|
||||
batch.getLoc(),
|
||||
TypeRange {},
|
||||
rewriter.getI32IntegerAttr(batch.getLaneCount()),
|
||||
ValueRange(newWeights),
|
||||
ValueRange {});
|
||||
newBatch.getProperties().setOperandSegmentSizes({static_cast<int>(newWeights.size()), 0});
|
||||
|
||||
SmallVector<int32_t> coreIds = getBatchCoreIds(batch, static_cast<size_t>(batch.getLaneCount()));
|
||||
if (!coreIds.empty())
|
||||
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange {}, ArrayRef<Location> {});
|
||||
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
struct BatchReceiveEntry {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
};
|
||||
SmallVector<BatchReceiveEntry> receiveEntries;
|
||||
receiveEntries.reserve(inputReceives.size());
|
||||
for (auto receiveOp : inputReceives)
|
||||
receiveEntries.push_back({receiveOp.getChannelId(), receiveOp.getSourceCoreId(), receiveOp.getTargetCoreId()});
|
||||
llvm::stable_sort(receiveEntries, [](const BatchReceiveEntry& lhs, const BatchReceiveEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> receiveChannelIds;
|
||||
SmallVector<int32_t> receiveSourceCoreIds;
|
||||
SmallVector<int32_t> receiveTargetCoreIds;
|
||||
receiveChannelIds.reserve(receiveEntries.size());
|
||||
receiveSourceCoreIds.reserve(receiveEntries.size());
|
||||
receiveTargetCoreIds.reserve(receiveEntries.size());
|
||||
for (const BatchReceiveEntry& entry : receiveEntries) {
|
||||
(void) entry;
|
||||
receiveChannelIds.push_back(nextChannelId++);
|
||||
receiveSourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
receiveTargetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
}
|
||||
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
|
||||
batch.getLoc(),
|
||||
oldBlock.getArgument(0).getType(),
|
||||
rewriter.getDenseI64ArrayAttr(receiveChannelIds),
|
||||
rewriter.getDenseI32ArrayAttr(receiveSourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(receiveTargetCoreIds));
|
||||
|
||||
IRMapping mapper;
|
||||
mapper.map(oldBlock.getArgument(0), batchReceive.getOutput());
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : oldBlock) {
|
||||
if (&op == oldYield)
|
||||
continue;
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
|
||||
Value sendInput = mapper.lookup(oldYield.getOperand(0));
|
||||
struct BatchSendEntry {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
};
|
||||
SmallVector<BatchSendEntry> sendEntries;
|
||||
sendEntries.reserve(resultSends.size());
|
||||
for (auto sendOp : resultSends)
|
||||
sendEntries.push_back({sendOp.getChannelId(), sendOp.getSourceCoreId(), sendOp.getTargetCoreId()});
|
||||
llvm::stable_sort(sendEntries, [](const BatchSendEntry& lhs, const BatchSendEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> sendChannelIds;
|
||||
SmallVector<int32_t> sendSourceCoreIds;
|
||||
SmallVector<int32_t> sendTargetCoreIds;
|
||||
sendChannelIds.reserve(sendEntries.size());
|
||||
sendSourceCoreIds.reserve(sendEntries.size());
|
||||
sendTargetCoreIds.reserve(sendEntries.size());
|
||||
for (const BatchSendEntry& entry : sendEntries) {
|
||||
(void) entry;
|
||||
sendChannelIds.push_back(nextChannelId++);
|
||||
sendSourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
sendTargetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
}
|
||||
spatial::SpatChannelSendBatchOp::create(rewriter,
|
||||
batch.getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(sendChannelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sendSourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(sendTargetCoreIds),
|
||||
sendInput);
|
||||
spatial::SpatYieldOp::create(rewriter, batch.getLoc(), ValueRange {});
|
||||
|
||||
for (auto receiveOp : inputReceives)
|
||||
opsToErase.push_back(receiveOp);
|
||||
for (auto sendOp : resultSends)
|
||||
opsToErase.push_back(sendOp);
|
||||
opsToErase.push_back(batch);
|
||||
}
|
||||
}
|
||||
|
||||
void sinkChannelsIntoComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
||||
SmallVector<Operation*> opsToErase;
|
||||
|
||||
for (auto compute : computes) {
|
||||
SmallVector<unsigned> keptInputIndices;
|
||||
SmallVector<unsigned> keptResultIndices;
|
||||
SmallVector<spatial::SpatChannelReceiveOp> internalizedReceives(compute.getInputs().size());
|
||||
SmallVector<SmallVector<spatial::SpatChannelSendOp>> resultSendOps(compute.getNumResults());
|
||||
|
||||
bool needsRewrite = false;
|
||||
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
|
||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
|
||||
if (!receiveOp) {
|
||||
keptInputIndices.push_back(inputIndex);
|
||||
continue;
|
||||
}
|
||||
|
||||
internalizedReceives[inputIndex] = receiveOp;
|
||||
opsToErase.push_back(receiveOp);
|
||||
needsRewrite = true;
|
||||
}
|
||||
|
||||
for (auto [resultIndex, result] : llvm::enumerate(compute.getResults())) {
|
||||
bool hasNonSendUser = false;
|
||||
for (Operation* user : result.getUsers()) {
|
||||
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(user)) {
|
||||
resultSendOps[resultIndex].push_back(sendOp);
|
||||
opsToErase.push_back(sendOp);
|
||||
needsRewrite = true;
|
||||
continue;
|
||||
}
|
||||
hasNonSendUser = true;
|
||||
}
|
||||
|
||||
if (hasNonSendUser || resultSendOps[resultIndex].empty())
|
||||
keptResultIndices.push_back(resultIndex);
|
||||
}
|
||||
|
||||
if (!needsRewrite)
|
||||
continue;
|
||||
|
||||
SmallVector<Value> newOperands;
|
||||
SmallVector<Type> newResultTypes;
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
newOperands.reserve(compute.getNumOperands());
|
||||
newResultTypes.reserve(keptResultIndices.size());
|
||||
newBlockArgTypes.reserve(keptInputIndices.size());
|
||||
newBlockArgLocs.reserve(keptInputIndices.size());
|
||||
|
||||
for (Value weight : compute.getWeights())
|
||||
newOperands.push_back(weight);
|
||||
for (unsigned inputIndex : keptInputIndices) {
|
||||
Value input = compute.getInputs()[inputIndex];
|
||||
newOperands.push_back(input);
|
||||
newBlockArgTypes.push_back(input.getType());
|
||||
newBlockArgLocs.push_back(compute.getLoc());
|
||||
}
|
||||
for (unsigned resultIndex : keptResultIndices)
|
||||
newResultTypes.push_back(compute.getResult(resultIndex).getType());
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
auto newCompute =
|
||||
SpatCompute::create(rewriter, compute.getLoc(), TypeRange(newResultTypes), ValueRange(newOperands));
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(compute.getWeights().size()), static_cast<int>(keptInputIndices.size())});
|
||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, coreIdAttr);
|
||||
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [mappedIndex, inputIndex] : llvm::enumerate(keptInputIndices))
|
||||
mapper.map(compute.getBody().front().getArgument(inputIndex), newBlock->getArgument(mappedIndex));
|
||||
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
for (auto [inputIndex, receiveOp] : llvm::enumerate(internalizedReceives)) {
|
||||
if (!receiveOp)
|
||||
continue;
|
||||
|
||||
auto internalReceive = spatial::SpatChannelReceiveOp::create(rewriter,
|
||||
receiveOp.getLoc(),
|
||||
receiveOp.getResult().getType(),
|
||||
receiveOp.getChannelIdAttr(),
|
||||
receiveOp.getSourceCoreIdAttr(),
|
||||
receiveOp.getTargetCoreIdAttr());
|
||||
mapper.map(compute.getBody().front().getArgument(inputIndex), internalReceive.getResult());
|
||||
}
|
||||
|
||||
auto oldYieldOp = cast<spatial::SpatYieldOp>(compute.getBody().front().getTerminator());
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : compute.getBody().front()) {
|
||||
if (&op == oldYieldOp)
|
||||
continue;
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
|
||||
for (auto [resultIndex, sendOps] : llvm::enumerate(resultSendOps)) {
|
||||
if (sendOps.empty())
|
||||
continue;
|
||||
|
||||
Value yieldedValue = mapper.lookup(oldYieldOp.getOperand(resultIndex));
|
||||
for (auto sendOp : sendOps)
|
||||
spatial::SpatChannelSendOp::create(rewriter,
|
||||
sendOp.getLoc(),
|
||||
sendOp.getChannelIdAttr(),
|
||||
sendOp.getSourceCoreIdAttr(),
|
||||
sendOp.getTargetCoreIdAttr(),
|
||||
yieldedValue);
|
||||
}
|
||||
|
||||
SmallVector<Value> keptYieldOperands;
|
||||
keptYieldOperands.reserve(keptResultIndices.size());
|
||||
for (unsigned resultIndex : keptResultIndices)
|
||||
keptYieldOperands.push_back(mapper.lookup(oldYieldOp.getOperand(resultIndex)));
|
||||
spatial::SpatYieldOp::create(rewriter, oldYieldOp.getLoc(), ValueRange(keptYieldOperands));
|
||||
|
||||
for (auto [newResultIndex, oldResultIndex] : llvm::enumerate(keptResultIndices))
|
||||
compute.getResult(oldResultIndex).replaceAllUsesWith(newCompute.getResult(newResultIndex));
|
||||
|
||||
opsToErase.push_back(compute);
|
||||
}
|
||||
|
||||
sinkChannelsIntoBatchComputes(funcOp, rewriter, opsToErase, nextChannelId);
|
||||
|
||||
SmallVector<Operation*> pendingRemovals(opsToErase.begin(), opsToErase.end());
|
||||
while (!pendingRemovals.empty()) {
|
||||
bool erasedAny = false;
|
||||
for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) {
|
||||
if (!(*it)->use_empty()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
rewriter.eraseOp(*it);
|
||||
it = pendingRemovals.erase(it);
|
||||
erasedAny = true;
|
||||
}
|
||||
|
||||
if (erasedAny)
|
||||
continue;
|
||||
|
||||
for (Operation* op : pendingRemovals)
|
||||
op->emitError("failed to sink channel op into compute");
|
||||
llvm_unreachable("channel sinking left cyclic top-level dependencies");
|
||||
}
|
||||
}
|
||||
|
||||
void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
||||
@@ -1280,7 +966,8 @@ public:
|
||||
|
||||
void runOnOperation() override {
|
||||
mergeTriviallyConnectedComputes(getOperation());
|
||||
emitMotifProfile(getOperation());
|
||||
if (std::getenv("DCP_MOTIF_PROFILE"))
|
||||
emitMotifProfile(getOperation());
|
||||
|
||||
func::FuncOp func = getOperation();
|
||||
Location loc = func.getLoc();
|
||||
@@ -1718,17 +1405,12 @@ public:
|
||||
for (Operation* user : result.getUsers())
|
||||
remainingUsers.push_back(user);
|
||||
if (!remainingUsers.empty()) {
|
||||
llvm::errs() << "[MergeComputeNodesPass] refusing to erase op with remaining uses: " << op->getName() << "\n";
|
||||
llvm::errs() << " erase-set: " << (allOpsToErase.contains(op) ? "yes" : "no") << "\n";
|
||||
op->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions());
|
||||
llvm::errs() << "\n";
|
||||
InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup")
|
||||
<< "; erase-set=" << (allOpsToErase.contains(op) ? "yes" : "no");
|
||||
for (Operation* user : remainingUsers) {
|
||||
llvm::errs() << " user: " << user->getName()
|
||||
<< " erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no") << "\n";
|
||||
user->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions());
|
||||
llvm::errs() << "\n";
|
||||
diagnostic.attachNote(user->getLoc())
|
||||
<< "remaining user " << user->getName() << "; erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no");
|
||||
}
|
||||
op->emitOpError("still has uses during per-cpu merge cleanup");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
@@ -40,6 +41,176 @@ struct RegularChunk {
|
||||
Value output;
|
||||
};
|
||||
|
||||
static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
|
||||
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
|
||||
packedShape[0] *= count;
|
||||
return RankedTensorType::get(packedShape, elementType.getElementType());
|
||||
}
|
||||
|
||||
static Value
|
||||
extractPackedChunk(Value packedValue, RankedTensorType chunkType, unsigned index, IRRewriter& rewriter, Location loc) {
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(chunkType.getRank());
|
||||
sizes.reserve(chunkType.getRank());
|
||||
strides.reserve(chunkType.getRank());
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(index) * chunkType.getDimSize(0)));
|
||||
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(0)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(dim)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, chunkType, packedValue, offsets, sizes, strides).getResult();
|
||||
}
|
||||
|
||||
static Value createPackedExtractRowsSlice(
|
||||
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
|
||||
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
int64_t rowsPerValue = rowType.getDimSize(0);
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(inputType.getRank());
|
||||
sizes.reserve(inputType.getRank());
|
||||
strides.reserve(inputType.getRank());
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
|
||||
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
static Value createPackedExtractSliceTensor(ValueRange values, IRRewriter& rewriter, Location loc) {
|
||||
if (values.empty())
|
||||
return {};
|
||||
if (values.size() == 1)
|
||||
return values.front();
|
||||
|
||||
auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
|
||||
if (!firstSliceOp)
|
||||
return {};
|
||||
|
||||
auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
|
||||
auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType());
|
||||
if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
|
||||
|| firstType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
auto hasStaticValues = [](ArrayRef<int64_t> values) {
|
||||
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
};
|
||||
if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(firstSliceOp.getStaticStrides()))
|
||||
return {};
|
||||
|
||||
ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
|
||||
ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
|
||||
int64_t rowsPerValue = firstSizes[0];
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
for (size_t index = 1; index < values.size(); ++index) {
|
||||
auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
|
||||
if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource()
|
||||
|| sliceOp.getResult().getType() != firstSliceOp.getResult().getType()
|
||||
|| !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(sliceOp.getStaticStrides()))
|
||||
return {};
|
||||
|
||||
if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
|
||||
return {};
|
||||
|
||||
if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast<int64_t>(index) * rowsPerValue)
|
||||
return {};
|
||||
|
||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim)
|
||||
if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim])
|
||||
return {};
|
||||
}
|
||||
|
||||
auto packedType = getPackedTensorType(firstType, static_cast<int64_t>(values.size()));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(firstType.getRank());
|
||||
sizes.reserve(firstType.getRank());
|
||||
strides.reserve(firstType.getRank());
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(firstOffsets[0]));
|
||||
sizes.push_back(rewriter.getIndexAttr(rowsPerValue * static_cast<int64_t>(values.size())));
|
||||
strides.push_back(rewriter.getIndexAttr(firstStrides[0]));
|
||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(firstOffsets[dim]));
|
||||
sizes.push_back(rewriter.getIndexAttr(firstSizes[dim]));
|
||||
strides.push_back(rewriter.getIndexAttr(firstStrides[dim]));
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
|
||||
if (values.empty())
|
||||
return false;
|
||||
|
||||
auto firstResult = dyn_cast<OpResult>(values.front());
|
||||
if (!firstResult)
|
||||
return false;
|
||||
|
||||
owner = firstResult.getOwner();
|
||||
startIndex = firstResult.getResultNumber();
|
||||
for (auto [index, value] : llvm::enumerate(values)) {
|
||||
auto result = dyn_cast<OpResult>(value);
|
||||
if (!result || result.getOwner() != owner || result.getResultNumber() != startIndex + index)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter, Location loc) {
|
||||
if (values.empty())
|
||||
return {};
|
||||
if (Value packedSlice = createPackedExtractSliceTensor(values, rewriter, loc))
|
||||
return packedSlice;
|
||||
|
||||
Operation* owner = nullptr;
|
||||
unsigned startIndex = 0;
|
||||
if (getContiguousOpResults(values, owner, startIndex))
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
|
||||
return createPackedExtractRowsSlice(
|
||||
extractRowsOp, startIndex, static_cast<unsigned>(values.size()), rewriter, loc);
|
||||
|
||||
auto firstType = dyn_cast<RankedTensorType>(values.front().getType());
|
||||
if (!firstType || !firstType.hasStaticShape() || firstType.getRank() == 0)
|
||||
return {};
|
||||
if (!llvm::all_of(values.drop_front(), [&](Value value) { return value.getType() == firstType; }))
|
||||
return {};
|
||||
|
||||
return tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, values).getResult();
|
||||
}
|
||||
|
||||
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
|
||||
return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand
|
||||
&& lhs.resultType == rhs.resultType;
|
||||
@@ -89,45 +260,97 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
return chunk;
|
||||
}
|
||||
|
||||
static void buildRegularMapBody(spatial::SpatMapOp mapOp, const RegularChunk& anchorChunk, IRRewriter& rewriter) {
|
||||
auto* block = rewriter.createBlock(
|
||||
&mapOp.getBody(), mapOp.getBody().end(), TypeRange {anchorChunk.input.getType()}, {anchorChunk.startOp->getLoc()});
|
||||
rewriter.setInsertionPointToEnd(block);
|
||||
|
||||
IRMapping mapping;
|
||||
mapping.map(anchorChunk.input, block->getArgument(0));
|
||||
|
||||
for (Operation* op : anchorChunk.ops) {
|
||||
Operation* cloned = rewriter.clone(*op, mapping);
|
||||
for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults()))
|
||||
mapping.map(oldResult, newResult);
|
||||
}
|
||||
|
||||
spatial::SpatYieldOp::create(
|
||||
rewriter, anchorChunk.startOp->getLoc(), ValueRange {mapping.lookup(anchorChunk.output)});
|
||||
}
|
||||
|
||||
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
||||
assert(!run.empty() && "expected a non-empty regular chunk run");
|
||||
const RegularChunk& anchorChunk = run.front();
|
||||
|
||||
SmallVector<Value> inputs;
|
||||
SmallVector<Type> outputTypes;
|
||||
inputs.reserve(run.size());
|
||||
outputTypes.reserve(run.size());
|
||||
for (const RegularChunk& chunk : run) {
|
||||
for (const RegularChunk& chunk : run)
|
||||
inputs.push_back(chunk.input);
|
||||
outputTypes.push_back(chunk.output.getType());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(anchorChunk.startOp);
|
||||
auto mapOp =
|
||||
spatial::SpatMapOp::create(rewriter, anchorChunk.startOp->getLoc(), TypeRange(outputTypes), ValueRange(inputs));
|
||||
buildRegularMapBody(mapOp, anchorChunk, rewriter);
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc());
|
||||
if (!packedInput)
|
||||
return;
|
||||
|
||||
auto inputType = cast<RankedTensorType>(anchorChunk.input.getType());
|
||||
auto outputType = cast<RankedTensorType>(anchorChunk.output.getType());
|
||||
auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(run.size()));
|
||||
auto packedInit = tensor::EmptyOp::create(
|
||||
rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType());
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 0);
|
||||
auto upper = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), run.size());
|
||||
auto step = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 1);
|
||||
auto loop =
|
||||
scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
Block* loopBlock = loop.getBody();
|
||||
rewriter.setInsertionPointToStart(loopBlock);
|
||||
Value iv = loopBlock->getArgument(0);
|
||||
Value acc = loopBlock->getArgument(1);
|
||||
|
||||
Value inputRowOffset = iv;
|
||||
if (inputType.getDimSize(0) != 1) {
|
||||
auto rowsPerValue =
|
||||
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), inputType.getDimSize(0));
|
||||
inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> extractOffsets;
|
||||
SmallVector<OpFoldResult> extractSizes;
|
||||
SmallVector<OpFoldResult> extractStrides;
|
||||
extractOffsets.push_back(inputRowOffset);
|
||||
extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(0)));
|
||||
extractStrides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||
extractOffsets.push_back(rewriter.getIndexAttr(0));
|
||||
extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
|
||||
extractStrides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
auto inputSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, anchorChunk.startOp->getLoc(), inputType, packedInput, extractOffsets, extractSizes, extractStrides);
|
||||
|
||||
IRMapping mapping;
|
||||
mapping.map(anchorChunk.input, inputSlice.getResult());
|
||||
for (Operation* op : anchorChunk.ops) {
|
||||
Operation* cloned = rewriter.clone(*op, mapping);
|
||||
for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults()))
|
||||
mapping.map(oldResult, newResult);
|
||||
}
|
||||
|
||||
Value mappedOutput = mapping.lookup(anchorChunk.output);
|
||||
Value outputRowOffset = iv;
|
||||
if (outputType.getDimSize(0) != 1) {
|
||||
auto rowsPerValue =
|
||||
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), outputType.getDimSize(0));
|
||||
outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> insertOffsets;
|
||||
SmallVector<OpFoldResult> insertSizes;
|
||||
SmallVector<OpFoldResult> insertStrides;
|
||||
insertOffsets.push_back(outputRowOffset);
|
||||
insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(0)));
|
||||
insertStrides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < outputType.getRank(); ++dim) {
|
||||
insertOffsets.push_back(rewriter.getIndexAttr(0));
|
||||
insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(dim)));
|
||||
insertStrides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
auto inserted = tensor::InsertSliceOp::create(
|
||||
rewriter, anchorChunk.startOp->getLoc(), mappedOutput, acc, insertOffsets, insertSizes, insertStrides);
|
||||
scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult());
|
||||
}
|
||||
|
||||
for (auto [index, chunk] : llvm::enumerate(run)) {
|
||||
Value replacement = extractPackedChunk(
|
||||
loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc());
|
||||
Value output = chunk.output;
|
||||
output.replaceAllUsesWith(mapOp.getResult(index));
|
||||
output.replaceAllUsesWith(replacement);
|
||||
}
|
||||
|
||||
SmallVector<Operation*> opsToErase;
|
||||
@@ -178,28 +401,29 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Type> outputTypes;
|
||||
channelIds.reserve(sortedEntries.size());
|
||||
sourceCoreIds.reserve(sortedEntries.size());
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
outputTypes.reserve(sortedEntries.size());
|
||||
for (ReceiveEntry& entry : sortedEntries) {
|
||||
(void) entry;
|
||||
channelIds.push_back(nextChannelId++);
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
outputTypes.push_back(entry.op.getOutput().getType());
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
TypeRange(outputTypes),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
packedType,
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
|
||||
entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex));
|
||||
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
@@ -255,17 +479,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
spatial::SpatChannelSendManyOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
ValueRange(inputs));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
||||
if (packedInput) {
|
||||
spatial::SpatChannelSendTensorOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
packedInput);
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = runIt;
|
||||
continue;
|
||||
it = runIt;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -297,25 +524,25 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Type> outputTypes;
|
||||
outputTypes.reserve(run.size());
|
||||
for (auto op : run) {
|
||||
llvm::append_range(channelIds, op.getChannelIds());
|
||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||
outputTypes.push_back(op.getOutput().getType());
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveManyBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
TypeRange(outputTypes),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
packedType,
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
for (auto [index, op] : llvm::enumerate(run))
|
||||
op.getOutput().replaceAllUsesWith(compactReceive.getResult(index));
|
||||
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
@@ -352,17 +579,20 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
spatial::SpatChannelSendManyBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
ValueRange(inputs));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
||||
if (packedInput) {
|
||||
spatial::SpatChannelSendTensorBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
packedInput);
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = runIt;
|
||||
continue;
|
||||
it = runIt;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user