VGG-16 and ResNet now work
This commit is contained in:
+288
-48
@@ -1,6 +1,7 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
@@ -96,11 +97,35 @@ std::optional<ProducerKey>
|
||||
getProjectedWholeBatchReplacementProducer(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex);
|
||||
std::optional<ProducerKey> getProjectedWholeBatchReplacementProducer(MaterializerState& state,
|
||||
tensor::ExtractSliceOp extract);
|
||||
LogicalResult recordLocalPackedSlotValue(MaterializerState& state,
|
||||
MaterializedClass& materializedClass,
|
||||
ArrayRef<ProducerKey> keys,
|
||||
Value packed);
|
||||
FailureOr<Value> materializeProjectedWholeBatchExtractReplacement(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
tensor::ExtractSliceOp extract,
|
||||
ProducerKey producer,
|
||||
IRMapping* mapper = nullptr);
|
||||
// Relationship between the tensor type a consumer expects (logical) and the
// type of the value that was actually resolved for it (physical), as computed
// by classifyMaterializedTensorLayout.
enum class MaterializedLayoutKind {
|
||||
// Logical and physical types are identical; the value is used as-is.
DenseLogical,
|
||||
// Both types are static rank-4 with matching element type and encoding, where
// logical dim 0 and physical dim 2 are 1 and logical dims (1, 2, 3) equal
// physical dims (1, 0, 3). materializeLogicalTensorView converts this layout
// with a linalg transpose using permutation {2, 1, 0, 3}.
RowPackedNCHWFromRows,
|
||||
// Any other combination; materializeLogicalTensorView reports an error for it.
Unknown
|
||||
};
|
||||
|
||||
// Bundles a resolved tensor value with the type its consumer expects and the
// layout classification relating the two. Consumed by
// materializeLogicalTensorView.
struct MaterializedTensorView {
|
||||
// The resolved tensor value that gets localized into the target class.
Value value;
|
||||
// Type the consumer expects; its shape and element type define the
// destination tensor when a layout conversion is emitted.
RankedTensorType logicalType;
|
||||
// Type of `value` as resolved; in the visible code it is only read when
// building the layout-mismatch diagnostic.
RankedTensorType physicalType;
|
||||
// How physicalType relates to logicalType; defaults to Unknown, which
// materializeLogicalTensorView treats as an error.
MaterializedLayoutKind layout = MaterializedLayoutKind::Unknown;
|
||||
};
|
||||
|
||||
FailureOr<Value> materializeLogicalTensorView(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
MaterializedTensorView view,
|
||||
Operation* anchor,
|
||||
StringRef context,
|
||||
std::optional<ProducerKey> producer = std::nullopt,
|
||||
IRMapping* mapper = nullptr);
|
||||
bool isConstantLike(Value value) {
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
|
||||
@@ -1484,6 +1509,66 @@ FailureOr<Value> materializeTensorValueForMaterializedClassUse(MaterializerState
|
||||
return appendInput(state, targetClass, value);
|
||||
}
|
||||
|
||||
/// Classifies how a resolved (physical) tensor type relates to the tensor type
/// a consumer expects (logical).
///
/// \param logicalType  Type the consumer expects; may be null.
/// \param physicalType Type of the resolved value; may be null.
/// \returns
///   - `Unknown` when either type is null.
///   - `DenseLogical` when both types are identical.
///   - `RowPackedNCHWFromRows` when both types are static rank-4 tensors with
///     the same element type and encoding, logical dim 0 and physical dim 2
///     are 1, and logical dims (1, 2, 3) equal physical dims (1, 0, 3).
///   - `Unknown` in every other case.
MaterializedLayoutKind classifyMaterializedTensorLayout(RankedTensorType logicalType, RankedTensorType physicalType) {
  // Two null handles compare equal, so they must be rejected before the
  // identity check; otherwise a pair of missing types would be reported as a
  // dense match.
  if (!logicalType || !physicalType)
    return MaterializedLayoutKind::Unknown;

  if (logicalType == physicalType)
    return MaterializedLayoutKind::DenseLogical;

  if (!logicalType.hasStaticShape() || !physicalType.hasStaticShape())
    return MaterializedLayoutKind::Unknown;

  if (logicalType.getRank() != 4 || physicalType.getRank() != 4)
    return MaterializedLayoutKind::Unknown;
  if (logicalType.getElementType() != physicalType.getElementType()
      || logicalType.getEncoding() != physicalType.getEncoding())
    return MaterializedLayoutKind::Unknown;

  // The unit dimension sits at index 0 of the logical type and at index 2 of
  // the physical type.
  if (logicalType.getDimSize(0) != 1 || physicalType.getDimSize(2) != 1)
    return MaterializedLayoutKind::Unknown;

  // Remaining dimensions must line up as logical (1, 2, 3) == physical (1, 0, 3).
  if (logicalType.getDimSize(1) != physicalType.getDimSize(1)
      || logicalType.getDimSize(2) != physicalType.getDimSize(0)
      || logicalType.getDimSize(3) != physicalType.getDimSize(3))
    return MaterializedLayoutKind::Unknown;

  return MaterializedLayoutKind::RowPackedNCHWFromRows;
}
|
||||
|
||||
/// Localizes `view.value` into `targetClass` and, if needed, converts it from
/// its physical layout to the logical layout the consumer expects.
///
/// \param state       Materializer state; its rewriter emits any conversion ops.
/// \param targetClass Class whose body receives the localized value.
/// \param view        Value plus logical/physical types and layout kind.
/// \param anchor      Operation supplying the location and context for new ops
///                    and the target of the layout diagnostic.
/// \param context     Message forwarded to the localization helper.
/// \param producer    Optional producer key, forwarded to the localization
///                    helper and appended to the layout diagnostic.
/// \param mapper      Optional mapping forwarded to the localization helper.
/// \returns The localized value for `DenseLogical`, the transposed value for
///          `RowPackedNCHWFromRows`, and failure if localization fails or the
///          layout is anything else (an error is emitted on `anchor`).
FailureOr<Value> materializeLogicalTensorView(MaterializerState& state,
                                              MaterializedClass& targetClass,
                                              MaterializedTensorView view,
                                              Operation* anchor,
                                              StringRef context,
                                              std::optional<ProducerKey> producer,
                                              IRMapping* mapper) {
  FailureOr<Value> localValue =
      materializeTensorValueForMaterializedClassUse(state, targetClass, view.value, anchor, context, producer, mapper);
  if (failed(localValue))
    return failure();

  switch (view.layout) {
  case MaterializedLayoutKind::DenseLogical:
    // Already in the expected layout.
    return *localValue;

  case MaterializedLayoutKind::RowPackedNCHWFromRows: {
    // Swap dims 0 and 2 into a fresh tensor of the logical shape, emitted just
    // before the class body's terminator.
    Location loc = anchor->getLoc();
    state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
    auto destination =
        tensor::EmptyOp::create(state.rewriter, loc, view.logicalType.getShape(), view.logicalType.getElementType());
    Value init = destination.getResult();
    auto permutation = DenseI64ArrayAttr::get(anchor->getContext(), {2, 1, 0, 3});
    auto transposed = linalg::TransposeOp::create(state.rewriter, loc, *localValue, init, permutation);
    return transposed.getResult().front();
  }

  case MaterializedLayoutKind::Unknown:
    break;
  }

  // No known conversion between the two types: report both, plus the producer
  // when one was supplied.
  InFlightDiagnostic diag = anchor->emitError("materialized tensor replacement changed physical layout");
  diag << " logicalType=" << view.logicalType << " physicalType=" << view.physicalType;
  if (producer) {
    diag << " producerOp='" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart;
    diag << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex;
  }
  return failure();
}
|
||||
|
||||
std::optional<Value> mapExternalRegionBlockArgumentToLocalClone(const MaterializedClass& targetClass,
|
||||
Operation* anchor,
|
||||
BlockArgument externalArg) {
|
||||
@@ -1870,6 +1955,58 @@ Value getPackedSliceForDynamicRunIndex(
|
||||
return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0));
|
||||
}
|
||||
|
||||
/// Extracts the fragment for `key` out of the packed tensor of an indexed
/// batch run, localized into `targetClass`.
///
/// Slots are walked in order while counting the keys of the slots already
/// skipped; that running count locates the slot's rows inside the packed
/// tensor.
///
/// \param state       Materializer state (rewriter, constant folder).
/// \param targetClass Class that receives the localized tensor and the slices.
/// \param run         Indexed batch run; must carry a packed tensor.
/// \param key         Producer key whose fragment is requested.
/// \param loc         Location for the emitted slice operations.
/// \returns The slice holding `key`'s data, or failure when the run has no
///          packed tensor, no slot contains `key`, localization fails, the
///          slot's packed type cannot be computed, or the producer slice
///          cannot be extracted.
FailureOr<Value> materializeIndexedBatchRunPackedValue(MaterializerState& state,
                                                       MaterializedClass& targetClass,
                                                       IndexedBatchRunValue& run,
                                                       ProducerKey key,
                                                       Location loc) {
  if (!run.packed)
    return failure();

  size_t flattenedIndexBase = 0;
  for (const PackedScalarRunSlot& slot : run.slots) {
    std::optional<ProducerKey> contiguousKey = getContiguousProducerRangeForKeys(slot.keys);
    auto keyIt = llvm::find(slot.keys, key);

    // A contiguous slot is matched by range containment; otherwise the key
    // must be listed in the slot explicitly.
    bool slotHoldsKey = contiguousKey ? containsProducerKey(*contiguousKey, key) : keyIt != slot.keys.end();
    if (!slotHoldsKey) {
      flattenedIndexBase += slot.keys.size();
      continue;
    }

    FailureOr<Value> packed = materializeTensorValueForMaterializedClassUse(
        state,
        targetClass,
        run.packed,
        targetClass.op,
        "indexed batch run packed resolution tried to reuse a tensor from another materialized class");
    if (failed(packed))
      return failure();

    // Emit the slices at the end of the target class body, as the other
    // packed-run paths in this file do before slicing. The call site used to
    // set this itself; without it the slices land wherever the rewriter
    // happened to be left.
    // NOTE(review): assumes the slice helpers do not set the insertion point
    // on their own - confirm.
    state.rewriter.setInsertionPoint(targetClass.body->getTerminator());

    if (!contiguousKey) {
      size_t globalFragmentIndex = flattenedIndexBase + static_cast<size_t>(std::distance(slot.keys.begin(), keyIt));
      return getPackedSliceForRunIndex(state, targetClass.op, *packed, run.fragmentType, globalFragmentIndex, loc);
    }

    FailureOr<RankedTensorType> slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size());
    if (failed(slotPackedType))
      return failure();

    // The slot starts after the rows of every fragment in the skipped slots.
    int64_t rowOffset = static_cast<int64_t>(flattenedIndexBase) * run.fragmentType.getDimSize(0);
    Value firstOffset = getOrCreateIndexConstant(state.constantFolder, targetClass.op, rowOffset);
    Value slotPacked = createDim0ExtractSlice(state, loc, *packed, firstOffset, (*slotPackedType).getDimSize(0));

    // The whole slot was requested: no further slicing needed.
    if (*contiguousKey == key)
      return slotPacked;

    std::optional<Value> sliced = extractPackedProducerSlice(state, targetClass, *contiguousKey, slotPacked, key);
    if (!sliced)
      return failure();
    return *sliced;
  }

  return failure();
}
|
||||
|
||||
// Non-owning callback that maps a flat index value to a tensor value or
// failure. NOTE(review): presumably yields the fragment to insert for that
// index - confirm against the consumer of this alias.
using IndexedFragmentBuilder = llvm::function_ref<FailureOr<Value>(Value flatIndex)>;
|
||||
// Non-owning callback with the same signature. NOTE(review): presumably
// yields the insertion offset for that flat index - confirm.
using IndexedInsertOffsetBuilder = llvm::function_ref<FailureOr<Value>(Value flatIndex)>;
|
||||
|
||||
@@ -2017,11 +2154,15 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
|
||||
if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex)
|
||||
continue;
|
||||
|
||||
size_t flattenedIndexBase = 0;
|
||||
for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) {
|
||||
std::optional<ProducerKey> contiguousKey = getContiguousProducerRangeForKeys(slot.keys);
|
||||
auto keyIt = llvm::find(slot.keys, key);
|
||||
if ((!contiguousKey && keyIt == slot.keys.end()) || (contiguousKey && !containsProducerKey(*contiguousKey, key)))
|
||||
{
|
||||
flattenedIndexBase += slot.keys.size();
|
||||
continue;
|
||||
}
|
||||
|
||||
MaterializedClass& materializedClass = state.classes[classId];
|
||||
state.rewriter.setInsertionPoint(materializedClass.body->getTerminator());
|
||||
@@ -2032,11 +2173,12 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
|
||||
return std::nullopt;
|
||||
|
||||
if (!contiguousKey) {
|
||||
size_t globalFragmentIndex = flattenedIndexBase + static_cast<size_t>(std::distance(slot.keys.begin(), keyIt));
|
||||
Value sliced = getPackedSliceForRunIndex(state,
|
||||
materializedClass.op,
|
||||
*packed,
|
||||
run.fragmentType,
|
||||
static_cast<size_t>(std::distance(slot.keys.begin(), keyIt)),
|
||||
globalFragmentIndex,
|
||||
(*packed).getLoc());
|
||||
record(key, classId, sliced);
|
||||
return sliced;
|
||||
@@ -2046,8 +2188,10 @@ std::optional<Value> AvailableValueStore::lookupPackedRun(MaterializerState& sta
|
||||
if (failed(slotPackedType))
|
||||
return std::nullopt;
|
||||
|
||||
int64_t rowOffset = static_cast<int64_t>(flattenedIndexBase) * run.fragmentType.getDimSize(0);
|
||||
Value firstOffset = getOrCreateIndexConstant(state.constantFolder, materializedClass.op, rowOffset);
|
||||
Value slotPacked =
|
||||
getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc());
|
||||
createDim0ExtractSlice(state, (*packed).getLoc(), *packed, firstOffset, (*slotPackedType).getDimSize(0));
|
||||
|
||||
if (*contiguousKey == key) {
|
||||
record(key, classId, slotPacked);
|
||||
@@ -3862,8 +4006,10 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
||||
Value payload,
|
||||
Location loc) {
|
||||
if (sourceClass.id == targetClass.id) {
|
||||
for (ProducerKey key : keys)
|
||||
state.availableValues.record(key, targetClass.id, payload);
|
||||
if (keys.size() == 1)
|
||||
state.availableValues.record(keys.front(), targetClass.id, payload);
|
||||
else if (failed(recordLocalPackedSlotValue(state, targetClass, keys, payload)))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -4080,8 +4226,10 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
||||
if (!recordedProjectedHostFragments && failed(emitHostCommunication(state, sourceClass, payload, originalOutput)))
|
||||
return failure();
|
||||
|
||||
for (ProducerKey key : keys)
|
||||
state.availableValues.record(key, sourceClass.id, payload);
|
||||
if (keys.size() == 1)
|
||||
state.availableValues.record(keys.front(), sourceClass.id, payload);
|
||||
else if (failed(recordLocalPackedSlotValue(state, sourceClass, keys, payload)))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -4108,7 +4256,7 @@ struct WholeBatchFragmentGroup {
|
||||
SmallVector<int64_t, 16> sourceLanes;
|
||||
Value packed;
|
||||
RankedTensorType slotPackedType;
|
||||
SmallVector<int64_t, 16> slotIndices;
|
||||
SmallVector<int64_t, 16> packedRowOffsets;
|
||||
SmallVector<std::pair<Value, int64_t>, 16> directFragments;
|
||||
SmallVector<Operation*, 16> redundantReceives;
|
||||
};
|
||||
@@ -4232,13 +4380,10 @@ FailureOr<Value> extractPackedSlotForIndex(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
Value packed,
|
||||
RankedTensorType slotPackedType,
|
||||
Value slotIndex,
|
||||
Value packedRowOffset,
|
||||
Location loc) {
|
||||
FailureOr<Value> firstOffset =
|
||||
scaleIndexByDim0SizeInClass(state, targetClass, slotIndex, slotPackedType.getDimSize(0), loc);
|
||||
if (failed(firstOffset))
|
||||
return failure();
|
||||
return createDim0ExtractSliceInClass(state, targetClass, loc, packed, *firstOffset, slotPackedType.getDimSize(0));
|
||||
return createDim0ExtractSliceInClass(
|
||||
state, targetClass, loc, packed, packedRowOffset, slotPackedType.getDimSize(0));
|
||||
}
|
||||
|
||||
SmallVector<ProducerKey, 16> flattenPackedScalarRunKeys(const PackedScalarRunValue& run) {
|
||||
@@ -4655,7 +4800,7 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state,
|
||||
if (failed(slotPackedType))
|
||||
return failure();
|
||||
WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(*slotPackedType);
|
||||
group.slotIndices.push_back(slotIndex);
|
||||
group.packedRowOffsets.push_back(static_cast<int64_t>(flattenedIndexBase) * plan.rowsPerLane);
|
||||
group.outputOffsets.push_back(static_cast<int64_t>(contiguousKey->instance.laneStart) * plan.rowsPerLane);
|
||||
flattenedIndexBase += slot.keys.size();
|
||||
continue;
|
||||
@@ -4663,7 +4808,7 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state,
|
||||
|
||||
WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(run->fragmentType);
|
||||
for (auto [keyIndex, fragmentKey] : llvm::enumerate(slot.keys)) {
|
||||
group.slotIndices.push_back(flattenedIndexBase + keyIndex);
|
||||
group.packedRowOffsets.push_back(static_cast<int64_t>(flattenedIndexBase + keyIndex) * plan.rowsPerLane);
|
||||
group.outputOffsets.push_back(static_cast<int64_t>(fragmentKey.instance.laneStart) * plan.rowsPerLane);
|
||||
}
|
||||
flattenedIndexBase += slot.keys.size();
|
||||
@@ -4795,9 +4940,10 @@ FailureOr<Value> emitWholeBatchFragmentGroup(MaterializerState& state,
|
||||
state,
|
||||
targetClass,
|
||||
destination,
|
||||
static_cast<int64_t>(group.slotIndices.size()),
|
||||
static_cast<int64_t>(group.packedRowOffsets.size()),
|
||||
[&](Value flatIndex) -> FailureOr<Value> {
|
||||
Value packedSlotIndex = createIndexedIndexValue(state, targetClass.op, group.slotIndices, flatIndex, loc);
|
||||
Value packedRowOffset =
|
||||
createIndexedIndexValue(state, targetClass.op, group.packedRowOffsets, flatIndex, loc);
|
||||
FailureOr<Value> packed = materializeTensorValueForMaterializedClassUse(
|
||||
state,
|
||||
targetClass,
|
||||
@@ -4806,7 +4952,7 @@ FailureOr<Value> emitWholeBatchFragmentGroup(MaterializerState& state,
|
||||
"whole-batch packed fragment assembly tried to reuse a tensor from another materialized class");
|
||||
if (failed(packed))
|
||||
return failure();
|
||||
return extractPackedSlotForIndex(state, targetClass, *packed, group.slotPackedType, packedSlotIndex, loc);
|
||||
return extractPackedSlotForIndex(state, targetClass, *packed, group.slotPackedType, packedRowOffset, loc);
|
||||
},
|
||||
[&](Value flatIndex) -> FailureOr<Value> {
|
||||
return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc);
|
||||
@@ -5520,28 +5666,25 @@ FailureOr<Value> resolveInputValue(MaterializerState& state,
|
||||
return rejectNonLocalResolvedValue(*value);
|
||||
|
||||
if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) {
|
||||
size_t laneCount = targetClass.cpus.size();
|
||||
for (auto [slotIndex, slot] : llvm::enumerate(indexedRun->slots)) {
|
||||
if (!llvm::is_contained(slot.keys, *producer))
|
||||
continue;
|
||||
|
||||
Value received = Value();
|
||||
if (indexedRun->packed) {
|
||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||
received = getPackedSliceForRunIndex(state,
|
||||
targetClass.op,
|
||||
indexedRun->packed,
|
||||
indexedRun->fragmentType,
|
||||
slotIndex,
|
||||
consumerInstance.op->getLoc());
|
||||
FailureOr<Value> packed =
|
||||
materializeIndexedBatchRunPackedValue(state, targetClass, *indexedRun, *producer, consumerInstance.op->getLoc());
|
||||
if (failed(packed))
|
||||
return failure();
|
||||
received = *packed;
|
||||
}
|
||||
else {
|
||||
size_t laneCount = targetClass.cpus.size();
|
||||
MessageVector messages = indexedRun->messages.slice(slotIndex * laneCount, laneCount);
|
||||
received =
|
||||
appendReceive(state, targetClass, indexedRun->fragmentType, messages, consumerInstance.op->getLoc());
|
||||
}
|
||||
for (ProducerKey slotKey : slot.keys)
|
||||
state.availableValues.record(slotKey, targetClass.id, received);
|
||||
state.availableValues.record(*producer, targetClass.id, received);
|
||||
return rejectNonLocalResolvedValue(received);
|
||||
}
|
||||
}
|
||||
@@ -5661,13 +5804,31 @@ LogicalResult mapInputs(MaterializerState& state,
|
||||
const ComputeInstance& instance,
|
||||
IRMapping& mapper,
|
||||
CloneIndexingContext indexing) {
|
||||
auto mapResolvedInput = [&](Value resolved) -> FailureOr<Value> {
|
||||
return materializeTensorValueForMaterializedClassUse(
|
||||
state,
|
||||
targetClass,
|
||||
resolved,
|
||||
targetClass.op,
|
||||
"input mapping tried to reuse a tensor from another materialized class");
|
||||
auto mapResolvedInput = [&](Value resolved,
|
||||
Value logicalInput,
|
||||
Operation* anchor,
|
||||
std::optional<ProducerKey> producer) -> FailureOr<Value> {
|
||||
auto logicalType = dyn_cast<RankedTensorType>(logicalInput.getType());
|
||||
auto physicalType = dyn_cast<RankedTensorType>(resolved.getType());
|
||||
if (!logicalType || !physicalType || !logicalType.hasStaticShape() || !physicalType.hasStaticShape())
|
||||
return materializeTensorValueForMaterializedClassUse(
|
||||
state,
|
||||
targetClass,
|
||||
resolved,
|
||||
anchor,
|
||||
"input mapping tried to reuse a tensor from another materialized class",
|
||||
producer);
|
||||
|
||||
return materializeLogicalTensorView(state,
|
||||
targetClass,
|
||||
MaterializedTensorView {
|
||||
.value = resolved,
|
||||
.logicalType = logicalType,
|
||||
.physicalType = physicalType,
|
||||
.layout = classifyMaterializedTensorLayout(logicalType, physicalType)},
|
||||
anchor,
|
||||
"input mapping tried to reuse a tensor from another materialized class",
|
||||
producer);
|
||||
};
|
||||
|
||||
Operation* op = instance.op;
|
||||
@@ -5686,14 +5847,16 @@ LogicalResult mapInputs(MaterializerState& state,
|
||||
auto inputArg = compute.getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return compute.emitOpError("expected compute input block argument while materializing inputs");
|
||||
FailureOr<Value> remapped = mapResolvedInput(*mapped);
|
||||
std::optional<ProducerKey> producer = getInputRequestProducerKey(input, instance);
|
||||
FailureOr<Value> remapped = mapResolvedInput(*mapped, input, compute, producer);
|
||||
if (failed(remapped)) {
|
||||
emitNonLocalMaterializedClassValueDiagnostic(
|
||||
compute,
|
||||
targetClass,
|
||||
"mapInputs tried to append a tensor from another materialized class",
|
||||
*mapped,
|
||||
getInputRequestProducerKey(input, instance));
|
||||
if (isTensorValueDefinedInDifferentMaterializedClass(*mapped, targetClass))
|
||||
emitNonLocalMaterializedClassValueDiagnostic(
|
||||
compute,
|
||||
targetClass,
|
||||
"mapInputs tried to append a tensor from another materialized class",
|
||||
*mapped,
|
||||
producer);
|
||||
return failure();
|
||||
}
|
||||
mapper.map(*inputArg, *remapped);
|
||||
@@ -5733,13 +5896,37 @@ LogicalResult mapInputs(MaterializerState& state,
|
||||
auto inputArg = batch.getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return batch.emitOpError("expected compute_batch input block argument while materializing inputs");
|
||||
FailureOr<Value> remapped = mapResolvedInput(*mapped);
|
||||
std::optional<ProducerKey> producer = getInputRequestProducerKey(input, instance);
|
||||
if (demand->kind == BatchInputDemandKind::ProjectedFragment) {
|
||||
std::optional<AffineProjectedInputSliceMatch> match =
|
||||
getProjectedInputSliceMatch(state, batch, static_cast<unsigned>(index));
|
||||
auto mappedType = dyn_cast<RankedTensorType>((*mapped).getType());
|
||||
if (!match || !mappedType || mappedType != match->fragmentType) {
|
||||
InFlightDiagnostic diagnostic = batch.emitOpError(
|
||||
"projected compute_batch input resolved to an incompatible fragment")
|
||||
<< " #" << index << " resolvedType=" << (*mapped).getType();
|
||||
if (match)
|
||||
diagnostic << " expectedFragmentType=" << match->fragmentType;
|
||||
if (producer) {
|
||||
diagnostic << " from '" << producer->instance.op->getName()
|
||||
<< "' laneStart=" << producer->instance.laneStart
|
||||
<< " laneCount=" << producer->instance.laneCount
|
||||
<< " resultIndex=" << producer->resultIndex;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
mapper.map(*inputArg, *mapped);
|
||||
continue;
|
||||
}
|
||||
FailureOr<Value> remapped = mapResolvedInput(*mapped, input, batch, producer);
|
||||
if (failed(remapped)) {
|
||||
emitNonLocalMaterializedClassValueDiagnostic(batch,
|
||||
targetClass,
|
||||
"mapInputs tried to append a tensor from another materialized class",
|
||||
*mapped,
|
||||
getInputRequestProducerKey(input, instance));
|
||||
if (isTensorValueDefinedInDifferentMaterializedClass(*mapped, targetClass))
|
||||
emitNonLocalMaterializedClassValueDiagnostic(batch,
|
||||
targetClass,
|
||||
"mapInputs tried to append a tensor from another materialized class",
|
||||
*mapped,
|
||||
producer);
|
||||
return failure();
|
||||
}
|
||||
mapper.map(*inputArg, *remapped);
|
||||
@@ -6605,6 +6792,59 @@ LogicalResult registerPackedRunValue(MaterializerState& state,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Derives the type of a single fragment from a packed tensor that holds
/// `keyCount` equally sized fragments along dimension 0.
///
/// \param packed   Packed value; must have a static, non-scalar ranked tensor
///                 type for the derivation to succeed.
/// \param keyCount Number of fragments packed along dimension 0.
/// \returns The packed type with dimension 0 divided by `keyCount`, or failure
///          when the type is unsuitable, `keyCount` is zero, or dimension 0 is
///          not a multiple of `keyCount`.
FailureOr<RankedTensorType> getPackedSlotFragmentType(Value packed, size_t keyCount) {
  auto tensorType = dyn_cast<RankedTensorType>(packed.getType());
  if (!tensorType)
    return failure();
  if (!tensorType.hasStaticShape() || tensorType.getRank() == 0)
    return failure();
  if (keyCount == 0)
    return failure();

  const int64_t fragmentCount = static_cast<int64_t>(keyCount);
  const int64_t totalRows = tensorType.getDimSize(0);
  // Fragments must split the leading dimension evenly.
  if (totalRows % fragmentCount != 0)
    return failure();

  SmallVector<int64_t, 4> shape(tensorType.getShape());
  shape[0] = totalRows / fragmentCount;
  return RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding());
}
|
||||
|
||||
/// Registers a packed tensor that lives in `materializedClass` as the source
/// for all of `keys`.
///
/// Every key must name the same producer op and result index and cover exactly
/// one lane. When a per-fragment type can be derived from `packed`, the tensor
/// is recorded as a single-slot materialized packed run; otherwise `packed` is
/// recorded directly for each key.
///
/// \param state             Materializer state owning the available-value store.
/// \param materializedClass Class in which `packed` is available.
/// \param keys              Producer keys served by `packed`; empty is a no-op.
/// \param packed            The packed tensor value.
/// \returns success, or failure (with an error on the class op) when the keys
///          do not share one producer result or a key spans more than one lane.
LogicalResult recordLocalPackedSlotValue(MaterializerState& state,
                                         MaterializedClass& materializedClass,
                                         ArrayRef<ProducerKey> keys,
                                         Value packed) {
  if (keys.empty())
    return success();

  Operation* producerOp = keys.front().instance.op;
  size_t producerResult = keys.front().resultIndex;

  // Validate key by key so the first offending key decides which error fires.
  for (const ProducerKey& candidate : keys) {
    if (candidate.instance.op != producerOp || candidate.resultIndex != producerResult)
      return materializedClass.op->emitError("local packed slot registration expects one producer result");
    if (candidate.instance.laneCount != 1)
      return materializedClass.op->emitError("local packed slot registration expects one lane per packed fragment");
  }

  FailureOr<RankedTensorType> fragmentType = getPackedSlotFragmentType(packed, keys.size());
  if (failed(fragmentType)) {
    // No per-fragment type can be derived: record the tensor for each key.
    for (const ProducerKey& candidate : keys)
      state.availableValues.record(candidate, materializedClass.id, packed);
    return success();
  }

  // All keys share a single slot of the run.
  PackedScalarRunSlot onlySlot;
  llvm::append_range(onlySlot.keys, keys);

  PackedScalarRunValue localRun;
  localRun.targetClass = materializedClass.id;
  localRun.sourceOp = producerOp;
  localRun.resultIndex = producerResult;
  localRun.packed = packed;
  localRun.kind = PackedScalarRunKind::Materialized;
  localRun.fragmentType = *fragmentType;
  localRun.slots.push_back(std::move(onlySlot));

  state.availableValues.recordPackedRun(std::move(localRun));
  return success();
}
|
||||
|
||||
LogicalResult emitPackedRunFanout(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
ArrayRef<ClassId> destinationClasses,
|
||||
|
||||
Reference in New Issue
Block a user