Improve MaterializeMergeSchedule.cpp: compact both sends and receives into for loops
Validate Operations / validate-operations (push) Waiting to run
Validate Operations / validate-operations (push) Waiting to run
This commit is contained in:
@@ -9,8 +9,6 @@ add_pim_library(SpatialOps
|
|||||||
SpatialOpsCanonicalization.cpp
|
SpatialOpsCanonicalization.cpp
|
||||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||||
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
||||||
Transforms/MergeComputeNodes/PostMergeCompaction.cpp
|
|
||||||
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
|
|
||||||
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||||
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
|
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
|
||||||
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
|
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
|
||||||
|
|||||||
+291
-35
@@ -1,5 +1,6 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
@@ -557,22 +558,6 @@ Value createIndexConstant(MaterializerState& state, Operation* anchor, int64_t v
|
|||||||
return getOrCreateHostIndexConstant(anchor, value, state.constantFolder);
|
return getOrCreateHostIndexConstant(anchor, value, state.constantFolder);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates one index constant per entry of `values`, in the same order, by
/// forwarding each entry to `createIndexConstant` with the given anchor.
SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
  SmallVector<Value, 8> result;
  result.reserve(values.size());
  for (size_t position = 0, count = values.size(); position != count; ++position)
    result.push_back(createIndexConstant(state, anchor, values[position]));
  return result;
}
|
|
||||||
|
|
||||||
/// 32-bit convenience overload: widens every entry to `int64_t` and delegates
/// to the `ArrayRef<int64_t>` overload.
SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation* anchor, ArrayRef<int32_t> values) {
  // The iterator-range constructor performs the int32 -> int64 widening.
  SmallVector<int64_t, 8> asInt64(values.begin(), values.end());
  return createIndexConstants(state, anchor, ArrayRef<int64_t>(asInt64));
}
|
|
||||||
|
|
||||||
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
|
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
|
||||||
SmallVector<APInt, 8> elements;
|
SmallVector<APInt, 8> elements;
|
||||||
elements.reserve(values.size());
|
elements.reserve(values.size());
|
||||||
@@ -644,6 +629,28 @@ Value createLaneIndexedIndexValue(MaterializerState& state,
|
|||||||
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
|
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Produces an index value selected from `values` by the runtime `index`.
///
/// When every entry of `values` is the same, a single constant is returned and
/// `index` is not used. Otherwise a constant table is created from `values`
/// and a `tensor.extract` at `index` is emitted at `loc`.
Value createIndexedIndexValue(
    MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values, Value index, Location loc) {
  assert(!values.empty() && "expected at least one indexed value");

  if (!allEqual(values)) {
    Value lookupTable = createIndexTensorConstant(state, anchor, values);
    auto extract = tensor::ExtractOp::create(state.rewriter, loc, lookupTable, ValueRange {index});
    return extract.getResult();
  }

  // Uniform table: no lookup needed.
  return createIndexConstant(state, anchor, values.front());
}
|
||||||
|
|
||||||
|
/// 32-bit variant of the indexed lookup: returns a single constant when all
/// entries of `values` agree, otherwise a `tensor.extract` at `index` from a
/// constant table created from `values`.
Value createIndexedIndexValue(
    MaterializerState& state, Operation* anchor, ArrayRef<int32_t> values, Value index, Location loc) {
  assert(!values.empty() && "expected at least one indexed value");

  if (!allEqual(values)) {
    Value lookupTable = createIndexTensorConstant(state, anchor, values);
    auto extract = tensor::ExtractOp::create(state.rewriter, loc, lookupTable, ValueRange {index});
    return extract.getResult();
  }

  // Uniform table: no lookup needed.
  return createIndexConstant(state, anchor, values.front());
}
|
||||||
|
|
||||||
FailureOr<SmallVector<ComputeInstance, 8>>
|
FailureOr<SmallVector<ComputeInstance, 8>>
|
||||||
getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) {
|
getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) {
|
||||||
SmallVector<ComputeInstance, 8> peers;
|
SmallVector<ComputeInstance, 8> peers;
|
||||||
@@ -808,6 +815,53 @@ ArrayRef<ClassId> getDestinationClasses(MaterializerState& state, ProducerKey ke
|
|||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Emits a single channel send of `payload` just before the terminator of the
/// (non-batch) source class body.
///
/// The channel id and both core ids are materialized as index constants
/// anchored at `sourceClass.op`.
void appendScalarSend(MaterializerState& state,
                      MaterializedClass& sourceClass,
                      Value payload,
                      int64_t channelId,
                      int32_t sourceCoreId,
                      int32_t targetCoreId,
                      Location loc) {
  assert(!sourceClass.isBatch && "scalar send helper expects a scalar source class");

  auto& rewriter = state.rewriter;
  rewriter.setInsertionPoint(sourceClass.body->getTerminator());

  Value channel = createIndexConstant(state, sourceClass.op, channelId);
  Value sender = createIndexConstant(state, sourceClass.op, sourceCoreId);
  Value receiver = createIndexConstant(state, sourceClass.op, targetCoreId);

  SpatChannelSendOp::create(rewriter, loc, channel, sender, receiver, payload);
}
|
||||||
|
|
||||||
|
/// Emits an `scf.for` loop, placed before the terminator of the (non-batch)
/// source class body, that sends `payload` once per entry of `channelIds`.
///
/// Iteration `i` uses the i-th channel id, source core id and target core id,
/// each obtained through `createIndexedIndexValue` with the induction variable.
/// All three arrays must have the same length, and that length must exceed one.
void appendScalarSendLoop(MaterializerState& state,
                          MaterializedClass& sourceClass,
                          Value payload,
                          ArrayRef<int64_t> channelIds,
                          ArrayRef<int32_t> sourceCoreIds,
                          ArrayRef<int32_t> targetCoreIds,
                          Location loc) {
  assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class");
  assert(channelIds.size() > 1 && "send loop is only useful for multiple sends");
  assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
  assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");

  auto& rewriter = state.rewriter;
  rewriter.setInsertionPoint(sourceClass.body->getTerminator());

  // Loop bounds: [0, number of sends) with unit step.
  Value first = createIndexConstant(state, sourceClass.op, 0);
  Value last = createIndexConstant(state, sourceClass.op, static_cast<int64_t>(channelIds.size()));
  Value unitStep = createIndexConstant(state, sourceClass.op, 1);

  auto sendLoop = scf::ForOp::create(rewriter, loc, first, last, unitStep, ValueRange {});

  // Everything below is emitted inside the loop body; the guard puts the
  // rewriter back to where it was after creating the loop when we return.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(sendLoop.getBody());

  Value iteration = sendLoop.getInductionVar();
  Value channel = createIndexedIndexValue(state, sourceClass.op, channelIds, iteration, loc);
  Value sender = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, iteration, loc);
  Value receiver = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, iteration, loc);

  SpatChannelSendOp::create(rewriter, loc, channel, sender, receiver, payload);
}
|
||||||
|
|
||||||
void appendSend(MaterializerState& state,
|
void appendSend(MaterializerState& state,
|
||||||
MaterializedClass& sourceClass,
|
MaterializedClass& sourceClass,
|
||||||
Value payload,
|
Value payload,
|
||||||
@@ -819,9 +873,9 @@ void appendSend(MaterializerState& state,
|
|||||||
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
|
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
|
||||||
assert(!channelIds.empty() && "expected at least one send");
|
assert(!channelIds.empty() && "expected at least one send");
|
||||||
|
|
||||||
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
|
||||||
|
|
||||||
if (sourceClass.isBatch) {
|
if (sourceClass.isBatch) {
|
||||||
|
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
||||||
|
|
||||||
Value channelId = createLaneIndexedIndexValue(state, sourceClass, channelIds, loc);
|
Value channelId = createLaneIndexedIndexValue(state, sourceClass, channelIds, loc);
|
||||||
Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc);
|
Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc);
|
||||||
Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc);
|
Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc);
|
||||||
@@ -829,12 +883,13 @@ void appendSend(MaterializerState& state,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto index : llvm::seq<size_t>(0, channelIds.size())) {
|
if (channelIds.size() == 1) {
|
||||||
Value channelId = createIndexConstant(state, sourceClass.op, channelIds[index]);
|
appendScalarSend(
|
||||||
Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds[index]);
|
state, sourceClass, payload, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc);
|
||||||
Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds[index]);
|
return;
|
||||||
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
appendScalarSendLoop(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value appendScalarReceive(MaterializerState& state,
|
Value appendScalarReceive(MaterializerState& state,
|
||||||
@@ -911,6 +966,212 @@ Value appendPackedScalarReceives(MaterializerState& state,
|
|||||||
return packed;
|
return packed;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the integer behind `value` when it is a compile-time constant:
/// either produced by an `arith::ConstantIndexOp` or matched by
/// `m_ConstantInt`. Returns `std::nullopt` otherwise.
std::optional<int64_t> getConstantIndexValue(Value value) {
  auto indexConstant = value.getDefiningOp<arith::ConstantIndexOp>();
  if (indexConstant)
    return indexConstant.value();

  APInt matched;
  if (!matchPattern(value, m_ConstantInt(&matched)))
    return std::nullopt;

  return matched.getSExtValue();
}
|
||||||
|
|
||||||
|
/// Reads the constant channel id, source core id and target core id of
/// `receive` into the output parameters.
///
/// \returns true when all three operands are constants; false otherwise, in
///          which case none of the output parameters is modified.
bool getReceiveMetadata(SpatChannelReceiveOp receive,
                        int64_t& channelId,
                        int64_t& sourceCoreId,
                        int64_t& targetCoreId) {
  // SpatChannelReceiveOp operands are: channel, source core, target core.
  std::optional<int64_t> decoded[3];
  for (unsigned operandIndex = 0; operandIndex < 3; ++operandIndex)
    decoded[operandIndex] = getConstantIndexValue(receive->getOperand(operandIndex));

  const bool allConstant = decoded[0].has_value() && decoded[1].has_value() && decoded[2].has_value();
  if (!allConstant)
    return false;

  channelId = *decoded[0];
  sourceCoreId = *decoded[1];
  targetCoreId = *decoded[2];
  return true;
}
|
||||||
|
|
||||||
|
/// Checks that `concatType` is exactly `fragmentCount` tensors of
/// `fragmentType` stacked along dimension 0.
///
/// Requires both types to be statically shaped, of equal non-zero rank and
/// equal element type; dimension 0 of the concat must equal
/// `fragmentCount * fragment dim 0`, and all other dimensions must match.
bool hasCompatibleConcatTypes(RankedTensorType concatType, RankedTensorType fragmentType, size_t fragmentCount) {
  const bool bothStatic = concatType.hasStaticShape() && fragmentType.hasStaticShape();
  if (!bothStatic)
    return false;

  if (concatType.getRank() != fragmentType.getRank() || concatType.getRank() == 0)
    return false;

  if (concatType.getElementType() != fragmentType.getElementType())
    return false;

  const int64_t expectedLeadingDim = fragmentType.getDimSize(0) * static_cast<int64_t>(fragmentCount);
  if (concatType.getDimSize(0) != expectedLeadingDim)
    return false;

  // Every trailing dimension must be untouched by the concatenation.
  int64_t axis = 1;
  while (axis < concatType.getRank()) {
    if (concatType.getDimSize(axis) != fragmentType.getDimSize(axis))
      return false;
    ++axis;
  }

  return true;
}
|
||||||
|
|
||||||
|
/// Builds an `scf.for` loop that receives `channelIds.size()` fragments and
/// assembles them into one tensor of type `concatType`, stacked along
/// dimension 0.
///
/// Iteration `i` receives a `fragmentType` tensor using the i-th
/// channel/source/target ids and inserts it into the loop-carried accumulator
/// at row offset `i * fragmentType.getDimSize(0)` (offset 0 in all other
/// dimensions, unit strides). The accumulator starts as a `tensor.empty`.
///
/// \param anchor          Operation handed to the constant helpers as anchor.
/// \param insertionPoint  The `tensor.empty` and the loop are created right
///                        before this operation.
/// \param channelIds, sourceCoreIds, targetCoreIds
///                        Per-iteration receive metadata; all three must have
///                        the same non-zero length.
/// \returns The loop result holding the assembled tensor.
///
/// The rewriter's insertion point is restored to the caller's on return, so
/// the caller may erase `insertionPoint` afterwards without leaving the
/// rewriter positioned at an erased operation.
Value createReceiveConcatLoop(MaterializerState& state,
                              Operation* anchor,
                              Operation* insertionPoint,
                              RankedTensorType concatType,
                              RankedTensorType fragmentType,
                              ArrayRef<int64_t> channelIds,
                              ArrayRef<int64_t> sourceCoreIds,
                              ArrayRef<int64_t> targetCoreIds,
                              Location loc) {
  assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
  assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
  assert(!channelIds.empty() && "expected at least one receive");

  // Save the caller's insertion point before anything below moves it.
  OpBuilder::InsertionGuard guard(state.rewriter);

  Value lowerBound = createIndexConstant(state, anchor, 0);
  Value upperBound = createIndexConstant(state, anchor, static_cast<int64_t>(channelIds.size()));
  Value step = createIndexConstant(state, anchor, 1);

  state.rewriter.setInsertionPoint(insertionPoint);
  Value init =
      tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult();
  auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init});

  // If the builder already placed a yield in the body, drop it: this function
  // emits its own yield carrying the updated accumulator.
  Block* body = loop.getBody();
  if (!body->empty())
    if (auto yield = dyn_cast<scf::YieldOp>(body->back()))
      state.rewriter.eraseOp(yield);

  state.rewriter.setInsertionPointToEnd(body);

  Value index = loop.getInductionVar();
  // Block argument 0 is the induction variable; argument 1 is the accumulator.
  Value acc = body->getArgument(1);

  Value channelId = createIndexedIndexValue(state, anchor, channelIds, index, loc);
  Value sourceCoreId = createIndexedIndexValue(state, anchor, sourceCoreIds, index, loc);
  Value targetCoreId = createIndexedIndexValue(state, anchor, targetCoreIds, index, loc);

  Value received =
      SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId).getOutput();

  // Row offset of this fragment; the multiply is skipped for one-row fragments.
  Value firstOffset = index;
  if (fragmentType.getDimSize(0) != 1) {
    Value rowsPerFragment = createIndexConstant(state, anchor, fragmentType.getDimSize(0));
    firstOffset = arith::MulIOp::create(state.rewriter, loc, index, rowsPerFragment).getResult();
  }

  SmallVector<OpFoldResult, 4> offsets;
  SmallVector<OpFoldResult, 4> sizes;
  SmallVector<OpFoldResult, 4> strides;
  offsets.reserve(fragmentType.getRank());
  sizes.reserve(fragmentType.getRank());
  strides.reserve(fragmentType.getRank());

  offsets.push_back(firstOffset);
  sizes.push_back(state.rewriter.getIndexAttr(fragmentType.getDimSize(0)));
  strides.push_back(state.rewriter.getIndexAttr(1));

  for (int64_t dim = 1; dim < fragmentType.getRank(); ++dim) {
    offsets.push_back(state.rewriter.getIndexAttr(0));
    sizes.push_back(state.rewriter.getIndexAttr(fragmentType.getDimSize(dim)));
    strides.push_back(state.rewriter.getIndexAttr(1));
  }

  Value next = tensor::InsertSliceOp::create(state.rewriter, loc, received, acc, offsets, sizes, strides).getResult();
  scf::YieldOp::create(state.rewriter, loc, next);

  return loop.getResult(0);
}
|
||||||
|
|
||||||
|
/// Tries to replace a `tensor.concat` whose inputs are all channel receives
/// with a single receive loop built by `createReceiveConcatLoop`.
///
/// The rewrite is applied only when all of the following hold:
///  - the concat has an integer `dim` attribute equal to 0 and at least two
///    inputs;
///  - every input is the result of a `SpatChannelReceiveOp` in the same block
///    as the concat, with exactly one use;
///  - the receives sit immediately before the concat, contiguously and in
///    operand order;
///  - the result and fragment types pass `hasCompatibleConcatTypes`, and every
///    receive has the same result type as the first one;
///  - every receive has constant channel/source/target operands.
///
/// \returns true when the concat and its receives were replaced; false when
///          the IR was left untouched.
bool compactReceiveConcat(MaterializerState& state, MaterializedClass& materializedClass, tensor::ConcatOp concat) {
  auto dimAttr = concat->getAttrOfType<IntegerAttr>("dim");
  if (!dimAttr || dimAttr.getInt() != 0)
    return false;

  OperandRange inputs = concat->getOperands();
  if (inputs.size() < 2)
    return false;

  SmallVector<SpatChannelReceiveOp, 8> receives;
  receives.reserve(inputs.size());

  for (Value input : inputs) {
    auto receive = input.getDefiningOp<SpatChannelReceiveOp>();
    if (!receive)
      return false;
    if (receive->getBlock() != concat->getBlock())
      return false;
    if (!receive->getResult(0).hasOneUse())
      return false;
    receives.push_back(receive);
  }

  // Walk backwards from the concat: the receives must be the ops directly
  // preceding it, in operand order, with nothing interleaved.
  Operation* expected = concat.getOperation();
  for (SpatChannelReceiveOp receive : llvm::reverse(receives)) {
    Operation* previous = expected->getPrevNode();
    if (previous != receive.getOperation())
      return false;
    expected = previous;
  }

  auto concatType = dyn_cast<RankedTensorType>(concat->getResult(0).getType());
  auto fragmentType = dyn_cast<RankedTensorType>(receives.front()->getResult(0).getType());
  if (!concatType || !fragmentType)
    return false;
  if (!hasCompatibleConcatTypes(concatType, fragmentType, receives.size()))
    return false;

  SmallVector<int64_t, 8> channelIds;
  SmallVector<int64_t, 8> sourceCoreIds;
  SmallVector<int64_t, 8> targetCoreIds;
  channelIds.reserve(receives.size());
  sourceCoreIds.reserve(receives.size());
  targetCoreIds.reserve(receives.size());

  for (SpatChannelReceiveOp receive : receives) {
    if (receive->getResult(0).getType() != fragmentType)
      return false;

    int64_t channelId = 0;
    int64_t sourceCoreId = 0;
    int64_t targetCoreId = 0;
    if (!getReceiveMetadata(receive, channelId, sourceCoreId, targetCoreId))
      return false;

    channelIds.push_back(channelId);
    sourceCoreIds.push_back(sourceCoreId);
    targetCoreIds.push_back(targetCoreId);
  }

  Value replacement = createReceiveConcatLoop(state,
                                              materializedClass.op,
                                              receives.front().getOperation(),
                                              concatType,
                                              fragmentType,
                                              channelIds,
                                              sourceCoreIds,
                                              targetCoreIds,
                                              concat.getLoc());

  // Replace through the rewriter (instead of a raw replaceAllUsesWith) so the
  // use replacement and the erasure of the concat are both rewriter-visible.
  state.rewriter.replaceOp(concat.getOperation(), replacement);

  // The receives are now dead; erase them last-to-first.
  for (SpatChannelReceiveOp receive : llvm::reverse(receives))
    state.rewriter.eraseOp(receive.getOperation());

  return true;
}
|
||||||
|
|
||||||
|
/// Runs `compactReceiveConcat` on every `tensor.concat` found inside the
/// materialized classes.
///
/// Candidates are collected first and rewritten afterwards, because the
/// rewrite erases operations and must not run while the walk is in progress.
/// Whether an individual concat was actually compacted is not reported.
void compactReceiveConcats(MaterializerState& state) {
  using Candidate = std::pair<MaterializedClass*, tensor::ConcatOp>;
  SmallVector<Candidate, 16> candidates;

  for (MaterializedClass& owner : state.classes) {
    MaterializedClass* ownerPtr = &owner;
    owner.op->walk([&candidates, ownerPtr](tensor::ConcatOp concat) { candidates.push_back({ownerPtr, concat}); });
  }

  for (const Candidate& candidate : candidates)
    compactReceiveConcat(state, *candidate.first, candidate.second);
}
|
||||||
|
|
||||||
LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
||||||
MaterializedClass& sourceClass,
|
MaterializedClass& sourceClass,
|
||||||
MaterializedClass& targetClass,
|
MaterializedClass& targetClass,
|
||||||
@@ -1086,15 +1347,8 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult emitHostCommunication(MaterializerState& state,
|
LogicalResult
|
||||||
MaterializedClass& sourceClass,
|
emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, Value payload, Value originalOutput) {
|
||||||
ArrayRef<ProducerKey> keys,
|
|
||||||
Value payload,
|
|
||||||
Value originalOutput,
|
|
||||||
Location loc) {
|
|
||||||
(void) keys;
|
|
||||||
(void) loc;
|
|
||||||
|
|
||||||
if (!hasLiveExternalUse(originalOutput, state.oldComputeOps))
|
if (!hasLiveExternalUse(originalOutput, state.oldComputeOps))
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
@@ -1115,7 +1369,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
|||||||
if (failed(
|
if (failed(
|
||||||
emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
||||||
return failure();
|
return failure();
|
||||||
if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc)))
|
if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput)))
|
||||||
return failure();
|
return failure();
|
||||||
state.availableValues[keys.front()][sourceClass.id] = payload;
|
state.availableValues[keys.front()][sourceClass.id] = payload;
|
||||||
return success();
|
return success();
|
||||||
@@ -1129,7 +1383,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
|||||||
if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc)))
|
if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
for (ProducerKey key : keys)
|
for (ProducerKey key : keys)
|
||||||
@@ -1146,7 +1400,7 @@ FailureOr<Value> materializeWholeBatchInput(
|
|||||||
|
|
||||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||||
|
|
||||||
uint32_t batchLaneCount = static_cast<uint32_t>(batch.getLaneCount());
|
uint32_t batchLaneCount = batch.getLaneCount();
|
||||||
SmallVector<Value, 8> fragments;
|
SmallVector<Value, 8> fragments;
|
||||||
uint32_t lane = 0;
|
uint32_t lane = 0;
|
||||||
|
|
||||||
@@ -1426,6 +1680,8 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch
|
|||||||
if (failed(materializeInstanceSlot(state, instance)))
|
if (failed(materializeInstanceSlot(state, instance)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
compactReceiveConcats(state);
|
||||||
|
|
||||||
replaceHostUses(state);
|
replaceHostUses(state);
|
||||||
if (failed(eraseOldComputeOps(state)))
|
if (failed(eraseOldComputeOps(state)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -27,7 +27,6 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iterator>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
@@ -35,7 +34,6 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "MaterializeMergeSchedule.hpp"
|
#include "MaterializeMergeSchedule.hpp"
|
||||||
#include "PostMergeCompaction.hpp"
|
|
||||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||||
#include "Scheduling/MergeSchedulingAnalysis.hpp"
|
#include "Scheduling/MergeSchedulingAnalysis.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||||
@@ -660,13 +658,6 @@ public:
|
|||||||
|
|
||||||
emitMergeIrCounts("after-materialization", func);
|
emitMergeIrCounts("after-materialization", func);
|
||||||
|
|
||||||
/*if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) {
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
emitMergeIrCounts("after-post-merge-compaction", func);*/
|
|
||||||
|
|
||||||
{
|
{
|
||||||
ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
|
ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
|
||||||
if (!sortTopologically(&func.getBody().front())) {
|
if (!sortTopologically(&func.getBody().front())) {
|
||||||
|
|||||||
@@ -1,532 +0,0 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/IR/IRMapping.h"
|
|
||||||
#include "mlir/IR/Matchers.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
|
||||||
#include "llvm/ADT/DenseSet.h"
|
|
||||||
#include "llvm/ADT/Hashing.h"
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
|
||||||
#include "llvm/Support/raw_ostream.h"
|
|
||||||
|
|
||||||
#include <chrono>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <limits>
|
|
||||||
#include <optional>
|
|
||||||
|
|
||||||
#include "PostMergeCompaction.hpp"
|
|
||||||
#include "RegularOpCompaction.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
using SpatCompute = spatial::SpatCompute;
|
|
||||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
|
||||||
|
|
||||||
/// Reports whether the `RAPTOR_PROFILE_MERGE` environment variable is set
/// (to any value, including the empty string).
bool isMergeProfilingEnabled() {
  const char* flag = std::getenv("RAPTOR_PROFILE_MERGE");
  return flag != nullptr;
}
|
|
||||||
|
|
||||||
class ScopedMergePhaseTimer {
|
|
||||||
public:
|
|
||||||
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
|
||||||
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
|
|
||||||
if (enabled)
|
|
||||||
start = std::chrono::steady_clock::now();
|
|
||||||
}
|
|
||||||
|
|
||||||
~ScopedMergePhaseTimer() {
|
|
||||||
if (!enabled)
|
|
||||||
return;
|
|
||||||
auto elapsed = std::chrono::steady_clock::now() - start;
|
|
||||||
double millis = std::chrono::duration<double, std::milli>(elapsed).count();
|
|
||||||
llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
bool enabled = false;
|
|
||||||
std::string phase;
|
|
||||||
std::chrono::steady_clock::time_point start;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Returns the value of the integer attribute named `kCoreIdAttrName` on
/// `compute`, narrowed to `int32_t`, or `std::nullopt` when the attribute is
/// absent (or not an `IntegerAttr`).
std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
  auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
  if (!coreIdAttr)
    return std::nullopt;
  return static_cast<int32_t>(coreIdAttr.getInt());
}
|
|
||||||
|
|
||||||
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
|
|
||||||
|
|
||||||
/// Returns the sign-extended value of `value` when `m_ConstantInt` matches it;
/// fails otherwise.
static FailureOr<int64_t> getConstantI64Value(Value value) {
  APInt matched;
  if (matchPattern(value, m_ConstantInt(&matched)))
    return matched.getSExtValue();
  return failure();
}
|
|
||||||
|
|
||||||
/// Returns the sign-extended value of `value`, narrowed to `int32_t`, when
/// `m_ConstantInt` matches it; fails otherwise.
static FailureOr<int32_t> getConstantI32Value(Value value) {
  APInt matched;
  if (matchPattern(value, m_ConstantInt(&matched)))
    return static_cast<int32_t>(matched.getSExtValue());
  return failure();
}
|
|
||||||
|
|
||||||
/// Extracts the constant channel id and source/target core ids of a send op.
///
/// \returns true and fills all three outputs when every id is a constant;
///          false (outputs untouched) otherwise.
static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op,
                                     uint64_t& channelId,
                                     uint32_t& sourceCoreId,
                                     uint32_t& targetCoreId) {
  FailureOr<int64_t> channel = getConstantI64Value(op.getChannelId());
  FailureOr<int32_t> source = getConstantI32Value(op.getSourceCoreId());
  FailureOr<int32_t> target = getConstantI32Value(op.getTargetCoreId());

  if (succeeded(channel) && succeeded(source) && succeeded(target)) {
    channelId = static_cast<uint64_t>(*channel);
    sourceCoreId = static_cast<uint32_t>(*source);
    targetCoreId = static_cast<uint32_t>(*target);
    return true;
  }
  return false;
}
|
|
||||||
|
|
||||||
/// Extracts the constant channel id and source/target core ids of a receive op.
///
/// \returns true and fills all three outputs when every id is a constant;
///          false (outputs untouched) otherwise.
static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op,
                                     uint64_t& channelId,
                                     uint32_t& sourceCoreId,
                                     uint32_t& targetCoreId) {
  FailureOr<int64_t> channel = getConstantI64Value(op.getChannelId());
  FailureOr<int32_t> source = getConstantI32Value(op.getSourceCoreId());
  FailureOr<int32_t> target = getConstantI32Value(op.getTargetCoreId());

  if (succeeded(channel) && succeeded(source) && succeeded(target)) {
    channelId = static_cast<uint64_t>(*channel);
    sourceCoreId = static_cast<uint32_t>(*source);
    targetCoreId = static_cast<uint32_t>(*target);
    return true;
  }
  return false;
}
|
|
||||||
|
|
||||||
/// Returns one host index constant per entry of `values`, in order, each
/// obtained from `getOrCreateHostIndexConstant` with the given anchor/folder.
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
  SmallVector<Value> result;
  result.reserve(values.size());
  for (size_t position = 0, count = values.size(); position != count; ++position)
    result.push_back(getOrCreateHostIndexConstant(anchorOp, values[position], folder));
  return result;
}
|
|
||||||
|
|
||||||
/// 32-bit overload: returns one host index constant per entry of `values`, in
/// order, each obtained from `getOrCreateHostIndexConstant`.
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int32_t> values, OperationFolder& folder) {
  SmallVector<Value> result;
  result.reserve(values.size());
  for (size_t position = 0, count = values.size(); position != count; ++position)
    result.push_back(getOrCreateHostIndexConstant(anchorOp, values[position], folder));
  return result;
}
|
|
||||||
|
|
||||||
/// Returns the value of the integer attribute named `kRebatchPhaseAttrName`
/// on `compute` as `uint64_t`, or `std::nullopt` when it is absent (or not an
/// `IntegerAttr`).
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
  auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName);
  if (!phaseAttr)
    return std::nullopt;
  return static_cast<uint64_t>(phaseAttr.getInt());
}
|
|
||||||
|
|
||||||
/// Key describing a compute op for rebatching; two keys are equal only when
/// every field matches.
struct RebatchKey {
  unsigned inputCount = 0;
  unsigned resultCount = 0;
  unsigned weightCount = 0;
  uint64_t phase = 0;
  bool hasPhase = false;
  uint64_t structureHash = 0;

  /// Field-by-field equality.
  bool operator==(const RebatchKey& other) const {
    if (inputCount != other.inputCount)
      return false;
    if (resultCount != other.resultCount)
      return false;
    if (weightCount != other.weightCount)
      return false;
    if (phase != other.phase)
      return false;
    if (hasPhase != other.hasPhase)
      return false;
    return structureHash == other.structureHash;
  }
};
|
|
||||||
|
|
||||||
struct RebatchKeyInfo {
|
|
||||||
static inline RebatchKey getEmptyKey() { return {std::numeric_limits<unsigned>::max(), 0, 0, 0, false, 0}; }
|
|
||||||
|
|
||||||
static inline RebatchKey getTombstoneKey() { return {std::numeric_limits<unsigned>::max() - 1, 0, 0, 0, false, 0}; }
|
|
||||||
|
|
||||||
static unsigned getHashValue(const RebatchKey& key) {
|
|
||||||
return static_cast<unsigned>(
|
|
||||||
llvm::hash_combine(key.inputCount, key.resultCount, key.weightCount, key.phase, key.hasPhase, key.structureHash));
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool isEqual(const RebatchKey& lhs, const RebatchKey& rhs) { return lhs == rhs; }
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Returns the opaque pointer behind `type` reinterpreted as an integer, for
/// use as a hash ingredient.
uint64_t getTypeHash(Type type) {
  const void* opaque = type.getAsOpaquePointer();
  return reinterpret_cast<uintptr_t>(opaque);
}
|
|
||||||
|
|
||||||
/// Returns the opaque pointer behind `value` reinterpreted as an integer, for
/// use as a hash ingredient.
uint64_t getValueHash(Value value) {
  const void* opaque = value.getAsOpaquePointer();
  return reinterpret_cast<uintptr_t>(opaque);
}
|
|
||||||
|
|
||||||
/// Returns the opaque pointer behind `attr` reinterpreted as an integer, for
/// use as a hash ingredient.
uint64_t getAttributeHash(Attribute attr) {
  const void* opaque = attr.getAsOpaquePointer();
  return reinterpret_cast<uintptr_t>(opaque);
}
|
|
||||||
|
|
||||||
/// Builds the bucketing key used to pre-group rebatch candidates.
///
/// The key records the input/result/weight counts and the optional rebatch
/// phase of `compute`, plus a structural hash covering: the weight values, the
/// phase, the body block argument types, and for every body operation its
/// name, operand/result/region counts, result types and attributes.
///
/// The key is only a pre-filter; areEquivalentForRebatch makes the final
/// decision. The hash must therefore never separate two computes that check
/// would accept. That check does not compare the attributes of channel
/// send/receive operations, so those attributes are left out of the hash here
/// as well.
RebatchKey computeRebatchKey(SpatCompute compute) {
  // Query the phase attribute once; it feeds both the hash and the key fields.
  const std::optional<uint64_t> phase = getComputeRebatchPhase(compute);

  llvm::hash_code structureHash =
      llvm::hash_combine(compute.getInputs().size(), compute.getResultTypes().size(), compute.getWeights().size());

  for (Value weight : compute.getWeights())
    structureHash = llvm::hash_combine(structureHash, getValueHash(weight));
  if (phase)
    structureHash = llvm::hash_combine(structureHash, *phase);

  Block& body = compute.getBody().front();
  structureHash = llvm::hash_combine(structureHash, body.getNumArguments());
  for (BlockArgument arg : body.getArguments())
    structureHash = llvm::hash_combine(structureHash, getTypeHash(arg.getType()));

  for (Operation& op : body) {
    structureHash = llvm::hash_combine(
        structureHash, op.getName().getStringRef(), op.getNumOperands(), op.getNumResults(), op.getNumRegions());
    for (Type type : op.getResultTypes())
      structureHash = llvm::hash_combine(structureHash, getTypeHash(type));

    // Keep the hash no finer than areEquivalentForRebatch: channel ops are
    // matched there without looking at their attributes.
    if (isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelSendOp>(op))
      continue;

    for (NamedAttribute attr : op.getAttrs())
      structureHash = llvm::hash_combine(structureHash, attr.getName().strref(), getAttributeHash(attr.getValue()));
  }

  return {static_cast<unsigned>(compute.getInputs().size()),
          static_cast<unsigned>(compute.getResultTypes().size()),
          static_cast<unsigned>(compute.getWeights().size()),
          phase.value_or(0),
          phase.has_value(),
          static_cast<uint64_t>(structureHash)};
}
|
|
||||||
|
|
||||||
/// Decides whether two computes can be folded into the same batch.
///
/// Both must be non-null and agree on input count, result types, weight count,
/// rebatch phase and the exact weight values. Their body blocks must have
/// identically typed arguments and the same sequence of region-free
/// operations. Operations are walked pairwise while a map from lhs values to
/// rhs values is built up: an operand already in the map must match its mapped
/// rhs value, and any other operand must be the very same value on both sides.
/// Channel receive/send operations are compared by output/input type only;
/// every other operation must also carry identical attributes.
bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
  if (!lhs || !rhs)
    return false;

  // Signature-level checks.
  if (lhs.getInputs().size() != rhs.getInputs().size() || lhs.getWeights().size() != rhs.getWeights().size())
    return false;
  if (lhs.getResultTypes() != rhs.getResultTypes())
    return false;
  if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs))
    return false;
  if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
    return false;

  Block& lhsBody = lhs.getBody().front();
  Block& rhsBody = rhs.getBody().front();
  if (lhsBody.getNumArguments() != rhsBody.getNumArguments())
    return false;

  // Correspondence between values defined inside lhs and their rhs twins.
  DenseMap<Value, Value> lhsToRhs;
  for (auto [lhsArg, rhsArg] : llvm::zip(lhsBody.getArguments(), rhsBody.getArguments())) {
    if (lhsArg.getType() != rhsArg.getType())
      return false;
    lhsToRhs[lhsArg] = rhsArg;
  }

  // An lhs operand corresponds to an rhs operand when the rhs operand is its
  // mapped twin or, for unmapped values, the identical value.
  auto operandsCorrespond = [&lhsToRhs](Value lhsOperand, Value rhsOperand) {
    auto known = lhsToRhs.find(lhsOperand);
    Value expected = known != lhsToRhs.end() ? known->second : lhsOperand;
    return expected == rhsOperand;
  };

  auto lhsCursor = lhsBody.begin();
  auto rhsCursor = rhsBody.begin();
  while (lhsCursor != lhsBody.end() && rhsCursor != rhsBody.end()) {
    Operation& lhsOp = *lhsCursor;
    Operation& rhsOp = *rhsCursor;

    if (lhsOp.getName() != rhsOp.getName())
      return false;
    if (lhsOp.getNumOperands() != rhsOp.getNumOperands() || lhsOp.getNumResults() != rhsOp.getNumResults())
      return false;
    // Operations with nested regions are never considered equivalent.
    if (lhsOp.getNumRegions() != 0 || rhsOp.getNumRegions() != 0)
      return false;

    for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOp.getOperands(), rhsOp.getOperands()))
      if (!operandsCorrespond(lhsOperand, rhsOperand))
        return false;

    if (auto lhsReceive = dyn_cast<spatial::SpatChannelReceiveOp>(lhsOp)) {
      // Same op name was established above, so the rhs cast cannot fail.
      auto rhsReceive = cast<spatial::SpatChannelReceiveOp>(rhsOp);
      if (lhsReceive.getOutput().getType() != rhsReceive.getOutput().getType())
        return false;
    }
    else if (auto lhsSend = dyn_cast<spatial::SpatChannelSendOp>(lhsOp)) {
      auto rhsSend = cast<spatial::SpatChannelSendOp>(rhsOp);
      if (lhsSend.getInput().getType() != rhsSend.getInput().getType())
        return false;
    }
    else if (lhsOp.getAttrs() != rhsOp.getAttrs()) {
      return false;
    }

    if (lhsOp.getResultTypes() != rhsOp.getResultTypes())
      return false;
    for (auto [lhsResult, rhsResult] : llvm::zip(lhsOp.getResults(), rhsOp.getResults()))
      lhsToRhs[lhsResult] = rhsResult;

    ++lhsCursor;
    ++rhsCursor;
  }

  // Equivalent only if both bodies were exhausted together.
  return lhsCursor == lhsBody.end() && rhsCursor == rhsBody.end();
}
|
|
||||||
|
|
||||||
void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
|
||||||
IRRewriter rewriter(funcOp.getContext());
|
|
||||||
OperationFolder constantFolder(funcOp.getContext());
|
|
||||||
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
|
||||||
DenseSet<Operation*> consumed;
|
|
||||||
DenseMap<Operation*, size_t> computeOrder;
|
|
||||||
DenseMap<RebatchKey, SmallVector<SpatCompute>, RebatchKeyInfo> candidatesByKey;
|
|
||||||
|
|
||||||
for (auto [index, compute] : llvm::enumerate(computes)) {
|
|
||||||
computeOrder[compute.getOperation()] = index;
|
|
||||||
if (compute.getInputs().size() <= 1 && compute.getResults().empty())
|
|
||||||
candidatesByKey[computeRebatchKey(compute)].push_back(compute);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t index = 0; index < computes.size(); ++index) {
|
|
||||||
auto anchor = computes[index];
|
|
||||||
if (consumed.contains(anchor))
|
|
||||||
continue;
|
|
||||||
if (anchor.getInputs().size() > 1)
|
|
||||||
continue;
|
|
||||||
if (!anchor.getResults().empty())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
SmallVector<SpatCompute> group {anchor};
|
|
||||||
llvm::SmallDenseSet<int32_t, 8> usedCoreIds;
|
|
||||||
if (auto coreId = getComputeCoreId(anchor))
|
|
||||||
usedCoreIds.insert(*coreId);
|
|
||||||
|
|
||||||
auto bucketIt = candidatesByKey.find(computeRebatchKey(anchor));
|
|
||||||
if (bucketIt == candidatesByKey.end())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
for (auto candidate : bucketIt->second) {
|
|
||||||
if (computeOrder.lookup(candidate.getOperation()) <= index)
|
|
||||||
continue;
|
|
||||||
if (consumed.contains(candidate))
|
|
||||||
continue;
|
|
||||||
if (!areEquivalentForRebatch(anchor, candidate))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
if (auto coreId = getComputeCoreId(candidate))
|
|
||||||
if (!usedCoreIds.insert(*coreId).second)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
group.push_back(candidate);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (group.size() <= 1)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
auto insertionAnchor = group.front();
|
|
||||||
if (llvm::all_of(group, [](SpatCompute compute) { return getComputeCoreId(compute).has_value(); })) {
|
|
||||||
llvm::stable_sort(
|
|
||||||
group, [](SpatCompute lhs, SpatCompute rhs) { return *getComputeCoreId(lhs) < *getComputeCoreId(rhs); });
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value> weights;
|
|
||||||
weights.reserve(group.size() * anchor.getWeights().size());
|
|
||||||
SmallVector<Value> inputs;
|
|
||||||
inputs.reserve(group.size() * anchor.getInputs().size());
|
|
||||||
SmallVector<int32_t> coreIds;
|
|
||||||
coreIds.reserve(group.size());
|
|
||||||
bool haveAllCoreIds = true;
|
|
||||||
for (auto compute : group) {
|
|
||||||
llvm::append_range(weights, compute.getWeights());
|
|
||||||
llvm::append_range(inputs, compute.getInputs());
|
|
||||||
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
|
|
||||||
if (!coreIdAttr)
|
|
||||||
haveAllCoreIds = false;
|
|
||||||
else if (haveAllCoreIds)
|
|
||||||
coreIds.push_back(static_cast<int32_t>(coreIdAttr.getInt()));
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(insertionAnchor);
|
|
||||||
auto rebatched = SpatComputeBatch::create(rewriter,
|
|
||||||
insertionAnchor.getLoc(),
|
|
||||||
TypeRange {},
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(group.size())),
|
|
||||||
ValueRange(weights),
|
|
||||||
ValueRange(inputs));
|
|
||||||
rebatched.getProperties().setOperandSegmentSizes(
|
|
||||||
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
|
||||||
if (haveAllCoreIds)
|
|
||||||
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
|
||||||
|
|
||||||
SmallVector<Type> blockArgTypes;
|
|
||||||
SmallVector<Location> blockArgLocs;
|
|
||||||
for (BlockArgument arg : anchor.getBody().front().getArguments()) {
|
|
||||||
blockArgTypes.push_back(arg.getType());
|
|
||||||
blockArgLocs.push_back(arg.getLoc());
|
|
||||||
}
|
|
||||||
auto* newBlock =
|
|
||||||
rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
|
||||||
rewriter.setInsertionPointToEnd(newBlock);
|
|
||||||
|
|
||||||
IRMapping mapper;
|
|
||||||
auto& anchorBlock = anchor.getBody().front();
|
|
||||||
for (auto [oldArg, newArg] : llvm::zip(anchorBlock.getArguments(), newBlock->getArguments()))
|
|
||||||
mapper.map(oldArg, newArg);
|
|
||||||
auto opIts = llvm::map_to_vector(group, [](SpatCompute compute) { return compute.getBody().front().begin(); });
|
|
||||||
for (Operation& anchorOp : anchorBlock) {
|
|
||||||
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&anchorOp)) {
|
|
||||||
struct BatchReceiveEntry {
|
|
||||||
uint64_t channelId = 0;
|
|
||||||
uint32_t sourceCoreId = 0;
|
|
||||||
uint32_t targetCoreId = 0;
|
|
||||||
};
|
|
||||||
SmallVector<BatchReceiveEntry> entries;
|
|
||||||
entries.reserve(group.size());
|
|
||||||
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
|
||||||
auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]);
|
|
||||||
BatchReceiveEntry entry;
|
|
||||||
if (!getScalarChannelMetadata(groupReceive, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
|
|
||||||
return;
|
|
||||||
entries.push_back(entry);
|
|
||||||
++opIts[groupIndex];
|
|
||||||
}
|
|
||||||
SmallVector<int64_t> channelIds;
|
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
|
||||||
SmallVector<int32_t> targetCoreIds;
|
|
||||||
channelIds.reserve(group.size());
|
|
||||||
sourceCoreIds.reserve(group.size());
|
|
||||||
targetCoreIds.reserve(group.size());
|
|
||||||
for (const BatchReceiveEntry& entry : entries) {
|
|
||||||
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
|
||||||
}
|
|
||||||
SmallVector<Value> channelIdValues = createIndexConstants(receiveOp, channelIds, constantFolder);
|
|
||||||
SmallVector<Value> sourceCoreIdValues = createIndexConstants(receiveOp, sourceCoreIds, constantFolder);
|
|
||||||
SmallVector<Value> targetCoreIdValues = createIndexConstants(receiveOp, targetCoreIds, constantFolder);
|
|
||||||
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
|
|
||||||
receiveOp.getLoc(),
|
|
||||||
receiveOp.getOutput().getType(),
|
|
||||||
channelIdValues,
|
|
||||||
sourceCoreIdValues,
|
|
||||||
targetCoreIdValues);
|
|
||||||
mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&anchorOp)) {
|
|
||||||
struct BatchSendEntry {
|
|
||||||
uint64_t channelId = 0;
|
|
||||||
uint32_t sourceCoreId = 0;
|
|
||||||
uint32_t targetCoreId = 0;
|
|
||||||
};
|
|
||||||
SmallVector<BatchSendEntry> entries;
|
|
||||||
entries.reserve(group.size());
|
|
||||||
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
|
||||||
auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]);
|
|
||||||
BatchSendEntry entry;
|
|
||||||
if (!getScalarChannelMetadata(groupSend, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
|
|
||||||
return;
|
|
||||||
entries.push_back(entry);
|
|
||||||
++opIts[groupIndex];
|
|
||||||
}
|
|
||||||
SmallVector<int64_t> channelIds;
|
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
|
||||||
SmallVector<int32_t> targetCoreIds;
|
|
||||||
channelIds.reserve(group.size());
|
|
||||||
sourceCoreIds.reserve(group.size());
|
|
||||||
targetCoreIds.reserve(group.size());
|
|
||||||
for (const BatchSendEntry& entry : entries) {
|
|
||||||
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
|
||||||
}
|
|
||||||
SmallVector<Value> channelIdValues = createIndexConstants(sendOp, channelIds, constantFolder);
|
|
||||||
SmallVector<Value> sourceCoreIdValues = createIndexConstants(sendOp, sourceCoreIds, constantFolder);
|
|
||||||
SmallVector<Value> targetCoreIdValues = createIndexConstants(sendOp, targetCoreIds, constantFolder);
|
|
||||||
spatial::SpatChannelSendBatchOp::create(rewriter,
|
|
||||||
sendOp.getLoc(),
|
|
||||||
channelIdValues,
|
|
||||||
sourceCoreIdValues,
|
|
||||||
targetCoreIdValues,
|
|
||||||
mapper.lookup(sendOp.getInput()));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isa<spatial::SpatYieldOp>(anchorOp)) {
|
|
||||||
for (auto& opIt : opIts)
|
|
||||||
++opIt;
|
|
||||||
spatial::SpatYieldOp::create(rewriter, anchorOp.getLoc(), ValueRange {});
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
Operation* cloned = rewriter.clone(anchorOp, mapper);
|
|
||||||
for (auto [originalResult, clonedResult] : llvm::zip(anchorOp.getResults(), cloned->getResults()))
|
|
||||||
mapper.map(originalResult, clonedResult);
|
|
||||||
for (auto& opIt : opIts)
|
|
||||||
++opIt;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto compute : group) {
|
|
||||||
compute->removeAttr(kRebatchPhaseAttrName);
|
|
||||||
consumed.insert(compute);
|
|
||||||
rewriter.eraseOp(compute);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto compute : funcOp.getOps<SpatCompute>())
|
|
||||||
compute->removeAttr(kRebatchPhaseAttrName);
|
|
||||||
}
|
|
||||||
|
|
||||||
void cleanupDeadPackingOps(func::FuncOp funcOp) {
|
|
||||||
auto eraseUnusedOps = [&](auto tag) {
|
|
||||||
using OpTy = decltype(tag);
|
|
||||||
SmallVector<OpTy> ops;
|
|
||||||
funcOp.walk([&](OpTy op) { ops.push_back(op); });
|
|
||||||
for (auto op : llvm::reverse(ops))
|
|
||||||
if (op->use_empty())
|
|
||||||
op.erase();
|
|
||||||
};
|
|
||||||
eraseUnusedOps(tensor::ExtractSliceOp {});
|
|
||||||
eraseUnusedOps(spatial::SpatConcatOp {});
|
|
||||||
eraseUnusedOps(spatial::SpatExtractRowsOp {});
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
/// Runs the post-merge compaction phases on `funcOp` in a fixed order, each
/// under its own ScopedMergePhaseTimer: bilateral channel op ordering,
/// rebatching of equivalent computes, a first scalar and batch channel run
/// compaction, regular op run compaction, row-wise WVMM run compaction, a
/// second scalar and batch channel run compaction, and finally dead packing op
/// cleanup. `nextChannelId` is handed by reference to both scalar channel
/// compaction phases. Always returns success().
LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextChannelId) {
  // The timer lives exactly as long as the phase it measures.
  auto runTimedPhase = [](const char* phaseName, auto&& phase) {
    ScopedMergePhaseTimer timer(phaseName);
    phase();
  };

  runTimedPhase("order-bilateral-channel-ops", [&] { orderBilateralChannelOps(funcOp); });
  runTimedPhase("rebatch-equivalent-computes", [&] { rebatchEquivalentComputes(funcOp); });
  runTimedPhase("compact-scalar-channel-runs-1", [&] { compactScalarChannelRuns(funcOp, nextChannelId); });
  runTimedPhase("compact-batch-channel-runs-1", [&] { compactBatchChannelRuns(funcOp); });
  runTimedPhase("compact-regular-op-runs", [&] { compactRegularOpRuns(funcOp); });
  runTimedPhase("compact-row-wise-wvmm-runs", [&] { compactRowWiseWvmmRuns(funcOp); });
  runTimedPhase("compact-scalar-channel-runs-2", [&] { compactScalarChannelRuns(funcOp, nextChannelId); });
  runTimedPhase("compact-batch-channel-runs-2", [&] { compactBatchChannelRuns(funcOp); });
  runTimedPhase("cleanup-dead-packing-ops", [&] { cleanupDeadPackingOps(funcOp); });

  return success();
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/Support/LogicalResult.h"
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
/// Entry point of the post-merge compaction pipeline for `funcOp`.
/// `nextChannelId` is taken by reference and forwarded to the scalar channel
/// compaction phases. The definition visible in the implementation file runs
/// its phases unconditionally and returns success().
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,15 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
/// Phase invoked first by runPostMergeCompactionPipeline (timed as
/// "order-bilateral-channel-ops"). NOTE(review): definition not visible here;
/// presumably reorders paired channel operations - confirm.
void orderBilateralChannelOps(mlir::func::FuncOp funcOp);
|
|
||||||
/// Phase invoked twice by runPostMergeCompactionPipeline (timed as
/// "compact-scalar-channel-runs-1" and "-2"). Receives `nextChannelId` by
/// reference. NOTE(review): presumably allocates new channel ids from it -
/// confirm against the definition.
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
|
||||||
/// Phase invoked twice by runPostMergeCompactionPipeline, each time right
/// after compactScalarChannelRuns (timed as "compact-batch-channel-runs-1"
/// and "-2").
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
|
|
||||||
/// Phase invoked once by runPostMergeCompactionPipeline (timed as
/// "compact-regular-op-runs").
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
|
|
||||||
/// Phase invoked once by runPostMergeCompactionPipeline, after
/// compactRegularOpRuns (timed as "compact-row-wise-wvmm-runs").
void compactRowWiseWvmmRuns(mlir::func::FuncOp funcOp);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
Reference in New Issue
Block a user