Huge refactor: rely more heavily on RewritePatterns and less on ad-hoc C++ code
Validate Operations / validate-operations (push) Has been cancelled

Remove the Spatial "many" ops in favor of tensor ops, as in the PIM dialect
This commit is contained in:
NiccoloN
2026-05-12 10:35:44 +02:00
parent feaff820e1
commit 909c4acfdd
84 changed files with 4048 additions and 3310 deletions
+2 -2
View File
@@ -1,8 +1,8 @@
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
using namespace mlir;
namespace onnx_mlir::spatial {
+28 -33
View File
@@ -102,23 +102,6 @@ def SpatConcatOp : SpatOp<"concat", []> {
let hasCustomAssemblyFormat = 1;
}
// spat.map: takes a variadic list of tensors and produces a variadic list of
// tensors, with a single-block body region attached to the op.
def SpatMapOp : SpatOp<"map", [SingleBlock]> {
let summary = "Apply the same lane-local region to many independent tensors";
// Operands: any number of SpatTensor values.
let arguments = (ins
Variadic<SpatTensor>:$inputs
);
// Results: any number of SpatTensor values.
let results = (outs
Variadic<SpatTensor>:$outputs
);
// The body region is constrained to contain exactly one block.
let regions = (region SizedRegion<1>:$body);
// verify() and the print()/parse() pair are hand-written in C++.
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
@@ -156,22 +139,25 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
}];
}
def SpatChannelSendManyOp : SpatOp<"channel_send_many", []> {
let summary = "Send multiple tensors through logical channels";
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
let summary = "Send equal contiguous chunks of one tensor through logical channels";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
Variadic<SpatTensor>:$inputs
SpatTensor:$input
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
$input attr-dict `:` type($input)
}];
}
def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> {
let summary = "Receive multiple tensors from logical channels";
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
@@ -180,11 +166,14 @@ def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> {
);
let results = (outs
Variadic<SpatTensor>:$outputs
SpatTensor:$output
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
attr-dict `:` type($output)
}];
}
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
@@ -201,18 +190,21 @@ def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
let hasCustomAssemblyFormat = 1;
}
def SpatChannelSendManyBatchOp : SpatOp<"channel_send_many_batch", []> {
let summary = "Send multiple per-lane tensors through logical channels in a batch body";
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
Variadic<SpatTensor>:$inputs
SpatTensor:$input
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
$input attr-dict `:` type($input)
}];
}
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
@@ -232,8 +224,8 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
let hasCustomAssemblyFormat = 1;
}
def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> {
let summary = "Receive multiple per-lane tensors through logical channels in a batch body";
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> {
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
@@ -242,11 +234,14 @@ def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> {
);
let results = (outs
Variadic<SpatTensor>:$outputs
SpatTensor:$output
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
attr-dict `:` type($output)
}];
}
//===----------------------------------------------------------------------===//
+2 -225
View File
@@ -7,8 +7,8 @@
#include <string>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -129,9 +129,8 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
if (parser.parseKeyword("axis") || parser.parseInteger(axis))
return failure();
if (parseCompressedOperandSequence(parser, inputs)) {
if (parseCompressedOperandSequence(parser, inputs))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
@@ -151,46 +150,6 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
return success();
}
// Prints the custom assembly of spat.map:
//   <argument bindings> <attr-dict> : <first input type> -> <first output type> <body>
// Only the first input/output type is printed; the region is emitted without
// its entry block arguments because the bindings already name them.
void SpatMapOp::print(OpAsmPrinter& printer) {
  auto mappedInputs = getInputs();

  printer << " ";
  printArgumentBindings(printer, getBody().front(), mappedInputs);
  printer.printOptionalAttrDict((*this)->getAttrs());

  Type laneInputType = mappedInputs.front().getType();
  Type laneOutputType = getOutputs().front().getType();
  printer << " : ";
  printer.printType(laneInputType);
  printer << " -> ";
  printer.printType(laneOutputType);

  printer << " ";
  printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
// Parses the custom assembly of spat.map:
//   <argument bindings> <attr-dict> : <input type> -> <output type> <body>
// The single input type is applied to every operand and the single output type
// to every result (one result per input operand).
// NOTE(review): this parser requires one binding per input operand, while the
// op verifier requires exactly one body block argument - confirm how
// parseArgumentBindings reconciles the two.
ParseResult SpatMapOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::Argument> bodyArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  if (parseArgumentBindings(parser, bodyArgs, operands))
    return failure();
  if (operands.empty())
    return parser.emitError(parser.getCurrentLocation(), "map requires at least one input");

  // `attr-dict : inputType -> outputType`, consumed strictly in this order.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  Type laneInputType;
  if (parser.parseType(laneInputType))
    return failure();
  if (parser.parseArrow())
    return failure();
  Type laneOutputType;
  if (parser.parseType(laneOutputType))
    return failure();

  if (bodyArgs.size() != operands.size())
    return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");

  // Broadcast the single parsed type over all operands / results.
  const size_t operandTotal = operands.size();
  SmallVector<Type> operandTypes(operandTotal, laneInputType);
  SmallVector<Type> resultTypes(operandTotal, laneOutputType);

  if (parser.resolveOperands(operands, operandTypes, parser.getCurrentLocation(), result.operands))
    return failure();
  result.addTypes(resultTypes);

  applyArgumentTypes(operandTypes, bodyArgs);
  Region* bodyRegion = result.addRegion();
  return parser.parseRegion(*bodyRegion, bodyArgs);
}
void SpatCompute::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
@@ -357,97 +316,6 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
return parser.parseRegion(*body, regionArgs);
}
// Prints: <inputs> <channel metadata> <attr-dict> : <input types>.
// The three channel attributes are emitted by printChannelMetadata, so they are
// elided from the generic attribute dictionary.
void SpatChannelSendManyOp::print(OpAsmPrinter& printer) {
  auto sentValues = getInputs();

  printer << " ";
  printCompressedValueSequence(printer, sentValues);
  printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());

  const StringRef elidedAttrNames[] = {getChannelIdsAttrName().getValue(),
                                       getSourceCoreIdsAttrName().getValue(),
                                       getTargetCoreIdsAttrName().getValue()};
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrNames);

  printer << " : ";
  printCompressedTypeSequence(printer, TypeRange(sentValues));
}
// Parses: <inputs> [channels <ids> from <source ids> to <target ids>] <attr-dict> : <input types>.
// Channel metadata may be given positionally or through the attribute
// dictionary, but not both.
ParseResult SpatChannelSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  if (parseCompressedOperandSequence(parser, operands))
    return failure();

  // Optional positional metadata, introduced by the `channels` keyword.
  SmallVector<int64_t> channels;
  SmallVector<int32_t> sources;
  SmallVector<int32_t> targets;
  const bool positionalMetadata = succeeded(parser.parseOptionalKeyword("channels"));
  if (positionalMetadata) {
    if (parseCompressedIntegerList(parser, channels))
      return failure();
    if (parser.parseKeyword("from"))
      return failure();
    if (parseCompressedIntegerList(parser, sources))
      return failure();
    if (parser.parseKeyword("to"))
      return failure();
    if (parseCompressedIntegerList(parser, targets))
      return failure();
  }

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  SmallVector<Type> operandTypes;
  if (parseCompressedTypeSequence(parser, operandTypes, /*allowEmpty=*/false))
    return failure();

  if (operands.size() != operandTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");

  if (positionalMetadata) {
    // Reject metadata that was also spelled out in the attribute dictionary.
    const bool alsoInAttrDict = result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
                             || result.attributes.get("targetCoreIds");
    if (alsoInAttrDict)
      return parser.emitError(parser.getCurrentLocation(),
                              "channel metadata cannot be specified both positionally and in attr-dict");

    result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channels));
    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sources));
    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targets));
  }

  return parser.resolveOperands(operands, operandTypes, parser.getCurrentLocation(), result.operands);
}
// Prints: <channel metadata> <attr-dict> : <result types>.
// The three channel attributes are emitted by printChannelMetadata, so they are
// elided from the generic attribute dictionary.
void SpatChannelReceiveManyOp::print(OpAsmPrinter& printer) {
  printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());

  const StringRef elidedAttrNames[] = {getChannelIdsAttrName().getValue(),
                                       getSourceCoreIdsAttrName().getValue(),
                                       getTargetCoreIdsAttrName().getValue()};
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrNames);

  printer << " : ";
  printCompressedTypeSequence(printer, getResultTypes());
}
// Parses: [channels <ids> from <source ids> to <target ids>] <attr-dict> : <result types>.
// Channel metadata may be given positionally or through the attribute
// dictionary, but not both.
ParseResult SpatChannelReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
  // Optional positional metadata, introduced by the `channels` keyword.
  SmallVector<int64_t> channels;
  SmallVector<int32_t> sources;
  SmallVector<int32_t> targets;
  const bool positionalMetadata = succeeded(parser.parseOptionalKeyword("channels"));
  if (positionalMetadata) {
    if (parseCompressedIntegerList(parser, channels))
      return failure();
    if (parser.parseKeyword("from"))
      return failure();
    if (parseCompressedIntegerList(parser, sources))
      return failure();
    if (parser.parseKeyword("to"))
      return failure();
    if (parseCompressedIntegerList(parser, targets))
      return failure();
  }

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  SmallVector<Type> resultTypes;
  if (parseCompressedTypeSequence(parser, resultTypes, /*allowEmpty=*/false))
    return failure();

  if (positionalMetadata) {
    // Reject metadata that was also spelled out in the attribute dictionary.
    const bool alsoInAttrDict = result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
                             || result.attributes.get("targetCoreIds");
    if (alsoInAttrDict)
      return parser.emitError(parser.getCurrentLocation(),
                              "channel metadata cannot be specified both positionally and in attr-dict");

    result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channels));
    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sources));
    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targets));
  }

  result.addTypes(resultTypes);
  return success();
}
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
@@ -494,55 +362,6 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
return parser.resolveOperand(input, inputType, result.operands);
}
// Prints: <inputs> <channel metadata> <attr-dict> : <input types>.
// The three channel attributes are emitted by printChannelMetadata, so they are
// elided from the generic attribute dictionary.
void SpatChannelSendManyBatchOp::print(OpAsmPrinter& printer) {
  auto sentValues = getInputs();

  printer << " ";
  printCompressedValueSequence(printer, sentValues);
  printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());

  const StringRef elidedAttrNames[] = {getChannelIdsAttrName().getValue(),
                                       getSourceCoreIdsAttrName().getValue(),
                                       getTargetCoreIdsAttrName().getValue()};
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrNames);

  printer << " : ";
  printCompressedTypeSequence(printer, TypeRange(sentValues));
}
// Parses: <inputs> [channels <ids> from <source ids> to <target ids>] <attr-dict> : <input types>.
// Channel metadata may be given positionally or through the attribute
// dictionary, but not both.
ParseResult SpatChannelSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  if (parseCompressedOperandSequence(parser, operands))
    return failure();

  // Optional positional metadata, introduced by the `channels` keyword.
  SmallVector<int64_t> channels;
  SmallVector<int32_t> sources;
  SmallVector<int32_t> targets;
  const bool positionalMetadata = succeeded(parser.parseOptionalKeyword("channels"));
  if (positionalMetadata) {
    if (parseCompressedIntegerList(parser, channels))
      return failure();
    if (parser.parseKeyword("from"))
      return failure();
    if (parseCompressedIntegerList(parser, sources))
      return failure();
    if (parser.parseKeyword("to"))
      return failure();
    if (parseCompressedIntegerList(parser, targets))
      return failure();
  }

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  SmallVector<Type> operandTypes;
  if (parseCompressedTypeSequence(parser, operandTypes, /*allowEmpty=*/false))
    return failure();

  if (operands.size() != operandTypes.size())
    return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");

  if (positionalMetadata) {
    // Reject metadata that was also spelled out in the attribute dictionary.
    const bool alsoInAttrDict = result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
                             || result.attributes.get("targetCoreIds");
    if (alsoInAttrDict)
      return parser.emitError(parser.getCurrentLocation(),
                              "channel metadata cannot be specified both positionally and in attr-dict");

    result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channels));
    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sources));
    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targets));
  }

  return parser.resolveOperands(operands, operandTypes, parser.getCurrentLocation(), result.operands);
}
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
@@ -584,47 +403,5 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
return success();
}
// Prints: <channel metadata> <attr-dict> : <result types>.
// The three channel attributes are emitted by printChannelMetadata, so they are
// elided from the generic attribute dictionary.
void SpatChannelReceiveManyBatchOp::print(OpAsmPrinter& printer) {
  printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());

  const StringRef elidedAttrNames[] = {getChannelIdsAttrName().getValue(),
                                       getSourceCoreIdsAttrName().getValue(),
                                       getTargetCoreIdsAttrName().getValue()};
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrNames);

  printer << " : ";
  printCompressedTypeSequence(printer, getResultTypes());
}
// Parses: [channels <ids> from <source ids> to <target ids>] <attr-dict> : <result types>.
// Channel metadata may be given positionally or through the attribute
// dictionary, but not both.
ParseResult SpatChannelReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  // Optional positional metadata, introduced by the `channels` keyword.
  SmallVector<int64_t> channels;
  SmallVector<int32_t> sources;
  SmallVector<int32_t> targets;
  const bool positionalMetadata = succeeded(parser.parseOptionalKeyword("channels"));
  if (positionalMetadata) {
    if (parseCompressedIntegerList(parser, channels))
      return failure();
    if (parser.parseKeyword("from"))
      return failure();
    if (parseCompressedIntegerList(parser, sources))
      return failure();
    if (parser.parseKeyword("to"))
      return failure();
    if (parseCompressedIntegerList(parser, targets))
      return failure();
  }

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  SmallVector<Type> resultTypes;
  if (parseCompressedTypeSequence(parser, resultTypes, /*allowEmpty=*/false))
    return failure();

  if (positionalMetadata) {
    // Reject metadata that was also spelled out in the attribute dictionary.
    const bool alsoInAttrDict = result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
                             || result.attributes.get("targetCoreIds");
    if (alsoInAttrDict)
      return parser.emitError(parser.getCurrentLocation(),
                              "channel metadata cannot be specified both positionally and in attr-dict");

    result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channels));
    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sources));
    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targets));
  }

  result.addTypes(resultTypes);
  return success();
}
} // namespace spatial
} // namespace onnx_mlir
+67 -76
View File
@@ -105,26 +105,28 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
return batchOp.getLaneCount();
}
static LogicalResult verifyManyChannelSizes(Operation* op,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
size_t valueCount) {
static LogicalResult verifyTensorChannelSizes(Operation* op,
Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
StringRef kind) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
if (channelIds.size() != valueCount)
return op->emitError("channel metadata length must match the number of values");
return success();
}
if (channelIds.empty())
return op->emitError() << kind << " must carry at least one chunk";
static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types, StringRef kind) {
if (types.empty())
return op->emitError() << kind << " must carry at least one value";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor";
Type firstType = types.front();
for (Type type : types.drop_front())
if (type != firstType)
return op->emitError() << kind << " values must all have the same type";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
return success();
}
@@ -144,19 +146,33 @@ static LogicalResult verifyBatchChannelSizes(Operation* op,
return success();
}
static LogicalResult verifyManyBatchChannelSizes(Operation* op,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
size_t valueCount) {
static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
StringRef kind) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
if (channelIds.size() != valueCount * static_cast<size_t>(*laneCount))
return op->emitError("channel metadata length must match the number of values times parent laneCount");
if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0)
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount;
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % chunkCount != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
return success();
}
@@ -323,39 +339,6 @@ LogicalResult SpatConcatOp::verify() {
return success();
}
// Verifies the structural invariants of spat.map:
//  - at least one input, and exactly one output per input;
//  - all inputs share one type and all outputs share one type;
//  - the body block takes a single argument of the input type;
//  - the body ends in a spat.yield carrying a single value of the output type.
// Returns failure (after emitting a diagnostic on the op) on the first
// violated invariant.
LogicalResult SpatMapOp::verify() {
  if (getInputs().empty())
    return emitError("requires at least one input");
  if (getOutputs().size() != getInputs().size())
    return emitError("number of outputs must match number of inputs");

  Type inputType = getInputs().front().getType();
  for (Value input : getInputs().drop_front())
    if (input.getType() != inputType)
      return emitError("all inputs must have the same type");

  Type outputType = getOutputs().front().getType();
  for (Value output : getOutputs().drop_front())
    if (output.getType() != outputType)
      return emitError("all outputs must have the same type");

  Block& block = getBody().front();
  if (block.getNumArguments() != 1)
    return emitError("body must have exactly one block argument");
  if (block.getArgument(0).getType() != inputType)
    return emitError("body block argument type must match input type");

  // Block::getTerminator() requires a terminator to be present. The body may
  // still be empty or end in a non-terminator at this point, so check first
  // instead of relying on an assertion.
  if (!block.mightHaveTerminator())
    return emitError("body must terminate with spat.yield");
  auto yieldOp = dyn_cast<SpatYieldOp>(block.getTerminator());
  if (!yieldOp)
    return emitError("body must terminate with spat.yield");
  if (yieldOp.getNumOperands() != 1)
    return emitError("body yield must produce exactly one value");
  if (yieldOp.getOperand(0).getType() != outputType)
    return emitError("body yield type must match output type");

  return success();
}
LogicalResult SpatCompute::verify() {
auto& block = getBody().front();
if (block.mightHaveTerminator()) {
@@ -397,40 +380,48 @@ LogicalResult SpatCompute::verify() {
return success();
}
LogicalResult SpatChannelSendManyOp::verify() {
if (failed(verifyManyChannelSizes(
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
return failure();
return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many");
LogicalResult SpatChannelSendTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(),
getInput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
"channel_send_tensor");
}
LogicalResult SpatChannelReceiveManyOp::verify() {
if (failed(verifyManyChannelSizes(
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
return failure();
return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many");
LogicalResult SpatChannelReceiveTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(),
getOutput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
"channel_receive_tensor");
}
// Delegates to verifyBatchChannelSizes with this op's three channel metadata
// arrays (channel ids, source core ids, target core ids).
LogicalResult SpatChannelSendBatchOp::verify() {
  Operation* self = getOperation();
  return verifyBatchChannelSizes(self, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
}
LogicalResult SpatChannelSendManyBatchOp::verify() {
if (failed(verifyManyBatchChannelSizes(
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
return failure();
return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many_batch");
LogicalResult SpatChannelSendTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(),
getInput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
"channel_send_tensor_batch");
}
// Delegates to verifyBatchChannelSizes with this op's three channel metadata
// arrays (channel ids, source core ids, target core ids).
LogicalResult SpatChannelReceiveBatchOp::verify() {
  Operation* self = getOperation();
  return verifyBatchChannelSizes(self, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
}
LogicalResult SpatChannelReceiveManyBatchOp::verify() {
if (failed(verifyManyBatchChannelSizes(
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
return failure();
return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many_batch");
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(),
getOutput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
"channel_receive_tensor_batch");
}
LogicalResult SpatComputeBatch::verify() {
@@ -10,6 +10,7 @@
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstdlib>
#include <iterator>
#include <numeric>
#include <optional>
@@ -31,6 +32,8 @@ namespace {
using SpatCompute = onnx_mlir::spatial::SpatCompute;
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;
// Reports whether the DCP_COARSEN_DEBUG environment variable is present.
// Only its presence is checked; the value is not inspected.
bool isDcpCoarsenDebugEnabled() {
  const char* debugFlag = std::getenv("DCP_COARSEN_DEBUG");
  return debugFlag != nullptr;
}
struct VirtualNode {
SmallVector<size_t, 4> originalComputeIndices;
Weight weight = 0;
@@ -719,11 +722,12 @@ DCPAnalysisResult DCPAnalysis::run() {
VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
size_t iteration = 0;
bool debugCoarsening = isDcpCoarsenDebugEnabled();
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
size_t oldNodeCount = virtualGraph.nodes.size();
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
if (windowSchedule.mergeGroups.empty()) {
if (oldNodeCount >= 200)
if (debugCoarsening && oldNodeCount >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
iteration,
@@ -737,7 +741,7 @@ DCPAnalysisResult DCPAnalysis::run() {
std::vector<size_t> oldToNewNode;
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
return false;
if (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200)
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
iteration,
@@ -755,7 +759,7 @@ DCPAnalysisResult DCPAnalysis::run() {
while (virtualGraph.nodes.size() > 1) {
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
if (virtualGraph.nodes.size() >= 200)
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
break;
@@ -764,7 +768,7 @@ DCPAnalysisResult DCPAnalysis::run() {
iteration++;
TimingInfo timing = computeTiming(virtualGraph);
if (!timing.valid) {
if (virtualGraph.nodes.size() >= 200)
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
break;
@@ -776,7 +780,7 @@ DCPAnalysisResult DCPAnalysis::run() {
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
if (selectedNodes.size() < 2) {
if (virtualGraph.nodes.size() >= 200)
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
iteration,
virtualGraph.nodes.size(),
@@ -786,7 +790,7 @@ DCPAnalysisResult DCPAnalysis::run() {
if (tryCoarsenSelectedNodes(selectedNodes))
continue;
if (virtualGraph.nodes.size() >= 200)
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
break;
@@ -59,11 +59,7 @@ struct DenseMapInfo<ComputeInstance> {
static ComputeInstance getTombstoneKey() {
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
}
static unsigned getHashValue(const ComputeInstance& v) {
return llvm::hash_combine(v.op, v.laneStart, v.laneCount);
}
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) {
return a == b;
}
static unsigned getHashValue(const ComputeInstance& v) { return llvm::hash_combine(v.op, v.laneStart, v.laneCount); }
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; }
};
} // namespace llvm
@@ -38,9 +38,11 @@ void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += t
void DcpProgressLogger::printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const {
if (!logProgress)
return;
llvm::errs() << llvm::formatv(
"[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n",
totalTasks, readyCount, maxCpuCount, xbarsCapacity);
llvm::errs() << llvm::formatv("[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n",
totalTasks,
readyCount,
maxCpuCount,
xbarsCapacity);
}
void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
@@ -72,18 +74,17 @@ void DcpProgressLogger::printProgress(
double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast<double>(completedTasks) / totalTasks);
bool done = completedTasks == totalTasks;
llvm::errs() << llvm::formatv(
"[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n",
completedTasks,
totalTasks,
percent,
readyCount,
cpuCount,
maxCpuCount,
xbarsUsed,
xbarsAvailable,
llvm::formatv("elapsed={0}", formatDuration(elapsedSeconds)).str(),
done ? "" : llvm::formatv(" eta={0}", formatDuration(etaSeconds)).str());
llvm::errs() << llvm::formatv("[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n",
completedTasks,
totalTasks,
percent,
readyCount,
cpuCount,
maxCpuCount,
xbarsUsed,
xbarsAvailable,
llvm::formatv("elapsed={0}", formatDuration(elapsedSeconds)).str(),
done ? "" : llvm::formatv(" eta={0}", formatDuration(etaSeconds)).str());
lastProgressPrint = now;
}
@@ -100,9 +101,7 @@ void DcpProgressLogger::printProgress(size_t, CPU, int, size_t, size_t, bool) {}
#endif
void dumpGraphDot(const std::vector<TaskDCP>& nodes,
const std::vector<std::list<TaskDCP*>>& cpuTasks,
CPU lastCpu) {
void dumpGraphDot(const std::vector<TaskDCP>& nodes, const std::vector<std::list<TaskDCP*>>& cpuTasks, CPU lastCpu) {
static int dumpIndex = 0;
std::string outputDir = onnx_mlir::getOutputDir();
if (outputDir.empty())
@@ -9,9 +9,9 @@
#include "Task.hpp"
#include "Utils.hpp"
// Uncomment to enable DCP progress logging and per-phase profiling during
// development. When disabled the logger methods are no-ops and the helpers
// compile away.
// DCP_DEBUG_ENABLED turns on DCP progress logging and per-phase profiling.
// NOTE: it is currently defined unconditionally just below; remove or comment
// out that #define for normal builds so the logger methods become no-ops and
// the helpers compile away.
#define DCP_DEBUG_ENABLED
#ifdef DCP_DEBUG_ENABLED
@@ -33,10 +33,11 @@ public:
void printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const;
void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const;
void printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount,
size_t xbarsUsed, size_t xbarsAvailable, bool force);
void
printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount, size_t xbarsUsed, size_t xbarsAvailable, bool force);
#ifdef DCP_DEBUG_ENABLED
private:
static std::string formatDuration(double seconds);
@@ -51,8 +52,6 @@ private:
#endif
};
void dumpGraphDot(const std::vector<TaskDCP>& nodes,
const std::vector<std::list<TaskDCP*>>& cpuTasks,
CPU lastCpu);
void dumpGraphDot(const std::vector<TaskDCP>& nodes, const std::vector<std::list<TaskDCP*>>& cpuTasks, CPU lastCpu);
} // namespace dcp_graph
@@ -149,14 +149,6 @@ static SmallVector<int32_t> getMaterializedBatchCoreIds(size_t startCpu, size_t
return coreIds;
}
// Collects the core ids attached to `op`.
// A dense i32 array under kCoreIdsAttrName is copied as-is; otherwise an
// integer under kCoreIdAttrName is repeated `laneCount` times. If neither
// attribute is present the result is empty.
static SmallVector<int32_t> getBatchCoreIds(Operation* op, size_t laneCount) {
  auto perLaneAttr = op->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
  if (perLaneAttr) {
    ArrayRef<int32_t> perLaneIds = perLaneAttr.asArrayRef();
    return SmallVector<int32_t>(perLaneIds.begin(), perLaneIds.end());
  }

  auto sharedAttr = op->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
  if (!sharedAttr)
    return {};

  // One shared core id: replicate it for every lane.
  const auto sharedCoreId = static_cast<int32_t>(sharedAttr.getInt());
  return SmallVector<int32_t>(laneCount, sharedCoreId);
}
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(coreIdAttr.getInt());
@@ -245,312 +237,6 @@ static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
}
/// Rewrites every eligible `SpatComputeBatch` in `funcOp` so that its channel traffic happens
/// inside its body instead of through top-level scalar channel ops.
///
/// A batch is eligible when all of the following hold:
///  - it has as many inputs and as many results as lanes (and at least one of either);
///  - every input is produced by a `SpatChannelReceiveOp`;
///  - every result has exactly one use, and that user is a `SpatChannelSendOp`;
///  - its body block takes exactly one argument and its yield has exactly one operand.
///
/// For such a batch a replacement `SpatComputeBatch` is created right after it, with the same
/// weights, lane count and core ids but no inputs and no results. The new body starts with a
/// `SpatChannelReceiveBatchOp` standing in for the old block argument, continues with a clone of
/// the old body, and ends with a `SpatChannelSendBatchOp` of the formerly yielded value followed
/// by an empty yield.
///
/// \param funcOp         function whose top-level batches are scanned.
/// \param rewriter       rewriter used to create the replacement IR; its insertion point is
///                       left inside the last rewritten batch.
/// \param opsToErase     receives the replaced receives, sends and batches. Nothing is erased
///                       here, and the same receive may be appended more than once.
/// \param nextChannelId  counter from which fresh channel ids are drawn, one per endpoint.
static void sinkChannelsIntoBatchComputes(func::FuncOp funcOp,
                                          IRRewriter& rewriter,
                                          SmallVectorImpl<Operation*>& opsToErase,
                                          int64_t& nextChannelId) {
  // Routing data of one scalar channel op that is being folded into a batch channel op.
  struct ChannelEndpoint {
    uint64_t channelId = 0;
    uint32_t sourceCoreId = 0;
    uint32_t targetCoreId = 0;
  };

  // Orders the endpoints of `channelOps` by (source core, target core, old channel id) and
  // appends, in that order, a freshly allocated channel id plus the core ids of each endpoint.
  // The old channel id only acts as a sort key.
  auto appendRenumberedEndpoints = [&nextChannelId](const auto& channelOps,
                                                    SmallVectorImpl<int64_t>& channelIds,
                                                    SmallVectorImpl<int32_t>& sourceCoreIds,
                                                    SmallVectorImpl<int32_t>& targetCoreIds) {
    SmallVector<ChannelEndpoint> endpoints;
    endpoints.reserve(channelOps.size());
    for (auto channelOp : channelOps)
      endpoints.push_back(
          ChannelEndpoint {channelOp.getChannelId(), channelOp.getSourceCoreId(), channelOp.getTargetCoreId()});

    llvm::stable_sort(endpoints, [](const ChannelEndpoint& lhs, const ChannelEndpoint& rhs) {
      return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
           < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
    });

    channelIds.reserve(channelIds.size() + endpoints.size());
    sourceCoreIds.reserve(sourceCoreIds.size() + endpoints.size());
    targetCoreIds.reserve(targetCoreIds.size() + endpoints.size());
    for (const ChannelEndpoint& endpoint : endpoints) {
      channelIds.push_back(nextChannelId++);
      sourceCoreIds.push_back(static_cast<int32_t>(endpoint.sourceCoreId));
      targetCoreIds.push_back(static_cast<int32_t>(endpoint.targetCoreId));
    }
  };

  // Snapshot the batches first: new batches are created while iterating.
  SmallVector<SpatComputeBatch> batches(funcOp.getOps<SpatComputeBatch>());
  for (auto batch : batches) {
    const auto laneCount = static_cast<size_t>(batch.getLaneCount());
    if (batch.getInputs().empty() && batch.getResults().empty())
      continue;
    if (batch.getInputs().size() != laneCount || batch.getResults().size() != laneCount)
      continue;

    // Every input must come straight from a scalar channel receive.
    SmallVector<spatial::SpatChannelReceiveOp> inputReceives;
    inputReceives.reserve(batch.getInputs().size());
    for (Value input : batch.getInputs()) {
      auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
      if (!receiveOp)
        break;
      inputReceives.push_back(receiveOp);
    }
    if (inputReceives.size() != batch.getInputs().size())
      continue;

    // Every result must feed exactly one scalar channel send and nothing else.
    SmallVector<spatial::SpatChannelSendOp> resultSends;
    resultSends.reserve(batch.getResults().size());
    for (Value result : batch.getResults()) {
      if (!result.hasOneUse())
        break;
      auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(*result.getUsers().begin());
      if (!sendOp)
        break;
      resultSends.push_back(sendOp);
    }
    if (resultSends.size() != batch.getResults().size())
      continue;

    Block& oldBlock = batch.getBody().front();
    if (oldBlock.getNumArguments() != 1)
      continue;

    // Only one yielded value is forwarded to the send below. Check this before any IR is
    // created so that an unexpected yield cannot leave a half-built batch behind.
    auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
    if (oldYield->getNumOperands() != 1)
      continue;

    SmallVector<Value> newWeights(batch.getWeights().begin(), batch.getWeights().end());
    rewriter.setInsertionPointAfter(batch);
    auto newBatch = SpatComputeBatch::create(rewriter,
                                             batch.getLoc(),
                                             TypeRange {},
                                             rewriter.getI32IntegerAttr(batch.getLaneCount()),
                                             ValueRange(newWeights),
                                             ValueRange {});
    newBatch.getProperties().setOperandSegmentSizes({static_cast<int>(newWeights.size()), 0});

    SmallVector<int32_t> coreIds = getBatchCoreIds(batch, laneCount);
    if (!coreIds.empty())
      newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));

    auto* newBlock =
        rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange {}, ArrayRef<Location> {});
    rewriter.setInsertionPointToStart(newBlock);

    // Receive ids are allocated before send ids.
    SmallVector<int64_t> receiveChannelIds;
    SmallVector<int32_t> receiveSourceCoreIds;
    SmallVector<int32_t> receiveTargetCoreIds;
    appendRenumberedEndpoints(inputReceives, receiveChannelIds, receiveSourceCoreIds, receiveTargetCoreIds);

    auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
                                                                   batch.getLoc(),
                                                                   oldBlock.getArgument(0).getType(),
                                                                   rewriter.getDenseI64ArrayAttr(receiveChannelIds),
                                                                   rewriter.getDenseI32ArrayAttr(receiveSourceCoreIds),
                                                                   rewriter.getDenseI32ArrayAttr(receiveTargetCoreIds));

    // The batch receive takes the place of the old block argument.
    IRMapping mapper;
    mapper.map(oldBlock.getArgument(0), batchReceive.getOutput());

    rewriter.setInsertionPointToEnd(newBlock);
    for (Operation& op : oldBlock) {
      if (&op == oldYield.getOperation())
        continue;
      rewriter.clone(op, mapper);
    }
    Value sendInput = mapper.lookup(oldYield.getOperand(0));

    SmallVector<int64_t> sendChannelIds;
    SmallVector<int32_t> sendSourceCoreIds;
    SmallVector<int32_t> sendTargetCoreIds;
    appendRenumberedEndpoints(resultSends, sendChannelIds, sendSourceCoreIds, sendTargetCoreIds);

    spatial::SpatChannelSendBatchOp::create(rewriter,
                                            batch.getLoc(),
                                            rewriter.getDenseI64ArrayAttr(sendChannelIds),
                                            rewriter.getDenseI32ArrayAttr(sendSourceCoreIds),
                                            rewriter.getDenseI32ArrayAttr(sendTargetCoreIds),
                                            sendInput);
    spatial::SpatYieldOp::create(rewriter, batch.getLoc(), ValueRange {});

    // Erasure is deferred to the caller, which orders it by remaining uses.
    for (auto receiveOp : inputReceives)
      opsToErase.push_back(receiveOp);
    for (auto sendOp : resultSends)
      opsToErase.push_back(sendOp);
    opsToErase.push_back(batch);
  }
}
/// Moves top-level scalar channel ops into the `SpatCompute` ops they feed or drain, then does
/// the same for batches via `sinkChannelsIntoBatchComputes`.
///
/// For each compute:
///  - an input produced by a `SpatChannelReceiveOp` stops being an operand; an equivalent
///    receive is created at the start of the new body and replaces the matching block argument;
///  - a result consumed by `SpatChannelSendOp`s gets equivalent sends at the end of the new
///    body; the result itself is dropped unless it has other users or no send users at all.
///
/// Replaced ops are gathered in a worklist and erased once they have no uses left. An op may be
/// gathered several times (for instance a receive feeding two computes), so the worklist is
/// deduplicated before erasing; otherwise the second entry would be a dangling pointer. If a
/// full pass over the worklist erases nothing, the remaining ops get an error diagnostic and the
/// process is terminated with a fatal error.
///
/// \param funcOp         function whose top-level computes and batches are rewritten.
/// \param nextChannelId  channel id counter, forwarded to the batch rewrite.
void sinkChannelsIntoComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
  IRRewriter rewriter(funcOp.getContext());
  // Snapshot the computes first: replacements are created while iterating.
  SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
  SmallVector<Operation*> opsToErase;

  for (auto compute : computes) {
    SmallVector<unsigned> keptInputIndices;
    SmallVector<unsigned> keptResultIndices;
    // Indexed by original input position; a null entry means the input stays an operand.
    SmallVector<spatial::SpatChannelReceiveOp> internalizedReceives(compute.getInputs().size());
    // Indexed by original result position.
    SmallVector<SmallVector<spatial::SpatChannelSendOp>> resultSendOps(compute.getNumResults());
    bool needsRewrite = false;

    for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
      auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
      if (!receiveOp) {
        keptInputIndices.push_back(inputIndex);
        continue;
      }
      internalizedReceives[inputIndex] = receiveOp;
      opsToErase.push_back(receiveOp);
      needsRewrite = true;
    }

    for (auto [resultIndex, result] : llvm::enumerate(compute.getResults())) {
      bool hasNonSendUser = false;
      for (Operation* user : result.getUsers()) {
        if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(user)) {
          resultSendOps[resultIndex].push_back(sendOp);
          opsToErase.push_back(sendOp);
          needsRewrite = true;
          continue;
        }
        hasNonSendUser = true;
      }
      // A result survives when something other than a send reads it, or when nothing sends it.
      if (hasNonSendUser || resultSendOps[resultIndex].empty())
        keptResultIndices.push_back(resultIndex);
    }

    if (!needsRewrite)
      continue;

    SmallVector<Value> newOperands;
    SmallVector<Type> newResultTypes;
    SmallVector<Type> newBlockArgTypes;
    SmallVector<Location> newBlockArgLocs;
    newOperands.reserve(compute.getNumOperands());
    newResultTypes.reserve(keptResultIndices.size());
    newBlockArgTypes.reserve(keptInputIndices.size());
    newBlockArgLocs.reserve(keptInputIndices.size());

    // Operand order is weights first, then the inputs that were kept.
    for (Value weight : compute.getWeights())
      newOperands.push_back(weight);
    for (unsigned inputIndex : keptInputIndices) {
      Value input = compute.getInputs()[inputIndex];
      newOperands.push_back(input);
      newBlockArgTypes.push_back(input.getType());
      newBlockArgLocs.push_back(compute.getLoc());
    }
    for (unsigned resultIndex : keptResultIndices)
      newResultTypes.push_back(compute.getResult(resultIndex).getType());

    rewriter.setInsertionPointAfter(compute);
    auto newCompute =
        SpatCompute::create(rewriter, compute.getLoc(), TypeRange(newResultTypes), ValueRange(newOperands));
    newCompute.getProperties().setOperandSegmentSizes(
        {static_cast<int>(compute.getWeights().size()), static_cast<int>(keptInputIndices.size())});
    if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
      newCompute->setAttr(onnx_mlir::kCoreIdAttrName, coreIdAttr);

    auto* newBlock = rewriter.createBlock(
        &newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);

    // Kept inputs map to the new, renumbered block arguments.
    IRMapping mapper;
    for (auto [mappedIndex, inputIndex] : llvm::enumerate(keptInputIndices))
      mapper.map(compute.getBody().front().getArgument(inputIndex), newBlock->getArgument(mappedIndex));

    // Internalized inputs map to receives recreated at the top of the new body.
    rewriter.setInsertionPointToStart(newBlock);
    for (auto [inputIndex, receiveOp] : llvm::enumerate(internalizedReceives)) {
      if (!receiveOp)
        continue;
      auto internalReceive = spatial::SpatChannelReceiveOp::create(rewriter,
                                                                   receiveOp.getLoc(),
                                                                   receiveOp.getResult().getType(),
                                                                   receiveOp.getChannelIdAttr(),
                                                                   receiveOp.getSourceCoreIdAttr(),
                                                                   receiveOp.getTargetCoreIdAttr());
      mapper.map(compute.getBody().front().getArgument(inputIndex), internalReceive.getResult());
    }

    auto oldYieldOp = cast<spatial::SpatYieldOp>(compute.getBody().front().getTerminator());
    rewriter.setInsertionPointToEnd(newBlock);
    for (Operation& op : compute.getBody().front()) {
      if (&op == oldYieldOp)
        continue;
      rewriter.clone(op, mapper);
    }

    // Recreate every outside send inside the body, on the value that used to be yielded.
    for (auto [resultIndex, sendOps] : llvm::enumerate(resultSendOps)) {
      if (sendOps.empty())
        continue;
      Value yieldedValue = mapper.lookup(oldYieldOp.getOperand(resultIndex));
      for (auto sendOp : sendOps)
        spatial::SpatChannelSendOp::create(rewriter,
                                           sendOp.getLoc(),
                                           sendOp.getChannelIdAttr(),
                                           sendOp.getSourceCoreIdAttr(),
                                           sendOp.getTargetCoreIdAttr(),
                                           yieldedValue);
    }

    SmallVector<Value> keptYieldOperands;
    keptYieldOperands.reserve(keptResultIndices.size());
    for (unsigned resultIndex : keptResultIndices)
      keptYieldOperands.push_back(mapper.lookup(oldYieldOp.getOperand(resultIndex)));
    spatial::SpatYieldOp::create(rewriter, oldYieldOp.getLoc(), ValueRange(keptYieldOperands));

    for (auto [newResultIndex, oldResultIndex] : llvm::enumerate(keptResultIndices))
      compute.getResult(oldResultIndex).replaceAllUsesWith(newCompute.getResult(newResultIndex));

    opsToErase.push_back(compute);
  }

  sinkChannelsIntoBatchComputes(funcOp, rewriter, opsToErase, nextChannelId);

  // Keep the first occurrence of each op. Duplicates would be dereferenced after their first
  // copy has been erased.
  SmallVector<Operation*> pendingRemovals;
  pendingRemovals.reserve(opsToErase.size());
  for (Operation* op : opsToErase)
    if (!llvm::is_contained(pendingRemovals, op))
      pendingRemovals.push_back(op);

  // Erase to a fixed point: an op only becomes erasable once its users are gone.
  while (!pendingRemovals.empty()) {
    bool erasedAny = false;
    for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) {
      if (!(*it)->use_empty()) {
        ++it;
        continue;
      }
      rewriter.eraseOp(*it);
      it = pendingRemovals.erase(it);
      erasedAny = true;
    }
    if (erasedAny)
      continue;

    // No progress: some replaced op is still used by IR that was not rewritten. Input IR can
    // trigger this, so fail deterministically instead of marking the path unreachable.
    for (Operation* op : pendingRemovals)
      op->emitError("failed to sink channel op into compute");
    llvm::report_fatal_error("channel sinking left cyclic top-level dependencies");
  }
}
void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
@@ -1280,7 +966,8 @@ public:
void runOnOperation() override {
mergeTriviallyConnectedComputes(getOperation());
emitMotifProfile(getOperation());
if (std::getenv("DCP_MOTIF_PROFILE"))
emitMotifProfile(getOperation());
func::FuncOp func = getOperation();
Location loc = func.getLoc();
@@ -1718,17 +1405,12 @@ public:
for (Operation* user : result.getUsers())
remainingUsers.push_back(user);
if (!remainingUsers.empty()) {
llvm::errs() << "[MergeComputeNodesPass] refusing to erase op with remaining uses: " << op->getName() << "\n";
llvm::errs() << " erase-set: " << (allOpsToErase.contains(op) ? "yes" : "no") << "\n";
op->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions());
llvm::errs() << "\n";
InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup")
<< "; erase-set=" << (allOpsToErase.contains(op) ? "yes" : "no");
for (Operation* user : remainingUsers) {
llvm::errs() << " user: " << user->getName()
<< " erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no") << "\n";
user->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions());
llvm::errs() << "\n";
diagnostic.attachNote(user->getLoc())
<< "remaining user " << user->getName() << "; erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no");
}
op->emitOpError("still has uses during per-cpu merge cleanup");
signalPassFailure();
return;
}
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
@@ -40,6 +41,176 @@ struct RegularChunk {
Value output;
};
/// Returns the tensor type obtained by stacking `count` values of `elementType` along
/// dimension 0: the leading extent is multiplied by `count`, every other extent and the element
/// type are kept.
///
/// `elementType` must have rank >= 1 (asserted). A dynamic leading extent stays dynamic, because
/// the dynamic marker is a sentinel and must not be scaled.
static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
  assert(elementType.getRank() > 0 && "cannot pack rank-0 tensors along dimension 0");
  SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
  if (!ShapedType::isDynamic(packedShape.front()))
    packedShape.front() *= count;
  return RankedTensorType::get(packedShape, elementType.getElementType());
}
/// Emits a `tensor.extract_slice` that reads chunk number `index` out of `packedValue`.
///
/// The slice has type `chunkType`. Along dimension 0 it starts at `index` times the leading
/// extent of `chunkType`; along every other dimension it starts at 0 and spans the full extent
/// of `chunkType`. All strides are 1. The op is created at the rewriter's current insertion
/// point.
static Value
extractPackedChunk(Value packedValue, RankedTensorType chunkType, unsigned index, IRRewriter& rewriter, Location loc) {
  const int64_t rank = chunkType.getRank();
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);

  // Appends the unit-stride slice description of one dimension.
  auto appendDim = [&](int64_t offset, int64_t extent) {
    offsets.push_back(rewriter.getIndexAttr(offset));
    sizes.push_back(rewriter.getIndexAttr(extent));
    strides.push_back(rewriter.getIndexAttr(1));
  };

  const int64_t leadingExtent = chunkType.getDimSize(0);
  appendDim(static_cast<int64_t>(index) * leadingExtent, leadingExtent);
  for (int64_t dim = 1; dim < rank; ++dim)
    appendDim(0, chunkType.getDimSize(dim));

  auto chunk = tensor::ExtractSliceOp::create(rewriter, loc, chunkType, packedValue, offsets, sizes, strides);
  return chunk.getResult();
}
/// Emits one `tensor.extract_slice` of the input of `extractRowsOp` that stands for `count`
/// consecutive outputs of that op, beginning with output `startIndex`.
///
/// The slice starts at row `startIndex * rowsPerValue`, spans `count * rowsPerValue` rows and
/// the full extent of every other input dimension, with unit strides; `rowsPerValue` is the
/// leading extent of output `startIndex`. Returns a null `Value` when that output type or the
/// input type is not a statically shaped ranked tensor, or when the output has rank 0.
static Value createPackedExtractRowsSlice(
    spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
  auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
  auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
  const bool supported =
      rowType && inputType && rowType.hasStaticShape() && inputType.hasStaticShape() && rowType.getRank() != 0;
  if (!supported)
    return {};

  const int64_t rowsPerValue = rowType.getDimSize(0);
  if (ShapedType::isDynamic(rowsPerValue))
    return {};

  const auto chunkCount = static_cast<int64_t>(count);
  RankedTensorType packedType = getPackedTensorType(rowType, chunkCount);

  const int64_t inputRank = inputType.getRank();
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(inputRank);
  sizes.reserve(inputRank);
  strides.reserve(inputRank);

  // Appends the unit-stride slice description of one input dimension.
  auto appendDim = [&](int64_t offset, int64_t extent) {
    offsets.push_back(rewriter.getIndexAttr(offset));
    sizes.push_back(rewriter.getIndexAttr(extent));
    strides.push_back(rewriter.getIndexAttr(1));
  };

  appendDim(static_cast<int64_t>(startIndex) * rowsPerValue, chunkCount * rowsPerValue);
  for (int64_t dim = 1; dim < inputRank; ++dim)
    appendDim(0, inputType.getDimSize(dim));

  auto packedSlice =
      tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides);
  return packedSlice.getResult();
}
/// Tries to replace `values`, a run of `tensor.extract_slice` results taken from the same
/// source, with a single larger `tensor.extract_slice` covering all of them back to back along
/// dimension 0.
///
/// Returns a null `Value` when `values` is empty or the pattern does not match, `values.front()`
/// unchanged when there is exactly one value, and the new packed slice otherwise. The pattern
/// matches when:
///  - every value is produced by a `tensor.extract_slice` of the same source, with the same
///    result type and fully static offsets, sizes and strides;
///  - source and result types are statically shaped, of equal non-zero rank (rank-reducing
///    slices are rejected: their static operands are indexed by source dimension, so the
///    result-rank bookkeeping below would not describe them);
///  - all slices share sizes and strides, agree on every offset but the leading one, and the
///    leading offset of slice `i` is `firstOffset + i * rowsPerValue * leadingStride`.
///
/// The leading stride is part of the expected offset because the packed slice reuses that
/// stride: only slices spaced by a whole strided chunk are reproduced by it.
static Value createPackedExtractSliceTensor(ValueRange values, IRRewriter& rewriter, Location loc) {
  if (values.empty())
    return {};
  if (values.size() == 1)
    return values.front();

  auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
  if (!firstSliceOp)
    return {};

  auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
  auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType());
  if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
      || firstType.getRank() == 0)
    return {};
  if (firstType.getRank() != sourceType.getRank())
    return {};

  auto isFullyStatic = [](ArrayRef<int64_t> entries) {
    return llvm::all_of(entries, [](int64_t entry) { return !ShapedType::isDynamic(entry); });
  };
  auto hasStaticLayout = [&](tensor::ExtractSliceOp sliceOp) {
    return isFullyStatic(sliceOp.getStaticOffsets()) && isFullyStatic(sliceOp.getStaticSizes())
        && isFullyStatic(sliceOp.getStaticStrides());
  };
  if (!hasStaticLayout(firstSliceOp))
    return {};

  ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
  ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
  ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
  const int64_t rank = firstType.getRank();
  const int64_t rowsPerValue = firstSizes[0];
  // Distance, in source rows, between the first rows of two consecutive chunks.
  const int64_t chunkPitch = rowsPerValue * firstStrides[0];

  for (size_t index = 1; index < values.size(); ++index) {
    auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
    if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource()
        || sliceOp.getResult().getType() != firstSliceOp.getResult().getType() || !hasStaticLayout(sliceOp))
      return {};
    if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
      return {};

    ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
    if (sliceOffsets[0] != firstOffsets[0] + static_cast<int64_t>(index) * chunkPitch)
      return {};
    for (int64_t dim = 1; dim < rank; ++dim)
      if (sliceOffsets[dim] != firstOffsets[dim])
        return {};
  }

  const auto chunkCount = static_cast<int64_t>(values.size());
  auto packedType = getPackedTensorType(firstType, chunkCount);

  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);
  // Only the leading size changes; every other entry is inherited from the first slice.
  for (int64_t dim = 0; dim < rank; ++dim) {
    const int64_t extent = dim == 0 ? rowsPerValue * chunkCount : firstSizes[dim];
    offsets.push_back(rewriter.getIndexAttr(firstOffsets[dim]));
    sizes.push_back(rewriter.getIndexAttr(extent));
    strides.push_back(rewriter.getIndexAttr(firstStrides[dim]));
  }

  return tensor::ExtractSliceOp::create(rewriter, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
      .getResult();
}
/// Reports whether `values` are consecutive results of one operation, in ascending result
/// order.
///
/// `owner` and `startIndex` are set to the defining op and result number of the first value as
/// soon as that value is known to be an op result, so they are also written when a later value
/// breaks the run and `false` is returned. They are left untouched when `values` is empty or
/// its first element is not an op result.
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
  if (values.empty())
    return false;

  auto leading = dyn_cast<OpResult>(values.front());
  if (!leading)
    return false;

  owner = leading.getOwner();
  startIndex = leading.getResultNumber();

  // Position 0 matches by construction, so only the remaining values need checking.
  for (size_t position = 1; position < values.size(); ++position) {
    auto result = dyn_cast<OpResult>(values[position]);
    const bool continuesRun =
        result && result.getOwner() == owner && result.getResultNumber() == startIndex + position;
    if (!continuesRun)
      return false;
  }
  return true;
}
/// Produces one tensor holding all of `values` stacked along dimension 0, or a null `Value`
/// when that is not possible.
///
/// Strategies are tried in this order:
///  1. `createPackedExtractSliceTensor`, used whenever it yields a value;
///  2. if the values are consecutive results of a `SpatExtractRowsOp`, the outcome of
///     `createPackedExtractRowsSlice` is returned as is, even when it is null;
///  3. otherwise a `tensor.concat` along dimension 0, provided all values share one statically
///     shaped ranked tensor type of non-zero rank.
static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter, Location loc) {
  if (values.empty())
    return {};

  Value packedSlice = createPackedExtractSliceTensor(values, rewriter, loc);
  if (packedSlice)
    return packedSlice;

  Operation* owner = nullptr;
  unsigned startIndex = 0;
  if (getContiguousOpResults(values, owner, startIndex)) {
    if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner)) {
      const auto valueCount = static_cast<unsigned>(values.size());
      return createPackedExtractRowsSlice(extractRowsOp, startIndex, valueCount, rewriter, loc);
    }
  }

  auto firstType = dyn_cast<RankedTensorType>(values.front().getType());
  if (!firstType || !firstType.hasStaticShape() || firstType.getRank() == 0)
    return {};
  for (Value value : values.drop_front())
    if (value.getType() != firstType)
      return {};

  auto concat = tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, values);
  return concat.getResult();
}
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand
&& lhs.resultType == rhs.resultType;
@@ -89,45 +260,97 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
return chunk;
}
/// Fills the body region of `mapOp` with a copy of the ops of `anchorChunk`.
///
/// A block with a single argument of the chunk input type is appended to the region; the chunk
/// ops are cloned into it with the chunk input remapped to that argument, and the block is
/// closed by a `SpatYieldOp` of the remapped chunk output. The rewriter's insertion point is
/// left at the end of the new block.
static void buildRegularMapBody(spatial::SpatMapOp mapOp, const RegularChunk& anchorChunk, IRRewriter& rewriter) {
  const Location anchorLoc = anchorChunk.startOp->getLoc();
  Region& body = mapOp.getBody();
  Block* laneBlock = rewriter.createBlock(&body, body.end(), TypeRange {anchorChunk.input.getType()}, {anchorLoc});
  rewriter.setInsertionPointToEnd(laneBlock);

  IRMapping valueMap;
  valueMap.map(anchorChunk.input, laneBlock->getArgument(0));

  for (Operation* original : anchorChunk.ops) {
    Operation* copy = rewriter.clone(*original, valueMap);
    // Record the result mapping explicitly so later chunk ops and the yield see the copies.
    for (unsigned resultIndex = 0, resultCount = original->getNumResults(); resultIndex != resultCount;
         ++resultIndex)
      valueMap.map(original->getResult(resultIndex), copy->getResult(resultIndex));
  }

  Value yielded = valueMap.lookup(anchorChunk.output);
  spatial::SpatYieldOp::create(rewriter, anchorLoc, ValueRange {yielded});
}
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
assert(!run.empty() && "expected a non-empty regular chunk run");
const RegularChunk& anchorChunk = run.front();
SmallVector<Value> inputs;
SmallVector<Type> outputTypes;
inputs.reserve(run.size());
outputTypes.reserve(run.size());
for (const RegularChunk& chunk : run) {
for (const RegularChunk& chunk : run)
inputs.push_back(chunk.input);
outputTypes.push_back(chunk.output.getType());
}
rewriter.setInsertionPoint(anchorChunk.startOp);
auto mapOp =
spatial::SpatMapOp::create(rewriter, anchorChunk.startOp->getLoc(), TypeRange(outputTypes), ValueRange(inputs));
buildRegularMapBody(mapOp, anchorChunk, rewriter);
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc());
if (!packedInput)
return;
auto inputType = cast<RankedTensorType>(anchorChunk.input.getType());
auto outputType = cast<RankedTensorType>(anchorChunk.output.getType());
auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(run.size()));
auto packedInit = tensor::EmptyOp::create(
rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType());
auto zero = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 0);
auto upper = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), run.size());
auto step = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 1);
auto loop =
scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
{
OpBuilder::InsertionGuard guard(rewriter);
Block* loopBlock = loop.getBody();
rewriter.setInsertionPointToStart(loopBlock);
Value iv = loopBlock->getArgument(0);
Value acc = loopBlock->getArgument(1);
Value inputRowOffset = iv;
if (inputType.getDimSize(0) != 1) {
auto rowsPerValue =
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), inputType.getDimSize(0));
inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
}
SmallVector<OpFoldResult> extractOffsets;
SmallVector<OpFoldResult> extractSizes;
SmallVector<OpFoldResult> extractStrides;
extractOffsets.push_back(inputRowOffset);
extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(0)));
extractStrides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
extractOffsets.push_back(rewriter.getIndexAttr(0));
extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
extractStrides.push_back(rewriter.getIndexAttr(1));
}
auto inputSlice = tensor::ExtractSliceOp::create(
rewriter, anchorChunk.startOp->getLoc(), inputType, packedInput, extractOffsets, extractSizes, extractStrides);
IRMapping mapping;
mapping.map(anchorChunk.input, inputSlice.getResult());
for (Operation* op : anchorChunk.ops) {
Operation* cloned = rewriter.clone(*op, mapping);
for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults()))
mapping.map(oldResult, newResult);
}
Value mappedOutput = mapping.lookup(anchorChunk.output);
Value outputRowOffset = iv;
if (outputType.getDimSize(0) != 1) {
auto rowsPerValue =
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), outputType.getDimSize(0));
outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
}
SmallVector<OpFoldResult> insertOffsets;
SmallVector<OpFoldResult> insertSizes;
SmallVector<OpFoldResult> insertStrides;
insertOffsets.push_back(outputRowOffset);
insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(0)));
insertStrides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < outputType.getRank(); ++dim) {
insertOffsets.push_back(rewriter.getIndexAttr(0));
insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(dim)));
insertStrides.push_back(rewriter.getIndexAttr(1));
}
auto inserted = tensor::InsertSliceOp::create(
rewriter, anchorChunk.startOp->getLoc(), mappedOutput, acc, insertOffsets, insertSizes, insertStrides);
scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult());
}
for (auto [index, chunk] : llvm::enumerate(run)) {
Value replacement = extractPackedChunk(
loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc());
Value output = chunk.output;
output.replaceAllUsesWith(mapOp.getResult(index));
output.replaceAllUsesWith(replacement);
}
SmallVector<Operation*> opsToErase;
@@ -178,28 +401,29 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
SmallVector<Type> outputTypes;
channelIds.reserve(sortedEntries.size());
sourceCoreIds.reserve(sortedEntries.size());
targetCoreIds.reserve(sortedEntries.size());
outputTypes.reserve(sortedEntries.size());
for (ReceiveEntry& entry : sortedEntries) {
(void) entry;
channelIds.push_back(nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
outputTypes.push_back(entry.op.getOutput().getType());
}
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
rewriter.setInsertionPoint(run.front());
auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter,
run.front().getLoc(),
TypeRange(outputTypes),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
auto compactReceive =
spatial::SpatChannelReceiveTensorOp::create(rewriter,
run.front().getLoc(),
packedType,
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex));
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
for (auto op : run)
rewriter.eraseOp(op);
@@ -255,17 +479,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
}
rewriter.setInsertionPoint(run.front());
spatial::SpatChannelSendManyOp::create(rewriter,
run.front().getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
ValueRange(inputs));
for (auto op : run)
rewriter.eraseOp(op);
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
if (packedInput) {
spatial::SpatChannelSendTensorOp::create(rewriter,
run.front().getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
packedInput);
for (auto op : run)
rewriter.eraseOp(op);
it = runIt;
continue;
it = runIt;
continue;
}
}
}
@@ -297,25 +524,25 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
SmallVector<Type> outputTypes;
outputTypes.reserve(run.size());
for (auto op : run) {
llvm::append_range(channelIds, op.getChannelIds());
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
outputTypes.push_back(op.getOutput().getType());
}
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
rewriter.setInsertionPoint(run.front());
auto compactReceive =
spatial::SpatChannelReceiveManyBatchOp::create(rewriter,
run.front().getLoc(),
TypeRange(outputTypes),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
run.front().getLoc(),
packedType,
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
for (auto [index, op] : llvm::enumerate(run))
op.getOutput().replaceAllUsesWith(compactReceive.getResult(index));
op.getOutput().replaceAllUsesWith(extractPackedChunk(
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
for (auto op : run)
rewriter.eraseOp(op);
@@ -352,17 +579,20 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
}
rewriter.setInsertionPoint(run.front());
spatial::SpatChannelSendManyBatchOp::create(rewriter,
run.front().getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
ValueRange(inputs));
for (auto op : run)
rewriter.eraseOp(op);
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
if (packedInput) {
spatial::SpatChannelSendTensorBatchOp::create(rewriter,
run.front().getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
packedInput);
for (auto op : run)
rewriter.eraseOp(op);
it = runIt;
continue;
it = runIt;
continue;
}
}
}