This commit is contained in:
@@ -1220,9 +1220,38 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool isLaneProjectedOffsetValue(Value value, Value expected, bool& usesExpected) {
|
||||||
|
if (value == expected) {
|
||||||
|
usesExpected = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (matchPattern(value, m_Constant()))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
|
||||||
|
if (!affineApply || affineApply.getAffineMap().getNumResults() != 1)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
bool nestedUsesExpected = false;
|
||||||
|
for (Value operand : affineApply.getMapOperands()) {
|
||||||
|
bool operandUsesExpected = false;
|
||||||
|
if (!isLaneProjectedOffsetValue(operand, expected, operandUsesExpected))
|
||||||
|
return false;
|
||||||
|
nestedUsesExpected = nestedUsesExpected || operandUsesExpected;
|
||||||
|
}
|
||||||
|
|
||||||
|
usesExpected = usesExpected || nestedUsesExpected;
|
||||||
|
return nestedUsesExpected;
|
||||||
|
}
|
||||||
|
|
||||||
bool isValueOffset(OpFoldResult offset, Value expected) {
|
bool isValueOffset(OpFoldResult offset, Value expected) {
|
||||||
auto value = dyn_cast<Value>(offset);
|
auto value = dyn_cast<Value>(offset);
|
||||||
return value && value == expected;
|
if (!value)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
bool usesExpected = false;
|
||||||
|
return isLaneProjectedOffsetValue(value, expected, usesExpected) && usesExpected;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isStaticIndexAttr(OpFoldResult value, int64_t expected) {
|
bool isStaticIndexAttr(OpFoldResult value, int64_t expected) {
|
||||||
|
|||||||
@@ -89,9 +89,18 @@ bool isUsedAsWeightOnly(Operation* producerOp) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isLaneOffset(OpFoldResult offset, Value laneArg) {
|
static FailureOr<int64_t>
|
||||||
auto offsetValue = llvm::dyn_cast<Value>(offset);
|
evaluateIndexLike(Value value, const DenseMap<Value, int64_t>& bindings, std::optional<uint32_t> lane, Value laneArg);
|
||||||
return offsetValue == laneArg;
|
|
||||||
|
static FailureOr<int64_t> evaluateIndexLike(OpFoldResult value,
|
||||||
|
const DenseMap<Value, int64_t>& bindings,
|
||||||
|
std::optional<uint32_t> lane,
|
||||||
|
Value laneArg);
|
||||||
|
|
||||||
|
bool isProjectedBatchOffset(OpFoldResult offset, Value laneArg) {
|
||||||
|
DenseMap<Value, int64_t> bindings;
|
||||||
|
return succeeded(evaluateIndexLike(offset, bindings, /*lane=*/0, laneArg))
|
||||||
|
&& succeeded(evaluateIndexLike(offset, bindings, /*lane=*/1, laneArg));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
|
std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
|
||||||
@@ -110,7 +119,7 @@ std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, V
|
|||||||
auto extract = dyn_cast<tensor::ExtractSliceOp>(user);
|
auto extract = dyn_cast<tensor::ExtractSliceOp>(user);
|
||||||
if (!extract || extract.getSource() != *inputArg)
|
if (!extract || extract.getSource() != *inputArg)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
if (extract.getMixedOffsets().empty() || !isLaneOffset(extract.getMixedOffsets().front(), *laneArg))
|
if (extract.getMixedOffsets().empty() || !isProjectedBatchOffset(extract.getMixedOffsets().front(), *laneArg))
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
|
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
|
||||||
|
|||||||
Reference in New Issue
Block a user