This commit is contained in:
@@ -1220,9 +1220,38 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool isLaneProjectedOffsetValue(Value value, Value expected, bool& usesExpected) {
|
||||
if (value == expected) {
|
||||
usesExpected = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (matchPattern(value, m_Constant()))
|
||||
return true;
|
||||
|
||||
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
|
||||
if (!affineApply || affineApply.getAffineMap().getNumResults() != 1)
|
||||
return false;
|
||||
|
||||
bool nestedUsesExpected = false;
|
||||
for (Value operand : affineApply.getMapOperands()) {
|
||||
bool operandUsesExpected = false;
|
||||
if (!isLaneProjectedOffsetValue(operand, expected, operandUsesExpected))
|
||||
return false;
|
||||
nestedUsesExpected = nestedUsesExpected || operandUsesExpected;
|
||||
}
|
||||
|
||||
usesExpected = usesExpected || nestedUsesExpected;
|
||||
return nestedUsesExpected;
|
||||
}
|
||||
|
||||
bool isValueOffset(OpFoldResult offset, Value expected) {
|
||||
auto value = dyn_cast<Value>(offset);
|
||||
return value && value == expected;
|
||||
if (!value)
|
||||
return false;
|
||||
|
||||
bool usesExpected = false;
|
||||
return isLaneProjectedOffsetValue(value, expected, usesExpected) && usesExpected;
|
||||
}
|
||||
|
||||
bool isStaticIndexAttr(OpFoldResult value, int64_t expected) {
|
||||
|
||||
@@ -89,9 +89,18 @@ bool isUsedAsWeightOnly(Operation* producerOp) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool isLaneOffset(OpFoldResult offset, Value laneArg) {
|
||||
auto offsetValue = llvm::dyn_cast<Value>(offset);
|
||||
return offsetValue == laneArg;
|
||||
static FailureOr<int64_t>
|
||||
evaluateIndexLike(Value value, const DenseMap<Value, int64_t>& bindings, std::optional<uint32_t> lane, Value laneArg);
|
||||
|
||||
static FailureOr<int64_t> evaluateIndexLike(OpFoldResult value,
|
||||
const DenseMap<Value, int64_t>& bindings,
|
||||
std::optional<uint32_t> lane,
|
||||
Value laneArg);
|
||||
|
||||
bool isProjectedBatchOffset(OpFoldResult offset, Value laneArg) {
|
||||
DenseMap<Value, int64_t> bindings;
|
||||
return succeeded(evaluateIndexLike(offset, bindings, /*lane=*/0, laneArg))
|
||||
&& succeeded(evaluateIndexLike(offset, bindings, /*lane=*/1, laneArg));
|
||||
}
|
||||
|
||||
std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
|
||||
@@ -110,7 +119,7 @@ std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, V
|
||||
auto extract = dyn_cast<tensor::ExtractSliceOp>(user);
|
||||
if (!extract || extract.getSource() != *inputArg)
|
||||
return std::nullopt;
|
||||
if (extract.getMixedOffsets().empty() || !isLaneOffset(extract.getMixedOffsets().front(), *laneArg))
|
||||
if (extract.getMixedOffsets().empty() || !isProjectedBatchOffset(extract.getMixedOffsets().front(), *laneArg))
|
||||
return std::nullopt;
|
||||
|
||||
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
|
||||
|
||||
Reference in New Issue
Block a user