huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
@@ -90,6 +90,7 @@ static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
|
||||
return attr;
|
||||
}
|
||||
|
||||
// Folds constant linalg fills inside cores into private globals plus device copies.
|
||||
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -249,6 +250,7 @@ static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, M
|
||||
return DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
}
|
||||
|
||||
// Folds transposes of constant globals so weight-only transposes stay host-side.
|
||||
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -304,11 +306,9 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
||||
rewriter.setInsertionPoint(transposeOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());
|
||||
|
||||
bool isAlwaysWeight =
|
||||
!transposeOp->getUsers().empty()
|
||||
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) {
|
||||
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||
});
|
||||
bool isAlwaysWeight = !transposeOp->getUsers().empty()
|
||||
&& llvm::all_of(transposeOp->getUsers(),
|
||||
[](Operation* user) { return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user); });
|
||||
if (isAlwaysWeight) {
|
||||
markWeightAlways(newGlobal);
|
||||
markWeightAlways(newGetGlobal);
|
||||
@@ -330,6 +330,7 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
||||
}
|
||||
};
|
||||
|
||||
// Collapses fill-and-copy allocation chains into one folded constant global.
|
||||
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -367,9 +368,8 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
||||
}
|
||||
|
||||
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
|
||||
return llvm::all_of(castOp->getUsers(), [](Operation* user) {
|
||||
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||
});
|
||||
return llvm::all_of(castOp->getUsers(),
|
||||
[](Operation* user) { return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user); });
|
||||
})) {
|
||||
allLiveUsersAreCoreOps = false;
|
||||
}
|
||||
@@ -417,6 +417,7 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Converts host copies from dense globals into direct folded globals.
|
||||
struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -431,37 +432,14 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
if (!allocType || !allocType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
|
||||
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());
|
||||
|
||||
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
|
||||
if (failed(denseAttr))
|
||||
auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType);
|
||||
if (failed(foldedAttr))
|
||||
return failure();
|
||||
|
||||
DenseElementsAttr foldedAttr;
|
||||
if (succeeded(srcSubview)) {
|
||||
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
|
||||
if (failed(maybeFoldedAttr))
|
||||
return failure();
|
||||
foldedAttr = *maybeFoldedAttr;
|
||||
}
|
||||
else {
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
if (resultTensorType != denseAttr->getType())
|
||||
return failure();
|
||||
foldedAttr = *denseAttr;
|
||||
}
|
||||
|
||||
bool allLiveUsersAreCores = true;
|
||||
for (Operation* user : allocOp->getUsers()) {
|
||||
if (user == copyOp)
|
||||
@@ -477,7 +455,7 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_host_copy");
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_host_copy");
|
||||
if (allLiveUsersAreCores)
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
@@ -494,6 +472,7 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Converts PIM copies from dense globals into direct folded globals before codegen.
|
||||
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -511,37 +490,14 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
if (copyOp.getTargetOffset() != 0 || copyOp.getSourceOffset() != 0)
|
||||
return failure();
|
||||
|
||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
|
||||
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());
|
||||
|
||||
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
|
||||
if (failed(denseAttr))
|
||||
auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType);
|
||||
if (failed(foldedAttr))
|
||||
return failure();
|
||||
|
||||
DenseElementsAttr foldedAttr;
|
||||
if (succeeded(srcSubview)) {
|
||||
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
|
||||
if (failed(maybeFoldedAttr))
|
||||
return failure();
|
||||
foldedAttr = *maybeFoldedAttr;
|
||||
}
|
||||
else {
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
if (resultTensorType != denseAttr->getType())
|
||||
return failure();
|
||||
foldedAttr = *denseAttr;
|
||||
}
|
||||
|
||||
bool allLiveUsersAreCores = true;
|
||||
for (Operation* user : allocOp->getUsers()) {
|
||||
if (user == copyOp)
|
||||
@@ -557,7 +513,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_memcp");
|
||||
if (allLiveUsersAreCores)
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
@@ -577,13 +533,11 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
} // namespace
|
||||
|
||||
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
|
||||
patterns
|
||||
.add<FoldConstantTransposePattern,
|
||||
FoldConstantAllocPattern,
|
||||
FoldConstantCoreMapPattern,
|
||||
FoldConstantHostCopyPattern,
|
||||
FoldConstantMemCpPattern>(
|
||||
patterns.getContext());
|
||||
patterns.add<FoldConstantTransposePattern,
|
||||
FoldConstantAllocPattern,
|
||||
FoldConstantCoreMapPattern,
|
||||
FoldConstantHostCopyPattern,
|
||||
FoldConstantMemCpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user