better MaterializeMergeSchedule.cpp with %lane indexed batch computes

support for tensors of index values
This commit is contained in:
NiccoloN
2026-05-22 21:52:28 +02:00
parent 495186503c
commit c77ffa9c56
20 changed files with 398 additions and 300 deletions
@@ -17,7 +17,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
return PimMemCopyOp::create(rewriter,
loc,
@@ -1,9 +1,10 @@
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
auto type = mlir::cast<MemRefType>(memref.getType());
int32_t sizeInBytes = static_cast<int32_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
int32_t sizeInBytes = static_cast<int32_t>(getShapedTypeSizeInBytes(type));
return builder.getI32IntegerAttr(sizeInBytes);
}
@@ -9,6 +9,7 @@
#include <limits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
using namespace mlir;
@@ -23,11 +24,12 @@ static bool isSupportedAliasOp(Operation* op) {
}
static bool isCandidateAllocType(MemRefType type) {
return type && type.hasStaticShape() && type.getLayout().isIdentity() && type.getElementTypeBitWidth() > 0;
return type && type.hasStaticShape() && type.getLayout().isIdentity()
&& hasByteSizedElementType(type.getElementType());
}
static uint64_t getTypeSizeBytes(MemRefType type) {
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
return static_cast<uint64_t>(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()));
}
static FailureOr<uint64_t>
+15 -8
View File
@@ -34,7 +34,9 @@ void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t i
} // namespace
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBatchBodyArgument(getBody(), idx); }
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), idx);
}
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), getWeights().size() + idx);
@@ -74,11 +76,13 @@ SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Locat
resultTypes.insert(resultTypes.begin() + idx, type);
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
newCompute->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes(
newCompute.getOperation(), static_cast<int32_t>(newCompute.getWeights().size()), static_cast<int32_t>(newCompute.getInputs().size()));
setComputeOperandSegmentSizes(newCompute.getOperation(),
static_cast<int32_t>(newCompute.getWeights().size()),
static_cast<int32_t>(newCompute.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx).replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
getResult(oldResultIdx)
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
}
@@ -110,7 +114,8 @@ std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
}
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
std::optional<std::tuple<Value, BlockArgument>>
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight});
@@ -145,8 +150,9 @@ SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type,
auto newBatch =
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
newBatch->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes(
newBatch.getOperation(), static_cast<int32_t>(newBatch.getWeights().size()), static_cast<int32_t>(newBatch.getInputs().size()));
setComputeOperandSegmentSizes(newBatch.getOperation(),
static_cast<int32_t>(newBatch.getWeights().size()),
static_cast<int32_t>(newBatch.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
if (newBatch.getBody().empty()) {
rewriter.eraseOp(newBatch);
@@ -155,7 +161,8 @@ SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type,
auto blockArg = newBatch.getBody().front().insertArgument(
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx).replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
getResult(oldResultIdx)
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
}
@@ -4,7 +4,6 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/DenseMap.h"
@@ -131,7 +130,7 @@ struct MaterializerState {
DenseMap<CpuSlotKey, ComputeInstance, CpuSlotKeyInfo> cpuSlotToInstance;
DenseSet<ClassSlotKey, ClassSlotKeyInfo> materializedSlots;
DenseMap<ProducerKey, DenseSet<ClassId>, ProducerKeyInfo> producerDestClasses;
DenseMap<ProducerKey, SmallVector<ClassId, 4>, ProducerKeyInfo> producerDestClasses;
DenseMap<ProducerKey, DenseMap<ClassId, Value>, ProducerKeyInfo> availableValues;
DenseMap<Value, Value> hostReplacements;
DenseSet<Operation*> oldComputeOps;
@@ -574,6 +573,77 @@ SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation*
return createIndexConstants(state, anchor, ArrayRef<int64_t>(widened));
}
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
SmallVector<APInt, 8> elements;
elements.reserve(values.size());
for (int64_t value : values)
elements.push_back(APInt(64, value));
auto type = RankedTensorType::get({static_cast<int64_t>(values.size())}, state.rewriter.getIndexType());
auto attr = DenseIntElementsAttr::get(type, elements);
return getOrCreateHostConstant(anchor, attr, type, state.constantFolder);
}
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int32_t> values) {
SmallVector<int64_t, 8> widened;
widened.reserve(values.size());
for (int32_t value : values)
widened.push_back(value);
return createIndexTensorConstant(state, anchor, ArrayRef<int64_t>(widened));
}
bool allEqual(ArrayRef<int64_t> values) {
assert(!values.empty() && "expected at least one value");
for (int64_t value : values.drop_front())
if (value != values.front())
return false;
return true;
}
bool allEqual(ArrayRef<int32_t> values) {
assert(!values.empty() && "expected at least one value");
for (int32_t value : values.drop_front())
if (value != values.front())
return false;
return true;
}
Value createLaneIndexedIndexValue(MaterializerState& state,
MaterializedClass& materializedClass,
ArrayRef<int64_t> values,
Location loc) {
assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class");
assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane");
if (allEqual(values))
return createIndexConstant(state, materializedClass.op, values.front());
auto batch = cast<SpatComputeBatch>(materializedClass.op);
auto laneArg = batch.getLaneArgument();
assert(laneArg && "expected compute_batch lane argument");
Value table = createIndexTensorConstant(state, materializedClass.op, values);
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
}
Value createLaneIndexedIndexValue(MaterializerState& state,
MaterializedClass& materializedClass,
ArrayRef<int32_t> values,
Location loc) {
assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class");
assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane");
if (allEqual(values))
return createIndexConstant(state, materializedClass.op, values.front());
auto batch = cast<SpatComputeBatch>(materializedClass.op);
auto laneArg = batch.getLaneArgument();
assert(laneArg && "expected compute_batch lane argument");
Value table = createIndexTensorConstant(state, materializedClass.op, values);
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
}
FailureOr<SmallVector<ComputeInstance, 8>>
getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) {
SmallVector<ComputeInstance, 8> peers;
@@ -623,14 +693,12 @@ Value createOriginalLaneValue(MaterializerState& state,
return arith::AddIOp::create(state.rewriter, loc, *laneArg, baseValue).getResult();
}
SmallVector<APInt, 8> laneValues;
SmallVector<int64_t, 8> laneValues;
laneValues.reserve(peers.size());
for (const ComputeInstance& peer : peers)
laneValues.push_back(APInt(64, peer.laneStart));
laneValues.push_back(peer.laneStart);
auto tableType = RankedTensorType::get({static_cast<int64_t>(peers.size())}, state.rewriter.getIndexType());
auto tableAttr = DenseIntElementsAttr::get(tableType, laneValues);
Value table = arith::ConstantOp::create(state.rewriter, loc, tableType, tableAttr).getResult();
Value table = createIndexTensorConstant(state, materializedClass.op, laneValues);
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
}
@@ -659,6 +727,12 @@ bool hasLiveExternalUse(Value value, const DenseSet<Operation*>& oldComputeOps)
return false;
}
void appendDestinationClass(MaterializerState& state, ProducerKey key, ClassId classId) {
SmallVector<ClassId, 4>& destinations = state.producerDestClasses[key];
if (!llvm::is_contained(destinations, classId))
destinations.push_back(classId);
}
void replaceLiveExternalUses(Value oldValue, Value replacement, const DenseSet<Operation*>& oldComputeOps) {
SmallVector<OpOperand*> uses;
for (OpOperand& use : oldValue.getUses())
@@ -693,7 +767,7 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
if (sourceClass == targetClass)
continue;
state.producerDestClasses[producerKey].insert(targetClass);
appendDestinationClass(state, producerKey, targetClass);
}
}
}
@@ -714,29 +788,70 @@ bool haveSameDestinationClasses(MaterializerState& state, ArrayRef<ProducerKey>
return true;
auto firstIt = state.producerDestClasses.find(keys.front());
DenseSet<ClassId> empty;
const DenseSet<ClassId>& first = firstIt == state.producerDestClasses.end() ? empty : firstIt->second;
ArrayRef<ClassId> first = firstIt == state.producerDestClasses.end() ? ArrayRef<ClassId>() : firstIt->second;
for (ProducerKey key : keys.drop_front()) {
auto it = state.producerDestClasses.find(key);
const DenseSet<ClassId>& current = it == state.producerDestClasses.end() ? empty : it->second;
ArrayRef<ClassId> current = it == state.producerDestClasses.end() ? ArrayRef<ClassId>() : it->second;
if (first.size() != current.size())
return false;
for (ClassId classId : first)
if (!current.contains(classId))
for (auto [lhs, rhs] : llvm::zip(first, current))
if (lhs != rhs)
return false;
}
return true;
}
SmallVector<ClassId, 4> getSortedDestinationClasses(MaterializerState& state, ProducerKey key) {
SmallVector<ClassId, 4> destinations;
ArrayRef<ClassId> getDestinationClasses(MaterializerState& state, ProducerKey key) {
auto it = state.producerDestClasses.find(key);
if (it == state.producerDestClasses.end())
return destinations;
for (ClassId classId : it->second)
destinations.push_back(classId);
llvm::sort(destinations);
return destinations;
return {};
return it->second;
}
void appendSend(MaterializerState& state,
MaterializedClass& sourceClass,
Value payload,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
assert(!channelIds.empty() && "expected at least one send");
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
if (sourceClass.isBatch) {
Value channelId = createLaneIndexedIndexValue(state, sourceClass, channelIds, loc);
Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc);
Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc);
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
return;
}
for (auto index : llvm::seq<size_t>(0, channelIds.size())) {
Value channelId = createIndexConstant(state, sourceClass.op, channelIds[index]);
Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds[index]);
Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds[index]);
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
}
}
Value appendScalarReceive(MaterializerState& state,
MaterializedClass& targetClass,
Type type,
int64_t channelId,
int32_t sourceCoreId,
int32_t targetCoreId,
Location loc) {
assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class");
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
Value channelIdValue = createIndexConstant(state, targetClass.op, channelId);
Value sourceCoreIdValue = createIndexConstant(state, targetClass.op, sourceCoreId);
Value targetCoreIdValue = createIndexConstant(state, targetClass.op, targetCoreId);
return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue)
.getOutput();
}
Value appendReceive(MaterializerState& state,
@@ -746,50 +861,169 @@ Value appendReceive(MaterializerState& state,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
assert(!channelIds.empty() && "expected at least one receive");
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, targetClass.op, channelIds);
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, targetClass.op, sourceCoreIds);
SmallVector<Value, 8> targetCoreIdValues = createIndexConstants(state, targetClass.op, targetCoreIds);
if (targetClass.isBatch) {
return SpatChannelReceiveBatchOp::create(
state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues)
.getOutput();
Value channelId = createLaneIndexedIndexValue(state, targetClass, channelIds, loc);
Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, sourceCoreIds, loc);
Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, targetCoreIds, loc);
return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId).getOutput();
}
if (channelIds.size() != 1) {
return SpatChannelReceiveTensorOp::create(
state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues)
.getOutput();
}
return SpatChannelReceiveOp::create(
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
.getOutput();
assert(channelIds.size() == 1 && "scalar target class can only receive one message at a time");
return appendScalarReceive(
state, targetClass, type, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc);
}
Value appendHostReceive(MaterializerState& state,
MaterializedClass& sourceClass,
Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
state.rewriter.setInsertionPointAfter(sourceClass.op);
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, sourceClass.op, channelIds);
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds);
SmallVector<Value, 8> targetCoreIdValues = createIndexConstants(state, sourceClass.op, targetCoreIds);
Value appendPackedScalarReceives(MaterializerState& state,
MaterializedClass& targetClass,
Type fragmentType,
Type packedType,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
assert(!targetClass.isBatch && "packed scalar receive helper expects a scalar target class");
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
assert(!channelIds.empty() && "expected at least one receive");
if (sourceClass.isBatch) {
return SpatChannelReceiveTensorOp::create(
state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues)
.getOutput();
SmallVector<Value, 8> fragments;
fragments.reserve(channelIds.size());
for (auto index : llvm::seq<size_t>(0, channelIds.size())) {
fragments.push_back(appendScalarReceive(
state, targetClass, fragmentType, channelIds[index], sourceCoreIds[index], targetCoreIds[index], loc));
}
assert(channelIds.size() == 1 && "scalar host receive expects one channel");
return SpatChannelReceiveOp::create(
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
.getOutput();
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
Value packed = fragments.front();
if (fragments.size() != 1)
packed = tensor::ConcatOp::create(state.rewriter, loc, 0, ValueRange(fragments)).getResult();
if (packed.getType() != packedType)
packed = tensor::CastOp::create(state.rewriter, loc, packedType, packed).getResult();
return packed;
}
LogicalResult emitClassToClassCommunication(MaterializerState& state,
MaterializedClass& sourceClass,
MaterializedClass& targetClass,
ArrayRef<ProducerKey> keys,
Value payload,
Location loc) {
if (sourceClass.id == targetClass.id) {
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = payload;
return success();
}
if (!sourceClass.isBatch && !targetClass.isBatch) {
int64_t channelId = state.nextChannelId++;
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
SmallVector<int64_t, 1> channelIds {channelId};
SmallVector<int32_t, 1> sourceCoreIds {sourceCpu};
SmallVector<int32_t, 1> targetCoreIds {targetCpu};
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
}
if (!sourceClass.isBatch && targetClass.isBatch) {
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(targetClass.cpus.size());
sourceCoreIds.reserve(targetClass.cpus.size());
targetCoreIds.reserve(targetClass.cpus.size());
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
for (CpuId targetCpu : targetClass.cpus) {
channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(sourceCpu);
targetCoreIds.push_back(static_cast<int32_t>(targetCpu));
}
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
}
if (sourceClass.isBatch && !targetClass.isBatch) {
std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys);
if (!packedKey)
return sourceClass.op->emitError(
"cannot materialize batch-to-scalar communication because source lanes are not contiguous");
FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size());
if (failed(packedType))
return sourceClass.op->emitError(
"cannot materialize batch-to-scalar communication for non-static ranked tensor payload");
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(sourceClass.cpus.size());
sourceCoreIds.reserve(sourceClass.cpus.size());
targetCoreIds.reserve(sourceClass.cpus.size());
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
for (CpuId sourceCpu : sourceClass.cpus) {
channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
targetCoreIds.push_back(targetCpu);
}
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received = appendPackedScalarReceives(
state, targetClass, payload.getType(), *packedType, channelIds, sourceCoreIds, targetCoreIds, loc);
state.availableValues[*packedKey][targetClass.id] = received;
return success();
}
if (sourceClass.isBatch && targetClass.isBatch) {
if (sourceClass.cpus.size() != targetClass.cpus.size())
return sourceClass.op->emitError(
"cannot materialize batch communication between equivalence classes of different sizes");
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(sourceClass.cpus.size());
sourceCoreIds.reserve(sourceClass.cpus.size());
targetCoreIds.reserve(targetClass.cpus.size());
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
targetCoreIds.push_back(static_cast<int32_t>(targetClass.cpus[lane]));
}
appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
}
}
LogicalResult
@@ -821,207 +1055,50 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val
if (!payloadType || !payloadType.hasStaticShape())
return sourceClass.op->emitError("host-facing compute_batch payload must be a static ranked tensor");
auto laneArg = batch.getLaneArgument();
if (!laneArg)
return batch.emitOpError("expected compute_batch lane block argument while materializing batch output");
auto outputArg = batch.getOutputArgument(resultIndex);
if (!outputArg)
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult, 4> offsets;
SmallVector<OpFoldResult, 4> sizes;
SmallVector<OpFoldResult, 4> strides;
offsets.reserve(payloadType.getRank());
sizes.reserve(payloadType.getRank());
strides.reserve(payloadType.getRank());
auto laneArg = batch.getLaneArgument();
if (!laneArg)
return batch.emitOpError("expected compute_batch lane block argument while materializing batch output");
offsets.push_back(*laneArg);
sizes.push_back(state.rewriter.getIndexAttr(1));
strides.push_back(state.rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) {
offsets.push_back(state.rewriter.getIndexAttr(0));
sizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim)));
strides.push_back(state.rewriter.getIndexAttr(1));
}
auto outputArg = batch.getOutputArgument(resultIndex);
if (!outputArg)
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
tensor::ParallelInsertSliceOp::create(state.rewriter, payload.getLoc(), payload, *outputArg, offsets, sizes, strides);
return success();
}
void appendScalarSend(MaterializerState& state,
MaterializedClass& sourceClass,
Value payload,
int64_t channelId,
int32_t sourceCoreId,
int32_t targetCoreId,
Location loc) {
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
Value channelIdValue = createIndexConstant(state, sourceClass.op, channelId);
Value sourceCoreIdValue = createIndexConstant(state, sourceClass.op, sourceCoreId);
Value targetCoreIdValue = createIndexConstant(state, sourceClass.op, targetCoreId);
SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload);
}
void appendBatchSend(MaterializerState& state,
MaterializedClass& sourceClass,
Value payload,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, sourceClass.op, channelIds);
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds);
SmallVector<Value, 8> targetCoreIdValues = createIndexConstants(state, sourceClass.op, targetCoreIds);
SpatChannelSendBatchOp::create(state.rewriter, loc, channelIdValues, sourceCoreIdValues, targetCoreIdValues, payload);
}
LogicalResult emitClassToClassCommunication(MaterializerState& state,
MaterializedClass& sourceClass,
MaterializedClass& targetClass,
ArrayRef<ProducerKey> keys,
Value payload,
Location loc) {
if (sourceClass.id == targetClass.id) {
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = payload;
return success();
}
if (!sourceClass.isBatch && !targetClass.isBatch) {
int64_t channelId = state.nextChannelId++;
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
appendScalarSend(state, sourceClass, payload, channelId, sourceCpu, targetCpu, loc);
Value received = appendReceive(state,
targetClass,
payload.getType(),
ArrayRef<int64_t>(channelId),
ArrayRef<int32_t>(sourceCpu),
ArrayRef<int32_t>(targetCpu),
loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
}
if (!sourceClass.isBatch && targetClass.isBatch) {
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(targetClass.cpus.size());
sourceCoreIds.reserve(targetClass.cpus.size());
targetCoreIds.reserve(targetClass.cpus.size());
for (CpuId targetCpu : targetClass.cpus) {
int64_t channelId = state.nextChannelId++;
channelIds.push_back(channelId);
sourceCoreIds.push_back(static_cast<int32_t>(sourceClass.cpus.front()));
targetCoreIds.push_back(static_cast<int32_t>(targetCpu));
appendScalarSend(state,
sourceClass,
payload,
channelId,
static_cast<int32_t>(sourceClass.cpus.front()),
static_cast<int32_t>(targetCpu),
loc);
}
Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
}
if (sourceClass.isBatch && !targetClass.isBatch) {
std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys);
if (!packedKey)
return sourceClass.op->emitError("cannot materialize batch-to-scalar communication as concat because source "
"lanes are not contiguous in send order");
FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size());
if (failed(packedType))
return sourceClass.op->emitError(
"cannot materialize batch-to-scalar communication as concat for non-static ranked tensor payload");
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(sourceClass.cpus.size());
sourceCoreIds.reserve(sourceClass.cpus.size());
targetCoreIds.reserve(sourceClass.cpus.size());
for (CpuId sourceCpu : sourceClass.cpus) {
channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
targetCoreIds.push_back(static_cast<int32_t>(targetClass.cpus.front()));
}
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received = appendReceive(state, targetClass, *packedType, channelIds, sourceCoreIds, targetCoreIds, loc);
state.availableValues[*packedKey][targetClass.id] = received;
return success();
}
if (sourceClass.isBatch && targetClass.isBatch) {
if (sourceClass.cpus.size() != targetClass.cpus.size())
return sourceClass.op->emitError(
"cannot materialize batch communication between equivalence classes of different sizes");
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(sourceClass.cpus.size());
sourceCoreIds.reserve(sourceClass.cpus.size());
targetCoreIds.reserve(targetClass.cpus.size());
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
targetCoreIds.push_back(static_cast<int32_t>(targetClass.cpus[lane]));
}
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
}
return sourceClass.op->emitError("unhandled materialized communication pattern");
}
LogicalResult emitHostCommunication(MaterializerState& state,
MaterializedClass& sourceClass,
ArrayRef<ProducerKey> keys,
Value payload,
Value originalOutput,
Location loc) {
(void) keys;
(void) loc;
if (!hasLiveExternalUse(originalOutput, state.oldComputeOps))
return success();
if (!sourceClass.hostOutputs.empty())
return setHostOutputValue(state, sourceClass, originalOutput, payload);
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
SmallVector<int32_t, 8> targetCoreIds;
channelIds.reserve(sourceClass.cpus.size());
sourceCoreIds.reserve(sourceClass.cpus.size());
targetCoreIds.reserve(sourceClass.cpus.size());
for (CpuId sourceCpu : sourceClass.cpus) {
channelIds.push_back(state.nextChannelId++);
sourceCoreIds.push_back(static_cast<int32_t>(sourceCpu));
targetCoreIds.push_back(0);
}
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
state.hostReplacements[originalOutput] = received;
return success();
return setHostOutputValue(state, sourceClass, originalOutput, payload);
}
LogicalResult emitOutputFanout(MaterializerState& state,
@@ -1034,7 +1111,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
return success();
if (!sourceClass.isBatch) {
for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front()))
for (ClassId destinationClass : getDestinationClasses(state, keys.front()))
if (failed(
emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
return failure();
@@ -1048,7 +1125,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
return sourceClass.op->emitError(
"cannot materialize batched output whose lanes have different destination equivalence classes");
for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front()))
for (ClassId destinationClass : getDestinationClasses(state, keys.front()))
if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
return failure();
@@ -660,12 +660,12 @@ public:
emitMergeIrCounts("after-materialization", func);
if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) {
/*if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) {
signalPassFailure();
return;
}
emitMergeIrCounts("after-post-merge-compaction", func);
emitMergeIrCounts("after-post-merge-compaction", func);*/
{
ScopedMergePhaseTimer timer("cleanup-topological-sort-report");