Huge refactor: heavier use of RewritePatterns and less ad-hoc C++ code
Validate Operations / validate-operations (push) Has been cancelled

Remove many Spatial ops in favor of tensor ops, as in pim
This commit is contained in:
NiccoloN
2026-05-12 10:35:44 +02:00
parent feaff820e1
commit 909c4acfdd
84 changed files with 4048 additions and 3310 deletions
@@ -90,6 +90,7 @@ static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
return attr;
}
// Folds constant linalg fills inside cores into private globals plus device copies.
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
using OpRewritePattern::OpRewritePattern;
@@ -249,6 +250,7 @@ static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, M
return DenseElementsAttr::get(resultTensorType, resultValues);
}
// Folds transposes of constant globals so weight-only transposes stay host-side.
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
using OpRewritePattern::OpRewritePattern;
@@ -304,11 +306,9 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
rewriter.setInsertionPoint(transposeOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());
bool isAlwaysWeight =
!transposeOp->getUsers().empty()
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) {
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
});
bool isAlwaysWeight = !transposeOp->getUsers().empty()
&& llvm::all_of(transposeOp->getUsers(),
[](Operation* user) { return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user); });
if (isAlwaysWeight) {
markWeightAlways(newGlobal);
markWeightAlways(newGetGlobal);
@@ -330,6 +330,7 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
}
};
// Collapses fill-and-copy allocation chains into one folded constant global.
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
using OpRewritePattern::OpRewritePattern;
@@ -367,9 +368,8 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
}
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
return llvm::all_of(castOp->getUsers(), [](Operation* user) {
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
});
return llvm::all_of(castOp->getUsers(),
[](Operation* user) { return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user); });
})) {
allLiveUsersAreCoreOps = false;
}
@@ -417,6 +417,7 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
}
};
// Converts host copies from dense globals into direct folded globals.
struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
using OpRewritePattern::OpRewritePattern;
@@ -431,37 +432,14 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
if (!allocType || !allocType.hasStaticShape())
return failure();
auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
if (failed(denseAttr))
auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType);
if (failed(foldedAttr))
return failure();
DenseElementsAttr foldedAttr;
if (succeeded(srcSubview)) {
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
if (failed(staticOffsets))
return failure();
auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
if (failed(maybeFoldedAttr))
return failure();
foldedAttr = *maybeFoldedAttr;
}
else {
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
if (resultTensorType != denseAttr->getType())
return failure();
foldedAttr = *denseAttr;
}
bool allLiveUsersAreCores = true;
for (Operation* user : allocOp->getUsers()) {
if (user == copyOp)
@@ -477,7 +455,7 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
return failure();
}
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_host_copy");
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_host_copy");
if (allLiveUsersAreCores)
markWeightAlways(newGlobal);
@@ -494,6 +472,7 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
}
};
// Converts PIM copies from dense globals into direct folded globals before codegen.
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
using OpRewritePattern::OpRewritePattern;
@@ -511,37 +490,14 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
if (copyOp.getTargetOffset() != 0 || copyOp.getSourceOffset() != 0)
return failure();
auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
if (failed(denseAttr))
auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType);
if (failed(foldedAttr))
return failure();
DenseElementsAttr foldedAttr;
if (succeeded(srcSubview)) {
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
if (failed(staticOffsets))
return failure();
auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
if (failed(maybeFoldedAttr))
return failure();
foldedAttr = *maybeFoldedAttr;
}
else {
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
if (resultTensorType != denseAttr->getType())
return failure();
foldedAttr = *denseAttr;
}
bool allLiveUsersAreCores = true;
for (Operation* user : allocOp->getUsers()) {
if (user == copyOp)
@@ -557,7 +513,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
return failure();
}
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_memcp");
if (allLiveUsersAreCores)
markWeightAlways(newGlobal);
@@ -577,13 +533,11 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
} // namespace
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
patterns
.add<FoldConstantTransposePattern,
FoldConstantAllocPattern,
FoldConstantCoreMapPattern,
FoldConstantHostCopyPattern,
FoldConstantMemCpPattern>(
patterns.getContext());
patterns.add<FoldConstantTransposePattern,
FoldConstantAllocPattern,
FoldConstantCoreMapPattern,
FoldConstantHostCopyPattern,
FoldConstantMemCpPattern>(patterns.getContext());
}
} // namespace onnx_mlir