This commit is contained in:
@@ -97,11 +97,17 @@ static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewr
|
||||
value.getLoc(),
|
||||
outputType,
|
||||
value,
|
||||
ValueRange {},
|
||||
rewriter.getStringAttr(kLogicalLayout),
|
||||
rewriter.getStringAttr(kRowStripLayout),
|
||||
rewriter.getDenseI64ArrayAttr(offsets),
|
||||
rewriter.getDenseI64ArrayAttr(sizes),
|
||||
rewriter.getStringAttr(kRowStripIndexMap));
|
||||
rewriter.getStringAttr(kRowStripIndexMap),
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
static void materializeDenseUses(IRRewriter& rewriter,
|
||||
|
||||
@@ -233,15 +233,21 @@ def SpatReluPlanOp : SpatOp<"relu_plan", []> {
|
||||
}
|
||||
|
||||
def SpatReconciliatorOp : SpatOp<"reconciliator", []> {
|
||||
let summary = "Passive logical-to-physical layout selection record";
|
||||
let summary = "Logical-to-physical layout record or explicit fragment assembly";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$input,
|
||||
Variadic<SpatTensor>:$fragments,
|
||||
StrAttr:$logicalLayout,
|
||||
StrAttr:$physicalLayout,
|
||||
DenseI64ArrayAttr:$fragmentOffsets,
|
||||
DenseI64ArrayAttr:$fragmentSizes,
|
||||
StrAttr:$indexMap
|
||||
StrAttr:$indexMap,
|
||||
OptionalAttr<StrAttr>:$mode,
|
||||
OptionalAttr<DenseI64ArrayAttr>:$fragmentOperandIndices,
|
||||
OptionalAttr<DenseI64ArrayAttr>:$fragmentStrides,
|
||||
OptionalAttr<StrAttr>:$conflictPolicy,
|
||||
OptionalAttr<StrAttr>:$coveragePolicy
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
|
||||
@@ -383,7 +383,7 @@ LogicalResult SpatConcatOp::verify() {
|
||||
static bool isKnownLogicalLayout(StringRef layout) { return layout == "nchw"; }
|
||||
|
||||
static bool isKnownPhysicalLayout(StringRef layout) {
|
||||
return layout == "dense_nchw" || layout == "nchw_row_strip";
|
||||
return layout == "dense_nchw" || layout == "nchw_row_strip" || layout == "fragmented";
|
||||
}
|
||||
|
||||
static LogicalResult verifyPlanTensorTypes(Operation* op, Value input, Value output, StringRef kind) {
|
||||
@@ -437,7 +437,9 @@ LogicalResult SpatReluPlanOp::verify() {
|
||||
}
|
||||
|
||||
LogicalResult SpatReconciliatorOp::verify() {
|
||||
if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.reconciliator")))
|
||||
auto modeAttr = getModeAttr();
|
||||
bool isFragmentAssembly = modeAttr && modeAttr.getValue() == "fragment_assembly";
|
||||
if (!isFragmentAssembly && failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.reconciliator")))
|
||||
return failure();
|
||||
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||
return emitError("requires a known logical layout");
|
||||
@@ -452,23 +454,154 @@ LogicalResult SpatReconciliatorOp::verify() {
|
||||
auto sizes = getFragmentSizes();
|
||||
if (offsets.size() != sizes.size())
|
||||
return emitError("fragment offset and size arrays must have the same length");
|
||||
int64_t rank = logicalType.getRank();
|
||||
if (offsets.empty())
|
||||
return success();
|
||||
|
||||
int64_t rank = logicalType.getRank();
|
||||
if (rank <= 0 || offsets.size() % rank != 0)
|
||||
return emitError("fragment metadata must be a whole number of rank-sized fragments");
|
||||
|
||||
ArrayRef<int64_t> shape = logicalType.getShape();
|
||||
for (int64_t index = 0; index < static_cast<int64_t>(offsets.size()); ++index) {
|
||||
int64_t dim = index % rank;
|
||||
int64_t offset = offsets[index];
|
||||
int64_t size = sizes[index];
|
||||
if (offset < 0 || size < 0)
|
||||
return emitError("fragment offsets and sizes must be non-negative");
|
||||
int64_t logicalDim = shape[dim];
|
||||
if (!ShapedType::isDynamic(logicalDim) && offset + size > logicalDim)
|
||||
return emitError("fragment bounds must stay within the logical tensor shape");
|
||||
auto verifyBoundsOnly = [&](ArrayRef<int64_t> strideValues) -> LogicalResult {
|
||||
ArrayRef<int64_t> shape = logicalType.getShape();
|
||||
for (int64_t index = 0; index < static_cast<int64_t>(offsets.size()); ++index) {
|
||||
int64_t dim = index % rank;
|
||||
int64_t offset = offsets[index];
|
||||
int64_t size = sizes[index];
|
||||
int64_t stride = strideValues.empty() ? 1 : strideValues[index];
|
||||
if (offset < 0 || size < 0 || stride < 0)
|
||||
return emitError("fragment offsets, sizes, and strides must be non-negative");
|
||||
int64_t logicalDim = shape[dim];
|
||||
if (!ShapedType::isDynamic(logicalDim) && offset + size > logicalDim)
|
||||
return emitError("fragment bounds must stay within the logical tensor shape");
|
||||
if (stride != 1)
|
||||
return emitError("fragment assembly currently requires unit strides");
|
||||
}
|
||||
return success();
|
||||
};
|
||||
|
||||
if (!isFragmentAssembly) {
|
||||
if (failed(verifyBoundsOnly({})))
|
||||
return failure();
|
||||
if (!getFragments().empty())
|
||||
return emitError("legacy reconciliator does not accept extra fragment operands");
|
||||
if (getFragmentStridesAttr() || getConflictPolicyAttr() || getCoveragePolicyAttr())
|
||||
return emitError("legacy reconciliator does not accept fragment assembly attributes");
|
||||
return success();
|
||||
}
|
||||
|
||||
auto stridesAttr = getFragmentStridesAttr();
|
||||
auto operandIndicesAttr = getFragmentOperandIndicesAttr();
|
||||
if (!operandIndicesAttr)
|
||||
return emitError("fragment assembly reconciliator requires fragment operand indices");
|
||||
if (!stridesAttr)
|
||||
return emitError("fragment assembly reconciliator requires fragment strides");
|
||||
ArrayRef<int64_t> operandIndices = operandIndicesAttr.asArrayRef();
|
||||
ArrayRef<int64_t> strides = stridesAttr.asArrayRef();
|
||||
if (strides.size() != offsets.size())
|
||||
return emitError("fragment stride and offset arrays must have the same length");
|
||||
if (!getConflictPolicyAttr() || !getCoveragePolicyAttr())
|
||||
return emitError("fragment assembly reconciliator requires conflict and coverage policies");
|
||||
if (getConflictPolicy() != "disjoint")
|
||||
return emitError("fragment assembly reconciliator currently supports only conflict_policy=\"disjoint\"");
|
||||
if (getCoveragePolicy() != "complete" && getCoveragePolicy() != "partial")
|
||||
return emitError("fragment assembly reconciliator coverage_policy must be \"complete\" or \"partial\"");
|
||||
|
||||
SmallVector<Value> operands;
|
||||
operands.push_back(getInput());
|
||||
llvm::append_range(operands, getFragments());
|
||||
int64_t operandCount = static_cast<int64_t>(operands.size());
|
||||
int64_t fragmentCount = static_cast<int64_t>(operandIndices.size());
|
||||
if (operandCount == 0)
|
||||
return emitError("fragment assembly reconciliator requires at least one operand");
|
||||
if (static_cast<int64_t>(offsets.size()) != fragmentCount * rank)
|
||||
return emitError("fragment assembly metadata count must match operand count * result rank");
|
||||
if (failed(verifyBoundsOnly(strides)))
|
||||
return failure();
|
||||
|
||||
SmallVector<std::pair<SmallVector<int64_t, 4>, SmallVector<int64_t, 4>>, 8> slices;
|
||||
slices.reserve(static_cast<size_t>(fragmentCount));
|
||||
SmallVector<SmallVector<SmallVector<int64_t, 4>, 4>, 8> sizesByOperand(static_cast<size_t>(operandCount));
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
|
||||
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||
if (operandIndex < 0 || operandIndex >= operandCount)
|
||||
return emitError("fragment assembly operand index is out of range");
|
||||
|
||||
auto operandType = dyn_cast<RankedTensorType>(operands[operandIndex].getType());
|
||||
if (!operandType || !operandType.hasStaticShape())
|
||||
return emitError("fragment assembly reconciliator requires static ranked tensor operands");
|
||||
if (operandType.getRank() != rank)
|
||||
return emitError("fragment assembly reconciliator requires operand/result rank match");
|
||||
|
||||
SmallVector<int64_t, 4> fragmentOffsets;
|
||||
SmallVector<int64_t, 4> fragmentSizes;
|
||||
fragmentOffsets.reserve(rank);
|
||||
fragmentSizes.reserve(rank);
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||
fragmentOffsets.push_back(offsets[flatIndex]);
|
||||
fragmentSizes.push_back(sizes[flatIndex]);
|
||||
}
|
||||
|
||||
sizesByOperand[static_cast<size_t>(operandIndex)].push_back(fragmentSizes);
|
||||
|
||||
for (const auto& [existingOffsets, existingSizes] : slices) {
|
||||
bool overlaps = true;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
int64_t begin = fragmentOffsets[dim];
|
||||
int64_t end = begin + fragmentSizes[dim];
|
||||
int64_t existingBegin = existingOffsets[dim];
|
||||
int64_t existingEnd = existingBegin + existingSizes[dim];
|
||||
if (end <= existingBegin || existingEnd <= begin) {
|
||||
overlaps = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (overlaps)
|
||||
return emitError("fragment assembly reconciliator requires disjoint static slices");
|
||||
}
|
||||
slices.push_back({std::move(fragmentOffsets), std::move(fragmentSizes)});
|
||||
}
|
||||
|
||||
for (int64_t operandIndex = 0; operandIndex < operandCount; ++operandIndex) {
|
||||
if (sizesByOperand[static_cast<size_t>(operandIndex)].empty())
|
||||
return emitError("fragment assembly reconciliator requires every operand to contribute at least one fragment");
|
||||
|
||||
auto operandType = cast<RankedTensorType>(operands[operandIndex].getType());
|
||||
ArrayRef<int64_t> operandShape = operandType.getShape();
|
||||
auto& fragmentShapes = sizesByOperand[static_cast<size_t>(operandIndex)];
|
||||
if (fragmentShapes.size() == 1) {
|
||||
if (!llvm::equal(operandShape, fragmentShapes.front()))
|
||||
return emitError("single-fragment reconciliator operand shape must match declared fragment size");
|
||||
continue;
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> fragmentShape = fragmentShapes.front();
|
||||
for (ArrayRef<int64_t> otherShape : fragmentShapes)
|
||||
if (!llvm::equal(fragmentShape, otherShape))
|
||||
return emitError("packed reconciliator operand requires equal fragment sizes per operand");
|
||||
if (llvm::equal(operandShape, fragmentShape))
|
||||
continue;
|
||||
if (!llvm::equal(operandShape.drop_front(), fragmentShape.drop_front()))
|
||||
return emitError("packed reconciliator operand must match fragment shape on non-packed dimensions");
|
||||
if (operandShape.front() != static_cast<int64_t>(fragmentShapes.size()) * fragmentShape.front())
|
||||
return emitError("packed reconciliator operand first dimension must equal fragment_count * fragment_size");
|
||||
}
|
||||
|
||||
if (getCoveragePolicy() == "complete") {
|
||||
int64_t covered = 0;
|
||||
int64_t logicalElements = 1;
|
||||
for (int64_t dimSize : logicalType.getShape()) {
|
||||
if (ShapedType::isDynamic(dimSize))
|
||||
return emitError("fragment assembly complete coverage requires static result shape");
|
||||
logicalElements *= dimSize;
|
||||
}
|
||||
for (const auto& [ignoredOffsets, fragmentSizes] : slices) {
|
||||
int64_t fragmentElements = 1;
|
||||
for (int64_t dimSize : fragmentSizes)
|
||||
fragmentElements *= dimSize;
|
||||
covered += fragmentElements;
|
||||
}
|
||||
if (covered != logicalElements)
|
||||
return emitError("fragment assembly complete coverage must cover the whole result exactly");
|
||||
}
|
||||
|
||||
return success();
|
||||
|
||||
+923
-2846
File diff suppressed because it is too large
Load Diff
+9510
File diff suppressed because it is too large
Load Diff
+7548
File diff suppressed because it is too large
Load Diff
+128
@@ -0,0 +1,128 @@
|
||||
--- src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp 2026-06-24 18:51:29.043731129 +0000
|
||||
+++ src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp 2026-06-24 18:51:29.026726895 +0000
|
||||
@@ -4112,104 +4112,8 @@
|
||||
Value originalOutput,
|
||||
Location loc);
|
||||
|
||||
-FailureOr<SmallVector<OpFoldResult, 4>> rematerializeProjectionIndexListForBatchHostOutput(
|
||||
- MaterializerState& state,
|
||||
- MaterializedClass& sourceClass,
|
||||
- ArrayRef<OpFoldResult> values,
|
||||
- IRMapping& mapper,
|
||||
- Location loc) {
|
||||
- SmallVector<OpFoldResult, 4> localized;
|
||||
- localized.reserve(values.size());
|
||||
- for (OpFoldResult value : values) {
|
||||
- FailureOr<OpFoldResult> remapped =
|
||||
- rematerializeIndexOpFoldResultInClass(state, sourceClass, value, loc, &mapper);
|
||||
- if (failed(remapped))
|
||||
- return failure();
|
||||
- localized.push_back(*remapped);
|
||||
- }
|
||||
- return localized;
|
||||
-}
|
||||
-
|
||||
-LogicalResult createProjectionAwareBatchHostInsert(MaterializerState& state,
|
||||
- MaterializedClass& sourceClass,
|
||||
- Value originalOutput,
|
||||
- Value payload,
|
||||
- Value destination,
|
||||
- ArrayRef<ProducerKey> keys,
|
||||
- Location loc) {
|
||||
- auto originalResult = dyn_cast<OpResult>(originalOutput);
|
||||
- if (!originalResult)
|
||||
- return failure();
|
||||
-
|
||||
- auto sourceBatch = dyn_cast_or_null<SpatComputeBatch>(originalResult.getOwner());
|
||||
- if (!sourceBatch || sourceBatch.getNumResults() == 0)
|
||||
- return failure();
|
||||
-
|
||||
- FailureOr<tensor::ParallelInsertSliceOp> projection =
|
||||
- getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber());
|
||||
- if (failed(projection))
|
||||
- return failure();
|
||||
-
|
||||
- auto sourceLaneArg = sourceBatch.getLaneArgument();
|
||||
- if (!sourceLaneArg)
|
||||
- return failure();
|
||||
-
|
||||
- auto materializedBatch = dyn_cast<SpatScheduledComputeBatch>(sourceClass.op);
|
||||
- if (!materializedBatch)
|
||||
- return failure();
|
||||
-
|
||||
- auto materializedLaneArg = materializedBatch.getLaneArgument();
|
||||
- if (!materializedLaneArg)
|
||||
- return failure();
|
||||
-
|
||||
- if (keys.size() != sourceClass.cpus.size())
|
||||
- return failure();
|
||||
-
|
||||
- SmallVector<int64_t, 8> logicalLanes;
|
||||
- logicalLanes.reserve(keys.size());
|
||||
- for (ProducerKey key : keys) {
|
||||
- if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != originalResult.getResultNumber())
|
||||
- return failure();
|
||||
- logicalLanes.push_back(key.instance.laneStart);
|
||||
- }
|
||||
-
|
||||
- IRMapping mapper;
|
||||
- Value logicalLane = createIndexedIndexValue(state,
|
||||
- sourceClass.op,
|
||||
- ArrayRef<int64_t>(logicalLanes),
|
||||
- *materializedLaneArg,
|
||||
- loc,
|
||||
- static_cast<int64_t>(sourceClass.cpus.size()),
|
||||
- /*allowExhaustiveTiledSearch=*/false);
|
||||
- mapper.map(*sourceLaneArg, logicalLane);
|
||||
-
|
||||
- FailureOr<SmallVector<OpFoldResult, 4>> offsets =
|
||||
- rematerializeProjectionIndexListForBatchHostOutput(
|
||||
- state, sourceClass, projection->getMixedOffsets(), mapper, loc);
|
||||
- if (failed(offsets))
|
||||
- return failure();
|
||||
- FailureOr<SmallVector<OpFoldResult, 4>> sizes =
|
||||
- rematerializeProjectionIndexListForBatchHostOutput(
|
||||
- state, sourceClass, projection->getMixedSizes(), mapper, loc);
|
||||
- if (failed(sizes))
|
||||
- return failure();
|
||||
- FailureOr<SmallVector<OpFoldResult, 4>> strides =
|
||||
- rematerializeProjectionIndexListForBatchHostOutput(
|
||||
- state, sourceClass, projection->getMixedStrides(), mapper, loc);
|
||||
- if (failed(strides))
|
||||
- return failure();
|
||||
-
|
||||
- tensor::ParallelInsertSliceOp::create(
|
||||
- state.rewriter, loc, payload, destination, *offsets, *sizes, *strides);
|
||||
- return success();
|
||||
-}
|
||||
-
|
||||
LogicalResult
|
||||
-setHostOutputValue(MaterializerState& state,
|
||||
- MaterializedClass& sourceClass,
|
||||
- Value originalOutput,
|
||||
- Value payload,
|
||||
- ArrayRef<ProducerKey> keys = {}) {
|
||||
+setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) {
|
||||
auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput);
|
||||
if (resultIt == sourceClass.hostOutputToResultIndex.end())
|
||||
return sourceClass.op->emitError("missing host result slot for materialized output")
|
||||
@@ -4253,10 +4157,6 @@
|
||||
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
|
||||
|
||||
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
- if (succeeded(createProjectionAwareBatchHostInsert(
|
||||
- state, sourceClass, originalOutput, payload, *outputArg, keys, payload.getLoc())))
|
||||
- return success();
|
||||
-
|
||||
createDim0ParallelInsertSlice(state, payload.getLoc(), payload, *outputArg, *laneArg);
|
||||
return success();
|
||||
}
|
||||
@@ -4276,7 +4176,7 @@
|
||||
|
||||
MaterializedClass& ownerClass = state.classes[ownerIt->second];
|
||||
if (sourceClass.id == ownerClass.id)
|
||||
- return setHostOutputValue(state, ownerClass, originalOutput, payload, keys);
|
||||
+ return setHostOutputValue(state, ownerClass, originalOutput, payload);
|
||||
|
||||
// Keep the old deadlock-free communication discipline: only scalar-to-scalar
|
||||
// host-owner forwarding is introduced here. Batch host publication remains on
|
||||
Reference in New Issue
Block a user