|
|
|
@@ -4,7 +4,6 @@
|
|
|
|
|
#include "mlir/IR/IRMapping.h"
|
|
|
|
|
#include "mlir/IR/Matchers.h"
|
|
|
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
|
|
|
|
#include "mlir/Transforms/FoldUtils.h"
|
|
|
|
|
|
|
|
|
|
#include "llvm/ADT/DenseMap.h"
|
|
|
|
@@ -131,7 +130,7 @@ struct MaterializerState {
|
|
|
|
|
DenseMap<CpuSlotKey, ComputeInstance, CpuSlotKeyInfo> cpuSlotToInstance;
|
|
|
|
|
DenseSet<ClassSlotKey, ClassSlotKeyInfo> materializedSlots;
|
|
|
|
|
|
|
|
|
|
DenseMap<ProducerKey, DenseSet<ClassId>, ProducerKeyInfo> producerDestClasses;
|
|
|
|
|
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
|
|
|
|
|
DenseMap<ProducerKey, DenseMap<ClassId, Value>, ProducerKeyInfo> availableValues;
|
|
|
|
|
DenseMap<Value, Value> hostReplacements;
|
|
|
|
|
DenseSet<Operation*> oldComputeOps;
|
|
|
|
@@ -574,6 +573,77 @@ SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation*
|
|
|
|
|
return createIndexConstants(state, anchor, ArrayRef<int64_t>(widened));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Builds a rank-1 tensor constant with `index` element type that holds `values`.
//
// Every value is encoded as a 64-bit APInt. The dense attribute built from them
// is passed to getOrCreateHostConstant together with `anchor` and the state's
// constant folder, and whatever that helper returns is the result.
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
  // One-dimensional shape: exactly one slot per input value.
  const int64_t elementCount = static_cast<int64_t>(values.size());
  auto tensorType = RankedTensorType::get({elementCount}, state.rewriter.getIndexType());

  SmallVector<APInt, 8> encoded;
  encoded.reserve(values.size());
  for (size_t index = 0, count = values.size(); index != count; ++index)
    encoded.push_back(APInt(64, values[index]));

  auto tensorAttr = DenseIntElementsAttr::get(tensorType, encoded);
  return getOrCreateHostConstant(anchor, tensorAttr, tensorType, state.constantFolder);
}
|
|
|
|
|
|
|
|
|
|
// 32-bit convenience overload: widens every element to int64_t and forwards to
// the int64_t overload of createIndexTensorConstant.
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int32_t> values) {
  SmallVector<int64_t, 8> wideValues;
  wideValues.reserve(values.size());
  for (size_t index = 0, count = values.size(); index != count; ++index) {
    const int64_t wide = values[index];
    wideValues.push_back(wide);
  }
  return createIndexTensorConstant(state, anchor, ArrayRef<int64_t>(wideValues));
}
|
|
|
|
|
|
|
|
|
|
// Returns true when every element of `values` equals the first one.
// `values` must not be empty (checked by assertion).
bool allEqual(ArrayRef<int64_t> values) {
  assert(!values.empty() && "expected at least one value");
  // Compare every element after the first against the first.
  for (size_t index = 1, count = values.size(); index < count; ++index) {
    const bool matchesFirst = values[index] == values.front();
    if (!matchesFirst)
      return false;
  }
  return true;
}
|
|
|
|
|
|
|
|
|
|
// Returns true when every element of `values` equals the first one.
// `values` must not be empty (checked by assertion).
bool allEqual(ArrayRef<int32_t> values) {
  assert(!values.empty() && "expected at least one value");
  // Compare every element after the first against the first.
  for (size_t index = 1, count = values.size(); index < count; ++index) {
    const bool matchesFirst = values[index] == values.front();
    if (!matchesFirst)
      return false;
  }
  return true;
}
|
|
|
|
|
|
|
|
|
|
// Produces an index Value for a materialized batch class from per-lane `values`.
//
// If all entries of `values` are equal, a single index constant is returned.
// Otherwise the values are stored in an index tensor constant and a
// tensor.extract addressed by the compute_batch lane argument is created at
// `loc` through the state's rewriter.
//
// Asserts that `materializedClass` is a batch and that `values` has one entry
// per element of `materializedClass.cpus`.
Value createLaneIndexedIndexValue(MaterializerState& state,
                                  MaterializedClass& materializedClass,
                                  ArrayRef<int64_t> values,
                                  Location loc) {
  assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class");
  assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane");

  if (!allEqual(values)) {
    auto batchOp = cast<SpatComputeBatch>(materializedClass.op);
    auto laneIndex = batchOp.getLaneArgument();
    assert(laneIndex && "expected compute_batch lane argument");

    // Lanes disagree: look the value up in a constant table.
    Value lookupTable = createIndexTensorConstant(state, materializedClass.op, values);
    auto extractOp = tensor::ExtractOp::create(state.rewriter, loc, lookupTable, ValueRange {*laneIndex});
    return extractOp.getResult();
  }

  // All lanes agree: a plain constant is enough.
  return createIndexConstant(state, materializedClass.op, values.front());
}
|
|
|
|
|
|
|
|
|
|
// Produces an index Value for a materialized batch class from per-lane 32-bit
// `values`.
//
// If all entries of `values` are equal, a single index constant is returned.
// Otherwise the values are stored in an index tensor constant and a
// tensor.extract addressed by the compute_batch lane argument is created at
// `loc` through the state's rewriter.
//
// Asserts that `materializedClass` is a batch and that `values` has one entry
// per element of `materializedClass.cpus`.
Value createLaneIndexedIndexValue(MaterializerState& state,
                                  MaterializedClass& materializedClass,
                                  ArrayRef<int32_t> values,
                                  Location loc) {
  assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class");
  assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane");

  if (!allEqual(values)) {
    auto batchOp = cast<SpatComputeBatch>(materializedClass.op);
    auto laneIndex = batchOp.getLaneArgument();
    assert(laneIndex && "expected compute_batch lane argument");

    // Lanes disagree: look the value up in a constant table.
    Value lookupTable = createIndexTensorConstant(state, materializedClass.op, values);
    auto extractOp = tensor::ExtractOp::create(state.rewriter, loc, lookupTable, ValueRange {*laneIndex});
    return extractOp.getResult();
  }

  // All lanes agree: a plain constant is enough.
  return createIndexConstant(state, materializedClass.op, values.front());
}
|
|
|
|
|
|
|
|
|
|
FailureOr<SmallVector<ComputeInstance, 8>>
|
|
|
|
|
getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) {
|
|
|
|
|
SmallVector<ComputeInstance, 8> peers;
|
|
|
|
@@ -623,14 +693,12 @@ Value createOriginalLaneValue(MaterializerState& state,
|
|
|
|
|
return arith::AddIOp::create(state.rewriter, loc, *laneArg, baseValue).getResult();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SmallVector<APInt, 8> laneValues;
|
|
|
|
|
SmallVector<int64_t, 8> laneValues;
|
|
|
|
|
laneValues.reserve(peers.size());
|
|
|
|
|
for (const ComputeInstance& peer : peers)
|
|
|
|
|
laneValues.push_back(APInt(64, peer.laneStart));
|
|
|
|
|
laneValues.push_back(peer.laneStart);
|
|
|
|
|
|
|
|
|
|
auto tableType = RankedTensorType::get({static_cast<int64_t>(peers.size())}, state.rewriter.getIndexType());
|
|
|
|
|
auto tableAttr = DenseIntElementsAttr::get(tableType, laneValues);
|
|
|
|
|
Value table = arith::ConstantOp::create(state.rewriter, loc, tableType, tableAttr).getResult();
|
|
|
|
|
Value table = createIndexTensorConstant(state, materializedClass.op, laneValues);
|
|
|
|
|
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -659,6 +727,12 @@ bool hasLiveExternalUse(Value value, const DenseSet<Operation*>& oldComputeOps)
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Records `classId` as a destination of producer `key`, keeping the order in
// which destinations were first added and never storing a duplicate.
// Looking up `key` with operator[] creates an empty entry if none exists yet.
void appendDestinationClass(MaterializerState& state, ProducerKey key, ClassId classId) {
  SmallVector<ClassId, 4>& known = state.producerDestClasses[key];
  for (ClassId existing : known) {
    if (existing == classId)
      return; // already recorded
  }
  known.push_back(classId);
}
|
|
|
|
|
|
|
|
|
|
void replaceLiveExternalUses(Value oldValue, Value replacement, const DenseSet<Operation*>& oldComputeOps) {
|
|
|
|
|
SmallVector<OpOperand*> uses;
|
|
|
|
|
for (OpOperand& use : oldValue.getUses())
|
|
|
|
@@ -693,7 +767,7 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
|
|
|
|
|
if (sourceClass == targetClass)
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
|
|
state.producerDestClasses[producerKey].insert(targetClass);
|
|
|
|
|
appendDestinationClass(state, producerKey, targetClass);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@@ -714,29 +788,70 @@ bool haveSameDestinationClasses(MaterializerState& state, ArrayRef<ProducerKey>
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
|
|
auto firstIt = state.producerDestClasses.find(keys.front());
|
|
|
|
|
DenseSet<ClassId> empty;
|
|
|
|
|
const DenseSet<ClassId>& first = firstIt == state.producerDestClasses.end() ? empty : firstIt->second;
|
|
|
|
|
ArrayRef<ClassId> first = firstIt == state.producerDestClasses.end() ? ArrayRef<ClassId>() : firstIt->second;
|
|
|
|
|
for (ProducerKey key : keys.drop_front()) {
|
|
|
|
|
auto it = state.producerDestClasses.find(key);
|
|
|
|
|
const DenseSet<ClassId>& current = it == state.producerDestClasses.end() ? empty : it->second;
|
|
|
|
|
ArrayRef<ClassId> current = it == state.producerDestClasses.end() ? ArrayRef<ClassId>() : it->second;
|
|
|
|
|
if (first.size() != current.size())
|
|
|
|
|
return false;
|
|
|
|
|
for (ClassId classId : first)
|
|
|
|
|
if (!current.contains(classId))
|
|
|
|
|
for (auto [lhs, rhs] : llvm::zip(first, current))
|
|
|
|
|
if (lhs != rhs)
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SmallVector<ClassId, 4> getSortedDestinationClasses(MaterializerState& state, ProducerKey key) {
|
|
|
|
|
SmallVector<ClassId, 4> destinations;
|
|
|
|
|
ArrayRef<ClassId> getDestinationClasses(MaterializerState& state, ProducerKey key) {
|
|
|
|
|
auto it = state.producerDestClasses.find(key);
|
|
|
|
|
if (it == state.producerDestClasses.end())
|
|
|
|
|
return destinations;
|
|
|
|
|
for (ClassId classId : it->second)
|
|
|
|
|
destinations.push_back(classId);
|
|
|
|
|
llvm::sort(destinations);
|
|
|
|
|
return destinations;
|
|
|
|
|
return {};
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Emits channel send operations for `payload` immediately before the terminator
// of `sourceClass.body`.
//
// `channelIds`, `sourceCoreIds` and `targetCoreIds` are parallel arrays of equal,
// non-zero length (asserted). For a batch source class a single
// SpatChannelSendOp is created whose three ids come from
// createLaneIndexedIndexValue. For a scalar source class one SpatChannelSendOp
// with constant ids is created per array entry, in order.
void appendSend(MaterializerState& state,
                MaterializedClass& sourceClass,
                Value payload,
                ArrayRef<int64_t> channelIds,
                ArrayRef<int32_t> sourceCoreIds,
                ArrayRef<int32_t> targetCoreIds,
                Location loc) {
  assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
  assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
  assert(!channelIds.empty() && "expected at least one send");

  state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());

  if (!sourceClass.isBatch) {
    // Scalar source: one explicit send per (channel, source, target) triple.
    for (size_t sendIndex = 0, sendCount = channelIds.size(); sendIndex != sendCount; ++sendIndex) {
      Value channel = createIndexConstant(state, sourceClass.op, channelIds[sendIndex]);
      Value fromCore = createIndexConstant(state, sourceClass.op, sourceCoreIds[sendIndex]);
      Value toCore = createIndexConstant(state, sourceClass.op, targetCoreIds[sendIndex]);
      SpatChannelSendOp::create(state.rewriter, loc, channel, fromCore, toCore, payload);
    }
    return;
  }

  // Batch source: a single send whose ids are derived per lane.
  Value channel = createLaneIndexedIndexValue(state, sourceClass, channelIds, loc);
  Value fromCore = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc);
  Value toCore = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc);
  SpatChannelSendOp::create(state.rewriter, loc, channel, fromCore, toCore, payload);
}
|
|
|
|
|
|
|
|
|
|
// Creates a SpatChannelReceiveOp of result type `type` immediately before the
// terminator of `targetClass.body` and returns its output value.
//
// The channel and core ids are materialized as index constants anchored at
// `targetClass.op`. Asserts that `targetClass` is not a batch class.
Value appendScalarReceive(MaterializerState& state,
                          MaterializedClass& targetClass,
                          Type type,
                          int64_t channelId,
                          int32_t sourceCoreId,
                          int32_t targetCoreId,
                          Location loc) {
  assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class");

  state.rewriter.setInsertionPoint(targetClass.body->getTerminator());

  Value channel = createIndexConstant(state, targetClass.op, channelId);
  Value fromCore = createIndexConstant(state, targetClass.op, sourceCoreId);
  Value toCore = createIndexConstant(state, targetClass.op, targetCoreId);

  auto receiveOp = SpatChannelReceiveOp::create(state.rewriter, loc, type, channel, fromCore, toCore);
  return receiveOp.getOutput();
}
|
|
|
|
|
|
|
|
|
|
Value appendReceive(MaterializerState& state,
|
|
|
|
@@ -746,50 +861,169 @@ Value appendReceive(MaterializerState& state,
|
|
|
|
|
ArrayRef<int32_t> sourceCoreIds,
|
|
|
|
|
ArrayRef<int32_t> targetCoreIds,
|
|
|
|
|
Location loc) {
|
|
|
|
|
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
|
|
|
|
|
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
|
|
|
|
|
assert(!channelIds.empty() && "expected at least one receive");
|
|
|
|
|
|
|
|
|
|
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
|
|
|
|
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, targetClass.op, channelIds);
|
|
|
|
|
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, targetClass.op, sourceCoreIds);
|
|
|
|
|
SmallVector<Value, 8> targetCoreIdValues = createIndexConstants(state, targetClass.op, targetCoreIds);
|
|
|
|
|
|
|
|
|
|
if (targetClass.isBatch) {
|
|
|
|
|
return SpatChannelReceiveBatchOp::create(
|
|
|
|
|
state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues)
|
|
|
|
|
.getOutput();
|
|
|
|
|
Value channelId = createLaneIndexedIndexValue(state, targetClass, channelIds, loc);
|
|
|
|
|
Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, sourceCoreIds, loc);
|
|
|
|
|
Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, targetCoreIds, loc);
|
|
|
|
|
return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId).getOutput();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (channelIds.size() != 1) {
|
|
|
|
|
return SpatChannelReceiveTensorOp::create(
|
|
|
|
|
state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues)
|
|
|
|
|
.getOutput();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return SpatChannelReceiveOp::create(
|
|
|
|
|
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
|
|
|
|
|
.getOutput();
|
|
|
|
|
assert(channelIds.size() == 1 && "scalar target class can only receive one message at a time");
|
|
|
|
|
return appendScalarReceive(
|
|
|
|
|
state, targetClass, type, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value appendHostReceive(MaterializerState& state,
|
|
|
|
|
MaterializedClass& sourceClass,
|
|
|
|
|
Type type,
|
|
|
|
|
ArrayRef<int64_t> channelIds,
|
|
|
|
|
ArrayRef<int32_t> sourceCoreIds,
|
|
|
|
|
ArrayRef<int32_t> targetCoreIds,
|
|
|
|
|
Location loc) {
|
|
|
|
|
state.rewriter.setInsertionPointAfter(sourceClass.op);
|
|
|
|
|
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, sourceClass.op, channelIds);
|
|
|
|
|
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds);
|
|
|
|
|
SmallVector<Value, 8> targetCoreIdValues = createIndexConstants(state, sourceClass.op, targetCoreIds);
|
|
|
|
|
Value appendPackedScalarReceives(MaterializerState& state,
|
|
|
|
|
MaterializedClass& targetClass,
|
|
|
|
|
Type fragmentType,
|
|
|
|
|
Type packedType,
|
|
|
|
|
ArrayRef<int64_t> channelIds,
|
|
|
|
|
ArrayRef<int32_t> sourceCoreIds,
|
|
|
|
|
ArrayRef<int32_t> targetCoreIds,
|
|
|
|
|
Location loc) {
|
|
|
|
|
assert(!targetClass.isBatch && "packed scalar receive helper expects a scalar target class");
|
|
|
|
|
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
|
|
|
|
|
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
|
|
|
|
|
assert(!channelIds.empty() && "expected at least one receive");
|
|
|
|
|
|
|
|
|
|
if (sourceClass.isBatch) {
|
|
|
|
|
return SpatChannelReceiveTensorOp::create(
|
|
|
|
|
state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues)
|
|
|
|
|
.getOutput();
|
|
|
|
|
SmallVector<Value, 8> fragments;
|
|
|
|
|
fragments.reserve(channelIds.size());
|
|
|
|
|
for (auto index : llvm::seq<size_t>(0, channelIds.size())) {
|
|
|
|
|
fragments.push_back(appendScalarReceive(
|
|
|
|
|
state, targetClass, fragmentType, channelIds[index], sourceCoreIds[index], targetCoreIds[index], loc));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
assert(channelIds.size() == 1 && "scalar host receive expects one channel");
|
|
|
|
|
return SpatChannelReceiveOp::create(
|
|
|
|
|
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
|
|
|
|
|
.getOutput();
|
|
|
|
|
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
|
|
|
|
|
|
|
|
|
Value packed = fragments.front();
|
|
|
|
|
if (fragments.size() != 1)
|
|
|
|
|
packed = tensor::ConcatOp::create(state.rewriter, loc, 0, ValueRange(fragments)).getResult();
|
|
|
|
|
|
|
|
|
|
if (packed.getType() != packedType)
|
|
|
|
|
packed = tensor::CastOp::create(state.rewriter, loc, packedType, packed).getResult();
|
|
|
|
|
|
|
|
|
|
return packed;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Makes `payload`, produced inside `sourceClass`, available inside `targetClass`
// and records the resulting value in `state.availableValues`.
//
// Handled cases:
//  * same class: `payload` itself is recorded for every key, nothing is emitted;
//  * scalar -> scalar: one fresh channel, one send and one receive;
//  * scalar -> batch: one fresh channel per target CPU;
//  * batch -> scalar: one fresh channel per source CPU; the received fragments
//    are packed by appendPackedScalarReceives and recorded under the single
//    contiguous producer key derived from `keys`;
//  * batch -> batch: lane i of the source is paired with lane i of the target;
//    both classes must have the same number of CPUs.
//
// Returns success() when communication was emitted (or was unnecessary), and a
// failure carrying a diagnostic attached to `sourceClass.op` otherwise.
LogicalResult emitClassToClassCommunication(MaterializerState& state,
                                            MaterializedClass& sourceClass,
                                            MaterializedClass& targetClass,
                                            ArrayRef<ProducerKey> keys,
                                            Value payload,
                                            Location loc) {
  // Producer and consumer live in the same class: reuse the value directly.
  if (sourceClass.id == targetClass.id) {
    for (ProducerKey key : keys)
      state.availableValues[key][targetClass.id] = payload;
    return success();
  }

  if (!sourceClass.isBatch && !targetClass.isBatch) {
    int64_t channelId = state.nextChannelId++;
    int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
    int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());

    // Single-element vectors so the array-based send/receive helpers can be used.
    SmallVector<int64_t, 1> channelIds {channelId};
    SmallVector<int32_t, 1> sourceCoreIds {sourceCpu};
    SmallVector<int32_t, 1> targetCoreIds {targetCpu};

    appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
    Value received =
        appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);

    for (ProducerKey key : keys)
      state.availableValues[key][targetClass.id] = received;
    return success();
  }

  if (!sourceClass.isBatch && targetClass.isBatch) {
    SmallVector<int64_t, 8> channelIds;
    SmallVector<int32_t, 8> sourceCoreIds;
    SmallVector<int32_t, 8> targetCoreIds;
    channelIds.reserve(targetClass.cpus.size());
    sourceCoreIds.reserve(targetClass.cpus.size());
    targetCoreIds.reserve(targetClass.cpus.size());

    // The single source CPU sends once per target CPU, each on its own channel.
    int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
    for (CpuId targetCpu : targetClass.cpus) {
      channelIds.push_back(state.nextChannelId++);
      sourceCoreIds.push_back(sourceCpu);
      targetCoreIds.push_back(static_cast<int32_t>(targetCpu));
    }

    appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
    Value received =
        appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);

    for (ProducerKey key : keys)
      state.availableValues[key][targetClass.id] = received;
    return success();
  }

  if (sourceClass.isBatch && !targetClass.isBatch) {
    std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys);
    if (!packedKey)
      return sourceClass.op->emitError(
          "cannot materialize batch-to-scalar communication because source lanes are not contiguous");

    FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size());
    if (failed(packedType))
      return sourceClass.op->emitError(
          "cannot materialize batch-to-scalar communication for non-static ranked tensor payload");

    SmallVector<int64_t, 8> channelIds;
    SmallVector<int32_t, 8> sourceCoreIds;
    SmallVector<int32_t, 8> targetCoreIds;
    channelIds.reserve(sourceClass.cpus.size());
    sourceCoreIds.reserve(sourceClass.cpus.size());
    targetCoreIds.reserve(sourceClass.cpus.size());

    // Every source CPU sends to the single target CPU on its own channel.
    int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
    for (CpuId sourceCpu : sourceClass.cpus) {
      channelIds.push_back(state.nextChannelId++);
      sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
      targetCoreIds.push_back(targetCpu);
    }

    appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
    Value received = appendPackedScalarReceives(
        state, targetClass, payload.getType(), *packedType, channelIds, sourceCoreIds, targetCoreIds, loc);

    // Only the packed key is recorded; the individual keys are not.
    state.availableValues[*packedKey][targetClass.id] = received;
    return success();
  }

  if (sourceClass.isBatch && targetClass.isBatch) {
    if (sourceClass.cpus.size() != targetClass.cpus.size())
      return sourceClass.op->emitError(
          "cannot materialize batch communication between equivalence classes of different sizes");

    SmallVector<int64_t, 8> channelIds;
    SmallVector<int32_t, 8> sourceCoreIds;
    SmallVector<int32_t, 8> targetCoreIds;
    channelIds.reserve(sourceClass.cpus.size());
    sourceCoreIds.reserve(sourceClass.cpus.size());
    targetCoreIds.reserve(targetClass.cpus.size());

    // Lane-wise pairing: source lane i talks to target lane i.
    for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
      channelIds.push_back(state.nextChannelId++);
      sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
      targetCoreIds.push_back(static_cast<int32_t>(targetClass.cpus[lane]));
    }

    appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
    Value received =
        appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);

    for (ProducerKey key : keys)
      state.availableValues[key][targetClass.id] = received;
    return success();
  }

  // The branches above cover every scalar/batch combination, but the compiler
  // cannot prove that. Return a diagnostic instead of flowing off the end of a
  // value-returning function, which would be undefined behaviour.
  return sourceClass.op->emitError("unhandled materialized communication pattern");
}
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
@@ -821,207 +1055,50 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val
|
|
|
|
|
if (!payloadType || !payloadType.hasStaticShape())
|
|
|
|
|
return sourceClass.op->emitError("host-facing compute_batch payload must be a static ranked tensor");
|
|
|
|
|
|
|
|
|
|
auto laneArg = batch.getLaneArgument();
|
|
|
|
|
if (!laneArg)
|
|
|
|
|
return batch.emitOpError("expected compute_batch lane block argument while materializing batch output");
|
|
|
|
|
|
|
|
|
|
auto outputArg = batch.getOutputArgument(resultIndex);
|
|
|
|
|
if (!outputArg)
|
|
|
|
|
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
|
|
|
|
|
|
|
|
|
|
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
|
|
|
|
|
|
|
|
|
SmallVector<OpFoldResult, 4> offsets;
|
|
|
|
|
SmallVector<OpFoldResult, 4> sizes;
|
|
|
|
|
SmallVector<OpFoldResult, 4> strides;
|
|
|
|
|
offsets.reserve(payloadType.getRank());
|
|
|
|
|
sizes.reserve(payloadType.getRank());
|
|
|
|
|
strides.reserve(payloadType.getRank());
|
|
|
|
|
auto laneArg = batch.getLaneArgument();
|
|
|
|
|
if (!laneArg)
|
|
|
|
|
return batch.emitOpError("expected compute_batch lane block argument while materializing batch output");
|
|
|
|
|
|
|
|
|
|
offsets.push_back(*laneArg);
|
|
|
|
|
sizes.push_back(state.rewriter.getIndexAttr(1));
|
|
|
|
|
strides.push_back(state.rewriter.getIndexAttr(1));
|
|
|
|
|
|
|
|
|
|
for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) {
|
|
|
|
|
offsets.push_back(state.rewriter.getIndexAttr(0));
|
|
|
|
|
sizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim)));
|
|
|
|
|
strides.push_back(state.rewriter.getIndexAttr(1));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto outputArg = batch.getOutputArgument(resultIndex);
|
|
|
|
|
if (!outputArg)
|
|
|
|
|
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
|
|
|
|
|
|
|
|
|
|
tensor::ParallelInsertSliceOp::create(state.rewriter, payload.getLoc(), payload, *outputArg, offsets, sizes, strides);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Creates one SpatChannelSendOp for `payload` immediately before the terminator
// of `sourceClass.body`. The channel and core ids are materialized as index
// constants anchored at `sourceClass.op`.
void appendScalarSend(MaterializerState& state,
                      MaterializedClass& sourceClass,
                      Value payload,
                      int64_t channelId,
                      int32_t sourceCoreId,
                      int32_t targetCoreId,
                      Location loc) {
  state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());

  auto anchor = sourceClass.op;
  Value channel = createIndexConstant(state, anchor, channelId);
  Value fromCore = createIndexConstant(state, anchor, sourceCoreId);
  Value toCore = createIndexConstant(state, anchor, targetCoreId);

  SpatChannelSendOp::create(state.rewriter, loc, channel, fromCore, toCore, payload);
}
|
|
|
|
|
|
|
|
|
|
// Creates one SpatChannelSendBatchOp for `payload` immediately before the
// terminator of `sourceClass.body`. Each id array is turned into a list of
// index constants anchored at `sourceClass.op`.
void appendBatchSend(MaterializerState& state,
                     MaterializedClass& sourceClass,
                     Value payload,
                     ArrayRef<int64_t> channelIds,
                     ArrayRef<int32_t> sourceCoreIds,
                     ArrayRef<int32_t> targetCoreIds,
                     Location loc) {
  state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());

  auto anchor = sourceClass.op;
  SmallVector<Value, 8> channels = createIndexConstants(state, anchor, channelIds);
  SmallVector<Value, 8> fromCores = createIndexConstants(state, anchor, sourceCoreIds);
  SmallVector<Value, 8> toCores = createIndexConstants(state, anchor, targetCoreIds);

  SpatChannelSendBatchOp::create(state.rewriter, loc, channels, fromCores, toCores, payload);
}
|
|
|
|
|
|
|
|
|
|
// Makes `payload`, produced inside `sourceClass`, available inside `targetClass`
// and records the resulting value in `state.availableValues`.
//
// Same-class transfers record `payload` directly. Otherwise fresh channel ids
// are drawn from `state.nextChannelId` and sends/receives are emitted:
// scalar sources use appendScalarSend (once per target CPU), batch sources use
// appendBatchSend. Batch-to-scalar transfers are recorded under the single
// contiguous producer key derived from `keys`; every other case records the
// received value under each key. Failures carry a diagnostic attached to
// `sourceClass.op`.
LogicalResult emitClassToClassCommunication(MaterializerState& state,
                                            MaterializedClass& sourceClass,
                                            MaterializedClass& targetClass,
                                            ArrayRef<ProducerKey> keys,
                                            Value payload,
                                            Location loc) {
  // Records `value` as the target-class view of every producer key.
  auto publish = [&](Value value) {
    for (ProducerKey key : keys)
      state.availableValues[key][targetClass.id] = value;
  };

  if (sourceClass.id == targetClass.id) {
    publish(payload);
    return success();
  }

  // Scalar -> scalar: one channel, one send, one receive.
  if (!sourceClass.isBatch && !targetClass.isBatch) {
    int64_t channelId = state.nextChannelId++;
    int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
    int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
    appendScalarSend(state, sourceClass, payload, channelId, sourceCpu, targetCpu, loc);

    ArrayRef<int64_t> channelRef(channelId);
    ArrayRef<int32_t> sourceRef(sourceCpu);
    ArrayRef<int32_t> targetRef(targetCpu);
    Value received = appendReceive(state, targetClass, payload.getType(), channelRef, sourceRef, targetRef, loc);
    publish(received);
    return success();
  }

  // Scalar -> batch: the single source CPU sends once per target CPU.
  if (!sourceClass.isBatch && targetClass.isBatch) {
    SmallVector<int64_t, 8> channelIds;
    SmallVector<int32_t, 8> sourceCoreIds;
    SmallVector<int32_t, 8> targetCoreIds;
    channelIds.reserve(targetClass.cpus.size());
    sourceCoreIds.reserve(targetClass.cpus.size());
    targetCoreIds.reserve(targetClass.cpus.size());

    for (CpuId targetCpu : targetClass.cpus) {
      int64_t channelId = state.nextChannelId++;
      const int32_t fromCore = static_cast<int32_t>(sourceClass.cpus.front());
      const int32_t toCore = static_cast<int32_t>(targetCpu);

      channelIds.push_back(channelId);
      sourceCoreIds.push_back(fromCore);
      targetCoreIds.push_back(toCore);
      appendScalarSend(state, sourceClass, payload, channelId, fromCore, toCore, loc);
    }

    Value received =
        appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
    publish(received);
    return success();
  }

  // Batch -> scalar: every source CPU sends to the single target CPU.
  if (sourceClass.isBatch && !targetClass.isBatch) {
    std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys);
    if (!packedKey)
      return sourceClass.op->emitError("cannot materialize batch-to-scalar communication as concat because source "
                                       "lanes are not contiguous in send order");

    FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size());
    if (failed(packedType))
      return sourceClass.op->emitError(
          "cannot materialize batch-to-scalar communication as concat for non-static ranked tensor payload");

    SmallVector<int64_t, 8> channelIds;
    SmallVector<int32_t, 8> sourceCoreIds;
    SmallVector<int32_t, 8> targetCoreIds;
    channelIds.reserve(sourceClass.cpus.size());
    sourceCoreIds.reserve(sourceClass.cpus.size());
    targetCoreIds.reserve(sourceClass.cpus.size());

    for (CpuId sourceCpu : sourceClass.cpus) {
      channelIds.push_back(state.nextChannelId++);
      sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
      targetCoreIds.push_back(static_cast<int32_t>(targetClass.cpus.front()));
    }

    appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
    Value received = appendReceive(state, targetClass, *packedType, channelIds, sourceCoreIds, targetCoreIds, loc);
    // Only the packed key is recorded here.
    state.availableValues[*packedKey][targetClass.id] = received;
    return success();
  }

  // Batch -> batch: source lane i is paired with target lane i.
  if (sourceClass.isBatch && targetClass.isBatch) {
    if (sourceClass.cpus.size() != targetClass.cpus.size())
      return sourceClass.op->emitError(
          "cannot materialize batch communication between equivalence classes of different sizes");

    SmallVector<int64_t, 8> channelIds;
    SmallVector<int32_t, 8> sourceCoreIds;
    SmallVector<int32_t, 8> targetCoreIds;
    channelIds.reserve(sourceClass.cpus.size());
    sourceCoreIds.reserve(sourceClass.cpus.size());
    targetCoreIds.reserve(targetClass.cpus.size());

    for (size_t lane = 0, laneCount = sourceClass.cpus.size(); lane != laneCount; ++lane) {
      channelIds.push_back(state.nextChannelId++);
      sourceCoreIds.push_back(static_cast<int32_t>(sourceClass.cpus[lane]));
      targetCoreIds.push_back(static_cast<int32_t>(targetClass.cpus[lane]));
    }

    appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
    Value received =
        appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
    publish(received);
    return success();
  }

  return sourceClass.op->emitError("unhandled materialized communication pattern");
}
|
|
|
|
|
|
|
|
|
|
LogicalResult emitHostCommunication(MaterializerState& state,
|
|
|
|
|
MaterializedClass& sourceClass,
|
|
|
|
|
ArrayRef<ProducerKey> keys,
|
|
|
|
|
Value payload,
|
|
|
|
|
Value originalOutput,
|
|
|
|
|
Location loc) {
|
|
|
|
|
(void) keys;
|
|
|
|
|
(void) loc;
|
|
|
|
|
|
|
|
|
|
if (!hasLiveExternalUse(originalOutput, state.oldComputeOps))
|
|
|
|
|
return success();
|
|
|
|
|
|
|
|
|
|
if (!sourceClass.hostOutputs.empty())
|
|
|
|
|
return setHostOutputValue(state, sourceClass, originalOutput, payload);
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t, 8> channelIds;
|
|
|
|
|
SmallVector<int32_t, 8> sourceCoreIds;
|
|
|
|
|
SmallVector<int32_t, 8> targetCoreIds;
|
|
|
|
|
channelIds.reserve(sourceClass.cpus.size());
|
|
|
|
|
sourceCoreIds.reserve(sourceClass.cpus.size());
|
|
|
|
|
targetCoreIds.reserve(sourceClass.cpus.size());
|
|
|
|
|
for (CpuId sourceCpu : sourceClass.cpus) {
|
|
|
|
|
channelIds.push_back(state.nextChannelId++);
|
|
|
|
|
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
|
|
|
|
|
targetCoreIds.push_back(0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
|
|
|
|
Value received =
|
|
|
|
|
appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
|
|
|
|
state.hostReplacements[originalOutput] = received;
|
|
|
|
|
return success();
|
|
|
|
|
return setHostOutputValue(state, sourceClass, originalOutput, payload);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LogicalResult emitOutputFanout(MaterializerState& state,
|
|
|
|
@@ -1034,7 +1111,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
|
|
|
|
return success();
|
|
|
|
|
|
|
|
|
|
if (!sourceClass.isBatch) {
|
|
|
|
|
for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front()))
|
|
|
|
|
for (ClassId destinationClass : getDestinationClasses(state, keys.front()))
|
|
|
|
|
if (failed(
|
|
|
|
|
emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
|
|
|
|
return failure();
|
|
|
|
@@ -1048,7 +1125,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
|
|
|
|
return sourceClass.op->emitError(
|
|
|
|
|
"cannot materialize batched output whose lanes have different destination equivalence classes");
|
|
|
|
|
|
|
|
|
|
for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front()))
|
|
|
|
|
for (ClassId destinationClass : getDestinationClasses(state, keys.front()))
|
|
|
|
|
if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|