merge remote changes
This commit is contained in:
@@ -116,10 +116,9 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||
auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(coreOp);
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
mapOp.getLoc(),
|
||||
@@ -258,9 +257,18 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
// Look through an optional pim.memcp_hd to find the source get_global.
|
||||
// This occurs when the constant was staged into device memory before transposing.
|
||||
pim::PimMemCopyHostToDevOp memcpHd;
|
||||
auto sourceGetGlobal = transposeOp.getInput().getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!sourceGetGlobal)
|
||||
return failure();
|
||||
if (!sourceGetGlobal) {
|
||||
memcpHd = transposeOp.getInput().getDefiningOp<pim::PimMemCopyHostToDevOp>();
|
||||
if (!memcpHd)
|
||||
return failure();
|
||||
sourceGetGlobal = memcpHd.getHostSource().getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!sourceGetGlobal)
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
@@ -298,13 +306,26 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
||||
|
||||
bool isAlwaysWeight =
|
||||
!transposeOp->getUsers().empty()
|
||||
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
|
||||
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) {
|
||||
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||
});
|
||||
if (isAlwaysWeight) {
|
||||
markWeightAlways(newGlobal);
|
||||
markWeightAlways(newGetGlobal);
|
||||
}
|
||||
|
||||
auto outputAllocOp = transposeOp.getOutputBuffer().getDefiningOp<memref::AllocOp>();
|
||||
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
|
||||
|
||||
if (memcpHd && memcpHd.use_empty()) {
|
||||
auto deviceAllocOp = memcpHd.getDeviceTarget().getDefiningOp<memref::AllocOp>();
|
||||
rewriter.eraseOp(memcpHd);
|
||||
if (deviceAllocOp && deviceAllocOp->use_empty())
|
||||
rewriter.eraseOp(deviceAllocOp);
|
||||
}
|
||||
if (outputAllocOp && outputAllocOp->use_empty())
|
||||
rewriter.eraseOp(outputAllocOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -341,18 +362,25 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<pim::PimCoreOp>(user))
|
||||
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
|
||||
return llvm::all_of(castOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
|
||||
return llvm::all_of(castOp->getUsers(), [](Operation* user) {
|
||||
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||
});
|
||||
})) {
|
||||
allLiveUsersAreCoreOps = false;
|
||||
}
|
||||
|
||||
if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) {
|
||||
return isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp, memref::CastOp, pim::PimCoreOp>(user);
|
||||
return isa<linalg::MapOp,
|
||||
memref::SubViewOp,
|
||||
memref::DeallocOp,
|
||||
memref::CastOp,
|
||||
pim::PimCoreOp,
|
||||
pim::PimCoreBatchOp>(user);
|
||||
})) {
|
||||
return failure();
|
||||
}
|
||||
@@ -389,6 +417,83 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewrite pattern for a memref.copy whose target is a statically shaped
/// memref.alloc and whose source resolves to a global for which
/// getDenseGlobalValue yields a dense attribute, either directly (after
/// stripping memref casts) or through a static subview with unit strides.
///
/// On success the copied contents are materialised as a new global (created
/// with the name hint "pim_folded_host_copy"), every use of the alloc is
/// redirected to a memref.get_global of that global, and the copy (plus the
/// now unused alloc) is erased.
struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
    // Copies nested inside a pim::PimCoreOp are not handled by this pattern.
    if (copyOp->getParentOfType<pim::PimCoreOp>())
      return failure();

    // The destination must come straight from a statically shaped alloc.
    auto targetAlloc = copyOp.getTarget().getDefiningOp<memref::AllocOp>();
    if (!targetAlloc)
      return failure();
    auto targetType = dyn_cast<MemRefType>(targetAlloc.getType());
    if (!targetType || !targetType.hasStaticShape())
      return failure();

    // Resolve the value that should be backed by a global: the subview's
    // source when the copy source is a static subview, otherwise the copy
    // source itself with casts stripped.
    auto subviewInfo = getStaticSubviewInfo(copyOp.getSource());
    const bool readsThroughSubview = succeeded(subviewInfo);
    Value constantSource;
    if (readsThroughSubview)
      constantSource = subviewInfo->source;
    else
      constantSource = stripMemRefCasts(copyOp.getSource());

    auto enclosingModule = copyOp->getParentOfType<ModuleOp>();
    if (!enclosingModule)
      return failure();

    auto sourceValues = getDenseGlobalValue(enclosingModule, constantSource);
    if (failed(sourceValues))
      return failure();

    // Compute the attribute holding exactly the elements the copy would write.
    DenseElementsAttr copiedValues;
    if (readsThroughSubview) {
      // Only unit-stride subviews are folded.
      const bool unitStrides = llvm::all_of(subviewInfo->strides, [](int64_t stride) { return stride == 1; });
      if (!unitStrides)
        return failure();

      auto offsets = getStaticSubviewOffsets(*subviewInfo);
      if (failed(offsets))
        return failure();

      auto slicedValues = foldDenseSubview(*sourceValues, *offsets, targetType.getShape());
      if (failed(slicedValues))
        return failure();
      copiedValues = *slicedValues;
    }
    else {
      // Whole-buffer copy: the global's tensor type must match the target's
      // shape and element type exactly.
      auto expectedTensorType = RankedTensorType::get(targetType.getShape(), targetType.getElementType());
      if (expectedTensorType != sourceValues->getType())
        return failure();
      copiedValues = *sourceValues;
    }

    // Inspect the remaining users of the alloc. The copy itself, deallocs and
    // core ops are accepted as-is; a subview user is accepted but clears the
    // flag that gates the markWeightAlways calls below; any other user makes
    // the pattern fail.
    // NOTE(review): dealloc users are skipped here, not erased, so after the
    // replaceAllUsesWith below they operate on the get_global result; confirm
    // they are cleaned up elsewhere.
    bool onlyCoreConsumers = true;
    for (Operation* user : targetAlloc->getUsers()) {
      if (user == copyOp || isa<memref::DeallocOp, pim::PimCoreOp, pim::PimCoreBatchOp>(user))
        continue;
      if (!isa<memref::SubViewOp>(user))
        return failure();
      onlyCoreConsumers = false;
    }

    // All checks passed: create the folded global and a get_global in place
    // of the alloc.
    auto foldedGlobal =
        createFoldedGlobal(enclosingModule, targetAlloc.getLoc(), targetType, copiedValues, "pim_folded_host_copy");
    if (onlyCoreConsumers)
      markWeightAlways(foldedGlobal);

    rewriter.setInsertionPoint(targetAlloc);
    auto foldedGetGlobal =
        memref::GetGlobalOp::create(rewriter, targetAlloc.getLoc(), targetType, foldedGlobal.getName());
    if (onlyCoreConsumers)
      markWeightAlways(foldedGetGlobal);

    // Redirect every use of the alloc, then drop the copy and the dead alloc.
    rewriter.replaceAllUsesWith(targetAlloc.getResult(), foldedGetGlobal.getResult());
    rewriter.eraseOp(copyOp);
    if (targetAlloc.use_empty())
      rewriter.eraseOp(targetAlloc);
    return success();
  }
};
|
||||
|
||||
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -443,7 +548,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
continue;
|
||||
if (isa<memref::DeallocOp>(user))
|
||||
continue;
|
||||
if (isa<pim::PimCoreOp>(user))
|
||||
if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
|
||||
continue;
|
||||
if (isa<memref::SubViewOp>(user)) {
|
||||
allLiveUsersAreCores = false;
|
||||
@@ -473,7 +578,11 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
|
||||
/// Adds the constant-folding rewrite patterns to `patterns`, constructing each
/// one with the context owned by the pattern set.
///
/// The registered patterns are the transpose, alloc, core-map, host-copy and
/// memcp folders; the superseded four-pattern `.add<...>(` line is dropped so
/// that a single well-formed `add` call remains.
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
  patterns.add<FoldConstantTransposePattern,
               FoldConstantAllocPattern,
               FoldConstantCoreMapPattern,
               FoldConstantHostCopyPattern,
               FoldConstantMemCpPattern>(patterns.getContext());
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user