VGG-16 works, and so does ResNet

This commit is contained in:
ilgeco
2026-07-01 13:49:21 +02:00
parent f5e1c2e706
commit 8d3eb929f6
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
@@ -96,11 +97,35 @@ std::optional<ProducerKey>
getProjectedWholeBatchReplacementProducer(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex);
std::optional<ProducerKey> getProjectedWholeBatchReplacementProducer(MaterializerState& state,
tensor::ExtractSliceOp extract);
/// Forward declaration; the definition appears later in this file.
/// Registers `packed` as the value available inside `materializedClass` for
/// every producer key in `keys`. All keys must refer to the same producer op
/// and result index and cover exactly one lane each, otherwise an error is
/// emitted on the class op and failure is returned.
LogicalResult recordLocalPackedSlotValue(MaterializerState& state,
MaterializedClass& materializedClass,
ArrayRef<ProducerKey> keys,
Value packed);
FailureOr<Value> materializeProjectedWholeBatchExtractReplacement(MaterializerState& state,
MaterializedClass& targetClass,
tensor::ExtractSliceOp extract,
ProducerKey producer,
IRMapping* mapper = nullptr);
/// Relationship between the tensor type a consumer expects (logical) and the
/// type of the value that was actually resolved (physical), as reported by
/// classifyMaterializedTensorLayout.
enum class MaterializedLayoutKind {
/// Logical and physical types are identical; the value is used unchanged.
DenseLogical,
/// Both types are static rank-4 tensors with equal element type and encoding,
/// logical dim 0 and physical dim 2 are both 1, and logical dims (1, 2, 3)
/// equal physical dims (1, 0, 3).
RowPackedNCHWFromRows,
/// Any other combination; materializeLogicalTensorView rejects it with an error.
Unknown
};
/// A resolved tensor value bundled with the type its consumer expects, the
/// type it currently has, and the classified relationship between the two.
/// Consumed by materializeLogicalTensorView.
struct MaterializedTensorView {
/// The value to make available inside the target materialized class.
Value value;
/// Type expected by the consumer; mapInputs fills it from the original input's type.
RankedTensorType logicalType;
/// Type of the resolved value; mapInputs fills it from `value`'s own type.
RankedTensorType physicalType;
/// Defaults to Unknown, which materializeLogicalTensorView reports as an error.
MaterializedLayoutKind layout = MaterializedLayoutKind::Unknown;
};
/// Forward declaration; the definition appears later in this file.
/// Localizes `view.value` for `targetClass` and returns it in the logical
/// layout described by `view`: unchanged for DenseLogical, transposed for
/// RowPackedNCHWFromRows. Any other layout emits an error on `anchor` and
/// returns failure. `context`, `producer` and `mapper` are forwarded to
/// materializeTensorValueForMaterializedClassUse.
FailureOr<Value> materializeLogicalTensorView(MaterializerState& state,
MaterializedClass& targetClass,
MaterializedTensorView view,
Operation* anchor,
StringRef context,
std::optional<ProducerKey> producer = std::nullopt,
IRMapping* mapper = nullptr);
bool isConstantLike(Value value) {
Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
@@ -1484,6 +1509,66 @@ FailureOr<Value> materializeTensorValueForMaterializedClassUse(MaterializerState
return appendInput(state, targetClass, value);
}
/// Classifies how the type of a resolved (physical) tensor relates to the
/// logical type expected by its consumer.
///
/// Returns:
///  - Unknown when either type is null;
///  - DenseLogical when both types are identical;
///  - RowPackedNCHWFromRows when both are static rank-4 tensors with the same
///    element type and encoding, logical dim 0 and physical dim 2 are 1, and
///    logical dims (1, 2, 3) equal physical dims (1, 0, 3);
///  - Unknown in every other case.
MaterializedLayoutKind classifyMaterializedTensorLayout(RankedTensorType logicalType, RankedTensorType physicalType) {
  // Reject null types before the equality test: two null types compare equal
  // and would otherwise be reported as a valid DenseLogical layout.
  if (!logicalType || !physicalType)
    return MaterializedLayoutKind::Unknown;

  if (logicalType == physicalType)
    return MaterializedLayoutKind::DenseLogical;

  if (!logicalType.hasStaticShape() || !physicalType.hasStaticShape())
    return MaterializedLayoutKind::Unknown;

  if (logicalType.getRank() != 4 || physicalType.getRank() != 4)
    return MaterializedLayoutKind::Unknown;

  if (logicalType.getElementType() != physicalType.getElementType()
      || logicalType.getEncoding() != physicalType.getEncoding())
    return MaterializedLayoutKind::Unknown;

  // The unit dimension sits at index 0 of the logical type and at index 2 of
  // the physical type.
  if (logicalType.getDimSize(0) != 1 || physicalType.getDimSize(2) != 1)
    return MaterializedLayoutKind::Unknown;

  // Remaining dimensions must match under the (1, 2, 3) <- (1, 0, 3) mapping.
  if (logicalType.getDimSize(1) != physicalType.getDimSize(1)
      || logicalType.getDimSize(2) != physicalType.getDimSize(0)
      || logicalType.getDimSize(3) != physicalType.getDimSize(3))
    return MaterializedLayoutKind::Unknown;

  return MaterializedLayoutKind::RowPackedNCHWFromRows;
}
/// Makes `view.value` available inside `targetClass` and returns it in the
/// layout described by `view.logicalType`.
///
/// The value is first localized through
/// materializeTensorValueForMaterializedClassUse (forwarding `anchor`,
/// `context`, `producer` and `mapper`). Then:
///  - DenseLogical: the localized value is returned unchanged;
///  - RowPackedNCHWFromRows: a linalg.transpose with permutation {2, 1, 0, 3}
///    into a fresh tensor.empty of the logical type is emitted before the
///    terminator of the class body, and its result is returned;
///  - anything else: an error describing both types (and the producer, when
///    known) is emitted on `anchor` and failure is returned.
///
/// Note: the rewriter insertion point is left before the class body terminator
/// after the transpose path.
FailureOr<Value> materializeLogicalTensorView(MaterializerState& state,
                                              MaterializedClass& targetClass,
                                              MaterializedTensorView view,
                                              Operation* anchor,
                                              StringRef context,
                                              std::optional<ProducerKey> producer,
                                              IRMapping* mapper) {
  FailureOr<Value> localizedValue = materializeTensorValueForMaterializedClassUse(
      state, targetClass, view.value, anchor, context, producer, mapper);
  if (failed(localizedValue))
    return failure();

  if (view.layout == MaterializedLayoutKind::DenseLogical)
    return *localizedValue;

  if (view.layout == MaterializedLayoutKind::RowPackedNCHWFromRows) {
    state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
    // Carry the logical encoding over to the destination tensor. The layout
    // classifier accepts matching non-null encodings, and the transpose result
    // takes the type of its init operand, so dropping the encoding would
    // produce a value whose type differs from view.logicalType.
    Value empty = tensor::EmptyOp::create(state.rewriter,
                                          anchor->getLoc(),
                                          view.logicalType.getShape(),
                                          view.logicalType.getElementType(),
                                          view.logicalType.getEncoding())
                      .getResult();
    // Result dim i is taken from input dim permutation[i]: swaps dims 0 and 2.
    auto permutation = DenseI64ArrayAttr::get(anchor->getContext(), {2, 1, 0, 3});
    auto transpose =
        linalg::TransposeOp::create(state.rewriter, anchor->getLoc(), *localizedValue, empty, permutation);
    return transpose.getResult().front();
  }

  // Unknown layout: the diagnostic is reported when it goes out of scope.
  InFlightDiagnostic diagnostic = anchor->emitError("materialized tensor replacement changed physical layout")
                               << " logicalType=" << view.logicalType << " physicalType=" << view.physicalType;
  if (producer) {
    diagnostic << " producerOp='" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart
               << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex;
  }
  return failure();
}
std::optional<Value> mapExternalRegionBlockArgumentToLocalClone(const MaterializedClass& targetClass,
Operation* anchor,
BlockArgument externalArg) {
@@ -1870,6 +1955,58 @@ Value getPackedSliceForDynamicRunIndex(
return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0));
}
/// Extracts the value belonging to `key` from the packed tensor of an indexed
/// batch run, localized for `targetClass`.
///
/// Slots are scanned in order while counting the keys of all preceding slots.
/// For the first slot that holds `key`:
///  - if the slot's keys do not form a contiguous producer range, the fragment
///    at (keys before the slot + position of `key` in the slot) is sliced out;
///  - otherwise the slot's rows are sliced out of the packed tensor, and either
///    returned directly (the range equals `key`) or narrowed further through
///    extractPackedProducerSlice.
///
/// Fails when the run has no packed value, no slot holds `key`, or any helper
/// fails.
FailureOr<Value> materializeIndexedBatchRunPackedValue(MaterializerState& state,
                                                       MaterializedClass& targetClass,
                                                       IndexedBatchRunValue& run,
                                                       ProducerKey key,
                                                       Location loc) {
  if (!run.packed)
    return failure();

  // Number of keys in the slots that precede the one being inspected.
  size_t keysBeforeSlot = 0;
  for (const PackedScalarRunSlot& slot : run.slots) {
    std::optional<ProducerKey> slotRange = getContiguousProducerRangeForKeys(slot.keys);
    auto matchIt = llvm::find(slot.keys, key);

    bool slotHoldsKey = slotRange ? containsProducerKey(*slotRange, key) : matchIt != slot.keys.end();
    if (!slotHoldsKey) {
      keysBeforeSlot += slot.keys.size();
      continue;
    }

    FailureOr<Value> localPacked = materializeTensorValueForMaterializedClassUse(
        state,
        targetClass,
        run.packed,
        targetClass.op,
        "indexed batch run packed resolution tried to reuse a tensor from another materialized class");
    if (failed(localPacked))
      return failure();

    // Non-contiguous slot: address the fragment by its flattened index.
    if (!slotRange) {
      size_t positionInSlot = static_cast<size_t>(std::distance(slot.keys.begin(), matchIt));
      size_t fragmentIndex = keysBeforeSlot + positionInSlot;
      return getPackedSliceForRunIndex(state, targetClass.op, *localPacked, run.fragmentType, fragmentIndex, loc);
    }

    // Contiguous slot: slice out all rows of the slot first.
    FailureOr<RankedTensorType> slotType = getPackedBatchTensorType(run.fragmentType, slot.keys.size());
    if (failed(slotType))
      return failure();

    int64_t firstRow = static_cast<int64_t>(keysBeforeSlot) * run.fragmentType.getDimSize(0);
    Value firstRowIndex = getOrCreateIndexConstant(state.constantFolder, targetClass.op, firstRow);
    Value slotValue = createDim0ExtractSlice(state, loc, *localPacked, firstRowIndex, (*slotType).getDimSize(0));
    if (*slotRange == key)
      return slotValue;

    std::optional<Value> fragment = extractPackedProducerSlice(state, targetClass, *slotRange, slotValue, key);
    if (!fragment)
      return failure();
    return *fragment;
  }

  return failure();
}
/// Non-owning callback that maps a flat index value to a value, or fails.
using IndexedFragmentBuilder = llvm::function_ref<FailureOr<Value>(Value flatIndex)>;
/// Non-owning callback with the same signature as IndexedFragmentBuilder.
/// NOTE(review): presumably yields the insertion offset for the given flat
/// index - confirm against the function that consumes it.
using IndexedInsertOffsetBuilder = llvm::function_ref<FailureOr<Value>(Value flatIndex)>;
@@ -2017,11 +2154,15 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
continue;
size_t flattenedIndexBase = 0;
for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) {
std::optional<ProducerKey> contiguousKey = getContiguousProducerRangeForKeys(slot.keys);
auto keyIt = llvm::find(slot.keys, key);
if ((!contiguousKey && keyIt == slot.keys.end()) || (contiguousKey && !containsProducerKey(*contiguousKey, key)))
{
flattenedIndexBase += slot.keys.size();
continue;
}
MaterializedClass& materializedClass = state.classes[classId];
state.rewriter.setInsertionPoint(materializedClass.body->getTerminator());
@@ -2032,11 +2173,12 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
return std::nullopt;
if (!contiguousKey) {
size_t globalFragmentIndex = flattenedIndexBase + static_cast<size_t>(std::distance(slot.keys.begin(), keyIt));
Value sliced = getPackedSliceForRunIndex(state,
materializedClass.op,
*packed,
run.fragmentType,
static_cast<size_t>(std::distance(slot.keys.begin(), keyIt)),
globalFragmentIndex,
(*packed).getLoc());
record(key, classId, sliced);
return sliced;
@@ -2046,8 +2188,10 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
if (failed(slotPackedType))
return std::nullopt;
int64_t rowOffset = static_cast<int64_t>(flattenedIndexBase) * run.fragmentType.getDimSize(0);
Value firstOffset = getOrCreateIndexConstant(state.constantFolder, materializedClass.op, rowOffset);
Value slotPacked =
getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc());
createDim0ExtractSlice(state, (*packed).getLoc(), *packed, firstOffset, (*slotPackedType).getDimSize(0));
if (*contiguousKey == key) {
record(key, classId, slotPacked);
@@ -3862,8 +4006,10 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
Value payload,
Location loc) {
if (sourceClass.id == targetClass.id) {
for (ProducerKey key : keys)
state.availableValues.record(key, targetClass.id, payload);
if (keys.size() == 1)
state.availableValues.record(keys.front(), targetClass.id, payload);
else if (failed(recordLocalPackedSlotValue(state, targetClass, keys, payload)))
return failure();
return success();
}
@@ -4080,8 +4226,10 @@ LogicalResult emitOutputFanout(MaterializerState& state,
if (!recordedProjectedHostFragments && failed(emitHostCommunication(state, sourceClass, payload, originalOutput)))
return failure();
for (ProducerKey key : keys)
state.availableValues.record(key, sourceClass.id, payload);
if (keys.size() == 1)
state.availableValues.record(keys.front(), sourceClass.id, payload);
else if (failed(recordLocalPackedSlotValue(state, sourceClass, keys, payload)))
return failure();
return success();
}
@@ -4108,7 +4256,7 @@ struct WholeBatchFragmentGroup {
SmallVector<int64_t, 16> sourceLanes;
Value packed;
RankedTensorType slotPackedType;
SmallVector<int64_t, 16> slotIndices;
SmallVector<int64_t, 16> packedRowOffsets;
SmallVector<std::pair<Value, int64_t>, 16> directFragments;
SmallVector<Operation*, 16> redundantReceives;
};
@@ -4232,13 +4380,10 @@ FailureOr<Value> extractPackedSlotForIndex(MaterializerState& state,
MaterializedClass& targetClass,
Value packed,
RankedTensorType slotPackedType,
Value slotIndex,
Value packedRowOffset,
Location loc) {
FailureOr<Value> firstOffset =
scaleIndexByDim0SizeInClass(state, targetClass, slotIndex, slotPackedType.getDimSize(0), loc);
if (failed(firstOffset))
return failure();
return createDim0ExtractSliceInClass(state, targetClass, loc, packed, *firstOffset, slotPackedType.getDimSize(0));
return createDim0ExtractSliceInClass(
state, targetClass, loc, packed, packedRowOffset, slotPackedType.getDimSize(0));
}
SmallVector<ProducerKey, 16> flattenPackedScalarRunKeys(const PackedScalarRunValue& run) {
@@ -4655,7 +4800,7 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state,
if (failed(slotPackedType))
return failure();
WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(*slotPackedType);
group.slotIndices.push_back(slotIndex);
group.packedRowOffsets.push_back(static_cast<int64_t>(flattenedIndexBase) * plan.rowsPerLane);
group.outputOffsets.push_back(static_cast<int64_t>(contiguousKey->instance.laneStart) * plan.rowsPerLane);
flattenedIndexBase += slot.keys.size();
continue;
@@ -4663,7 +4808,7 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state,
WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(run->fragmentType);
for (auto [keyIndex, fragmentKey] : llvm::enumerate(slot.keys)) {
group.slotIndices.push_back(flattenedIndexBase + keyIndex);
group.packedRowOffsets.push_back(static_cast<int64_t>(flattenedIndexBase + keyIndex) * plan.rowsPerLane);
group.outputOffsets.push_back(static_cast<int64_t>(fragmentKey.instance.laneStart) * plan.rowsPerLane);
}
flattenedIndexBase += slot.keys.size();
@@ -4795,9 +4940,10 @@ FailureOr<Value> emitWholeBatchFragmentGroup(MaterializerState& state,
state,
targetClass,
destination,
static_cast<int64_t>(group.slotIndices.size()),
static_cast<int64_t>(group.packedRowOffsets.size()),
[&](Value flatIndex) -> FailureOr<Value> {
Value packedSlotIndex = createIndexedIndexValue(state, targetClass.op, group.slotIndices, flatIndex, loc);
Value packedRowOffset =
createIndexedIndexValue(state, targetClass.op, group.packedRowOffsets, flatIndex, loc);
FailureOr<Value> packed = materializeTensorValueForMaterializedClassUse(
state,
targetClass,
@@ -4806,7 +4952,7 @@ FailureOr<Value> emitWholeBatchFragmentGroup(MaterializerState& state,
"whole-batch packed fragment assembly tried to reuse a tensor from another materialized class");
if (failed(packed))
return failure();
return extractPackedSlotForIndex(state, targetClass, *packed, group.slotPackedType, packedSlotIndex, loc);
return extractPackedSlotForIndex(state, targetClass, *packed, group.slotPackedType, packedRowOffset, loc);
},
[&](Value flatIndex) -> FailureOr<Value> {
return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc);
@@ -5520,28 +5666,25 @@ FailureOr<Value> resolveInputValue(MaterializerState& state,
return rejectNonLocalResolvedValue(*value);
if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) {
size_t laneCount = targetClass.cpus.size();
for (auto [slotIndex, slot] : llvm::enumerate(indexedRun->slots)) {
if (!llvm::is_contained(slot.keys, *producer))
continue;
Value received = Value();
if (indexedRun->packed) {
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
received = getPackedSliceForRunIndex(state,
targetClass.op,
indexedRun->packed,
indexedRun->fragmentType,
slotIndex,
consumerInstance.op->getLoc());
FailureOr<Value> packed =
materializeIndexedBatchRunPackedValue(state, targetClass, *indexedRun, *producer, consumerInstance.op->getLoc());
if (failed(packed))
return failure();
received = *packed;
}
else {
size_t laneCount = targetClass.cpus.size();
MessageVector messages = indexedRun->messages.slice(slotIndex * laneCount, laneCount);
received =
appendReceive(state, targetClass, indexedRun->fragmentType, messages, consumerInstance.op->getLoc());
}
for (ProducerKey slotKey : slot.keys)
state.availableValues.record(slotKey, targetClass.id, received);
state.availableValues.record(*producer, targetClass.id, received);
return rejectNonLocalResolvedValue(received);
}
}
@@ -5661,13 +5804,31 @@ LogicalResult mapInputs(MaterializerState& state,
const ComputeInstance& instance,
IRMapping& mapper,
CloneIndexingContext indexing) {
auto mapResolvedInput = [&](Value resolved) -> FailureOr<Value> {
return materializeTensorValueForMaterializedClassUse(
state,
targetClass,
resolved,
targetClass.op,
"input mapping tried to reuse a tensor from another materialized class");
auto mapResolvedInput = [&](Value resolved,
Value logicalInput,
Operation* anchor,
std::optional<ProducerKey> producer) -> FailureOr<Value> {
auto logicalType = dyn_cast<RankedTensorType>(logicalInput.getType());
auto physicalType = dyn_cast<RankedTensorType>(resolved.getType());
if (!logicalType || !physicalType || !logicalType.hasStaticShape() || !physicalType.hasStaticShape())
return materializeTensorValueForMaterializedClassUse(
state,
targetClass,
resolved,
anchor,
"input mapping tried to reuse a tensor from another materialized class",
producer);
return materializeLogicalTensorView(state,
targetClass,
MaterializedTensorView {
.value = resolved,
.logicalType = logicalType,
.physicalType = physicalType,
.layout = classifyMaterializedTensorLayout(logicalType, physicalType)},
anchor,
"input mapping tried to reuse a tensor from another materialized class",
producer);
};
Operation* op = instance.op;
@@ -5686,14 +5847,16 @@ LogicalResult mapInputs(MaterializerState& state,
auto inputArg = compute.getInputArgument(index);
if (!inputArg)
return compute.emitOpError("expected compute input block argument while materializing inputs");
FailureOr<Value> remapped = mapResolvedInput(*mapped);
std::optional<ProducerKey> producer = getInputRequestProducerKey(input, instance);
FailureOr<Value> remapped = mapResolvedInput(*mapped, input, compute, producer);
if (failed(remapped)) {
emitNonLocalMaterializedClassValueDiagnostic(
compute,
targetClass,
"mapInputs tried to append a tensor from another materialized class",
*mapped,
getInputRequestProducerKey(input, instance));
if (isTensorValueDefinedInDifferentMaterializedClass(*mapped, targetClass))
emitNonLocalMaterializedClassValueDiagnostic(
compute,
targetClass,
"mapInputs tried to append a tensor from another materialized class",
*mapped,
producer);
return failure();
}
mapper.map(*inputArg, *remapped);
@@ -5733,13 +5896,37 @@ LogicalResult mapInputs(MaterializerState& state,
auto inputArg = batch.getInputArgument(index);
if (!inputArg)
return batch.emitOpError("expected compute_batch input block argument while materializing inputs");
FailureOr<Value> remapped = mapResolvedInput(*mapped);
std::optional<ProducerKey> producer = getInputRequestProducerKey(input, instance);
if (demand->kind == BatchInputDemandKind::ProjectedFragment) {
std::optional<AffineProjectedInputSliceMatch> match =
getProjectedInputSliceMatch(state, batch, static_cast<unsigned>(index));
auto mappedType = dyn_cast<RankedTensorType>((*mapped).getType());
if (!match || !mappedType || mappedType != match->fragmentType) {
InFlightDiagnostic diagnostic = batch.emitOpError(
"projected compute_batch input resolved to an incompatible fragment")
<< " #" << index << " resolvedType=" << (*mapped).getType();
if (match)
diagnostic << " expectedFragmentType=" << match->fragmentType;
if (producer) {
diagnostic << " from '" << producer->instance.op->getName()
<< "' laneStart=" << producer->instance.laneStart
<< " laneCount=" << producer->instance.laneCount
<< " resultIndex=" << producer->resultIndex;
}
return failure();
}
mapper.map(*inputArg, *mapped);
continue;
}
FailureOr<Value> remapped = mapResolvedInput(*mapped, input, batch, producer);
if (failed(remapped)) {
emitNonLocalMaterializedClassValueDiagnostic(batch,
targetClass,
"mapInputs tried to append a tensor from another materialized class",
*mapped,
getInputRequestProducerKey(input, instance));
if (isTensorValueDefinedInDifferentMaterializedClass(*mapped, targetClass))
emitNonLocalMaterializedClassValueDiagnostic(batch,
targetClass,
"mapInputs tried to append a tensor from another materialized class",
*mapped,
producer);
return failure();
}
mapper.map(*inputArg, *remapped);
@@ -6605,6 +6792,59 @@ LogicalResult registerPackedRunValue(MaterializerState& state,
return success();
}
/// Computes the type of one fragment of `packed` when its dim 0 is split
/// evenly among `keyCount` keys; all other dimensions, the element type and
/// the encoding are kept.
///
/// Fails when `packed` is not a statically shaped ranked tensor, has rank 0,
/// `keyCount` is 0, or dim 0 is not divisible by `keyCount`.
FailureOr<RankedTensorType> getPackedSlotFragmentType(Value packed, size_t keyCount) {
  auto tensorType = dyn_cast<RankedTensorType>(packed.getType());
  if (!tensorType || !tensorType.hasStaticShape())
    return failure();
  if (tensorType.getRank() == 0 || keyCount == 0)
    return failure();

  const int64_t slotCount = static_cast<int64_t>(keyCount);
  const int64_t totalRows = tensorType.getDimSize(0);
  if (totalRows % slotCount != 0)
    return failure();

  SmallVector<int64_t, 4> shape(tensorType.getShape());
  shape[0] = totalRows / slotCount;
  return RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding());
}
/// Registers `packed` as the locally available value for `keys` inside
/// `materializedClass`.
///
/// An empty key list is a no-op. Every key must name the same producer op and
/// result index and span exactly one lane; otherwise an error is emitted on
/// the class op. When a per-key fragment type can be derived from `packed`,
/// a single-slot materialized packed run holding all keys is recorded.
/// Otherwise `packed` itself is recorded for each key individually.
LogicalResult recordLocalPackedSlotValue(MaterializerState& state,
                                         MaterializedClass& materializedClass,
                                         ArrayRef<ProducerKey> keys,
                                         Value packed) {
  if (keys.empty())
    return success();

  Operation* producerOp = keys.front().instance.op;
  size_t producerResult = keys.front().resultIndex;

  // Validate per key, in order, so the first offending key decides the message.
  for (const ProducerKey& candidate : keys) {
    bool sameProducer = candidate.instance.op == producerOp && candidate.resultIndex == producerResult;
    if (!sameProducer)
      return materializedClass.op->emitError("local packed slot registration expects one producer result");
    if (candidate.instance.laneCount != 1)
      return materializedClass.op->emitError("local packed slot registration expects one lane per packed fragment");
  }

  FailureOr<RankedTensorType> perKeyType = getPackedSlotFragmentType(packed, keys.size());
  if (failed(perKeyType)) {
    // No fragment type could be derived: record the whole value under each key.
    for (const ProducerKey& candidate : keys)
      state.availableValues.record(candidate, materializedClass.id, packed);
    return success();
  }

  PackedScalarRunSlot onlySlot;
  llvm::append_range(onlySlot.keys, keys);

  PackedScalarRunValue run;
  run.kind = PackedScalarRunKind::Materialized;
  run.targetClass = materializedClass.id;
  run.sourceOp = producerOp;
  run.resultIndex = producerResult;
  run.fragmentType = *perKeyType;
  run.packed = packed;
  run.slots.push_back(std::move(onlySlot));

  state.availableValues.recordPackedRun(std::move(run));
  return success();
}
LogicalResult emitPackedRunFanout(MaterializerState& state,
MaterializedClass& sourceClass,
ArrayRef<ClassId> destinationClasses,