Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone
@@ -1,92 +1,210 @@
|
|||||||
- Always read the full README.md before doing anything.
|
* Always read the full README.md before doing anything
|
||||||
- Build commands:
|
* Build commands:
|
||||||
- `cmake --build ./build_release`
|
* `cmake --build ./build_release`
|
||||||
- `cmake --build ./build_debug`
|
* `cmake --build ./build_debug`
|
||||||
- Never use `ninja` directly: it bypasses cmake's configuration and invalidates the build cache.
|
* Never use `ninja` directly: it bypasses cmake's configuration and invalidates the build cache
|
||||||
- Always tries the release version build first and ask before building with the debug version
|
* Always try the release build first before building with the debug version
|
||||||
|
* Use the debug build only when it is useful to obtain a clear stack trace with symbols, inspect names, place breakpoints, or test a small case interactively
|
||||||
|
* The debug build is very slow, so use it only on small fast tests such as operation validations, not on network validations
|
||||||
|
|
||||||
|
# Core engineering philosophy
|
||||||
|
|
||||||
|
* Clean architecture matters as much as making the immediate test pass
|
||||||
|
* Prefer fixes that preserve clear ownership boundaries, explicit invariants, and simple dataflow
|
||||||
|
* Do not stack compensating fixes on top of earlier mistakes. If the current approach is becoming messy, stop and explain why
|
||||||
|
* A correct fix should usually make the responsible producer, resolver, verifier, or lowering own the behavior directly
|
||||||
|
* Avoid late repair passes, defensive cleanup, or broad rewrites when a cleaner owner-side fix is possible
|
||||||
|
* Do not hide an upstream modeling bug by normalizing it later in the pipeline. Fix the producer when the producer owns the invariant
|
||||||
|
* Prefer patterns/rewrites for local IR canonicalization. Use module walks only when pass-level structural analysis genuinely requires them
|
||||||
|
* Prefer compact, structured designs over long case-by-case implementations
|
||||||
|
|
||||||
|
# Think before coding
|
||||||
|
|
||||||
|
* State assumptions explicitly before implementing when they affect the design
|
||||||
|
* If multiple interpretations exist, present them instead of silently choosing one
|
||||||
|
* If a simpler approach exists, say so and prefer it unless there is a clear reason not to
|
||||||
|
* If something is unclear, stop, name what is confusing, and ask
|
||||||
|
* If the requested or obvious approach would make the architecture worse, push back and propose a cleaner alternative
|
||||||
|
|
||||||
# Code changes
|
# Code changes
|
||||||
|
|
||||||
- Keep changes minimal and localized to the relevant parts of the code.
|
* Keep changes minimal and localized to the relevant parts of the code
|
||||||
- Preserve the existing naming conventions and coding style used in the surrounding code.
|
* Preserve the existing naming conventions and coding style used in the surrounding code
|
||||||
- Keep code easy to read, well organized, and suitable for future extensibility. A function must not be longer than
|
* Keep code easy to read, well organized, and suitable for future extensibility
|
||||||
200/250 lines for readability and cognitive complexity.
|
* A function must not exceed roughly 200/250 lines. If a change pushes a function beyond that, extract focused helpers
|
||||||
- Prefer clear naming and structure over comments. Add comments only when they materially improve clarity.
|
* Prefer clear naming and structure over comments. Add comments only when they materially improve clarity
|
||||||
- Do not rename symbols, move files, or restructure modules unless that is necessary for the requested change.
|
* Do not rename symbols, move files, or restructure modules unless that is necessary for the requested change
|
||||||
|
* Avoid duplicate ad-hoc logic. If the same concept appears in multiple places, consider whether it deserves a shared helper/API
|
||||||
|
* When adding a helper or API, ask:
|
||||||
|
* Could this be useful to another component now
|
||||||
|
* Is another component already implementing the same idea differently
|
||||||
|
* Is this likely to be needed by a future adjacent component
|
||||||
|
* What is the narrowest useful abstraction
|
||||||
|
* What is the correct ownership level for this API
|
||||||
|
* If a shared API is justified, place it at the lowest clean layer that can be used by all relevant consumers without creating dependency cycles or leaking policy across layers
|
||||||
|
* If an existing component should use a newly introduced shared API, refactor that component in the same patch when doing so is directly related and reduces duplication
|
||||||
|
* Do not create broad frameworks just because a helper might someday be useful. Shared APIs should encode a real reusable concept, not speculative generality
|
||||||
|
* If the reusable abstraction is plausible but not clearly needed yet, keep the code local and mention the possible future extraction separately
|
||||||
|
|
||||||
|
# Avoid case-listing designs
|
||||||
|
|
||||||
|
* Avoid solving problems with large chains of `if`/`else`, switches, or repeated special cases that enumerate every possible situation
|
||||||
|
* Long case listings tend to overfit the current tests, grow the codebase, and hide the underlying abstraction
|
||||||
|
* When you see a growing list of special cases, stop and look for the shared concept, data model, interface, or normalization step that would make the cases collapse
|
||||||
|
* Prefer table-driven logic, traits/interfaces, small reusable predicates, structured dispatch, or producer-side normalization when they express the invariant more directly
|
||||||
|
* A few explicit cases are fine when the domain is genuinely small and closed
|
||||||
|
* If the list is likely to grow, refactor toward a cleaner and more compact design instead of adding another branch
|
||||||
|
* When keeping a case list is the pragmatic choice, explain why the domain is closed or why a broader abstraction would be premature
|
||||||
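As an illustration of the table-driven alternative mentioned in the list above, here is a minimal sketch. The `ElementKind` enum, the `ElementInfo` struct, and both helper functions are hypothetical names invented for this example; they are not taken from the Raptor codebase.

```cpp
#include <array>
#include <optional>
#include <string_view>

// Hypothetical element kinds, used only for illustration.
enum class ElementKind { F16, F32, I8, I32 };

struct ElementInfo {
  ElementKind kind;
  std::string_view name;
  int bitWidth;
  bool isFloat;
};

// One row per kind: supporting a new kind means adding a row here,
// not adding another branch to every function that inspects a kind.
constexpr std::array<ElementInfo, 4> kElementTable {{
    {ElementKind::F16, "f16", 16, true},
    {ElementKind::F32, "f32", 32, true},
    {ElementKind::I8, "i8", 8, false},
    {ElementKind::I32, "i32", 32, false},
}};

// Returns the table row for `kind`, or std::nullopt when the kind is not
// listed, so callers must handle the unknown case explicitly.
std::optional<ElementInfo> lookupElementInfo(ElementKind kind) {
  for (const ElementInfo& info : kElementTable)
    if (info.kind == kind)
      return info;
  return std::nullopt;
}

// Example consumer: derives the storage size from the table instead of
// switching over every kind.
std::optional<int> storageBytes(ElementKind kind, int elementCount) {
  const std::optional<ElementInfo> info = lookupElementInfo(kind);
  if (!info)
    return std::nullopt;
  return elementCount * info->bitWidth / 8;
}
```

Each consumer reads the property it needs from the table, so the per-kind knowledge lives in one place. For a genuinely small and closed domain, a plain `switch` may still be the clearer choice.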
|
|
||||||
|
# Ownership and invariants
|
||||||
|
|
||||||
|
Before implementing, identify the owner of the behavior:
|
||||||
|
|
||||||
|
* A producer should emit IR/data that satisfies the contract of the next stage
|
||||||
|
* A lowering should make representation changes explicit and semantically correct
|
||||||
|
* A resolver should resolve existing structure without silently changing semantics
|
||||||
|
* A verifier should reject invalid states with bounded, actionable diagnostics
|
||||||
|
* Codegen should assume verified invariants and fail clearly if they are violated
|
||||||
|
|
||||||
|
When fixing a bug:
|
||||||
|
|
||||||
|
* State the invariant that was violated
|
||||||
|
* State which component should own that invariant
|
||||||
|
* Fix that component directly
|
||||||
|
* Avoid fixes that merely mask the violation later in the pipeline
|
||||||
|
* Add or preserve verification if the invariant is important enough to regress
|
||||||
|
|
||||||
|
# Refactor and API policy
|
||||||
|
|
||||||
|
You may propose or implement a refactor when:
|
||||||
|
|
||||||
|
* the local fix would duplicate logic
|
||||||
|
* the local fix would violate a layer boundary
|
||||||
|
* the bug exists because responsibility is assigned to the wrong component
|
||||||
|
* multiple components already implement ad-hoc variants of the same concept
|
||||||
|
* a shared helper/API would make the code smaller, clearer, and easier to maintain
|
||||||
|
* existing callers can be migrated cleanly without broad churn
|
||||||
|
* the current implementation is turning into a long list of special cases instead of a structured solution
|
||||||
|
|
||||||
|
When proposing or implementing a refactor:
|
||||||
|
|
||||||
|
* Explain what responsibility is being moved or shared
|
||||||
|
* Justify why the new location is the right ownership level
|
||||||
|
* Keep the API narrow and named after the concept or invariant it represents
|
||||||
|
* Migrate directly related existing users when that improves compactness and consistency
|
||||||
|
* Separate changes required for correctness from optional cleanup
|
||||||
|
* Avoid unrelated renames, formatting changes, or module moves
|
||||||
|
* Do not expand a justified refactor beyond directly related callers
|
||||||
|
|
||||||
|
Do not refactor when:
|
||||||
|
|
||||||
|
* the issue is truly local and a local fix is clearer
|
||||||
|
* the abstraction would have only one user and no clear adjacent use
|
||||||
|
* the abstraction would mix policies from different layers
|
||||||
|
* the refactor would affect unrelated behavior
|
||||||
|
* the refactor is mainly aesthetic
|
||||||
|
|
||||||
# Working style
|
# Working style
|
||||||
|
|
||||||
- Infer style and conventions from the existing code before introducing new patterns.
|
* Infer style and conventions from the existing code before introducing new patterns
|
||||||
- When several implementation options are possible, prefer the simplest one that fits the current architecture and
|
* When several implementation options are possible, prefer the simplest one that fits the current architecture and minimizes churn
|
||||||
minimizes churn.
|
* Push back when the requested or obvious fix would make the architecture worse
|
||||||
- Avoid broad refactors unless I explicitly ask for them.
|
* If a cleaner fix requires a small refactor or shared helper/API, propose it explicitly and justify it
|
||||||
|
* Avoid broad refactors unless explicitly requested or clearly necessary for correctness and maintainability
|
||||||
|
* When tests fail, bucket failures by likely root cause and separate patch-related failures from pre-existing or out-of-scope failures
|
||||||
|
|
||||||
# Responses
|
# Simplicity first
|
||||||
|
|
||||||
- When showing code in chat, make it easy to copy-paste into the codebase.
|
* Minimum code that solves the problem cleanly. Nothing speculative
|
||||||
- Keep outputs focused on the changed parts.
|
* No features beyond what was asked
|
||||||
- At the end of the response, briefly list any bad practices, mistakes, or cleaner alternatives you noticed, separate
|
* No error handling for impossible scenarios
|
||||||
from the main solution.
|
* If you write 200 lines and it could be 50, rewrite it
|
||||||
|
* Ask: “Would a senior engineer say this is overcomplicated?” If yes, simplify
|
||||||
|
* Prefer direct, explicit code over generic machinery unless the generic machinery clearly reduces duplication and preserves boundaries
|
||||||
|
|
||||||
# Guidelines
|
# Fallbacks and defaults
|
||||||
|
|
||||||
## 1. Think Before Coding
|
* Avoid silent fallback behavior when the semantic category is unknown
|
||||||
|
* Do not treat “unknown” as “safe” unless the codebase already defines that convention
|
||||||
|
* If a value cannot be classified, either preserve the existing behavior deliberately or fail with a clear diagnostic
|
||||||
|
* When adding a fallback, state why it is semantically valid and what invariant makes it safe
|
||||||
|
|
||||||
**Don't assume. Don't hide confusion. Surface tradeoffs.**
|
# Surgical changes
|
||||||
|
|
||||||
Before implementing:
|
* Touch only what you must
|
||||||
|
* Clean up only the mess introduced by your own change
|
||||||
|
* Do not “improve” adjacent code, comments, or formatting
|
||||||
|
* Match existing style, even if you would personally do it differently
|
||||||
|
* If you notice unrelated dead code, bad abstractions, or fragile design, mention it separately. Do not delete or rewrite it unless asked
|
||||||
|
* When your changes create orphans, remove imports, variables, functions, or files made unused by your change
|
||||||
|
* Every changed line should trace directly to the requested fix, a required cleanup, or a justified reuse/refactor decision
|
||||||
|
|
||||||
- State your assumptions explicitly. If uncertain, ask.
|
# Diagnostics and verification
|
||||||
- If multiple interpretations exist, present them - don't pick silently.
|
|
||||||
- If a simpler approach exists, say so. Push back when warranted.
|
|
||||||
- If something is unclear, stop. Name what's confusing. Ask.
|
|
||||||
|
|
||||||
## 2. Simplicity First
|
* Use existing bounded diagnostic mechanisms for pass-level verification or codegen failures
|
||||||
|
* Do not emit unbounded repeated diagnostics from loops or parallel workers
|
||||||
|
* Diagnostics should identify the violated invariant and the relevant value/op when useful
|
||||||
|
* Verifiers should reject invalid states, not repair them
|
||||||
|
* Codegen should not compensate for invalid IR/data unless codegen is the owner of that invariant
|
||||||
|
* Do not make failing tests pass by weakening verifiers, assertions, or diagnostics unless the check itself is proven wrong
|
||||||
|
* If a check is too strict, explain the valid case it rejects and update the invariant accordingly
|
||||||
|
* Prefer fixing invalid IR/data producers over relaxing consumers
|
||||||
|
* If adding diagnostics only for debugging, remove them or cap them before finalizing
|
||||||
|
|
||||||
**Minimum code that solves the problem. Nothing speculative.**
|
# Temporary debugging code
|
||||||
|
|
||||||
- No features beyond what was asked.
|
* Temporary diagnostics, dumps, assertions, and debug-only helpers must be removed or intentionally converted into bounded permanent diagnostics before finalizing
|
||||||
- No error handling for impossible scenarios.
|
* If debug instrumentation remains, explain why it is useful as permanent infrastructure
|
||||||
- If you write 200 lines and it could be 50, rewrite it.
|
* Do not leave noisy validation output behind
|
||||||
|
|
||||||
Ask yourself: "Would a senior engineer say this is overcomplicated?" If yes, simplify.
|
# Performance awareness
|
||||||
|
|
||||||
## 3. Surgical Changes
|
* Avoid algorithmic regressions in compiler passes, especially repeated full-module walks, repeated expensive analyses, or per-op recomputation inside nested loops
|
||||||
|
* If a change adds a walk, cache, analysis, or structural traversal, justify why it is needed
|
||||||
|
* For hot paths, prefer preserving existing asymptotic behavior unless a better structure is part of the requested change
|
||||||
|
* If performance may change, mention the expected impact and suggest a targeted timing check
|
||||||
|
|
||||||
**Touch only what you must. Clean up only your own mess.**
|
# Goal-driven execution
|
||||||
|
|
||||||
When editing existing code:
|
|
||||||
|
|
||||||
- Don't "improve" adjacent code, comments, or formatting.
|
|
||||||
- Don't refactor things that aren't broken.
|
|
||||||
- Match existing style, even if you'd do it differently.
|
|
||||||
- If you notice unrelated dead code, mention it - don't delete it.
|
|
||||||
|
|
||||||
When your changes create orphans:
|
|
||||||
|
|
||||||
- Remove imports/variables/functions that YOUR changes made unused.
|
|
||||||
- Don't remove pre-existing dead code unless asked, but mention it.
|
|
||||||
|
|
||||||
The test: Every changed line should trace directly to the user's request.
|
|
||||||
|
|
||||||
## 4. Goal-Driven Execution
|
|
||||||
|
|
||||||
**Define success criteria. Loop until verified.**
|
|
||||||
|
|
||||||
Transform tasks into verifiable goals:
|
|
||||||
|
|
||||||
- "Add validation" → "Write tests for invalid inputs, then make them pass"
|
|
||||||
- "Fix the bug" → "Write a test that reproduces it, then make it pass"
|
|
||||||
- "Refactor X" → "Ensure tests pass before and after"
|
|
||||||
|
|
||||||
For multi-step tasks, state a brief plan:
|
For multi-step tasks, state a brief plan:
|
||||||
|
|
||||||
```
|
|
||||||
1. [Step] → verify: [check]
|
1. [Step] → verify: [check]
|
||||||
2. [Step] → verify: [check]
|
2. [Step] → verify: [check]
|
||||||
3. [Step] → verify: [check]
|
3. [Step] → verify: [check]
|
||||||
```
|
|
||||||
|
|
||||||
Strong success criteria let you loop independently. Weak criteria ("make it work") require constant clarification.
|
Define success criteria before implementing:
|
||||||
|
|
||||||
---
|
* For bug fixes, success means reproducing or identifying the failure, fixing the responsible owner, and verifying the targeted case
|
||||||
|
* For refactors, success means preserving behavior while making ownership, reuse, or structure cleaner
|
||||||
|
* For validation changes, success means checking both valid and invalid cases when applicable
|
||||||
|
|
||||||
|
Transform tasks into verifiable goals:
|
||||||
|
|
||||||
|
* “Fix the bug” → identify the invariant, reproduce the failure, fix the owner, verify the targeted case
|
||||||
|
* “Add validation” → write or identify tests for invalid inputs, then make them pass/fail as expected
|
||||||
|
* “Refactor X” → preserve behavior before and after, then run relevant tests
|
||||||
|
|
||||||
|
# Final self-review
|
||||||
|
|
||||||
|
Before reporting completion, check:
|
||||||
|
|
||||||
|
* Did I fix the owner of the invariant rather than masking the issue downstream
|
||||||
|
* Did I avoid broad case lists and ad-hoc special handling
|
||||||
|
* Did I introduce a helper/API only at the right ownership level
|
||||||
|
* Did I migrate directly related duplicate logic when doing so improves compactness
|
||||||
|
* Did I avoid weakening verifiers or assertions unnecessarily
|
||||||
|
* Did I remove temporary debugging code or make it bounded and intentional
|
||||||
|
* Did I avoid unrelated formatting, renames, or cleanup
|
||||||
|
* Did I consider performance impact for added walks, analyses, caches, or repeated computations
|
||||||
|
* Did I run the required build/test commands
|
||||||
|
* Did I clearly report remaining failures or risks
|
||||||
|
|
||||||
|
When reporting back:
|
||||||
|
|
||||||
|
* Say what changed
|
||||||
|
* Say what was verified
|
||||||
|
* Say what remains
|
||||||
|
* When showing code in chat, make it easy to copy-paste into the codebase
|
||||||
|
* Keep outputs focused on the changed parts
|
||||||
|
* List bad practices, fragile assumptions, or cleaner alternatives separately
|
||||||
|
* If a change is intentionally pragmatic rather than architecturally ideal, say so and explain the tradeoff
|
||||||
|
|||||||
File diff suppressed because it is too large
@@ -690,11 +690,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
Value b = gemmOpAdaptor.getB();
|
Value b = gemmOpAdaptor.getB();
|
||||||
Value c = gemmOpAdaptor.getC();
|
Value c = gemmOpAdaptor.getC();
|
||||||
|
|
||||||
if (gemmOpAdaptor.getTransA()) {
|
|
||||||
gemmOp.emitOpError("requires transA=false before tiled Spatial Gemm lowering");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto aType = dyn_cast<RankedTensorType>(a.getType());
|
auto aType = dyn_cast<RankedTensorType>(a.getType());
|
||||||
auto bType = dyn_cast<RankedTensorType>(b.getType());
|
auto bType = dyn_cast<RankedTensorType>(b.getType());
|
||||||
auto outType = dyn_cast<RankedTensorType>(gemmOp.getY().getType());
|
auto outType = dyn_cast<RankedTensorType>(gemmOp.getY().getType());
|
||||||
@@ -725,9 +720,12 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t numOutRows = outType.getDimSize(0);
|
if (gemmOpAdaptor.getTransA()) {
|
||||||
const int64_t numOutCols = outType.getDimSize(1);
|
auto aShape = aType.getShape();
|
||||||
const int64_t reductionSize = aType.getDimSize(1);
|
auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType(), aType.getEncoding());
|
||||||
|
a = ONNXTransposeOp::create(rewriter, loc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})).getResult();
|
||||||
|
aType = transposedType;
|
||||||
|
}
|
||||||
|
|
||||||
if (gemmOpAdaptor.getTransB()) {
|
if (gemmOpAdaptor.getTransB()) {
|
||||||
auto bShape = bType.getShape();
|
auto bShape = bType.getShape();
|
||||||
@@ -736,6 +734,10 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
bType = transposedType;
|
bType = transposedType;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const int64_t numOutRows = outType.getDimSize(0);
|
||||||
|
const int64_t numOutCols = outType.getDimSize(1);
|
||||||
|
const int64_t reductionSize = aType.getDimSize(1);
|
||||||
|
|
||||||
if (!isCompileTimeComputable(b)) {
|
if (!isCompileTimeComputable(b)) {
|
||||||
bool hasC = hasGemmBias(c);
|
bool hasC = hasGemmBias(c);
|
||||||
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
|
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
|
||||||
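The hunks above appear to replace the `transA` rejection with an explicit transpose of `a`, and to move the reads of `numOutRows`, `numOutCols`, and `reductionSize` after both transposes so that `reductionSize` comes from the already-transposed `aType`. The ordering can be sketched on plain shapes as follows. `MatrixShape`, `GemmDims`, and `inferGemmDims` are hypothetical names for this example and involve no MLIR types.

```cpp
#include <array>
#include <cstdint>
#include <optional>
#include <utility>

using MatrixShape = std::array<int64_t, 2>;

struct GemmDims {
  int64_t rows;      // rows of the output
  int64_t cols;      // columns of the output
  int64_t reduction; // shared inner dimension
};

// Applies the transpose flags to the operand shapes first, and only then
// reads the reduction size. Reading a[1] before handling transA would pick
// up the wrong dimension whenever transA is set.
std::optional<GemmDims> inferGemmDims(MatrixShape a, MatrixShape b, bool transA, bool transB) {
  if (transA)
    std::swap(a[0], a[1]);
  if (transB)
    std::swap(b[0], b[1]);
  if (a[1] != b[0])
    return std::nullopt; // inner dimensions do not agree
  return GemmDims {a[0], b[1], a[1]};
}
```

For example, with `a` of shape 4x3, `b` of shape 4x5, and `transA` set, the effective `a` is 3x4, so the output is 3x5 with a reduction size of 4.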
|
|||||||
@@ -22,13 +22,87 @@ namespace {
|
|||||||
|
|
||||||
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
||||||
ArrayRef<int64_t> rhsBatchShape) {
|
ArrayRef<int64_t> rhsBatchShape) {
|
||||||
if (lhsBatchShape.empty())
|
const int64_t resultRank = std::max<int64_t>(lhsBatchShape.size(), rhsBatchShape.size());
|
||||||
return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
|
SmallVector<int64_t> resultShape(resultRank, 1);
|
||||||
if (rhsBatchShape.empty())
|
for (int64_t resultIndex = resultRank - 1, lhsIndex = lhsBatchShape.size() - 1, rhsIndex = rhsBatchShape.size() - 1;
|
||||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
resultIndex >= 0;
|
||||||
if (!llvm::equal(lhsBatchShape, rhsBatchShape))
|
--resultIndex, --lhsIndex, --rhsIndex) {
|
||||||
return failure();
|
const int64_t lhsDim = lhsIndex >= 0 ? lhsBatchShape[lhsIndex] : 1;
|
||||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
const int64_t rhsDim = rhsIndex >= 0 ? rhsBatchShape[rhsIndex] : 1;
|
||||||
|
if (lhsDim != rhsDim && lhsDim != 1 && rhsDim != 1)
|
||||||
|
return failure();
|
||||||
|
resultShape[resultIndex] = std::max(lhsDim, rhsDim);
|
||||||
|
}
|
||||||
|
return resultShape;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int64_t mapStaticBroadcastedBatchIndex(int64_t outputBatchIndex,
|
||||||
|
ArrayRef<int64_t> sourceBatchShape,
|
||||||
|
ArrayRef<int64_t> outputBatchShape) {
|
||||||
|
if (sourceBatchShape.empty() || getStaticShapeElementCount(sourceBatchShape) == 1)
|
||||||
|
return 0;
|
||||||
|
if (llvm::equal(sourceBatchShape, outputBatchShape))
|
||||||
|
return outputBatchIndex;
|
||||||
|
|
||||||
|
SmallVector<int64_t> outputStrides = computeRowMajorStrides(outputBatchShape);
|
||||||
|
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceBatchShape);
|
||||||
|
int64_t sourceFlatIndex = 0;
|
||||||
|
for (int64_t sourceDimIndex = 0; sourceDimIndex < static_cast<int64_t>(sourceBatchShape.size()); ++sourceDimIndex) {
|
||||||
|
if (sourceBatchShape[sourceDimIndex] == 1)
|
||||||
|
continue;
|
||||||
|
const int64_t outputDimIndex = outputBatchShape.size() - sourceBatchShape.size() + sourceDimIndex;
|
||||||
|
const int64_t outputDimStride = outputStrides.empty() ? 1 : outputStrides[outputDimIndex];
|
||||||
|
const int64_t outputDimIndexValue = outputDimStride == 1
|
||||||
|
? outputBatchIndex % outputBatchShape[outputDimIndex]
|
||||||
|
: (outputBatchIndex / outputDimStride) % outputBatchShape[outputDimIndex];
|
||||||
|
sourceFlatIndex += outputDimIndexValue * sourceStrides[sourceDimIndex];
|
||||||
|
}
|
||||||
|
return sourceFlatIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value computeFlatBatchIndexCoordinate(
|
||||||
|
Value flatBatchIndex, ArrayRef<int64_t> batchShape, int64_t dimIndex, PatternRewriter& rewriter, Location loc) {
|
||||||
|
if (batchShape[dimIndex] == 1)
|
||||||
|
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||||
|
|
||||||
|
const int64_t dimStride = dimIndex + 1 == static_cast<int64_t>(batchShape.size())
|
||||||
|
? 1
|
||||||
|
: getStaticShapeElementCount(batchShape.drop_front(dimIndex + 1));
|
||||||
|
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||||
|
Value dimCoordinate = flatBatchIndex;
|
||||||
|
if (dimStride != 1)
|
||||||
|
dimCoordinate = affineFloorDivConst(rewriter, loc, dimCoordinate, dimStride, anchorOp);
|
||||||
|
return affineModConst(rewriter, loc, dimCoordinate, batchShape[dimIndex], anchorOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value mapOutputBatchIndexToSourceBatchIndex(Value outputBatchIndex,
|
||||||
|
ArrayRef<int64_t> sourceBatchShape,
|
||||||
|
ArrayRef<int64_t> outputBatchShape,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (sourceBatchShape.empty() || getStaticShapeElementCount(sourceBatchShape) == 1)
|
||||||
|
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||||
|
if (llvm::equal(sourceBatchShape, outputBatchShape))
|
||||||
|
return outputBatchIndex;
|
||||||
|
|
||||||
|
Value sourceBatchIndex = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||||
|
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceBatchShape);
|
||||||
|
for (int64_t sourceDimIndex = 0; sourceDimIndex < static_cast<int64_t>(sourceBatchShape.size()); ++sourceDimIndex) {
|
||||||
|
if (sourceBatchShape[sourceDimIndex] == 1)
|
||||||
|
continue;
|
||||||
|
const int64_t outputDimIndex = outputBatchShape.size() - sourceBatchShape.size() + sourceDimIndex;
|
||||||
|
Value outputCoordinate =
|
||||||
|
computeFlatBatchIndexCoordinate(outputBatchIndex, outputBatchShape, outputDimIndex, rewriter, loc);
|
||||||
|
Value contribution = sourceStrides[sourceDimIndex] == 1
|
||||||
|
? outputCoordinate
|
||||||
|
: affineMulConst(rewriter,
|
||||||
|
loc,
|
||||||
|
outputCoordinate,
|
||||||
|
sourceStrides[sourceDimIndex],
|
||||||
|
rewriter.getInsertionBlock()->getParentOp());
|
||||||
|
sourceBatchIndex = arith::AddIOp::create(rewriter, loc, sourceBatchIndex, contribution);
|
||||||
|
}
|
||||||
|
return sourceBatchIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value
|
static Value
|
||||||
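The new `inferSupportedBatchShape` and `mapStaticBroadcastedBatchIndex` in the hunk above look like right-aligned (NumPy-style) broadcasting of batch shapes, followed by a mapping from a flat output batch index back to a flat source batch index. The standalone sketch below restates that idea with `std::vector`. It assumes static shapes, and `rowMajorStrides` stands in for `computeRowMajorStrides`, whose definition is not shown in this diff; all names here are illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Right-aligned broadcast of two batch shapes. Returns std::nullopt when a
// pair of aligned dimensions differs and neither of them is 1.
std::optional<Shape> broadcastBatchShape(const Shape& lhs, const Shape& rhs) {
  const int64_t rank = static_cast<int64_t>(std::max(lhs.size(), rhs.size()));
  Shape result(rank, 1);
  int64_t l = static_cast<int64_t>(lhs.size()) - 1;
  int64_t r = static_cast<int64_t>(rhs.size()) - 1;
  for (int64_t i = rank - 1; i >= 0; --i, --l, --r) {
    const int64_t lhsDim = l >= 0 ? lhs[l] : 1; // missing leading dims act as 1
    const int64_t rhsDim = r >= 0 ? rhs[r] : 1;
    if (lhsDim != rhsDim && lhsDim != 1 && rhsDim != 1)
      return std::nullopt;
    result[i] = std::max(lhsDim, rhsDim);
  }
  return result;
}

// Row-major strides: the last dimension has stride 1.
Shape rowMajorStrides(const Shape& shape) {
  Shape strides(shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

// Maps a flat index in the broadcasted output batch space to the flat index
// of the source element it reads. Source dimensions of size 1 contribute 0.
int64_t mapToSourceBatchIndex(int64_t outputIndex, const Shape& source, const Shape& output) {
  assert(source.size() <= output.size());
  const Shape outputStrides = rowMajorStrides(output);
  const Shape sourceStrides = rowMajorStrides(source);
  const std::size_t rankDiff = output.size() - source.size();
  int64_t sourceIndex = 0;
  for (std::size_t d = 0; d < source.size(); ++d) {
    if (source[d] == 1)
      continue;
    const std::size_t outputDim = rankDiff + d;
    const int64_t coordinate = (outputIndex / outputStrides[outputDim]) % output[outputDim];
    sourceIndex += coordinate * sourceStrides[d];
  }
  return sourceIndex;
}
```

With batch shapes `{2, 1}` and `{3}`, the broadcasted shape is `{2, 3}`. Output batch index 4 has coordinates (1, 1), so it reads source batch 1 from the `{2, 1}` operand and source batch 1 from the `{3}` operand.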
@@ -67,6 +141,52 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt
|
|||||||
return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
|
return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value createMatrixFromVector(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
|
||||||
|
auto buildExpanded = [&](Value input) -> Value {
|
||||||
|
return tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
resultType,
|
||||||
|
input,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
return materializeOrComputeUnary(value, resultType, rewriter, loc, buildExpanded);
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> removedAxes) {
|
||||||
|
SmallVector<ReassociationIndices> reassociation;
|
||||||
|
ReassociationIndices currentGroup;
|
||||||
|
for (auto [axis, removeAxis] : llvm::enumerate(removedAxes)) {
|
||||||
|
currentGroup.push_back(axis);
|
||||||
|
if (!removeAxis) {
|
||||||
|
reassociation.push_back(currentGroup);
|
||||||
|
currentGroup.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!currentGroup.empty()) {
|
||||||
|
if (reassociation.empty())
|
||||||
|
reassociation.push_back(std::move(currentGroup));
|
||||||
|
else
|
||||||
|
reassociation.back().append(currentGroup.begin(), currentGroup.end());
|
||||||
|
}
|
||||||
|
return reassociation;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value squeezeUnitDims(
|
||||||
|
Value value, RankedTensorType resultType, ArrayRef<bool> removedAxes, PatternRewriter& rewriter, Location loc) {
|
||||||
|
if (cast<RankedTensorType>(value.getType()) == resultType)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
SmallVector<ReassociationIndices> reassociation =
|
||||||
|
resultType.getRank() == 0 ? SmallVector<ReassociationIndices> {} : buildCollapseReassociation(removedAxes);
|
||||||
|
auto buildCollapsed = [&](Value input) -> Value {
|
||||||
|
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation).getResult();
|
||||||
|
};
|
||||||
|
return materializeOrComputeUnary(value, resultType, rewriter, loc, buildCollapsed);
|
||||||
|
}
|
||||||
|
|
||||||
static Value ensureBatchedTensor(
|
static Value ensureBatchedTensor(
|
||||||
Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
|
Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
|
||||||
auto type = cast<RankedTensorType>(value.getType());
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
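The grouping logic of `buildCollapseReassociation` in the hunk above can be tried outside MLIR. The sketch below is a standalone restatement, assuming plain `std::vector` in place of `SmallVector<ReassociationIndices>`; the name `buildCollapseGroups` is made up for this example.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using Group = std::vector<int64_t>;

// Builds collapse groups for a squeeze: each removed (unit) axis is merged
// into the group closed by the next kept axis. Removed axes after the last
// kept axis are appended to the final group.
std::vector<Group> buildCollapseGroups(const std::vector<bool>& removedAxes) {
  std::vector<Group> groups;
  Group current;
  for (std::size_t axis = 0; axis < removedAxes.size(); ++axis) {
    current.push_back(static_cast<int64_t>(axis));
    if (!removedAxes[axis]) {
      groups.push_back(current); // a kept axis closes the current group
      current.clear();
    }
  }
  if (!current.empty()) {
    if (groups.empty())
      groups.push_back(std::move(current));
    else
      groups.back().insert(groups.back().end(), current.begin(), current.end());
  }
  return groups;
}
```

For `removedAxes = {false, true, false}` the groups are `{0}` and `{1, 2}`; for `{true, false, true}` everything collapses into the single group `{0, 1, 2}`. When every axis is removed this sketch returns one group holding all axes, whereas `squeezeUnitDims` in the diff passes an empty reassociation for a rank-0 result instead.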
@@ -171,8 +291,11 @@ static Value createPaddedBatchedInputCompute(Value input,
|
|||||||
return computeOp.getResult(0);
|
return computeOp.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<Value> materializePaddedBatchedWeight(
|
static FailureOr<Value> materializePaddedBatchedWeight(Value value,
|
||||||
Value value, int64_t sourceBatch, int64_t targetBatch, RankedTensorType resultType, PatternRewriter& rewriter) {
|
ArrayRef<int64_t> sourceBatchShape,
|
||||||
|
ArrayRef<int64_t> targetBatchShape,
|
||||||
|
RankedTensorType resultType,
|
||||||
|
PatternRewriter& rewriter) {
|
||||||
auto sourceType = cast<RankedTensorType>(value.getType());
|
auto sourceType = cast<RankedTensorType>(value.getType());
|
||||||
if (sourceType == resultType)
|
if (sourceType == resultType)
|
||||||
return value;
|
return value;
|
||||||
@@ -183,13 +306,15 @@ static FailureOr<Value> materializePaddedBatchedWeight(
|
|||||||
|
|
||||||
const int64_t sourceRows = sourceType.getRank() == 2 ? sourceType.getDimSize(0) : sourceType.getDimSize(1);
|
const int64_t sourceRows = sourceType.getRank() == 2 ? sourceType.getDimSize(0) : sourceType.getDimSize(1);
|
||||||
const int64_t sourceCols = sourceType.getRank() == 2 ? sourceType.getDimSize(1) : sourceType.getDimSize(2);
|
const int64_t sourceCols = sourceType.getRank() == 2 ? sourceType.getDimSize(1) : sourceType.getDimSize(2);
|
||||||
|
const int64_t targetBatch = targetBatchShape.empty() ? 1 : getStaticShapeElementCount(targetBatchShape);
|
||||||
const int64_t targetRows = resultType.getDimSize(1);
|
const int64_t targetRows = resultType.getDimSize(1);
|
||||||
const int64_t targetCols = resultType.getDimSize(2);
|
const int64_t targetCols = resultType.getDimSize(2);
|
||||||
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
||||||
SmallVector<Attribute> resultValues(resultType.getNumElements(), rewriter.getZeroAttr(resultType.getElementType()));
|
SmallVector<Attribute> resultValues(resultType.getNumElements(), rewriter.getZeroAttr(resultType.getElementType()));
|
||||||
|
|
||||||
for (int64_t batchIdx = 0; batchIdx < targetBatch; ++batchIdx) {
|
for (int64_t batchIdx = 0; batchIdx < targetBatch; ++batchIdx) {
|
||||||
const int64_t sourceBatchIdx = sourceType.getRank() == 2 ? 0 : (sourceBatch == 1 ? 0 : batchIdx);
|
const int64_t sourceBatchIdx =
|
||||||
|
sourceType.getRank() == 2 ? 0 : mapStaticBroadcastedBatchIndex(batchIdx, sourceBatchShape, targetBatchShape);
|
||||||
const int64_t sourceBatchBase = sourceType.getRank() == 2 ? 0 : sourceBatchIdx * sourceRows * sourceCols;
|
const int64_t sourceBatchBase = sourceType.getRank() == 2 ? 0 : sourceBatchIdx * sourceRows * sourceCols;
|
||||||
const int64_t targetBatchBase = batchIdx * targetRows * targetCols;
|
const int64_t targetBatchBase = batchIdx * targetRows * targetCols;
|
||||||
for (int64_t row = 0; row < sourceRows; ++row)
|
for (int64_t row = 0; row < sourceRows; ++row)
|
||||||
@@ -202,16 +327,18 @@ static FailureOr<Value> materializePaddedBatchedWeight(
|
|||||||
}
|
}
|
||||||
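Reviewer note: the hunk above replaces the old `sourceBatch == 1 ? 0 : batchIdx` shortcut with `mapStaticBroadcastedBatchIndex`, whose body is not shown in this diff. The following is only a minimal sketch of what such a mapping could look like, assuming right-aligned (NumPy/ONNX-style) broadcasting and row-major flat batch indices. The name `mapBroadcastedBatchIndex` and the `std::vector` signature are hypothetical and are not taken from the repository.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in (not the repository's helper): maps a flat row-major
// index into the broadcasted output batch shape to the flat row-major index
// of the source batch element that is broadcast to that position.
int64_t mapBroadcastedBatchIndex(int64_t outputIndex,
                                 const std::vector<int64_t>& sourceShape,
                                 const std::vector<int64_t>& outputShape) {
  assert(sourceShape.size() <= outputShape.size());
  const std::size_t rankDiff = outputShape.size() - sourceShape.size();
  int64_t sourceIndex = 0;
  int64_t sourceStride = 1;
  // Peel coordinates off the flat output index, innermost dimension first.
  for (std::size_t i = outputShape.size(); i-- > 0;) {
    const int64_t coord = outputIndex % outputShape[i];
    outputIndex /= outputShape[i];
    if (i < rankDiff)
      continue; // Leading output dim with no source counterpart: broadcast.
    const int64_t sourceDim = sourceShape[i - rankDiff];
    assert(sourceDim == 1 || sourceDim == outputShape[i]);
    // A size-1 source dim always reads coordinate 0, so it contributes nothing.
    if (sourceDim != 1)
      sourceIndex += coord * sourceStride;
    sourceStride *= sourceDim;
  }
  return sourceIndex;
}
```

With source batch shape `{3, 1}` and output batch shape `{2, 3, 4}`, only the middle output coordinate selects the source element. This is the kind of mixed broadcast that a single scalar batch count cannot express, which is presumably why the signatures in this change carry batch shapes instead.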
|
|
||||||
static Value extractBatchedATile(Value a,
|
static Value extractBatchedATile(Value a,
|
||||||
int64_t sourceBatchCount,
|
ArrayRef<int64_t> sourceBatchShape,
|
||||||
Value batch,
|
ArrayRef<int64_t> outputBatchShape,
|
||||||
|
Value outputBatchIndex,
|
||||||
Value row,
|
Value row,
|
||||||
Value kOffset,
|
Value kOffset,
|
||||||
RankedTensorType aTileType,
|
RankedTensorType aTileType,
|
||||||
PatternRewriter& rewriter,
|
PatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
auto aSliceType = RankedTensorType::get({1, 1, aTileType.getDimSize(1)}, aTileType.getElementType());
|
auto aSliceType = RankedTensorType::get({1, 1, aTileType.getDimSize(1)}, aTileType.getElementType());
|
||||||
SmallVector<OpFoldResult> offsets {
|
Value sourceBatchIndex =
|
||||||
sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), row, kOffset};
|
mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);
|
||||||
|
SmallVector<OpFoldResult> offsets {OpFoldResult(sourceBatchIndex), row, kOffset};
|
||||||
SmallVector<OpFoldResult> sizes {
|
SmallVector<OpFoldResult> sizes {
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(aTileType.getDimSize(1))};
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(aTileType.getDimSize(1))};
|
||||||
auto slice =
|
auto slice =
|
||||||
@@ -227,8 +354,9 @@ static Value extractBatchedATile(Value a,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static Value extractBatchedBTile(Value b,
|
static Value extractBatchedBTile(Value b,
|
||||||
int64_t sourceBatchCount,
|
ArrayRef<int64_t> sourceBatchShape,
|
||||||
Value batch,
|
ArrayRef<int64_t> outputBatchShape,
|
||||||
|
Value outputBatchIndex,
|
||||||
Value kOffset,
|
Value kOffset,
|
||||||
Value hOffset,
|
Value hOffset,
|
||||||
RankedTensorType bTileType,
|
RankedTensorType bTileType,
|
||||||
@@ -236,8 +364,9 @@ static Value extractBatchedBTile(Value b,
|
|||||||
Location loc) {
|
Location loc) {
|
||||||
auto bSliceType =
|
auto bSliceType =
|
||||||
RankedTensorType::get({1, bTileType.getDimSize(0), bTileType.getDimSize(1)}, bTileType.getElementType());
|
RankedTensorType::get({1, bTileType.getDimSize(0), bTileType.getDimSize(1)}, bTileType.getElementType());
|
||||||
SmallVector<OpFoldResult> offsets {
|
Value sourceBatchIndex =
|
||||||
sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), kOffset, hOffset};
|
mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);
|
||||||
|
SmallVector<OpFoldResult> offsets {OpFoldResult(sourceBatchIndex), kOffset, hOffset};
|
||||||
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(bTileType.getDimSize(0)),
|
rewriter.getIndexAttr(bTileType.getDimSize(0)),
|
||||||
rewriter.getIndexAttr(bTileType.getDimSize(1))};
|
rewriter.getIndexAttr(bTileType.getDimSize(1))};
|
||||||
@@ -262,9 +391,10 @@ static Value getBatchLaneIndex(
|
|||||||
static FailureOr<spatial::SpatComputeBatch> createBatchedVmmBatch(Value a,
|
static FailureOr<spatial::SpatComputeBatch> createBatchedVmmBatch(Value a,
|
||||||
Value b,
|
Value b,
|
||||||
RankedTensorType aType,
|
RankedTensorType aType,
|
||||||
int64_t aBatchCount,
|
ArrayRef<int64_t> aBatchShape,
|
||||||
RankedTensorType bType,
|
RankedTensorType bType,
|
||||||
int64_t bBatchCount,
|
ArrayRef<int64_t> bBatchShape,
|
||||||
|
ArrayRef<int64_t> outputBatchShape,
|
||||||
RankedTensorType partialPiecesType,
|
RankedTensorType partialPiecesType,
|
||||||
int64_t numOutRows,
|
int64_t numOutRows,
|
||||||
int64_t numKSlices,
|
int64_t numKSlices,
|
||||||
@@ -298,10 +428,10 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVmmBatch(Value a,
|
|||||||
auto pieceType =
|
auto pieceType =
|
||||||
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
|
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
|
||||||
|
|
||||||
Value aTile =
|
Value aTile = extractBatchedATile(
|
||||||
extractBatchedATile(args.inputs.front(), aBatchCount, batch, row, kOffset, aTileType, rewriter, loc);
|
args.inputs.front(), aBatchShape, outputBatchShape, batch, row, kOffset, aTileType, rewriter, loc);
|
||||||
Value bTile =
|
Value bTile = extractBatchedBTile(
|
||||||
extractBatchedBTile(args.weights.front(), bBatchCount, batch, kOffset, hOffset, bTileType, rewriter, loc);
|
args.weights.front(), bBatchShape, outputBatchShape, batch, kOffset, hOffset, bTileType, rewriter, loc);
|
||||||
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
||||||
|
|
||||||
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||||
@@ -315,17 +445,17 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVmmBatch(Value a,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static Value extractDynamicBatchedBColumn(Value matrix,
|
static Value extractDynamicBatchedBColumn(Value matrix,
|
||||||
int64_t sourceBatchCount,
|
ArrayRef<int64_t> sourceBatchShape,
|
||||||
Value batch,
|
ArrayRef<int64_t> outputBatchShape,
|
||||||
|
Value outputBatchIndex,
|
||||||
Value column,
|
Value column,
|
||||||
RankedTensorType vectorType,
|
RankedTensorType vectorType,
|
||||||
PatternRewriter& rewriter,
|
PatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
auto columnSliceType = RankedTensorType::get({1, vectorType.getDimSize(1), 1}, vectorType.getElementType());
|
auto columnSliceType = RankedTensorType::get({1, vectorType.getDimSize(1), 1}, vectorType.getElementType());
|
||||||
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
|
Value sourceBatchIndex =
|
||||||
: OpFoldResult(batch),
|
mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);
|
||||||
rewriter.getIndexAttr(0),
|
SmallVector<OpFoldResult> offsets {OpFoldResult(sourceBatchIndex), rewriter.getIndexAttr(0), column};
|
||||||
column};
|
|
||||||
SmallVector<OpFoldResult> sizes {
|
SmallVector<OpFoldResult> sizes {
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1)), rewriter.getIndexAttr(1)};
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1)), rewriter.getIndexAttr(1)};
|
||||||
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
@@ -350,17 +480,17 @@ static Value extractDynamicBatchedBColumn(Value matrix,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static Value extractDynamicBatchedRowVector(Value matrix,
|
static Value extractDynamicBatchedRowVector(Value matrix,
|
||||||
int64_t sourceBatchCount,
|
ArrayRef<int64_t> sourceBatchShape,
|
||||||
Value batch,
|
ArrayRef<int64_t> outputBatchShape,
|
||||||
|
Value outputBatchIndex,
|
||||||
Value row,
|
Value row,
|
||||||
RankedTensorType vectorType,
|
RankedTensorType vectorType,
|
||||||
PatternRewriter& rewriter,
|
PatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
|
auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
|
||||||
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
|
Value sourceBatchIndex =
|
||||||
: OpFoldResult(batch),
|
mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);
|
||||||
row,
|
SmallVector<OpFoldResult> offsets {OpFoldResult(sourceBatchIndex), row, rewriter.getIndexAttr(0)};
|
||||||
rewriter.getIndexAttr(0)};
|
|
||||||
SmallVector<OpFoldResult> sizes {
|
SmallVector<OpFoldResult> sizes {
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
||||||
auto rowSlice =
|
auto rowSlice =
|
||||||
@@ -376,9 +506,10 @@ static Value extractDynamicBatchedRowVector(Value matrix,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
|
static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
|
||||||
int64_t aBatchCount,
|
ArrayRef<int64_t> aBatchShape,
|
||||||
Value b,
|
Value b,
|
||||||
int64_t bBatchCount,
|
ArrayRef<int64_t> bBatchShape,
|
||||||
|
ArrayRef<int64_t> outputBatchShape,
|
||||||
RankedTensorType aType,
|
RankedTensorType aType,
|
||||||
RankedTensorType bType,
|
RankedTensorType bType,
|
||||||
RankedTensorType scalarPiecesType,
|
RankedTensorType scalarPiecesType,
|
||||||
@@ -406,10 +537,10 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
|
|||||||
|
|
||||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||||
Value aVector =
|
Value aVector = extractDynamicBatchedRowVector(
|
||||||
extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc);
|
args.inputs[0], aBatchShape, outputBatchShape, batch, row, vectorType, rewriter, loc);
|
||||||
Value bVector =
|
Value bVector = extractDynamicBatchedBColumn(
|
||||||
extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
|
args.inputs[1], bBatchShape, outputBatchShape, batch, column, vectorType, rewriter, loc);
|
||||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||||
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
@@ -629,11 +760,17 @@ static FailureOr<Value> createBatchedReductionCompute(Value partialPieces,
|
|||||||
return computeOp->getResult(0);
|
return computeOp->getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MatMulShapeInfo {
|
struct NormalizedMatMulInfo {
|
||||||
RankedTensorType lhsType;
|
RankedTensorType lhsType;
|
||||||
RankedTensorType rhsType;
|
RankedTensorType rhsType;
|
||||||
RankedTensorType outType;
|
RankedTensorType outType;
|
||||||
SmallVector<int64_t> batchShape;
|
RankedTensorType normalizedLhsType;
|
||||||
|
RankedTensorType normalizedRhsType;
|
||||||
|
SmallVector<int64_t> lhsBatchShape;
|
||||||
|
SmallVector<int64_t> rhsBatchShape;
|
||||||
|
SmallVector<int64_t> outputBatchShape;
|
||||||
|
bool lhsWasVector;
|
||||||
|
bool rhsWasVector;
|
||||||
int64_t lhsBatch;
|
int64_t lhsBatch;
|
||||||
int64_t rhsBatch;
|
int64_t rhsBatch;
|
||||||
int64_t batch;
|
int64_t batch;
|
||||||
@@ -642,46 +779,170 @@ struct MatMulShapeInfo {
|
|||||||
int64_t n;
|
int64_t n;
|
||||||
};
|
};
|
||||||
|
|
||||||
static FailureOr<MatMulShapeInfo> analyzeMatMulShape(ONNXMatMulOp matmulOp) {
|
struct MatMulLoweringPlan {
|
||||||
|
Value lhs;
|
||||||
|
Value rhs;
|
||||||
|
RankedTensorType lhsType;
|
||||||
|
RankedTensorType rhsType;
|
||||||
|
SmallVector<int64_t> lhsBatchShape;
|
||||||
|
SmallVector<int64_t> rhsBatchShape;
|
||||||
|
SmallVector<int64_t> outputBatchShape;
|
||||||
|
int64_t lhsBatch;
|
||||||
|
int64_t rhsBatch;
|
||||||
|
int64_t batch;
|
||||||
|
int64_t m;
|
||||||
|
int64_t k;
|
||||||
|
int64_t n;
|
||||||
|
bool transposedResult;
|
||||||
|
};
|
||||||
|
|
||||||
|
static SmallVector<int64_t> computeExpectedMatMulOutputShape(
|
||||||
|
ArrayRef<int64_t> batchShape, int64_t m, int64_t n, bool lhsWasVector, bool rhsWasVector) {
|
||||||
|
SmallVector<int64_t> shape(batchShape.begin(), batchShape.end());
|
||||||
|
if (lhsWasVector && rhsWasVector)
|
||||||
|
return shape;
|
||||||
|
if (lhsWasVector) {
|
||||||
|
shape.push_back(n);
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
if (rhsWasVector) {
|
||||||
|
shape.push_back(m);
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
shape.push_back(m);
|
||||||
|
shape.push_back(n);
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
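Reviewer note: `computeExpectedMatMulOutputShape` encodes the ONNX/NumPy MatMul rank rule, in which a 1-D operand contributes no matrix axis to the result. The sketch below restates that rule with plain `std::vector` so it can be checked outside MLIR. The function name and the condensed two-`if` form are illustrative assumptions, intended to be equivalent to the four-way branch in the diff rather than a copy of it.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative restatement (hypothetical name): the result keeps the broadcast
// batch dims, then appends m unless the LHS was a vector, and n unless the RHS
// was a vector.
std::vector<int64_t> expectedMatMulOutputShape(const std::vector<int64_t>& batchShape,
                                               int64_t m,
                                               int64_t n,
                                               bool lhsWasVector,
                                               bool rhsWasVector) {
  assert(m > 0 && n > 0); // Mirrors the static positive shape requirement.
  std::vector<int64_t> shape(batchShape);
  if (!lhsWasVector)
    shape.push_back(m);
  if (!rhsWasVector)
    shape.push_back(n);
  return shape;
}
```

For example, a rank-1 LHS of length `k` times a `[k, 4]` RHS normalizes to `m = 1, n = 4` and yields shape `{4}`, while two rank-1 operands yield a rank-0 result.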
|
|
||||||
|
static FailureOr<NormalizedMatMulInfo> analyzeMatMulShape(ONNXMatMulOp matmulOp) {
|
||||||
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
|
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
|
||||||
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
|
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
|
||||||
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
|
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
|
||||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||||
|| !outType.hasStaticShape())
|
|| !outType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
if (lhsType.getRank() < 1 || rhsType.getRank() < 1)
|
||||||
return failure();
|
return failure();
|
||||||
if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType))
|
if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
const bool lhsWasVector = lhsType.getRank() == 1;
|
||||||
SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
|
const bool rhsWasVector = rhsType.getRank() == 1;
|
||||||
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
|
auto normalizedLhsType =
|
||||||
if (failed(batchShape))
|
lhsWasVector ? RankedTensorType::get({1, lhsType.getDimSize(0)}, lhsType.getElementType(), lhsType.getEncoding())
|
||||||
|
: lhsType;
|
||||||
|
auto normalizedRhsType =
|
||||||
|
rhsWasVector ? RankedTensorType::get({rhsType.getDimSize(0), 1}, rhsType.getElementType(), rhsType.getEncoding())
|
||||||
|
: rhsType;
|
||||||
|
|
||||||
|
SmallVector<int64_t> lhsBatchShape(normalizedLhsType.getShape().begin(), normalizedLhsType.getShape().end() - 2);
|
||||||
|
SmallVector<int64_t> rhsBatchShape(normalizedRhsType.getShape().begin(), normalizedRhsType.getShape().end() - 2);
|
||||||
|
auto outputBatchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
|
||||||
|
if (failed(outputBatchShape))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
|
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
|
||||||
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
|
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
|
||||||
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
|
const int64_t batch = outputBatchShape->empty() ? 1 : getStaticShapeElementCount(*outputBatchShape);
|
||||||
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
|
const int64_t m = normalizedLhsType.getDimSize(normalizedLhsType.getRank() - 2);
|
||||||
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
|
const int64_t k = normalizedLhsType.getDimSize(normalizedLhsType.getRank() - 1);
|
||||||
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
|
const int64_t rhsK = normalizedRhsType.getDimSize(normalizedRhsType.getRank() - 2);
|
||||||
const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
|
const int64_t n = normalizedRhsType.getDimSize(normalizedRhsType.getRank() - 1);
|
||||||
if (k != rhsK)
|
if (k != rhsK)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (outType.getRank() == 2) {
|
if (SmallVector<int64_t>(outType.getShape().begin(), outType.getShape().end())
|
||||||
if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
|
!= computeExpectedMatMulOutputShape(*outputBatchShape, m, n, lhsWasVector, rhsWasVector)) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
|
||||||
else {
|
|
||||||
SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
|
|
||||||
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|
|
||||||
|| outType.getDimSize(outType.getRank() - 1) != n)
|
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return MatMulShapeInfo {lhsType, rhsType, outType, *batchShape, lhsBatch, rhsBatch, batch, m, k, n};
|
return NormalizedMatMulInfo {lhsType,
|
||||||
|
rhsType,
|
||||||
|
outType,
|
||||||
|
normalizedLhsType,
|
||||||
|
normalizedRhsType,
|
||||||
|
lhsBatchShape,
|
||||||
|
rhsBatchShape,
|
||||||
|
*outputBatchShape,
|
||||||
|
lhsWasVector,
|
||||||
|
rhsWasVector,
|
||||||
|
lhsBatch,
|
||||||
|
rhsBatch,
|
||||||
|
batch,
|
||||||
|
m,
|
||||||
|
k,
|
||||||
|
n};
|
||||||
|
}
|
||||||
|
|
||||||
|
static MatMulLoweringPlan buildLoweringPlan(Value normalizedLhs,
|
||||||
|
Value normalizedRhs,
|
||||||
|
const NormalizedMatMulInfo& info,
|
||||||
|
bool useTransposedForm,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
MatMulLoweringPlan plan {normalizedLhs,
|
||||||
|
normalizedRhs,
|
||||||
|
cast<RankedTensorType>(normalizedLhs.getType()),
|
||||||
|
cast<RankedTensorType>(normalizedRhs.getType()),
|
||||||
|
info.lhsBatchShape,
|
||||||
|
info.rhsBatchShape,
|
||||||
|
info.outputBatchShape,
|
||||||
|
info.lhsBatch,
|
||||||
|
info.rhsBatch,
|
||||||
|
info.batch,
|
||||||
|
info.m,
|
||||||
|
info.k,
|
||||||
|
info.n,
|
||||||
|
false};
|
||||||
|
if (!useTransposedForm)
|
||||||
|
return plan;
|
||||||
|
|
||||||
|
plan.lhs = transposeLastTwoDims(normalizedRhs, rewriter, loc);
|
||||||
|
plan.rhs = transposeLastTwoDims(normalizedLhs, rewriter, loc);
|
||||||
|
plan.lhsType = cast<RankedTensorType>(plan.lhs.getType());
|
||||||
|
plan.rhsType = cast<RankedTensorType>(plan.rhs.getType());
|
||||||
|
std::swap(plan.lhsBatchShape, plan.rhsBatchShape);
|
||||||
|
std::swap(plan.lhsBatch, plan.rhsBatch);
|
||||||
|
plan.m = info.n;
|
||||||
|
plan.n = info.m;
|
||||||
|
plan.transposedResult = true;
|
||||||
|
return plan;
|
||||||
|
}
|
||||||
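Reviewer note: the transposed form in `buildLoweringPlan` relies on the identity (A·B)ᵀ = Bᵀ·Aᵀ: the operands are swapped and transposed, `m` and `n` trade places, and `transposedResult` records that the result must be transposed back. The sketch below checks that identity on small dense integer matrices. `Matrix`, `transpose`, `multiply`, and `multiplyViaTransposedForm` are hypothetical helpers for illustration only, and they assume non-empty rectangular inputs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Returns the transpose of a non-empty rectangular matrix.
Matrix transpose(const Matrix& a) {
  assert(!a.empty());
  Matrix result(a[0].size(), std::vector<int>(a.size(), 0));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t j = 0; j < a[i].size(); ++j)
      result[j][i] = a[i][j];
  return result;
}

// Plain triple-loop product of an [m, k] matrix with a [k, n] matrix.
Matrix multiply(const Matrix& a, const Matrix& b) {
  assert(!a.empty() && !b.empty() && a[0].size() == b.size());
  Matrix result(a.size(), std::vector<int>(b[0].size(), 0));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t j = 0; j < b[0].size(); ++j)
      for (std::size_t p = 0; p < b.size(); ++p)
        result[i][j] += a[i][p] * b[p][j];
  return result;
}

// Computes A*B as transpose(B^T * A^T), mirroring the swapped-operand plan:
// the inner product has shape [n, m] and is transposed back to [m, n].
Matrix multiplyViaTransposedForm(const Matrix& a, const Matrix& b) {
  return transpose(multiply(transpose(b), transpose(a)));
}
```

In the lowering, the point of the swap appears to be that a compile-time-computable LHS ends up on the weight side; the sketch only demonstrates that the arithmetic is unchanged by the swap.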
|
|
||||||
|
static Value normalizeMatMulOperand(
|
||||||
|
Value value, RankedTensorType normalizedType, bool wasVector, PatternRewriter& rewriter, Location loc) {
|
||||||
|
if (!wasVector)
|
||||||
|
return value;
|
||||||
|
return createMatrixFromVector(value, normalizedType, rewriter, loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value finalizeNormalizedMatMulResult(Value value,
|
||||||
|
RankedTensorType directOutType,
|
||||||
|
const NormalizedMatMulInfo& info,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
// The direct lowered result is always [flatBatch, normalizedM, normalizedN].
|
||||||
|
// Restore ONNX MatMul result rank by expanding right-aligned batch dimensions
|
||||||
|
// and removing the synthetic unit matrix axes introduced for vector operands.
|
||||||
|
Value result = value;
|
||||||
|
RankedTensorType currentType = directOutType;
|
||||||
|
if (info.outputBatchShape.size() > 1) {
|
||||||
|
SmallVector<int64_t> expandedShape(info.outputBatchShape.begin(), info.outputBatchShape.end());
|
||||||
|
expandedShape.push_back(info.m);
|
||||||
|
expandedShape.push_back(info.n);
|
||||||
|
auto expandedType = RankedTensorType::get(expandedShape, info.outType.getElementType(), info.outType.getEncoding());
|
||||||
|
result = expandBatchDims(result, expandedType, info.outputBatchShape.size(), rewriter, loc);
|
||||||
|
currentType = expandedType;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<bool> removedAxes(currentType.getRank(), false);
|
||||||
|
if (info.outputBatchShape.empty())
|
||||||
|
removedAxes[0] = true;
|
||||||
|
if (info.lhsWasVector)
|
||||||
|
removedAxes[currentType.getRank() - 2] = true;
|
||||||
|
if (info.rhsWasVector)
|
||||||
|
removedAxes[currentType.getRank() - 1] = true;
|
||||||
|
return squeezeUnitDims(result, info.outType, removedAxes, rewriter, loc);
|
||||||
}
|
}
|
||||||
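Reviewer note: `finalizeNormalizedMatMulResult` goes from the direct `[flatBatch, m, n]` result back to the ONNX result rank in two steps: expand the flat batch when there is more than one batch dimension, then drop the synthetic unit axes. The sketch below models only the shape arithmetic of those two steps. `finalMatMulShape` is a hypothetical name, and the real helpers (`expandBatchDims`, `squeezeUnitDims`) operate on MLIR values rather than on shape vectors.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Shape-only model (hypothetical): returns the final result shape produced by
// expanding the flat batch and squeezing the axes marked for removal.
std::vector<int64_t> finalMatMulShape(const std::vector<int64_t>& outputBatchShape,
                                      int64_t m,
                                      int64_t n,
                                      bool lhsWasVector,
                                      bool rhsWasVector) {
  int64_t flatBatch = 1;
  for (int64_t dim : outputBatchShape)
    flatBatch *= dim;

  // The direct lowered shape is [flatBatch, m, n].
  std::vector<int64_t> shape{flatBatch, m, n};
  // With more than one batch dim, the flat batch is expanded back out.
  if (outputBatchShape.size() > 1) {
    shape.assign(outputBatchShape.begin(), outputBatchShape.end());
    shape.push_back(m);
    shape.push_back(n);
  }

  // Mark the synthetic unit axes: the batch axis of an unbatched MatMul and
  // the matrix axes introduced for rank-1 operands.
  std::vector<bool> removed(shape.size(), false);
  if (outputBatchShape.empty())
    removed[0] = true;
  if (lhsWasVector)
    removed[shape.size() - 2] = true;
  if (rhsWasVector)
    removed[shape.size() - 1] = true;

  std::vector<int64_t> result;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    if (removed[i]) {
      assert(shape[i] == 1); // Only unit axes may be squeezed.
      continue;
    }
    result.push_back(shape[i]);
  }
  return result;
}
```

For every combination accepted by `analyzeMatMulShape`, this shape should agree with the one computed by `computeExpectedMatMulOutputShape`, which gives a cheap consistency check between the analysis and the finalization step.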
|
|
||||||
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||||
@@ -689,7 +950,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||||
auto shapeInfo = analyzeMatMulShape(matmulOp);
|
auto shapeInfo = analyzeMatMulShape(matmulOp);
|
||||||
if (failed(shapeInfo) || shapeInfo->outType.getRank() != 2)
|
if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector || !shapeInfo->outputBatchShape.empty())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Location loc = matmulOp.getLoc();
|
Location loc = matmulOp.getLoc();
|
||||||
@@ -742,61 +1003,56 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
auto shapeInfo = analyzeMatMulShape(matmulOp);
|
auto shapeInfo = analyzeMatMulShape(matmulOp);
|
||||||
if (failed(shapeInfo))
|
if (failed(shapeInfo))
|
||||||
return failure();
|
return failure();
|
||||||
if (shapeInfo->outType.getRank() == 2)
|
if (!shapeInfo->lhsWasVector && !shapeInfo->rhsWasVector && shapeInfo->outputBatchShape.empty())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Location loc = matmulOp.getLoc();
|
Location loc = matmulOp.getLoc();
|
||||||
bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
|
bool useTransposedForm = !shapeInfo->lhsWasVector && !shapeInfo->rhsWasVector
|
||||||
|
&& isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
|
||||||
|
|
||||||
Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
|
Value lhs =
|
||||||
Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
|
normalizeMatMulOperand(matmulOp.getA(), shapeInfo->normalizedLhsType, shapeInfo->lhsWasVector, rewriter, loc);
|
||||||
int64_t lhsBatchForGemm = shapeInfo->lhsBatch;
|
Value rhs =
|
||||||
int64_t rhsBatchForGemm = shapeInfo->rhsBatch;
|
normalizeMatMulOperand(matmulOp.getB(), shapeInfo->normalizedRhsType, shapeInfo->rhsWasVector, rewriter, loc);
|
||||||
int64_t gemmM = shapeInfo->m;
|
lhs = collapseBatchDims(lhs, shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
|
||||||
int64_t gemmK = shapeInfo->k;
|
rhs = collapseBatchDims(rhs, shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
|
||||||
int64_t gemmN = shapeInfo->n;
|
MatMulLoweringPlan plan = buildLoweringPlan(lhs, rhs, *shapeInfo, useTransposedForm, rewriter, loc);
|
||||||
if (useTransposedForm) {
|
|
||||||
lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
|
|
||||||
lhsBatchForGemm = shapeInfo->rhsBatch;
|
|
||||||
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
|
||||||
rhsBatchForGemm = shapeInfo->lhsBatch;
|
|
||||||
gemmM = shapeInfo->n;
|
|
||||||
gemmN = shapeInfo->m;
|
|
||||||
}
|
|
||||||
|
|
||||||
lhs = ensureBatchedTensor(lhs, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
plan.lhs = ensureBatchedTensor(plan.lhs, plan.lhsBatch, plan.m, plan.k, rewriter, loc);
|
||||||
rhs = ensureBatchedTensor(rhs, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
plan.rhs = ensureBatchedTensor(plan.rhs, plan.rhsBatch, plan.k, plan.n, rewriter, loc);
|
||||||
auto lhsBatchedType = cast<RankedTensorType>(lhs.getType());
|
plan.lhsType = cast<RankedTensorType>(plan.lhs.getType());
|
||||||
auto rhsBatchedType = cast<RankedTensorType>(rhs.getType());
|
plan.rhsType = cast<RankedTensorType>(plan.rhs.getType());
|
||||||
auto directOutType = RankedTensorType::get({shapeInfo->batch, gemmM, gemmN}, shapeInfo->outType.getElementType());
|
auto directOutType = RankedTensorType::get(
|
||||||
|
{plan.batch, plan.m, plan.n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding());
|
||||||
|
|
||||||
if (isCompileTimeComputable(rhs)) {
|
if (isCompileTimeComputable(plan.rhs)) {
|
||||||
const int64_t numKSlices = ceilIntegerDivide(gemmK, crossbarSize.getValue());
|
const int64_t numKSlices = ceilIntegerDivide(plan.k, crossbarSize.getValue());
|
||||||
const int64_t numOutHSlices = ceilIntegerDivide(gemmN, crossbarSize.getValue());
|
const int64_t numOutHSlices = ceilIntegerDivide(plan.n, crossbarSize.getValue());
|
||||||
const int64_t paddedReductionSize = numKSlices * static_cast<int64_t>(crossbarSize.getValue());
|
const int64_t paddedReductionSize = numKSlices * static_cast<int64_t>(crossbarSize.getValue());
|
||||||
const int64_t paddedOutCols = numOutHSlices * static_cast<int64_t>(crossbarSize.getValue());
|
const int64_t paddedOutCols = numOutHSlices * static_cast<int64_t>(crossbarSize.getValue());
|
||||||
auto paddedLhsType = RankedTensorType::get(
|
auto paddedLhsType = RankedTensorType::get(
|
||||||
{lhsBatchForGemm, gemmM, paddedReductionSize}, lhsBatchedType.getElementType(), lhsBatchedType.getEncoding());
|
{plan.lhsBatch, plan.m, paddedReductionSize}, plan.lhsType.getElementType(), plan.lhsType.getEncoding());
|
||||||
auto paddedRhsType = RankedTensorType::get({shapeInfo->batch, paddedReductionSize, paddedOutCols},
|
auto paddedRhsType = RankedTensorType::get(
|
||||||
rhsBatchedType.getElementType(),
|
{plan.batch, paddedReductionSize, paddedOutCols}, plan.rhsType.getElementType(), plan.rhsType.getEncoding());
|
||||||
rhsBatchedType.getEncoding());
|
|
||||||
auto paddedOutType =
|
auto paddedOutType =
|
||||||
RankedTensorType::get({shapeInfo->batch, gemmM, paddedOutCols}, shapeInfo->outType.getElementType());
|
RankedTensorType::get({plan.batch, plan.m, paddedOutCols}, shapeInfo->outType.getElementType());
|
||||||
|
|
||||||
auto paddedRhs = materializePaddedBatchedWeight(rhs, rhsBatchForGemm, shapeInfo->batch, paddedRhsType, rewriter);
|
auto paddedRhs =
|
||||||
|
materializePaddedBatchedWeight(plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, paddedRhsType, rewriter);
|
||||||
if (succeeded(paddedRhs)) {
|
if (succeeded(paddedRhs)) {
|
||||||
Value paddedLhs = createPaddedBatchedInputCompute(lhs, paddedLhsType, rewriter, loc);
|
Value paddedLhs = createPaddedBatchedInputCompute(plan.lhs, paddedLhsType, rewriter, loc);
|
||||||
const int64_t laneCount = shapeInfo->batch * gemmM * numKSlices * numOutHSlices;
|
const int64_t laneCount = plan.batch * plan.m * numKSlices * numOutHSlices;
|
||||||
auto partialPiecesType = RankedTensorType::get({laneCount, static_cast<int64_t>(crossbarSize.getValue())},
|
auto partialPiecesType = RankedTensorType::get({laneCount, static_cast<int64_t>(crossbarSize.getValue())},
|
||||||
shapeInfo->outType.getElementType());
|
shapeInfo->outType.getElementType());
|
||||||
auto batchOp = createBatchedVmmBatch(paddedLhs,
|
auto batchOp = createBatchedVmmBatch(paddedLhs,
|
||||||
*paddedRhs,
|
*paddedRhs,
|
||||||
paddedLhsType,
|
paddedLhsType,
|
||||||
lhsBatchForGemm,
|
plan.lhsBatchShape,
|
||||||
paddedRhsType,
|
paddedRhsType,
|
||||||
rhsBatchForGemm,
|
plan.rhsBatchShape,
|
||||||
|
plan.outputBatchShape,
|
||||||
partialPiecesType,
|
partialPiecesType,
|
||||||
gemmM,
|
plan.m,
|
||||||
numKSlices,
|
numKSlices,
|
||||||
numOutHSlices,
|
numOutHSlices,
|
||||||
rewriter,
|
rewriter,
|
||||||
@@ -807,34 +1063,35 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
partialPiecesType,
|
partialPiecesType,
|
||||||
directOutType,
|
directOutType,
|
||||||
paddedOutType,
|
paddedOutType,
|
||||||
shapeInfo->batch,
|
plan.batch,
|
||||||
numKSlices,
|
numKSlices,
|
||||||
rewriter,
|
rewriter,
|
||||||
loc);
|
loc);
|
||||||
if (failed(result))
|
if (failed(result))
|
||||||
return failure();
|
return failure();
|
||||||
Value finalResult = *result;
|
Value finalResult = *result;
|
||||||
if (useTransposedForm) {
|
if (plan.transposedResult) {
|
||||||
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
|
auto transposedOutType = RankedTensorType::get({plan.batch, shapeInfo->m, shapeInfo->n},
|
||||||
shapeInfo->outType.getElementType(),
|
shapeInfo->outType.getElementType(),
|
||||||
shapeInfo->outType.getEncoding());
|
shapeInfo->outType.getEncoding());
|
||||||
finalResult =
|
finalResult =
|
||||||
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
|
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
|
||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
finalResult = finalizeNormalizedMatMulResult(finalResult, directOutType, *shapeInfo, rewriter, loc);
|
||||||
rewriter.replaceOp(matmulOp, finalResult);
|
rewriter.replaceOp(matmulOp, finalResult);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const int64_t laneCount = shapeInfo->batch * gemmM * gemmN;
|
const int64_t laneCount = plan.batch * plan.m * plan.n;
|
||||||
auto scalarPiecesType = RankedTensorType::get({laneCount, 1}, shapeInfo->outType.getElementType());
|
auto scalarPiecesType = RankedTensorType::get({laneCount, 1}, shapeInfo->outType.getElementType());
|
||||||
auto batchOp = createBatchedVvdmulBatch(lhs,
|
auto batchOp = createBatchedVvdmulBatch(plan.lhs,
|
||||||
lhsBatchForGemm,
|
plan.lhsBatchShape,
|
||||||
rhs,
|
plan.rhs,
|
||||||
rhsBatchForGemm,
|
plan.rhsBatchShape,
|
||||||
lhsBatchedType,
|
plan.outputBatchShape,
|
||||||
rhsBatchedType,
|
plan.lhsType,
|
||||||
|
plan.rhsType,
|
||||||
scalarPiecesType,
|
scalarPiecesType,
|
||||||
directOutType,
|
directOutType,
|
||||||
rewriter,
|
rewriter,
|
||||||
@@ -846,15 +1103,15 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
if (failed(result))
|
if (failed(result))
|
||||||
return failure();
|
return failure();
|
||||||
Value finalResult = *result;
|
Value finalResult = *result;
|
||||||
if (useTransposedForm) {
|
if (plan.transposedResult) {
|
||||||
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
|
auto transposedOutType = RankedTensorType::get({plan.batch, shapeInfo->m, shapeInfo->n},
|
||||||
shapeInfo->outType.getElementType(),
|
shapeInfo->outType.getElementType(),
|
||||||
shapeInfo->outType.getEncoding());
|
shapeInfo->outType.getEncoding());
|
||||||
finalResult =
|
finalResult =
|
||||||
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
|
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
|
||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
finalResult = finalizeNormalizedMatMulResult(finalResult, directOutType, *shapeInfo, rewriter, loc);
|
||||||
rewriter.replaceOp(matmulOp, finalResult);
|
rewriter.replaceOp(matmulOp, finalResult);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
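Review note: the hunk above transposes the batched result with the permutation `{0, 2, 1}` when `plan.transposedResult` is set. The sketch below checks the underlying identity in NumPy, assuming the transposed form computes `B^T A^T` per batch. That assumption is not confirmed by the diff, and the names `A`, `B`, `Yt`, and `Y` are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, m, k, n = 2, 3, 4, 5
A = rng.standard_normal((batch, m, k))
B = rng.standard_normal((batch, k, n))

# Assumed transposed form: compute B^T A^T per batch, shape (batch, n, m).
Yt = np.matmul(B.transpose(0, 2, 1), A.transpose(0, 2, 1))

# Swapping the last two axes (perm [0, 2, 1]) restores shape (batch, m, n).
Y = Yt.transpose(0, 2, 1)

# The result matches the direct batched product A @ B.
assert Y.shape == (batch, m, n)
assert np.allclose(Y, np.matmul(A, B))
```

The batch axis stays in place, so only the two matrix axes are swapped.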
|
|||||||
@@ -238,14 +238,8 @@ static Value squeezeReducedAxes(Value keepdimsValue,
|
|||||||
ArrayRef<bool> reducedAxes,
|
ArrayRef<bool> reducedAxes,
|
||||||
ConversionPatternRewriter& rewriter,
|
ConversionPatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
if (resultType.getRank() == 0) {
|
SmallVector<ReassociationIndices> reassociation =
|
||||||
SmallVector<Value> indices(cast<RankedTensorType>(keepdimsValue.getType()).getRank(),
|
resultType.getRank() == 0 ? SmallVector<ReassociationIndices> {} : buildCollapseReassociation(reducedAxes);
|
||||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0));
|
|
||||||
Value element = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, indices);
|
|
||||||
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
|
|
||||||
}
|
|
||||||
|
|
||||||
auto reassociation = buildCollapseReassociation(reducedAxes);
|
|
||||||
if (isCompileTimeComputable(keepdimsValue))
|
if (isCompileTimeComputable(keepdimsValue))
|
||||||
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
|
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -262,7 +262,7 @@ def conv_grouped_many_groups():
|
|||||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1024, 2, 2])
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1024, 2, 2])
|
||||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 2, 2])
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 2, 2])
|
||||||
W = numpy_helper.from_array(
|
W = numpy_helper.from_array(
|
||||||
np.random.default_rng(77).uniform(-1, 1, (1024, 64, 1, 1)).astype(np.float32), name="W")
|
np.random.default_rng(77).uniform(-1, 1, (1024, 16, 1, 1)).astype(np.float32), name="W")
|
||||||
node = helper.make_node("Conv", ["X", "W"], ["Y"],
|
node = helper.make_node("Conv", ["X", "W"], ["Y"],
|
||||||
kernel_shape=[1, 1], strides=[1, 1], pads=[0, 0, 0, 0], group=64)
|
kernel_shape=[1, 1], strides=[1, 1], pads=[0, 0, 0, 0], group=64)
|
||||||
graph = helper.make_graph([node], "conv_grouped_many_groups", [X], [Y], initializer=[W])
|
graph = helper.make_graph([node], "conv_grouped_many_groups", [X], [Y], initializer=[W])
|
||||||
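Review note: the weight shape change from `(1024, 64, 1, 1)` to `(1024, 16, 1, 1)` agrees with the ONNX Conv rule that `W` has shape `(M, C / group, kH, kW)`. With `C = 1024` and `group = 64`, the second dimension is 16. A minimal sketch of that check follows; `grouped_conv_weight_shape` is a hypothetical helper, not part of the test generator.

```python
import numpy as np


def grouped_conv_weight_shape(in_channels, out_channels, group, kernel_shape):
    """Return the ONNX Conv weight shape (M, C / group, kH, kW)."""
    if in_channels % group != 0 or out_channels % group != 0:
        raise ValueError("in_channels and out_channels must be divisible by group")
    return (out_channels, in_channels // group, *kernel_shape)


# Values taken from conv_grouped_many_groups in the diff above.
expected = grouped_conv_weight_shape(1024, 1024, 64, (1, 1))
W = np.random.default_rng(77).uniform(-1, 1, expected).astype(np.float32)
print(expected)  # (1024, 16, 1, 1)
```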
@@ -779,28 +779,6 @@ def matmul_matrix_vector():
|
|||||||
save_model(model, "matmul/matrix_vector", "matmul_matrix_vector.onnx")
|
save_model(model, "matmul/matrix_vector", "matmul_matrix_vector.onnx")
|
||||||
|
|
||||||
|
|
||||||
def matmul_vector_vector_dot():
|
|
||||||
"""Vector-vector MatMul producing a scalar output."""
|
|
||||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1024])
|
|
||||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [])
|
|
||||||
B = numpy_helper.from_array(np.random.default_rng(97).uniform(-1, 1, (1024,)).astype(np.float32), name="B")
|
|
||||||
node = helper.make_node("MatMul", ["A", "B"], ["Y"])
|
|
||||||
graph = helper.make_graph([node], "matmul_vector_vector_dot", [A], [Y], initializer=[B])
|
|
||||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
|
||||||
save_model(model, "matmul/vector_vector_dot", "matmul_vector_vector_dot.onnx")
|
|
||||||
|
|
||||||
|
|
||||||
def matmul_batched_4d_broadcast():
|
|
||||||
"""Batched 4D MatMul with broadcast across leading dimensions."""
|
|
||||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 1, 3, 4])
|
|
||||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 5, 3, 6])
|
|
||||||
B = numpy_helper.from_array(np.random.default_rng(98).uniform(-1, 1, (1, 5, 4, 6)).astype(np.float32), name="B")
|
|
||||||
node = helper.make_node("MatMul", ["A", "B"], ["Y"])
|
|
||||||
graph = helper.make_graph([node], "matmul_batched_4d_broadcast", [A], [Y], initializer=[B])
|
|
||||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
|
||||||
save_model(model, "matmul/batched_4d_broadcast", "matmul_batched_4d_broadcast.onnx")
|
|
||||||
|
|
||||||
|
|
||||||
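Review note: the removed `matmul_batched_4d_broadcast` test used inputs of shape `[2, 1, 3, 4]` and `[1, 5, 4, 6]` with output `[2, 5, 3, 6]`. ONNX MatMul follows `numpy.matmul` semantics, so the shape can be reproduced with the sketch below. It only illustrates the broadcasting rule and is not the project's validation code.

```python
import numpy as np

rng = np.random.default_rng(98)
A = rng.uniform(-1, 1, (2, 1, 3, 4)).astype(np.float32)
B = rng.uniform(-1, 1, (1, 5, 4, 6)).astype(np.float32)

# Leading (batch) dimensions broadcast like elementwise ops: (2, 1) with (1, 5).
batch_shape = np.broadcast_shapes(A.shape[:-2], B.shape[:-2])

# The last two dimensions follow the usual (m, k) @ (k, n) -> (m, n) rule.
Y = np.matmul(A, B)
assert Y.shape == (*batch_shape, 3, 6)

# Each output slice pairs the broadcast operands: Y[i, j] = A[i, 0] @ B[0, j].
assert np.allclose(Y[1, 3], A[1, 0] @ B[0, 3], atol=1e-5)
```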
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Pooling tests
|
# Pooling tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -1560,17 +1538,6 @@ def add_channel_broadcast_1024():
|
|||||||
save_model(model, "add/channel_broadcast_1024", "add_channel_broadcast_1024.onnx")
|
save_model(model, "add/channel_broadcast_1024", "add_channel_broadcast_1024.onnx")
|
||||||
|
|
||||||
|
|
||||||
def add_scalar_runtime():
|
|
||||||
"""Elementwise Add with a runtime scalar RHS."""
|
|
||||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 1024, 1, 1])
|
|
||||||
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 1, 1, 1])
|
|
||||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1])
|
|
||||||
node = helper.make_node("Add", ["A", "B"], ["Y"])
|
|
||||||
graph = helper.make_graph([node], "add_scalar_runtime", [A, B], [Y])
|
|
||||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
|
||||||
save_model(model, "add/scalar_runtime", "add_scalar_runtime.onnx")
|
|
||||||
|
|
||||||
|
|
||||||
def add_leading_dimension_broadcast():
|
def add_leading_dimension_broadcast():
|
||||||
"""Elementwise Add with trailing-dimension broadcasting."""
|
"""Elementwise Add with trailing-dimension broadcasting."""
|
||||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4])
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4])
|
||||||
@@ -1635,17 +1602,6 @@ def mul_channel_broadcast_1024():
|
|||||||
save_model(model, "mul/channel_broadcast_1024", "mul_channel_broadcast_1024.onnx")
|
save_model(model, "mul/channel_broadcast_1024", "mul_channel_broadcast_1024.onnx")
|
||||||
|
|
||||||
|
|
||||||
def mul_scalar_runtime():
|
|
||||||
"""Elementwise Mul with a runtime scalar RHS."""
|
|
||||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 1024, 1, 1])
|
|
||||||
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 1, 1, 1])
|
|
||||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1])
|
|
||||||
node = helper.make_node("Mul", ["A", "B"], ["Y"])
|
|
||||||
graph = helper.make_graph([node], "mul_scalar_runtime", [A, B], [Y])
|
|
||||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
|
||||||
save_model(model, "mul/scalar_runtime", "mul_scalar_runtime.onnx")
|
|
||||||
|
|
||||||
|
|
||||||
def mul_leading_dimension_broadcast():
|
def mul_leading_dimension_broadcast():
|
||||||
"""Elementwise Mul with trailing-dimension broadcasting."""
|
"""Elementwise Mul with trailing-dimension broadcasting."""
|
||||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4])
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4])
|
||||||
@@ -1721,17 +1677,6 @@ def div_runtime_scalar_rhs():
|
|||||||
save_model(model, "div/runtime_scalar_rhs", "div_runtime_scalar_rhs.onnx")
|
save_model(model, "div/runtime_scalar_rhs", "div_runtime_scalar_rhs.onnx")
|
||||||
|
|
||||||
|
|
||||||
def div_runtime_scalar_lhs():
|
|
||||||
"""Elementwise Div with a scalar constant numerator."""
|
|
||||||
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 1024, 1, 1])
|
|
||||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1])
|
|
||||||
A = numpy_helper.from_array(np.asarray([[[[2.0]]]], dtype=np.float32), name="A")
|
|
||||||
node = helper.make_node("Div", ["A", "B"], ["Y"])
|
|
||||||
graph = helper.make_graph([node], "div_runtime_scalar_lhs", [B], [Y], initializer=[A])
|
|
||||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
|
||||||
save_model(model, "div/runtime_scalar_lhs", "div_runtime_scalar_lhs.onnx")
|
|
||||||
|
|
||||||
|
|
||||||
def div_leading_dimension_broadcast():
|
def div_leading_dimension_broadcast():
|
||||||
"""Elementwise Div with trailing-dimension broadcasting."""
|
"""Elementwise Div with trailing-dimension broadcasting."""
|
||||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4])
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4])
|
||||||
@@ -1812,8 +1757,6 @@ if __name__ == "__main__":
|
|||||||
matmul_huge_1024()
|
matmul_huge_1024()
|
||||||
matmul_vector_matrix()
|
matmul_vector_matrix()
|
||||||
matmul_matrix_vector()
|
matmul_matrix_vector()
|
||||||
matmul_vector_vector_dot()
|
|
||||||
matmul_batched_4d_broadcast()
|
|
||||||
|
|
||||||
print("\nGenerating Pooling tests:")
|
print("\nGenerating Pooling tests:")
|
||||||
maxpool_basic()
|
maxpool_basic()
|
||||||
@@ -1899,7 +1842,6 @@ if __name__ == "__main__":
|
|||||||
add_broadcast_row()
|
add_broadcast_row()
|
||||||
add_after_gemm()
|
add_after_gemm()
|
||||||
add_channel_broadcast_1024()
|
add_channel_broadcast_1024()
|
||||||
add_scalar_runtime()
|
|
||||||
add_leading_dimension_broadcast()
|
add_leading_dimension_broadcast()
|
||||||
|
|
||||||
print("\nGenerating Mul tests:")
|
print("\nGenerating Mul tests:")
|
||||||
@@ -1907,7 +1849,6 @@ if __name__ == "__main__":
|
|||||||
mul_scalar_constant()
|
mul_scalar_constant()
|
||||||
mul_after_conv()
|
mul_after_conv()
|
||||||
mul_channel_broadcast_1024()
|
mul_channel_broadcast_1024()
|
||||||
mul_scalar_runtime()
|
|
||||||
mul_leading_dimension_broadcast()
|
mul_leading_dimension_broadcast()
|
||||||
|
|
||||||
print("\nGenerating Div tests:")
|
print("\nGenerating Div tests:")
|
||||||
@@ -1916,7 +1857,6 @@ if __name__ == "__main__":
|
|||||||
div_after_gemm()
|
div_after_gemm()
|
||||||
div_channel_broadcast_1024()
|
div_channel_broadcast_1024()
|
||||||
div_runtime_scalar_rhs()
|
div_runtime_scalar_rhs()
|
||||||
div_runtime_scalar_lhs()
|
|
||||||
div_leading_dimension_broadcast()
|
div_leading_dimension_broadcast()
|
||||||
|
|
||||||
print("\nDone.")
|
print("\nDone.")
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.