fix failing validations after last commit
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -72,12 +72,7 @@ delinearizeIndexValue(Value linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64
|
|||||||
indices.reserve(shape.size());
|
indices.reserve(shape.size());
|
||||||
|
|
||||||
Value remaining = linearIndex;
|
Value remaining = linearIndex;
|
||||||
for (auto [dim, stride] : llvm::enumerate(strides)) {
|
for (auto [_dim, stride] : llvm::enumerate(strides)) {
|
||||||
if (stride == 1) {
|
|
||||||
indices.push_back(remaining);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto cStride = arith::ConstantIndexOp::create(rewriter, linearIndex.getLoc(), stride);
|
auto cStride = arith::ConstantIndexOp::create(rewriter, linearIndex.getLoc(), stride);
|
||||||
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
||||||
indices.push_back(index);
|
indices.push_back(index);
|
||||||
@@ -155,6 +150,7 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
|||||||
int64_t dstOffset,
|
int64_t dstOffset,
|
||||||
int64_t srcOffset,
|
int64_t srcOffset,
|
||||||
int64_t size,
|
int64_t size,
|
||||||
|
bool allowLoopRewrite,
|
||||||
PatternRewriter& rewriter,
|
PatternRewriter& rewriter,
|
||||||
CreateCopyOp createCopyOp) {
|
CreateCopyOp createCopyOp) {
|
||||||
auto srcSubview = getStaticSubviewInfo(src);
|
auto srcSubview = getStaticSubviewInfo(src);
|
||||||
@@ -196,7 +192,8 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
|||||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||||
|
|
||||||
if (numSlices > 1 && srcOffset == 0 && dstOffset == 0 && sourceType.getRank() == static_cast<int64_t>(copyShape.size())
|
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
|
||||||
|
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
|
||||||
&& dstType.getRank() == static_cast<int64_t>(copyShape.size())) {
|
&& dstType.getRank() == static_cast<int64_t>(copyShape.size())) {
|
||||||
auto c0 = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 0);
|
auto c0 = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 0);
|
||||||
auto cUpper = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), numSlices);
|
auto cUpper = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), numSlices);
|
||||||
@@ -244,6 +241,7 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
|
|||||||
copyOp.getTargetOffset(),
|
copyOp.getTargetOffset(),
|
||||||
copyOp.getSourceOffset(),
|
copyOp.getSourceOffset(),
|
||||||
copyOp.getSize(),
|
copyOp.getSize(),
|
||||||
|
/*allowLoopRewrite=*/true,
|
||||||
rewriter,
|
rewriter,
|
||||||
[&](
|
[&](
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||||
@@ -276,6 +274,7 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
|
|||||||
copyOp.getDeviceTargetOffset(),
|
copyOp.getDeviceTargetOffset(),
|
||||||
copyOp.getHostSourceOffset(),
|
copyOp.getHostSourceOffset(),
|
||||||
copyOp.getSize(),
|
copyOp.getSize(),
|
||||||
|
/*allowLoopRewrite=*/true,
|
||||||
rewriter,
|
rewriter,
|
||||||
[&](
|
[&](
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||||
@@ -308,6 +307,7 @@ struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDe
|
|||||||
copyOp.getHostTargetOffset(),
|
copyOp.getHostTargetOffset(),
|
||||||
copyOp.getDeviceSourceOffset(),
|
copyOp.getDeviceSourceOffset(),
|
||||||
copyOp.getSize(),
|
copyOp.getSize(),
|
||||||
|
/*allowLoopRewrite=*/false,
|
||||||
rewriter,
|
rewriter,
|
||||||
[&](
|
[&](
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||||
|
|||||||
@@ -67,6 +67,14 @@ static bool isCodegenAddressableValue(Value value) {
|
|||||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
|
||||||
|
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
|
||||||
|
if (failed(resolvedAddress))
|
||||||
|
return false;
|
||||||
|
return isa<BlockArgument>(resolvedAddress->base)
|
||||||
|
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||||
|
}
|
||||||
|
|
||||||
static bool isConstantGlobalView(Value value) {
|
static bool isConstantGlobalView(Value value) {
|
||||||
while (true) {
|
while (true) {
|
||||||
Operation* defOp = value.getDefiningOp();
|
Operation* defOp = value.getDefiningOp();
|
||||||
@@ -260,7 +268,7 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (isExplicitHostOperand(&op, operandIndex)) {
|
if (isExplicitHostOperand(&op, operandIndex)) {
|
||||||
if (!isCodegenAddressableValue(operand)) {
|
if (!isCodegenAddressableValue(operand, knowledge)) {
|
||||||
op.emitOpError() << "host operand #" << operandIndex
|
op.emitOpError() << "host operand #" << operandIndex
|
||||||
<< " is not backed by contiguous addressable storage";
|
<< " is not backed by contiguous addressable storage";
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
|
|||||||
Reference in New Issue
Block a user