diff --git a/AGENTS.md b/AGENTS.md
index f19d8ed..9cadae5 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -1,92 +1,210 @@
-- Always read the full README.md before doing anything.
-- Build commands:
-  - `cmake --build ./build_release`
-  - `cmake --build ./build_debug`
-- Never use `ninja` directly: it bypasses cmake's configuration and invalidates the build cache.
-- Always tries the release version build first and ask before building with the debug version
+* Always read the full README.md before doing anything
+* Build commands:
+  * `cmake --build ./build_release`
+  * `cmake --build ./build_debug`
+* Never use `ninja` directly: it bypasses cmake's configuration and invalidates the build cache
+* Always try the release build first before building with the debug version
+* Use the debug build only when it is useful to obtain a clear stack trace with symbols, inspect names, place breakpoints, or test a small case interactively
+* The debug build is very slow, so use it only on small fast tests such as operation validations, not on network validations
+
+# Core engineering philosophy
+
+* Clean architecture matters as much as making the immediate test pass
+* Prefer fixes that preserve clear ownership boundaries, explicit invariants, and simple dataflow
+* Do not stack compensating fixes on top of earlier mistakes. If the current approach is becoming messy, stop and explain why
+* A correct fix should usually make the responsible producer, resolver, verifier, or lowering own the behavior directly
+* Avoid late repair passes, defensive cleanup, or broad rewrites when a cleaner owner-side fix is possible
+* Do not hide an upstream modeling bug by normalizing it later in the pipeline. Fix the producer when the producer owns the invariant
+* Prefer patterns/rewrites for local IR canonicalization. Use module walks only when pass-level structural analysis genuinely requires them
+* Prefer compact, structured designs over long case-by-case implementations
+
+# Think before coding
+
+* State assumptions explicitly before implementing when they affect the design
+* If multiple interpretations exist, present them instead of silently choosing one
+* If a simpler approach exists, say so and prefer it unless there is a clear reason not to
+* If something is unclear, stop, name what is confusing, and ask
+* If the requested or obvious approach would make the architecture worse, push back and propose a cleaner alternative
 
 # Code changes
 
-- Keep changes minimal and localized to the relevant parts of the code.
-- Preserve the existing naming conventions and coding style used in the surrounding code.
-- Keep code easy to read, well organized, and suitable for future extensibility. A function must not be longer than
-  200/250 lines for readability and cognitive complexity.
-- Prefer clear naming and structure over comments. Add comments only when they materially improve clarity.
-- Do not rename symbols, move files, or restructure modules unless that is necessary for the requested change.
+* Keep changes minimal and localized to the relevant parts of the code
+* Preserve the existing naming conventions and coding style used in the surrounding code
+* Keep code easy to read, well organized, and suitable for future extensibility
+* A function must not exceed roughly 200/250 lines. If a change pushes a function beyond that, extract focused helpers
+* Prefer clear naming and structure over comments. Add comments only when they materially improve clarity
+* Do not rename symbols, move files, or restructure modules unless that is necessary for the requested change
+* Avoid duplicate ad-hoc logic. If the same concept appears in multiple places, consider whether it deserves a shared helper/API
+* When adding a helper or API, ask:
+  * Could this be useful to another component now
+  * Is another component already implementing the same idea differently
+  * Is this likely to be needed by a future adjacent component
+  * What is the narrowest useful abstraction
+  * What is the correct ownership level for this API
+* If a shared API is justified, place it at the lowest clean layer that can be used by all relevant consumers without creating dependency cycles or leaking policy across layers
+* If an existing component should use a newly introduced shared API, refactor that component in the same patch when doing so is directly related and reduces duplication
+* Do not create broad frameworks just because a helper might someday be useful. Shared APIs should encode a real reusable concept, not speculative generality
+* If the reusable abstraction is plausible but not clearly needed yet, keep the code local and mention the possible future extraction separately
+
+# Avoid case-listing designs
+
+* Avoid solving problems with large chains of `if`/`else`, switches, or repeated special cases that enumerate every possible situation
+* Long case listings tend to overfit the current tests, grow the codebase, and hide the underlying abstraction
+* When you see a growing list of special cases, stop and look for the shared concept, data model, interface, or normalization step that would make the cases collapse
+* Prefer table-driven logic, traits/interfaces, small reusable predicates, structured dispatch, or producer-side normalization when they express the invariant more directly
+* A few explicit cases are fine when the domain is genuinely small and closed
+* If the list is likely to grow, refactor toward a cleaner and more compact design instead of adding another branch
+* When keeping a case list is the pragmatic choice, explain why the domain is closed or why a broader abstraction would be premature
+
+# Ownership and invariants
+
+Before implementing, identify the owner of the behavior:
+
+* A producer should emit IR/data that satisfies the contract of the next stage
+* A lowering should make representation changes explicit and semantically correct
+* A resolver should resolve existing structure without silently changing semantics
+* A verifier should reject invalid states with bounded, actionable diagnostics
+* Codegen should assume verified invariants and fail clearly if they are violated
+
+When fixing a bug:
+
+* State the invariant that was violated
+* State which component should own that invariant
+* Fix that component directly
+* Avoid fixes that merely mask the violation later in the pipeline
+* Add or preserve verification if the invariant is important enough to regress
+
+# Refactor and API policy
+
+You may propose or implement a refactor when:
+
+* the local fix would duplicate logic
+* the local fix would violate a layer boundary
+* the bug exists because responsibility is assigned to the wrong component
+* multiple components already implement ad-hoc variants of the same concept
+* a shared helper/API would make the code smaller, clearer, and easier to maintain
+* existing callers can be migrated cleanly without broad churn
+* the current implementation is turning into a long list of special cases instead of a structured solution
+
+When proposing or implementing a refactor:
+
+* Explain what responsibility is being moved or shared
+* Justify why the new location is the right ownership level
+* Keep the API narrow and named after the concept or invariant it represents
+* Migrate directly related existing users when that improves compactness and consistency
+* Separate changes required for correctness from optional cleanup
+* Avoid unrelated renames, formatting changes, or module moves
+* Do not expand a justified refactor beyond directly related callers
+
+Do not refactor when:
+
+* the issue is truly local and a local fix is clearer
+* the abstraction would have only one user and no clear adjacent use
+* the abstraction would mix policies from different layers
+* the refactor would affect unrelated behavior
+* the refactor is mainly aesthetic
 
 # Working style
 
-- Infer style and conventions from the existing code before introducing new patterns.
-- When several implementation options are possible, prefer the simplest one that fits the current architecture and
-  minimizes churn.
-- Avoid broad refactors unless I explicitly ask for them.
+* Infer style and conventions from the existing code before introducing new patterns
+* When several implementation options are possible, prefer the simplest one that fits the current architecture and minimizes churn
+* Push back when the requested or obvious fix would make the architecture worse
+* If a cleaner fix requires a small refactor or shared helper/API, propose it explicitly and justify it
+* Avoid broad refactors unless explicitly requested or clearly necessary for correctness and maintainability
+* When tests fail, bucket failures by likely root cause and separate patch-related failures from pre-existing or out-of-scope failures
 
-# Responses
+# Simplicity first
 
-- When showing code in chat, make it easy to copy-paste into the codebase.
-- Keep outputs focused on the changed parts.
-- At the end of the response, briefly list any bad practices, mistakes, or cleaner alternatives you noticed, separate
-  from the main solution.
+* Minimum code that solves the problem cleanly. Nothing speculative
+* No features beyond what was asked
+* No error handling for impossible scenarios
+* If you write 200 lines and it could be 50, rewrite it
+* Ask: “Would a senior engineer say this is overcomplicated?” If yes, simplify
+* Prefer direct, explicit code over generic machinery unless the generic machinery clearly reduces duplication and preserves boundaries
 
-# Guidelines
+# Fallbacks and defaults
 
-## 1. Think Before Coding
+* Avoid silent fallback behavior when the semantic category is unknown
+* Do not treat “unknown” as “safe” unless the codebase already defines that convention
+* If a value cannot be classified, either preserve the existing behavior deliberately or fail with a clear diagnostic
+* When adding a fallback, state why it is semantically valid and what invariant makes it safe
 
-**Don't assume. Don't hide confusion. Surface tradeoffs.**
+# Surgical changes
 
-Before implementing:
+* Touch only what you must
+* Clean up only the mess introduced by your own change
+* Do not “improve” adjacent code, comments, or formatting
+* Match existing style, even if you would personally do it differently
+* If you notice unrelated dead code, bad abstractions, or fragile design, mention it separately. Do not delete or rewrite it unless asked
+* When your changes create orphans, remove imports, variables, functions, or files made unused by your change
+* Every changed line should trace directly to the requested fix, a required cleanup, or a justified reuse/refactor decision
 
-- State your assumptions explicitly. If uncertain, ask.
-- If multiple interpretations exist, present them - don't pick silently.
-- If a simpler approach exists, say so. Push back when warranted.
-- If something is unclear, stop. Name what's confusing. Ask.
+# Diagnostics and verification
 
-## 2. Simplicity First
+* Use existing bounded diagnostic mechanisms for pass-level verification or codegen failures
+* Do not emit unbounded repeated diagnostics from loops or parallel workers
+* Diagnostics should identify the violated invariant and the relevant value/op when useful
+* Verifiers should reject invalid states, not repair them
+* Codegen should not compensate for invalid IR/data unless codegen is the owner of that invariant
+* Do not make failing tests pass by weakening verifiers, assertions, or diagnostics unless the check itself is proven wrong
+* If a check is too strict, explain the valid case it rejects and update the invariant accordingly
+* Prefer fixing invalid IR/data producers over relaxing consumers
+* If adding diagnostics only for debugging, remove them or cap them before finalizing
 
-**Minimum code that solves the problem. Nothing speculative.**
+# Temporary debugging code
 
-- No features beyond what was asked.
-- No error handling for impossible scenarios.
-- If you write 200 lines and it could be 50, rewrite it.
+* Temporary diagnostics, dumps, assertions, and debug-only helpers must be removed or intentionally converted into bounded permanent diagnostics before finalizing
+* If debug instrumentation remains, explain why it is useful as permanent infrastructure
+* Do not leave noisy validation output behind
 
-Ask yourself: "Would a senior engineer say this is overcomplicated?" If yes, simplify.
+# Performance awareness
 
-## 3. Surgical Changes
+* Avoid algorithmic regressions in compiler passes, especially repeated full-module walks, repeated expensive analyses, or per-op recomputation inside nested loops
+* If a change adds a walk, cache, analysis, or structural traversal, justify why it is needed
+* For hot paths, prefer preserving existing asymptotic behavior unless a better structure is part of the requested change
+* If performance may change, mention the expected impact and suggest a targeted timing check
 
-**Touch only what you must. Clean up only your own mess.**
-
-When editing existing code:
-
-- Don't "improve" adjacent code, comments, or formatting.
-- Don't refactor things that aren't broken.
-- Match existing style, even if you'd do it differently.
-- If you notice unrelated dead code, mention it - don't delete it.
-
-When your changes create orphans:
-
-- Remove imports/variables/functions that YOUR changes made unused.
-- Don't remove pre-existing dead code unless asked, but mention it.
-
-The test: Every changed line should trace directly to the user's request.
-
-## 4. Goal-Driven Execution
-
-**Define success criteria. Loop until verified.**
-
-Transform tasks into verifiable goals:
-
-- "Add validation" → "Write tests for invalid inputs, then make them pass"
-- "Fix the bug" → "Write a test that reproduces it, then make it pass"
-- "Refactor X" → "Ensure tests pass before and after"
+# Goal-driven execution
 
 For multi-step tasks, state a brief plan:
 
-```
 1. [Step] → verify: [check]
 2. [Step] → verify: [check]
 3. [Step] → verify: [check]
-```
 
-Strong success criteria let you loop independently. Weak criteria ("make it work") require constant clarification.
+Define success criteria before implementing:
 
----
+* For bug fixes, success means reproducing or identifying the failure, fixing the responsible owner, and verifying the targeted case
+* For refactors, success means preserving behavior while making ownership, reuse, or structure cleaner
+* For validation changes, success means checking both valid and invalid cases when applicable
+
+Transform tasks into verifiable goals:
+
+* “Fix the bug” → identify the invariant, reproduce the failure, fix the owner, verify the targeted case
+* “Add validation” → write or identify tests for invalid inputs, then make them pass/fail as expected
+* “Refactor X” → preserve behavior before and after, then run relevant tests
+
+# Final self-review
+
+Before reporting completion, check:
+
+* Did I fix the owner of the invariant rather than masking the issue downstream
+* Did I avoid broad case lists and ad-hoc special handling
+* Did I introduce a helper/API only at the right ownership level
+* Did I migrate directly related duplicate logic when doing so improves compactness
+* Did I avoid weakening verifiers or assertions unnecessarily
+* Did I remove temporary debugging code or make it bounded and intentional
+* Did I avoid unrelated formatting, renames, or cleanup
+* Did I consider performance impact for added walks, analyses, caches, or repeated computations
+* Did I run the required build/test commands
+* Did I clearly report remaining failures or risks
+
+When reporting back:
+
+* Say what changed
+* Say what was verified
+* Say what remains
+* When showing code in chat, make it easy to copy-paste into the codebase
+* Keep outputs focused on the changed parts
+* List bad practices, fragile assumptions, or cleaner alternatives separately
+* If a change is intentionally pragmatic rather than architecturally ideal, say so and explain the tradeoff
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp
index 0954a5c..d4dad33 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp
@@ -7,7 +7,9 @@
 #include "llvm/ADT/SmallVector.h"
 
 #include
+#include
 
+#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
 #include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
 #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
 #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
@@ -29,6 +31,36 @@ struct ConvToGemm : OpConversionPattern {
                                 ConversionPatternRewriter& rewriter) const override;
 };
 
+struct ConvLoweringState {
+  Value x;
+  Value w;
+  Value b;
+  RankedTensorType xType;
+  RankedTensorType wType;
+  RankedTensorType outType;
+  int64_t batchSize;
+  int64_t numChannelsIn;
+  int64_t xHeight;
+  int64_t xWidth;
+  int64_t numChannelsOut;
+  int64_t wHeight;
+  int64_t wWidth;
+  int64_t outHeight;
+  int64_t outWidth;
+  int64_t group;
+  int64_t numChannelsInPerGroup;
+  int64_t numChannelsOutPerGroup;
+  int64_t padHeightBegin;
+  int64_t padHeightEnd;
+  int64_t padWidthBegin;
+  int64_t padWidthEnd;
+  int64_t strideHeight;
+  int64_t strideWidth;
+  int64_t dilationHeight;
+  int64_t dilationWidth;
+  bool hasBias;
+};
+
 static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
   auto biasType = cast(bias.getType());
   if (biasType.getRank() != 1)
@@ -44,58 +76,621 @@ static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter,
   });
 }
 
-static Value createPaddedRows(Value tensorValue,
-                              RankedTensorType tensorType,
-                              int64_t paddedRows,
-                              ConversionPatternRewriter& rewriter,
-                              Location loc) {
-  if (tensorType.getDimSize(0) == paddedRows)
-    return tensorValue;
+static bool
+isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) {
+  return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0;
+}
 
-  auto paddedType = RankedTensorType::get(
-      {paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
-  SmallVector lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
-  SmallVector highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
-                          rewriter.getIndexAttr(0)};
-  auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads);
+static int64_t findLargestDivisorAtMost(int64_t value, int64_t limit) {
+  assert(value > 0 && "expected positive value");
+  limit = std::min(value, limit);
+  for (int64_t candidate = limit; candidate >= 1; --candidate)
+    if (value % candidate == 0)
+      return candidate;
+  return 1;
+}
+
+static Value createZeroPaddedTensor(Value value,
+                                    RankedTensorType resultType,
+                                    ArrayRef lowPadValues,
+                                    ArrayRef highPadValues,
+                                    ConversionPatternRewriter& rewriter,
+                                    Location loc) {
+  auto valueType = cast(value.getType());
+  if (valueType == resultType)
+    return value;
+
+  SmallVector lowPads;
+  SmallVector highPads;
+  lowPads.reserve(lowPadValues.size());
+  highPads.reserve(highPadValues.size());
+  for (auto lowPad : lowPadValues)
+    lowPads.push_back(rewriter.getIndexAttr(lowPad));
+  for (auto highPad : highPadValues)
+    highPads.push_back(rewriter.getIndexAttr(highPad));
+
+  auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
   auto* padBlock = new Block();
-  for (int i = 0; i < 2; ++i)
+  for (int64_t dim = 0, rank = resultType.getRank(); dim < rank; ++dim)
     padBlock->addArgument(rewriter.getIndexType(), loc);
   padOp.getRegion().push_back(padBlock);
   rewriter.setInsertionPointToStart(padBlock);
   auto zero = getOrCreateConstant(
-      rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()), tensorType.getElementType());
+      rewriter, padOp.getOperation(), rewriter.getZeroAttr(resultType.getElementType()), resultType.getElementType());
   tensor::YieldOp::create(rewriter, loc, zero);
   rewriter.setInsertionPointAfter(padOp);
   return padOp.getResult();
 }
 
+static Value affineAddConst(
+    ConversionPatternRewriter& rewriter, Location loc, Value value, int64_t offset, Operation* constantAnchor) {
+  if (offset == 0)
+    return value;
+
+  MLIRContext* context = rewriter.getContext();
+  AffineExpr d0 = getAffineDimExpr(0, context);
+  return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor);
+}
+
+static Value createConvInputPatch(Value input,
+                                  RankedTensorType patchType,
+                                  Value batchIndex,
+                                  Value channelOffset,
+                                  Value inputHeightOffset,
+                                  Value inputWidthOffset,
+                                  int64_t dilationHeight,
+                                  int64_t dilationWidth,
+                                  ConversionPatternRewriter& rewriter,
+                                  Location loc) {
+  const int64_t patchChannels = patchType.getDimSize(1);
+  const int64_t kernelHeight = patchType.getDimSize(2);
+  const int64_t kernelWidth = patchType.getDimSize(3);
+  if (dilationHeight == 1 && dilationWidth == 1) {
+    SmallVector offsets {batchIndex, channelOffset, inputHeightOffset, inputWidthOffset};
+    SmallVector sizes {rewriter.getIndexAttr(1),
+                       rewriter.getIndexAttr(patchChannels),
+                       rewriter.getIndexAttr(kernelHeight),
+                       rewriter.getIndexAttr(kernelWidth)};
+    return tensor::ExtractSliceOp::create(rewriter, loc, patchType, input, offsets, sizes, getUnitStrides(rewriter, 4));
+  }
+
+  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
+  auto elementType = patchType.getElementType();
+  auto pixelType = RankedTensorType::get({1, patchChannels, 1, 1}, elementType, patchType.getEncoding());
+  Value patch = tensor::EmptyOp::create(rewriter, loc, patchType.getShape(), elementType);
+  for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
+    Value sourceHeightOffset = affineAddConst(rewriter, loc, inputHeightOffset, kernelH * dilationHeight, anchorOp);
+    for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
+      Value sourceWidthOffset = affineAddConst(rewriter, loc, inputWidthOffset, kernelW * dilationWidth, anchorOp);
+      SmallVector sourceOffsets {batchIndex, channelOffset, sourceHeightOffset, sourceWidthOffset};
+      SmallVector sourceSizes {rewriter.getIndexAttr(1),
+                               rewriter.getIndexAttr(patchChannels),
+                               rewriter.getIndexAttr(1),
+                               rewriter.getIndexAttr(1)};
+      Value sourcePixel = tensor::ExtractSliceOp::create(
+          rewriter, loc, pixelType, input, sourceOffsets, sourceSizes, getUnitStrides(rewriter, 4));
+      SmallVector targetOffsets {
+          rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(kernelH), rewriter.getIndexAttr(kernelW)};
+      patch = tensor::InsertSliceOp::create(
+          rewriter, loc, sourcePixel, patch, targetOffsets, sourceSizes, getUnitStrides(rewriter, 4));
+    }
+  }
+  return patch;
+}
+
+static Value createCollectedConvOutput(ValueRange gemmRows,
+                                       Type convType,
+                                       RankedTensorType gemmOutType,
+                                       RankedTensorType nhwcType,
+                                       RankedTensorType outType,
+                                       int64_t numPatches,
+                                       int64_t numChannelsOut,
+                                       int64_t packFactor,
+                                       ConversionPatternRewriter& rewriter,
+                                       Location loc);
+
+namespace depthwise {
+
+struct Tiling {
+  int64_t outputMultiplier;
+  int64_t kernelElements;
+  int64_t channelsPerTile;
+  int64_t tileInputRows;
+  int64_t tileOutputChannels;
+  int64_t numChannelTiles;
+  int64_t spatialPatchesPerBatch;
+  int64_t totalPatches;
+};
+
+static std::optional computeTiling(int64_t batchSize,
+                                   int64_t numChannelsIn,
+                                   int64_t numChannelsOut,
+                                   int64_t wHeight,
+                                   int64_t wWidth,
+                                   int64_t outHeight,
+                                   int64_t outWidth) {
+  const int64_t kernelElements = wHeight * wWidth;
+  const int64_t outputMultiplier = numChannelsOut / numChannelsIn;
+  const int64_t xbarDim = static_cast(crossbarSize.getValue());
+  if (kernelElements <= 0 || outputMultiplier <= 0 || kernelElements > xbarDim || outputMultiplier > xbarDim)
+    return std::nullopt;
+
+  const int64_t maxChannelsPerTile = std::min(xbarDim / kernelElements, xbarDim / outputMultiplier);
+  if (maxChannelsPerTile <= 0)
+    return std::nullopt;
+
+  const int64_t channelsPerTile = findLargestDivisorAtMost(numChannelsIn, maxChannelsPerTile);
+  const int64_t tileInputRows = channelsPerTile * kernelElements;
+  const int64_t tileOutputChannels = channelsPerTile * outputMultiplier;
+  if (tileInputRows > xbarDim || tileOutputChannels > xbarDim)
+    return std::nullopt;
+
+  return Tiling {
+      outputMultiplier,
+      kernelElements,
+      channelsPerTile,
+      tileInputRows,
+      tileOutputChannels,
+      numChannelsIn / channelsPerTile,
+      outHeight * outWidth,
+      batchSize * outHeight * outWidth,
+  };
+}
+
+static Value buildPackedWeights(DenseElementsAttr wDenseAttr,
+                                RankedTensorType wType,
+                                const Tiling& tiling,
+                                ConversionPatternRewriter& rewriter,
+                                Location loc) {
+  auto packedWeightType = RankedTensorType::get(
+      {tiling.numChannelTiles, tiling.tileInputRows, tiling.tileOutputChannels}, wType.getElementType());
+  SmallVector packedValues(packedWeightType.getNumElements(),
+                           cast(rewriter.getZeroAttr(wType.getElementType())));
+  SmallVector sourceValues(wDenseAttr.getValues());
+
+  for (int64_t tileIndex = 0; tileIndex < tiling.numChannelTiles; ++tileIndex) {
+    const int64_t channelBase = tileIndex * tiling.channelsPerTile;
+    for (int64_t localChannel = 0; localChannel < tiling.channelsPerTile; ++localChannel) {
+      const int64_t globalChannel = channelBase + localChannel;
+      for (int64_t kernelIndex = 0; kernelIndex < tiling.kernelElements; ++kernelIndex) {
+        const int64_t kernelH = kernelIndex / wType.getDimSize(3);
+        const int64_t kernelW = kernelIndex % wType.getDimSize(3);
+        const int64_t targetRow = localChannel * tiling.kernelElements + kernelIndex;
+        for (int64_t multiplierIndex = 0; multiplierIndex < tiling.outputMultiplier; ++multiplierIndex) {
+          const int64_t globalOutChannel = globalChannel * tiling.outputMultiplier + multiplierIndex;
+          const int64_t sourceFlatIndex =
+              ((globalOutChannel * wType.getDimSize(1) * wType.getDimSize(2)) + kernelH) * wType.getDimSize(3) + kernelW;
+          const int64_t targetCol = localChannel * tiling.outputMultiplier + multiplierIndex;
+          const int64_t targetFlatIndex =
+              ((tileIndex * tiling.tileInputRows) + targetRow) * tiling.tileOutputChannels + targetCol;
+          packedValues[targetFlatIndex] = sourceValues[sourceFlatIndex];
+        }
+      }
+    }
+  }
+
+  auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
+  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
+}
+
+static Value createPaddedInput(Value input,
+                               RankedTensorType inputType,
+                               int64_t padHeightBegin,
+                               int64_t padHeightEnd,
+                               int64_t padWidthBegin,
+                               int64_t padWidthEnd,
+                               ConversionPatternRewriter& rewriter,
+                               Location loc) {
+  if (padHeightBegin == 0 && padHeightEnd == 0 && padWidthBegin == 0 && padWidthEnd == 0)
+    return input;
+
+  auto paddedInputType = RankedTensorType::get({inputType.getDimSize(0),
+                                                inputType.getDimSize(1),
+                                                inputType.getDimSize(2) + padHeightBegin + padHeightEnd,
+                                                inputType.getDimSize(3) + padWidthBegin + padWidthEnd},
+                                               inputType.getElementType());
+  auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
+    Value padded = createZeroPaddedTensor(computeInput,
+                                          paddedInputType,
+                                          {0, 0, padHeightBegin, padWidthBegin},
+                                          {0, 0, padHeightEnd, padWidthEnd},
+                                          rewriter,
+                                          loc);
+    spatial::SpatYieldOp::create(rewriter, loc, padded);
+  });
+  return computeOp.getResult(0);
+}
+
+static Value createInputTile(Value input,
+                             Value patchIndex,
+                             Value channelTileIndex,
+                             RankedTensorType inputTileType,
+                             const Tiling& tiling,
+                             int64_t strideHeight,
+                             int64_t strideWidth,
+                             int64_t dilationHeight,
+                             int64_t dilationWidth,
+                             int64_t outWidth,
+                             ConversionPatternRewriter& rewriter,
+                             Location loc) {
+  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
+  Value batchIndex = affineFloorDivConst(rewriter, loc, patchIndex, tiling.spatialPatchesPerBatch, anchorOp);
+  Value batchPatchIndex = affineModConst(rewriter, loc, patchIndex, tiling.spatialPatchesPerBatch, anchorOp);
+  Value outHeightIndex = affineFloorDivConst(rewriter, loc, batchPatchIndex, outWidth, anchorOp);
+  Value outWidthIndex = affineModConst(rewriter, loc, batchPatchIndex, outWidth, anchorOp);
+  Value inputHeightOffset =
+      strideHeight == 1 ? outHeightIndex : affineMulConst(rewriter, loc, outHeightIndex, strideHeight, anchorOp);
+  Value inputWidthOffset =
+      strideWidth == 1 ? outWidthIndex : affineMulConst(rewriter, loc, outWidthIndex, strideWidth, anchorOp);
+  Value channelOffset = tiling.channelsPerTile == 1
+                          ? channelTileIndex
+                          : affineMulConst(rewriter, loc, channelTileIndex, tiling.channelsPerTile, anchorOp);
+  Value tile4D = createConvInputPatch(input,
+                                      inputTileType,
+                                      batchIndex,
+                                      channelOffset,
+                                      inputHeightOffset,
+                                      inputWidthOffset,
+                                      dilationHeight,
+                                      dilationWidth,
+                                      rewriter,
+                                      loc);
+  auto collapsedType = RankedTensorType::get({1, tiling.tileInputRows}, inputTileType.getElementType());
+  return tensor::CollapseShapeOp::create(rewriter,
+                                         loc,
+                                         collapsedType,
+                                         tile4D,
+                                         SmallVector {
+                                             {0},
+                                             {1, 2, 3}
+  });
+}
+
+static Value createWeightTile(Value packedWeights,
+                              Value channelTileIndex,
+                              RankedTensorType packedWeightType,
+                              const Tiling& tiling,
+                              ConversionPatternRewriter& rewriter,
+                              Location loc) {
+  SmallVector offsets {channelTileIndex, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
+  SmallVector sizes {rewriter.getIndexAttr(1),
+                     rewriter.getIndexAttr(tiling.tileInputRows),
+                     rewriter.getIndexAttr(tiling.tileOutputChannels)};
+  auto sliceType =
+      RankedTensorType::get({1, tiling.tileInputRows, tiling.tileOutputChannels}, packedWeightType.getElementType());
+  Value slice = tensor::ExtractSliceOp::create(
+      rewriter, loc, sliceType, packedWeights, offsets, sizes, getUnitStrides(rewriter, 3));
+  auto collapsedType =
+      RankedTensorType::get({tiling.tileInputRows, tiling.tileOutputChannels}, packedWeightType.getElementType());
+  return tensor::CollapseShapeOp::create(rewriter,
+                                         loc,
+                                         collapsedType,
+                                         slice,
+                                         SmallVector {
+                                             {0, 1},
+                                             {2}
+  });
+}
+
+static Value createBiasTile(
+    Value bias, Value channelTileIndex, const Tiling& tiling, ConversionPatternRewriter& rewriter, Location loc) {
+  auto biasType = cast(bias.getType());
+  auto biasTileType = RankedTensorType::get({1, tiling.tileOutputChannels}, biasType.getElementType());
+  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
+  Value channelOffset = tiling.tileOutputChannels == 1
+                          ? channelTileIndex
+                          : affineMulConst(rewriter, loc, channelTileIndex, tiling.tileOutputChannels, anchorOp);
+  SmallVector offsets {rewriter.getIndexAttr(0), channelOffset};
+  SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tiling.tileOutputChannels)};
+  return tensor::ExtractSliceOp::create(rewriter, loc, biasTileType, bias, offsets, sizes, getUnitStrides(rewriter, 2));
+}
+
+static Value insertOutputTile(Value rowTile,
+                              Value rowAccumulator,
+                              Value channelTileIndex,
+                              const Tiling& tiling,
+                              ConversionPatternRewriter& rewriter,
+                              Location loc) {
+  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
+  Value channelOffset = tiling.tileOutputChannels == 1
+                          ? channelTileIndex
+                          : affineMulConst(rewriter, loc, channelTileIndex, tiling.tileOutputChannels, anchorOp);
+  SmallVector offsets {rewriter.getIndexAttr(0), channelOffset};
+  SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tiling.tileOutputChannels)};
+  return tensor::InsertSliceOp::create(
+      rewriter, loc, rowTile, rowAccumulator, offsets, sizes, getUnitStrides(rewriter, 2));
+}
+
+static FailureOr reconstructDepthwiseGemmRows(Value pieces,
+                                              RankedTensorType piecesType,
+                                              RankedTensorType gemmOutType,
+                                              const Tiling& tiling,
+                                              ConversionPatternRewriter& rewriter,
+                                              Location loc) {
+  auto collectedOp = createSpatCompute<1>(rewriter, loc, TypeRange {gemmOutType}, {}, pieces, [&](Value piecesArg) {
+    auto rowType = RankedTensorType::get({1, gemmOutType.getDimSize(1)}, gemmOutType.getElementType());
+    Value outputInit = tensor::EmptyOp::create(rewriter, loc, gemmOutType.getShape(), gemmOutType.getElementType());
+    Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
+    Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
+    Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
+    Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, tiling.totalPatches);
+    Value cNumChannelTiles = getOrCreateIndexConstant(rewriter, anchorOp, tiling.numChannelTiles);
+
+    auto patchLoop = buildNormalizedScfFor(
+        rewriter,
+        loc,
+        c0,
+        cNumPatches,
+        c1,
+        ValueRange {outputInit},
+        [&](OpBuilder&,
+            Location nestedLoc,
+            Value patchIndex,
+            ValueRange patchIterArgs,
+            SmallVectorImpl& patchYielded) {
+          Value outputAcc = patchIterArgs.front();
+          Value rowInit = tensor::EmptyOp::create(rewriter, nestedLoc, rowType.getShape(), rowType.getElementType());
+          auto tileLoop = buildNormalizedScfFor(
+              rewriter,
+              nestedLoc,
+              c0,
+              cNumChannelTiles,
+              c1,
+              ValueRange {rowInit},
+              [&](OpBuilder&,
+                  Location tileLoc,
+                  Value channelTileIndex,
+                  ValueRange tileIterArgs,
+                  SmallVectorImpl& tileYielded) {
+                Value rowAcc = tileIterArgs.front();
+                MLIRContext* context = rewriter.getContext();
+                AffineExpr d0 = getAffineDimExpr(0, context);
+                AffineExpr d1 = getAffineDimExpr(1, context);
+                Value laneIndex = createOrFoldAffineApply(
+                    rewriter, tileLoc, (d0 * tiling.totalPatches) + d1, ValueRange {channelTileIndex, patchIndex}, anchorOp);
+                auto rowTileType = RankedTensorType::get({1, tiling.tileOutputChannels}, piecesType.getElementType());
+                SmallVector pieceOffsets {laneIndex, rewriter.getIndexAttr(0)};
+                SmallVector pieceSizes {rewriter.getIndexAttr(1),
+                                        rewriter.getIndexAttr(tiling.tileOutputChannels)};
+                Value rowTile = tensor::ExtractSliceOp::create(
+                    rewriter, tileLoc, rowTileType, piecesArg, pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2));
+                Value rowNext = insertOutputTile(rowTile, rowAcc, channelTileIndex, tiling, rewriter, tileLoc);
+                tileYielded.push_back(rowNext);
+                return success();
+              });
+          if (failed(tileLoop))
+            return failure();
+
+          SmallVector rowOffsets {patchIndex, rewriter.getIndexAttr(0)};
+          SmallVector rowSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(gemmOutType.getDimSize(1))};
+          Value outputNext = tensor::InsertSliceOp::create(rewriter,
+                                                           nestedLoc,
+                                                           tileLoop->results.front(),
+                                                           outputAcc,
+                                                           rowOffsets,
+                                                           rowSizes,
+                                                           getUnitStrides(rewriter, 2))
+                                 .getResult();
+          patchYielded.push_back(outputNext);
+          return success();
+        });
+    if (failed(patchLoop))
+      return failure();
+
+    spatial::SpatYieldOp::create(rewriter, loc, patchLoop->results.front());
+    return success();
+  });
+  if (failed(collectedOp))
+    return failure();
+  return collectedOp->getResult(0);
+}
+
+static bool canUseStructuredRewrite(const ConvLoweringState& state) {
+  if (!getHostConstDenseElementsAttr(state.w))
+    return false;
+
+  if (!computeTiling(state.batchSize,
+                     state.numChannelsIn,
+                     state.numChannelsOut,
+                     state.wHeight,
+                     state.wWidth,
+                     state.outHeight,
+                     state.outWidth)) {
+    return false;
+  }
+
+  if (isa(state.b.getDefiningOp()))
+    return true;
+
+  auto biasType = dyn_cast(state.b.getType());
+ if (!biasType) + return false; + if (biasType.getRank() == 1) + return biasType.getDimSize(0) == state.numChannelsOut; + if (biasType.getRank() != 2) + return false; + return biasType.getDimSize(0) == 1 && biasType.getDimSize(1) == state.numChannelsOut; +} + +static FailureOr +rewriteConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPatternRewriter& rewriter, Location loc) { + auto wDenseAttr = getHostConstDenseElementsAttr(state.w); + if (!wDenseAttr) { + convOp.emitOpError("requires constant-derived weights for structured depthwise Spatial lowering"); + return failure(); + } + + auto tiling = computeTiling(state.xType.getDimSize(0), + state.xType.getDimSize(1), + state.outType.getDimSize(1), + state.wType.getDimSize(2), + state.wType.getDimSize(3), + state.outType.getDimSize(2), + state.outType.getDimSize(3)); + if (!tiling) { + convOp.emitOpError("failed to derive a structured depthwise tiling that fits Spatial weighted VMM lowering"); + return failure(); + } + + Value paddedInput = createPaddedInput(state.x, + state.xType, + state.padHeightBegin, + state.padHeightEnd, + state.padWidthBegin, + state.padWidthEnd, + rewriter, + loc); + Value packedWeights = buildPackedWeights(wDenseAttr, state.wType, *tiling, rewriter, loc); + + Value expandedBias; + SmallVector batchInputs {paddedInput}; + if (!isa(state.b.getDefiningOp())) { + expandedBias = expandBiasIfNeeded(state.b, rewriter, loc); + auto biasType = dyn_cast(expandedBias.getType()); + if (!biasType || biasType.getRank() != 2 || biasType.getDimSize(0) != 1 + || biasType.getDimSize(1) != state.outType.getDimSize(1)) { + convOp.emitOpError("requires bias sliceable as tensor<1xCout> for structured depthwise Spatial lowering"); + return failure(); + } + batchInputs.push_back(expandedBias); + } + + auto gemmOutType = + RankedTensorType::get({tiling->totalPatches, state.outType.getDimSize(1)}, state.outType.getElementType()); + auto piecesType = RankedTensorType::get({tiling->totalPatches * 
tiling->numChannelTiles, tiling->tileOutputChannels}, + state.outType.getElementType()); + auto paddedInputType = cast(paddedInput.getType()); + auto inputTileType = + RankedTensorType::get({1, tiling->channelsPerTile, state.wType.getDimSize(2), state.wType.getDimSize(3)}, + paddedInputType.getElementType()); + + auto batchOp = createSpatComputeBatch( + rewriter, + loc, + TypeRange {piecesType}, + tiling->totalPatches * tiling->numChannelTiles, + ValueRange {packedWeights}, + batchInputs, + [&](detail::SpatComputeBatchBodyArgs args) { + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value patchIndex = affineModConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp); + Value channelTileIndex = affineFloorDivConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp); + Value inputTile = createInputTile(args.inputs.front(), + patchIndex, + channelTileIndex, + inputTileType, + *tiling, + state.strideHeight, + state.strideWidth, + state.dilationHeight, + state.dilationWidth, + state.outType.getDimSize(3), + rewriter, + loc); + Value weightTile = createWeightTile(args.weights.front(), + channelTileIndex, + cast(args.weights.front().getType()), + *tiling, + rewriter, + loc); + auto rowTileType = RankedTensorType::get({1, tiling->tileOutputChannels}, state.outType.getElementType()); + Value rowTile = spatial::SpatVMMOp::create(rewriter, loc, rowTileType, weightTile, inputTile).getResult(); + if (args.inputs.size() > 1) { + Value biasTile = createBiasTile(args.inputs[1], channelTileIndex, *tiling, rewriter, loc); + rowTile = spatial::SpatVAddOp::create(rewriter, loc, rowTileType, rowTile, biasTile).getResult(); + } + + SmallVector outputOffsets {args.lane, rewriter.getIndexAttr(0)}; + SmallVector outputSizes {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(tiling->tileOutputChannels)}; + createParallelInsertSliceIntoBatchOutput( + rewriter, loc, rowTile, args.outputs.front(), outputOffsets, outputSizes, getUnitStrides(rewriter, 2)); + 
}); + if (failed(batchOp)) + return failure(); + + auto nhwcType = RankedTensorType::get( + {state.xType.getDimSize(0), state.outType.getDimSize(2), state.outType.getDimSize(3), state.outType.getDimSize(1)}, + state.outType.getElementType()); + auto collectedRows = + reconstructDepthwiseGemmRows(batchOp->getResult(0), piecesType, gemmOutType, *tiling, rewriter, loc); + if (failed(collectedRows)) + return failure(); + + return createCollectedConvOutput(ValueRange {*collectedRows}, + state.outType, + gemmOutType, + nhwcType, + state.outType, + tiling->totalPatches, + state.outType.getDimSize(1), + /*packFactor=*/1, + rewriter, + loc); +} + +} // namespace depthwise + +namespace standard { + +struct ConvGemmPlan { + int64_t patchSize; + int64_t numPatchesPerBatch; + int64_t numPatches; + int64_t maxParallelPixels; + int64_t effectiveMaxParallelPixels; + int64_t packedNumRows; + + RankedTensorType im2colType; + RankedTensorType im2colRowType; + RankedTensorType gemmInputRowsType; + RankedTensorType wFlatType; + RankedTensorType wTransType; + RankedTensorType gemmOutType; + RankedTensorType gemmOutputRowsType; + RankedTensorType nhwcType; +}; + +static Value createPaddedRows(Value rows, + RankedTensorType rowsType, + int64_t paddedRows, + ConversionPatternRewriter& rewriter, + Location loc) { + if (rowsType.getDimSize(0) == paddedRows) + return rows; + + auto paddedType = + RankedTensorType::get({paddedRows, rowsType.getDimSize(1)}, rowsType.getElementType(), rowsType.getEncoding()); + return createZeroPaddedTensor( + rows, paddedType, {0, 0}, {paddedRows - rowsType.getDimSize(0), 0}, rewriter, loc); +} + static Value packRowsForParallelGemm( Value rows, RankedTensorType rowsType, int64_t packFactor, ConversionPatternRewriter& rewriter, Location loc) { if (packFactor == 1) return rows; - const int64_t packedNumRows = ceilIntegerDivide(rowsType.getDimSize(0), packFactor); - const int64_t paddedNumRows = packedNumRows * packFactor; + const int64_t paddedNumRows = 
ceilIntegerDivide(rowsType.getDimSize(0), packFactor) * packFactor; + const int64_t packedNumRows = paddedNumRows / packFactor; const int64_t rowWidth = rowsType.getDimSize(1); auto groupedType = RankedTensorType::get({packedNumRows, packFactor, rowWidth}, rowsType.getElementType(), rowsType.getEncoding()); auto packedType = RankedTensorType::get({packedNumRows, packFactor * rowWidth}, rowsType.getElementType(), rowsType.getEncoding()); - Value paddedRows = createPaddedRows(rows, rowsType, paddedNumRows, rewriter, loc); - Value groupedRows = tensor::ExpandShapeOp::create(rewriter, - loc, - groupedType, - paddedRows, - SmallVector { - {0, 1}, - {2} + Value padded = createPaddedRows(rows, rowsType, paddedNumRows, rewriter, loc); + Value grouped = tensor::ExpandShapeOp::create(rewriter, + loc, + groupedType, + padded, + SmallVector { + {0, 1}, + {2} }); return tensor::CollapseShapeOp::create(rewriter, loc, packedType, - groupedRows, + grouped, SmallVector { {0}, {1, 2} @@ -121,62 +716,82 @@ static Value unpackRowsFromParallelGemm(Value packedRows, auto unpackedType = RankedTensorType::get({unpackedRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding()); - Value expandedRows = tensor::ExpandShapeOp::create(rewriter, - loc, - expandedType, - packedRows, - SmallVector { - {0}, - {1, 2} + Value expanded = tensor::ExpandShapeOp::create(rewriter, + loc, + expandedType, + packedRows, + SmallVector { + {0}, + {1, 2} }); - Value paddedRows = tensor::CollapseShapeOp::create(rewriter, - loc, - paddedType, - expandedRows, - SmallVector { - {0, 1}, - {2} + Value padded = tensor::CollapseShapeOp::create(rewriter, + loc, + paddedType, + expanded, + SmallVector { + {0, 1}, + {2} }); if (paddedNumRows == unpackedRows) - return paddedRows; + return padded; SmallVector offsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector sizes {rewriter.getIndexAttr(unpackedRows), rewriter.getIndexAttr(rowWidth)}; - SmallVector strides 
{rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - return tensor::ExtractSliceOp::create(rewriter, loc, unpackedType, paddedRows, offsets, sizes, strides); + return tensor::ExtractSliceOp::create(rewriter, loc, unpackedType, padded, offsets, sizes, getUnitStrides(rewriter, 2)); } -static Value buildPackedWeight(DenseElementsAttr wDenseAttr, - Value wTrans, - RankedTensorType wType, - int64_t numChannelsIn, - int64_t numChannelsOut, - int64_t wHeight, - int64_t wWidth, - int64_t patchSize, - int64_t packFactor, - ConversionPatternRewriter& rewriter, - Location loc) { - if (packFactor == 1) +static Value createWeightMatrix( + Value weights, const ConvGemmPlan& plan, ConversionPatternRewriter& rewriter, Location loc) { + auto buildWeightMatrix = [&](Value weight) -> Value { + Value flattened = tensor::CollapseShapeOp::create(rewriter, + loc, + plan.wFlatType, + weight, + SmallVector { + {0}, + {1, 2, 3} + }); + return ONNXTransposeOp::create(rewriter, loc, plan.wTransType, flattened, rewriter.getI64ArrayAttr({1, 0})) + .getResult(); + }; + + if (isCompileTimeComputable(weights)) + return buildWeightMatrix(weights); + + auto computeOp = + createSpatCompute<1>(rewriter, loc, TypeRange {plan.wTransType}, {}, ValueRange {weights}, [&](Value weight) { + spatial::SpatYieldOp::create(rewriter, loc, buildWeightMatrix(weight)); + }); + return computeOp.getResult(0); +} + +static Value buildPackedWeights(DenseElementsAttr wDenseAttr, + Value wTrans, + const ConvLoweringState& state, + const ConvGemmPlan& plan, + ConversionPatternRewriter& rewriter, + Location loc) { + if (plan.effectiveMaxParallelPixels == 1) return wTrans; - auto packedWeightType = - RankedTensorType::get({packFactor * patchSize, packFactor * numChannelsOut}, wType.getElementType()); + auto packedWeightType = RankedTensorType::get( + {plan.effectiveMaxParallelPixels * plan.patchSize, plan.effectiveMaxParallelPixels * state.numChannelsOut}, + state.wType.getElementType()); SmallVector 
sourceValues(wDenseAttr.getValues()); SmallVector packedValues(packedWeightType.getNumElements(), - cast(rewriter.getZeroAttr(wType.getElementType()))); + cast(rewriter.getZeroAttr(state.wType.getElementType()))); - for (int64_t copyId = 0; copyId < packFactor; copyId++) { - for (int64_t outChannel = 0; outChannel < numChannelsOut; outChannel++) { - for (int64_t inChannel = 0; inChannel < numChannelsIn; inChannel++) { - for (int64_t kernelH = 0; kernelH < wHeight; kernelH++) { - for (int64_t kernelW = 0; kernelW < wWidth; kernelW++) { + for (int64_t copyId = 0; copyId < plan.effectiveMaxParallelPixels; ++copyId) { + for (int64_t outChannel = 0; outChannel < state.numChannelsOut; ++outChannel) { + for (int64_t inChannel = 0; inChannel < state.numChannelsIn; ++inChannel) { + for (int64_t kernelH = 0; kernelH < state.wHeight; ++kernelH) { + for (int64_t kernelW = 0; kernelW < state.wWidth; ++kernelW) { const int64_t sourceFlatIndex = - (((outChannel * numChannelsIn) + inChannel) * wHeight + kernelH) * wWidth + kernelW; - const int64_t patchIndex = ((inChannel * wHeight) + kernelH) * wWidth + kernelW; - const int64_t targetRow = copyId * patchSize + patchIndex; - const int64_t targetCol = copyId * numChannelsOut + outChannel; - packedValues[targetRow * (packFactor * numChannelsOut) + targetCol] = sourceValues[sourceFlatIndex]; + (((outChannel * state.numChannelsIn) + inChannel) * state.wHeight + kernelH) * state.wWidth + kernelW; + const int64_t patchIndex = ((inChannel * state.wHeight) + kernelH) * state.wWidth + kernelW; + const int64_t targetRow = copyId * plan.patchSize + patchIndex; + const int64_t targetCol = copyId * state.numChannelsOut + outChannel; + packedValues[targetRow * packedWeightType.getDimSize(1) + targetCol] = sourceValues[sourceFlatIndex]; } } } @@ -187,124 +802,95 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr, return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType); } -static 
Value createConvWeightMatrix( - Value w, RankedTensorType wFlatType, RankedTensorType wTransType, ConversionPatternRewriter& rewriter, Location loc) { - auto buildWeightMatrix = [&](Value weight) -> Value { - Value wFlat = tensor::CollapseShapeOp::create(rewriter, - loc, - wFlatType, - weight, - SmallVector { - {0}, - {1, 2, 3} - }); - return ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})).getResult(); - }; - - if (isCompileTimeComputable(w)) - return buildWeightMatrix(w); - - auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {wTransType}, {}, ValueRange {w}, [&](Value weight) { - spatial::SpatYieldOp::create(rewriter, loc, buildWeightMatrix(weight)); - }); - return computeOp.getResult(0); -} - -static Value buildPackedBias(bool hasBias, - Value gemmBias, +static Value buildPackedBias(Value gemmBias, Value biasMatrix, DenseElementsAttr biasDenseAttr, - RankedTensorType outType, - int64_t numChannelsOut, - int64_t packFactor, + const ConvLoweringState& state, + const ConvGemmPlan& plan, ConversionPatternRewriter& rewriter, Location loc) { - if (!hasBias) + if (!state.hasBias) return gemmBias; - if (packFactor == 1) + if (plan.effectiveMaxParallelPixels == 1) return biasMatrix; SmallVector sourceValues(biasDenseAttr.getValues()); SmallVector packedValues; - packedValues.reserve(packFactor * numChannelsOut); - for (int64_t copyId = 0; copyId < packFactor; copyId++) + packedValues.reserve(plan.effectiveMaxParallelPixels * state.numChannelsOut); + for (int64_t copyId = 0; copyId < plan.effectiveMaxParallelPixels; ++copyId) packedValues.append(sourceValues.begin(), sourceValues.end()); - auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType()); + auto packedBiasType = + RankedTensorType::get({1, plan.effectiveMaxParallelPixels * state.numChannelsOut}, state.outType.getElementType()); auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues); return 
getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedBiasAttr, packedBiasType); } -static Value createIm2colRowComputes(Value x, - RankedTensorType xType, - RankedTensorType im2colType, - RankedTensorType im2colRowType, - RankedTensorType gemmInputRowsType, - int64_t batchSize, - int64_t numChannelsIn, - int64_t xHeight, - int64_t xWidth, - int64_t wHeight, - int64_t wWidth, - int64_t padHeightBegin, - int64_t padHeightEnd, - int64_t padWidthBegin, - int64_t padWidthEnd, - int64_t strideHeight, - int64_t strideWidth, - int64_t dilationHeight, - int64_t dilationWidth, - int64_t outWidth, - int64_t patchSize, - int64_t numPatches, - int64_t numPatchesPerBatch, - int64_t packFactor, - ConversionPatternRewriter& rewriter, - Location loc) { - auto elemType = xType.getElementType(); +static ConvGemmPlan +buildConvGemmPlan(const ConvLoweringState& state, bool canPackWeightsAsConstants, bool canPackBiasAsConstants) { + ConvGemmPlan plan; + plan.patchSize = state.numChannelsIn * state.wHeight * state.wWidth; + plan.numPatchesPerBatch = state.outHeight * state.outWidth; + plan.numPatches = state.batchSize * plan.numPatchesPerBatch; + const int64_t wMaxDim = std::max(plan.patchSize, state.numChannelsOut); + plan.maxParallelPixels = std::max(1, static_cast(crossbarSize.getValue()) / wMaxDim); + plan.effectiveMaxParallelPixels = + (canPackWeightsAsConstants && canPackBiasAsConstants) ? 
plan.maxParallelPixels : 1; + plan.packedNumRows = ceilIntegerDivide(plan.numPatches, plan.effectiveMaxParallelPixels); + + auto elemType = state.xType.getElementType(); + auto outElemType = state.outType.getElementType(); + plan.im2colType = RankedTensorType::get({plan.numPatches, plan.patchSize}, elemType); + plan.im2colRowType = RankedTensorType::get({1, plan.patchSize}, elemType); + plan.gemmInputRowsType = + RankedTensorType::get({plan.packedNumRows, plan.effectiveMaxParallelPixels * plan.patchSize}, elemType); + plan.wFlatType = RankedTensorType::get({state.numChannelsOut, plan.patchSize}, state.wType.getElementType()); + plan.wTransType = RankedTensorType::get({plan.patchSize, state.numChannelsOut}, state.wType.getElementType()); + plan.gemmOutType = RankedTensorType::get({plan.numPatches, state.numChannelsOut}, outElemType); + plan.gemmOutputRowsType = + RankedTensorType::get({plan.packedNumRows, plan.effectiveMaxParallelPixels * state.numChannelsOut}, outElemType); + plan.nhwcType = + RankedTensorType::get({state.batchSize, state.outHeight, state.outWidth, state.numChannelsOut}, outElemType); + return plan; +} + +static Value createIm2colRows( + const ConvLoweringState& state, const ConvGemmPlan& plan, ConversionPatternRewriter& rewriter, Location loc) { constexpr size_t numInputs = 1; auto im2colComputeOp = - createSpatCompute(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) { + createSpatCompute(rewriter, loc, TypeRange {plan.gemmInputRowsType}, {}, state.x, [&](Value xArg) { + auto elemType = state.xType.getElementType(); Value paddedInput = xArg; - - // Pad input with zeros if needed: - // [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth] - if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) { - const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd; - const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd; - auto paddedType = 
RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType); - SmallVector lowPads = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(padHeightBegin), - rewriter.getIndexAttr(padWidthBegin)}; - SmallVector highPads = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(padHeightEnd), - rewriter.getIndexAttr(padWidthEnd)}; - auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads); - auto* padBlock = new Block(); - for (int i = 0; i < 4; i++) - padBlock->addArgument(rewriter.getIndexType(), loc); - padOp.getRegion().push_back(padBlock); - rewriter.setInsertionPointToStart(padBlock); - auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getFloatAttr(elemType, 0.0), elemType); - tensor::YieldOp::create(rewriter, loc, zero); - rewriter.setInsertionPointAfter(padOp); - paddedInput = padOp.getResult(); + if (state.padHeightBegin || state.padHeightEnd || state.padWidthBegin || state.padWidthEnd) { + auto paddedInputType = RankedTensorType::get( + {state.batchSize, + state.numChannelsIn, + state.xHeight + state.padHeightBegin + state.padHeightEnd, + state.xWidth + state.padWidthBegin + state.padWidthEnd}, + elemType); + paddedInput = createZeroPaddedTensor(paddedInput, + paddedInputType, + {0, 0, state.padHeightBegin, state.padWidthBegin}, + {0, 0, state.padHeightEnd, state.padWidthEnd}, + rewriter, + loc); } - // Build im2col [numPatches, patchSize] incrementally to keep the IR small - // until the late PIM unrolling step. 
- Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType); + // Keep the standard im2col view of convolution, flipped so filters sit in + // B / crossbar columns: + // A (im2col): [numPatches, patchSize] -- one row per output spatial position + // B (weights): [patchSize, cOut] + // Gemm output: [numPatches, cOut] + Value im2colInit = tensor::EmptyOp::create(rewriter, loc, plan.im2colType.getShape(), elemType); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); - auto c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); - auto c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); - auto cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, numPatches); - auto cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, numPatchesPerBatch); - auto cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, outWidth); - auto cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight); - auto cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth); + Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); + Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); + Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, plan.numPatches); + Value cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, plan.numPatchesPerBatch); + Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth); + Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, state.strideHeight); + Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.strideWidth); auto im2colLoop = buildNormalizedScfFor( rewriter, @@ -322,44 +908,44 @@ static Value createIm2colRowComputes(Value x, Value inputHeightOffset = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight); Value inputWidthOffset = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth); - SmallVector offsets = { - batchIndex, 
rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset}; - SmallVector sizes = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(numChannelsIn), - rewriter.getIndexAttr(wHeight), - rewriter.getIndexAttr(wWidth)}; - SmallVector strides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(dilationHeight), - rewriter.getIndexAttr(dilationWidth)}; - auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); - Value patch = - tensor::ExtractSliceOp::create(rewriter, nestedLoc, patchType, paddedInput, offsets, sizes, strides); - + auto patchType = + RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elemType); + Value patch = createConvInputPatch(paddedInput, + patchType, + batchIndex, + c0, + inputHeightOffset, + inputWidthOffset, + state.dilationHeight, + state.dilationWidth, + rewriter, + nestedLoc); Value row = tensor::CollapseShapeOp::create(rewriter, nestedLoc, - im2colRowType, + plan.im2colRowType, patch, SmallVector { {0}, {1, 2, 3} }); - SmallVector rowOffsets = {patchIndex, rewriter.getIndexAttr(0)}; - SmallVector rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)}; - SmallVector rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - Value updatedIm2col = - tensor::InsertSliceOp::create(rewriter, nestedLoc, row, im2colAcc, rowOffsets, rowSizes, rowStrides); - yielded.push_back(updatedIm2col); + SmallVector rowOffsets {patchIndex, rewriter.getIndexAttr(0)}; + SmallVector rowSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(plan.patchSize)}; + Value next = tensor::InsertSliceOp::create( + rewriter, nestedLoc, row, im2colAcc, rowOffsets, rowSizes, getUnitStrides(rewriter, 2)); + yielded.push_back(next); return success(); }); if (failed(im2colLoop)) return failure(); - Value im2col = im2colLoop->results.front(); - Value gemmInputRows = im2col; - if (packFactor != 1) - gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, 
rewriter, loc); + Value gemmInputRows = im2colLoop->results.front(); + // Pack N old im2col rows into one longer row so one GEMM can cover N + // pixels in parallel. The corresponding packed weight matrix contains N + // block-diagonal copies of W^T, and the packed output must be unpacked + // back to one row per spatial patch. + if (plan.effectiveMaxParallelPixels != 1) + gemmInputRows = packRowsForParallelGemm(gemmInputRows, plan.im2colType, plan.effectiveMaxParallelPixels, rewriter, loc); spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows); return success(); @@ -369,6 +955,52 @@ static Value createIm2colRowComputes(Value x, return im2colComputeOp->getResult(0); } +static Value rewriteConv(const ConvLoweringState& state, ConversionPatternRewriter& rewriter, Location loc) { + auto wDenseAttr = getHostConstDenseElementsAttr(state.w); + Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); + Value biasMatrix; + DenseElementsAttr biasDenseAttr; + if (state.hasBias) { + gemmBias = state.b; + biasDenseAttr = getHostConstDenseElementsAttr(state.b); + biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc); + } + + ConvGemmPlan plan = + buildConvGemmPlan(state, static_cast(wDenseAttr), !state.hasBias || static_cast(biasDenseAttr)); + // Prepare weight matrix W for crossbar storage: + // W: [Cout, Cin, KH, KW] -> [Cout, patchSize] -> [patchSize, Cout] + Value weightMatrix = createWeightMatrix(state.w, plan, rewriter, loc); + Value gemmInputRows = createIm2colRows(state, plan, rewriter, loc); + Value gemmB = buildPackedWeights(wDenseAttr, weightMatrix, state, plan, rewriter, loc); + Value gemmC = buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, plan, rewriter, loc); + + Value gemmRows = ONNXGemmOp::create(rewriter, + loc, + plan.gemmOutputRowsType, + gemmInputRows, + gemmB, + gemmC, + rewriter.getF32FloatAttr(1.0f), + rewriter.getF32FloatAttr(1.0f), + rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)) + .getY(); + + return 
createCollectedConvOutput(ValueRange {gemmRows}, + state.outType, + plan.gemmOutType, + plan.nhwcType, + state.outType, + plan.numPatches, + state.numChannelsOut, + plan.effectiveMaxParallelPixels, + rewriter, + loc); +} + +} // namespace standard + static Value createCollectedConvOutput(ValueRange gemmRows, Type convType, RankedTensorType gemmOutType, @@ -386,19 +1018,14 @@ static Value createCollectedConvOutput(ValueRange gemmRows, } else { Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs); - gemmOut = unpackRowsFromParallelGemm(packedOutput, - cast(packedOutput.getType()), - numPatches, - numChannelsOut, - packFactor, - rewriter, - loc); + gemmOut = standard::unpackRowsFromParallelGemm( + packedOutput, cast(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc); } - // Restore to NCHW layout: + // Restore output layout: // [numPatches, numChannelsOut] - // -> [1, outHeight, outWidth, numChannelsOut] - // -> [1, numChannelsOut, outHeight, outWidth] + // -> [N, Hout, Wout, Cout] + // -> [N, Cout, Hout, Wout] Value nhwcOut = tensor::ExpandShapeOp::create(rewriter, loc, nhwcType, @@ -413,232 +1040,82 @@ static Value createCollectedConvOutput(ValueRange gemmRows, return collectComputeOp.getResult(0); } -static Value lowerSingleConvGroup(Value x, - Value w, - Value b, - RankedTensorType xType, - RankedTensorType wType, - RankedTensorType outType, - int64_t padHeightBegin, - int64_t padHeightEnd, - int64_t padWidthBegin, - int64_t padWidthEnd, - int64_t strideHeight, - int64_t strideWidth, - int64_t dilationHeight, - int64_t dilationWidth, - ConversionPatternRewriter& rewriter, - Location loc) { - const int64_t batchSize = xType.getDimSize(0); - const int64_t numChannelsIn = xType.getDimSize(1); - const int64_t xHeight = xType.getDimSize(2); - const int64_t xWidth = xType.getDimSize(3); - const int64_t numChannelsOut = wType.getDimSize(0); - const int64_t wHeight = wType.getDimSize(2); - const int64_t wWidth = 
wType.getDimSize(3); - const int64_t outHeight = outType.getDimSize(2); - const int64_t outWidth = outType.getDimSize(3); +static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, ONNXConvOpAdaptor convOpAdaptor) { + ConvLoweringState state; + state.x = convOpAdaptor.getX(); + state.w = convOpAdaptor.getW(); + state.b = convOpAdaptor.getB(); + state.xType = cast(state.x.getType()); + state.wType = cast(state.w.getType()); + state.outType = cast(convOp.getY().getType()); - // im2col layout (flipped with respect to the standard, so filters sit in B = crossbar): - // A (im2col): [numPatches, patchSize] -- one row per output spatial position - // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns - // Gemm output: [numPatches, cOut] - const int64_t patchSize = numChannelsIn * wHeight * wWidth; - const int64_t numPatchesPerBatch = outHeight * outWidth; - const int64_t numPatches = batchSize * numPatchesPerBatch; - - auto elemType = xType.getElementType(); - auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType); - auto rowType = RankedTensorType::get({1, patchSize}, elemType); - auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType()); - auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType()); - auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType()); - auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType()); - - const int64_t xbarSize = static_cast(crossbarSize.getValue()); - const int64_t wMaxDim = std::max(patchSize, numChannelsOut); - const int64_t maxParallelPixels = std::max(1, xbarSize / wMaxDim); - auto wDenseAttr = getHostConstDenseElementsAttr(w); - - // Prepare weight matrix W for crossbar storage: - // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut] - Value wTrans = 
createConvWeightMatrix(w, wFlatType, wTransType, rewriter, loc); - - // Pass bias through directly; Gemm handles rank-1 C canonicalization. - bool hasB = !isa(b.getDefiningOp()); - Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); - Value biasMatrix; - DenseElementsAttr biasDenseAttr; - if (hasB) { - gemmBias = b; - biasDenseAttr = getHostConstDenseElementsAttr(b); - biasMatrix = expandBiasIfNeeded(b, rewriter, loc); - } - const bool canPackWeightsAsConstants = static_cast(wDenseAttr); - const bool canPackBiasAsConstants = !hasB || static_cast(biasDenseAttr); - const int64_t effectiveMaxParallelPixels = - (canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1; - - // Keep the standard im2col view of convolution: - // A (im2col): [numPatches, patchSize] -- one row per output spatial position - // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns - // and optionally repack several old rows into one GEMM row to use the available crossbar size better. - // - // We want to process N pixels at the same time. 
Instead of doing N separate operations - // of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix - // containing N copies of W^T and concatenate N im2col rows into one longer row: - // A_packed: [ceil(numPatches / N), N * patchSize] - // B_packed: [N * patchSize, N * cOut] - // Y_packed: [ceil(numPatches / N), N * cOut] - const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels); - auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType); - auto gemmOutputRowsType = - RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType()); - Value gemmInputRows = createIm2colRowComputes(x, - xType, - im2colType, - rowType, - gemmInputRowsType, - batchSize, - numChannelsIn, - xHeight, - xWidth, - wHeight, - wWidth, - padHeightBegin, - padHeightEnd, - padWidthBegin, - padWidthEnd, - strideHeight, - strideWidth, - dilationHeight, - dilationWidth, - outWidth, - patchSize, - numPatches, - numPatchesPerBatch, - effectiveMaxParallelPixels, - rewriter, - loc); - - Value gemmB = buildPackedWeight(wDenseAttr, - wTrans, - wType, - numChannelsIn, - numChannelsOut, - wHeight, - wWidth, - patchSize, - effectiveMaxParallelPixels, - rewriter, - loc); - Value gemmC = buildPackedBias( - hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc); - - Value gemmRows = ONNXGemmOp::create(rewriter, - loc, - gemmOutputRowsType, - gemmInputRows, - gemmB, - gemmC, - rewriter.getF32FloatAttr(1.0f), - rewriter.getF32FloatAttr(1.0f), - rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)) - .getY(); - - return createCollectedConvOutput(ValueRange {gemmRows}, - outType, - gemmOutType, - nhwcType, - outType, - numPatches, - numChannelsOut, - effectiveMaxParallelPixels, - rewriter, - loc); -} - -} // namespace - -LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, - 
ONNXConvOpAdaptor convOpAdaptor, - ConversionPatternRewriter& rewriter) const { - Location loc = convOp.getLoc(); - Value x = convOpAdaptor.getX(); - Value w = convOpAdaptor.getW(); - Value b = convOpAdaptor.getB(); - - auto xType = cast(x.getType()); - auto wType = cast(w.getType()); - auto outType = cast(convOp.getY().getType()); - - if (!xType.hasStaticShape()) { + if (!state.xType.hasStaticShape()) { pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input"); return failure(); } - if (!wType.hasStaticShape()) { + if (!state.wType.hasStaticShape()) { pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight"); return failure(); } - if (!outType.hasStaticShape()) { + if (!state.outType.hasStaticShape()) { pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result"); return failure(); } - if (xType.getRank() != 4) { - pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4}); + if (state.xType.getRank() != 4) { + pim::emitUnsupportedRankDiagnostic(convOp, "conv input", state.xType.getRank(), {4}); return failure(); } - if (wType.getRank() != 4) { - pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4}); + if (state.wType.getRank() != 4) { + pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", state.wType.getRank(), {4}); return failure(); } - if (outType.getRank() != 4) { - pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4}); + if (state.outType.getRank() != 4) { + pim::emitUnsupportedRankDiagnostic(convOp, "conv result", state.outType.getRank(), {4}); return failure(); } - if (convOp.getGroup() < 1) { + + state.group = convOp.getGroup(); + if (state.group < 1) { convOp.emitOpError("requires group >= 1 for Spatial lowering"); return failure(); } - const int64_t batchSize = xType.getDimSize(0); - const int64_t numChannelsIn = xType.getDimSize(1); - const int64_t xHeight = xType.getDimSize(2); - const int64_t xWidth = xType.getDimSize(3); - const int64_t numChannelsOut = 
wType.getDimSize(0); - const int64_t wHeight = wType.getDimSize(2); - const int64_t wWidth = wType.getDimSize(3); - const int64_t outHeight = outType.getDimSize(2); - const int64_t outWidth = outType.getDimSize(3); - const int64_t group = convOp.getGroup(); - const bool hasB = !isa(b.getDefiningOp()); + state.batchSize = state.xType.getDimSize(0); + state.numChannelsIn = state.xType.getDimSize(1); + state.xHeight = state.xType.getDimSize(2); + state.xWidth = state.xType.getDimSize(3); + state.numChannelsOut = state.wType.getDimSize(0); + state.wHeight = state.wType.getDimSize(2); + state.wWidth = state.wType.getDimSize(3); + state.outHeight = state.outType.getDimSize(2); + state.outWidth = state.outType.getDimSize(3); + state.hasBias = !isa(state.b.getDefiningOp()); - if (numChannelsIn % group != 0) { - convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group + if (state.numChannelsIn % state.group != 0) { + convOp.emitOpError() << "requires input channels " << state.numChannelsIn << " to be divisible by group " + << state.group << " for Spatial lowering"; + return failure(); + } + if (state.numChannelsOut % state.group != 0) { + convOp.emitOpError() << "requires output channels " << state.numChannelsOut << " to be divisible by group " + << state.group << " for Spatial lowering"; + return failure(); + } + + state.numChannelsInPerGroup = state.numChannelsIn / state.group; + state.numChannelsOutPerGroup = state.numChannelsOut / state.group; + if (state.wType.getDimSize(1) != state.numChannelsInPerGroup) { + convOp.emitOpError() << "requires grouped conv weight input channels " << state.wType.getDimSize(1) + << " to match input channels per group " << state.numChannelsInPerGroup << " for Spatial lowering"; return failure(); } - if (numChannelsOut % group != 0) { - convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group - << " for Spatial lowering"; + if 
(state.wType.getDimSize(0) != state.numChannelsOut) { + convOp.emitOpError() << "requires weight output channels " << state.wType.getDimSize(0) + << " to match result channels " << state.numChannelsOut << " for Spatial lowering"; return failure(); } - const int64_t numChannelsInPerGroup = numChannelsIn / group; - const int64_t numChannelsOutPerGroup = numChannelsOut / group; - if (wType.getDimSize(1) != numChannelsInPerGroup) { - convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1) - << " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering"; - return failure(); - } - if (wType.getDimSize(0) != numChannelsOut) { - convOp.emitOpError() << "requires weight output channels " << wType.getDimSize(0) << " to match result channels " - << numChannelsOut << " for Spatial lowering"; - return failure(); - } - - // Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0) const auto stridesAttr = convOp.getStrides(); const auto dilationsAttr = convOp.getDilations(); const auto padsAttr = convOp.getPads(); @@ -656,79 +1133,104 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, return failure(); } - const int64_t strideHeight = getOptionalI64Attr(stridesAttr, 0, 1); - const int64_t strideWidth = getOptionalI64Attr(stridesAttr, 1, 1); - const int64_t dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1); - const int64_t dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1); - - int64_t padHeightBegin = 0; - int64_t padHeightEnd = 0; - int64_t padWidthBegin = 0; - int64_t padWidthEnd = 0; + state.strideHeight = getOptionalI64Attr(stridesAttr, 0, 1); + state.strideWidth = getOptionalI64Attr(stridesAttr, 1, 1); + state.dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1); + state.dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1); + state.padHeightBegin = 0; + state.padHeightEnd = 0; + state.padWidthBegin = 0; + state.padWidthEnd = 0; if (padsAttr) { - 
padHeightBegin = getI64Attr(*padsAttr, 0); - padWidthBegin = getI64Attr(*padsAttr, 1); - padHeightEnd = getI64Attr(*padsAttr, 2); - padWidthEnd = getI64Attr(*padsAttr, 3); + state.padHeightBegin = getI64Attr(*padsAttr, 0); + state.padWidthBegin = getI64Attr(*padsAttr, 1); + state.padHeightEnd = getI64Attr(*padsAttr, 2); + state.padWidthEnd = getI64Attr(*padsAttr, 3); + return state; } - else { - // Compute padding from auto_pad attribute - const auto autoPad = convOp.getAutoPad(); - if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { - const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1; - const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1; - const int64_t totalPadH = - std::max(static_cast(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight); - const int64_t totalPadW = - std::max(static_cast(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth); - if (autoPad == "SAME_UPPER") { - padHeightBegin = totalPadH / 2; - padHeightEnd = totalPadH - padHeightBegin; - padWidthBegin = totalPadW / 2; - padWidthEnd = totalPadW - padWidthBegin; - } - else { // SAME_LOWER - padHeightEnd = totalPadH / 2; - padHeightBegin = totalPadH - padHeightEnd; - padWidthEnd = totalPadW / 2; - padWidthBegin = totalPadW - padWidthEnd; - } + const auto autoPad = convOp.getAutoPad(); + if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { + const int64_t effectiveKernelH = (state.wHeight - 1) * state.dilationHeight + 1; + const int64_t effectiveKernelW = (state.wWidth - 1) * state.dilationWidth + 1; + const int64_t totalPadH = + std::max(static_cast(0), (state.outHeight - 1) * state.strideHeight + effectiveKernelH - state.xHeight); + const int64_t totalPadW = + std::max(static_cast(0), (state.outWidth - 1) * state.strideWidth + effectiveKernelW - state.xWidth); + + if (autoPad == "SAME_UPPER") { + state.padHeightBegin = totalPadH / 2; + state.padHeightEnd = totalPadH - state.padHeightBegin; + state.padWidthBegin = totalPadW / 2; + 
state.padWidthEnd = totalPadW - state.padWidthBegin; } - else if (autoPad != "NOTSET" && autoPad != "VALID") { - convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering"; - return failure(); + else { + state.padHeightEnd = totalPadH / 2; + state.padHeightBegin = totalPadH - state.padHeightEnd; + state.padWidthEnd = totalPadW / 2; + state.padWidthBegin = totalPadW - state.padWidthEnd; } - // "NOTSET" or "VALID" -> all pads stay 0 + return state; } - if (group == 1) { - rewriter.replaceOp(convOp, - lowerSingleConvGroup(x, - w, - b, - xType, - wType, - outType, - padHeightBegin, - padHeightEnd, - padWidthBegin, - padWidthEnd, - strideHeight, - strideWidth, - dilationHeight, - dilationWidth, - rewriter, - loc)); - return success(); + if (autoPad != "NOTSET" && autoPad != "VALID") { + convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering"; + return failure(); } - SmallVector xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc); - SmallVector wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc); + return state; +} + +static LogicalResult +rewriteUngroupedConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPatternRewriter& rewriter) { + rewriter.replaceOp(convOp, standard::rewriteConv(state, rewriter, convOp.getLoc())); + return success(); +} + +static LogicalResult +rewriteDepthwiseConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPatternRewriter& rewriter) { + FailureOr result = depthwise::rewriteConv(convOp, state, rewriter, convOp.getLoc()); + if (failed(result)) + return failure(); + + rewriter.replaceOp(convOp, *result); + return success(); +} + +static ConvLoweringState makeGroupedConvLoweringState( + const ConvLoweringState& parent, Value groupX, Value groupW, Value groupB, RankedTensorType groupOutType) { + ConvLoweringState state = parent; + state.x = groupX; + state.w = groupW; + state.b = groupB; + state.xType = 
cast(groupX.getType()); + state.wType = cast(groupW.getType()); + state.outType = groupOutType; + state.batchSize = state.xType.getDimSize(0); + state.numChannelsIn = state.xType.getDimSize(1); + state.xHeight = state.xType.getDimSize(2); + state.xWidth = state.xType.getDimSize(3); + state.numChannelsOut = state.wType.getDimSize(0); + state.wHeight = state.wType.getDimSize(2); + state.wWidth = state.wType.getDimSize(3); + state.outHeight = state.outType.getDimSize(2); + state.outWidth = state.outType.getDimSize(3); + state.group = 1; + state.numChannelsInPerGroup = state.numChannelsIn; + state.numChannelsOutPerGroup = state.numChannelsOut; + state.hasBias = !isa(groupB.getDefiningOp()); + return state; +} + +static LogicalResult +rewriteGroupedConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPatternRewriter& rewriter) { + SmallVector xSlices = sliceTensor(state.x, /*axis=*/1, state.numChannelsInPerGroup, rewriter, convOp.getLoc()); + SmallVector wSlices = + sliceTensor(state.w, /*axis=*/0, state.numChannelsOutPerGroup, rewriter, convOp.getLoc()); SmallVector bSlices; - if (hasB) { - auto biasType = cast(b.getType()); + if (state.hasBias) { + auto biasType = cast(state.b.getType()); int64_t biasAxis = -1; if (biasType.getRank() == 1) biasAxis = 0; @@ -739,50 +1241,38 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, << biasType.getRank(); return failure(); } - bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc); + bSlices = sliceTensor(state.b, biasAxis, state.numChannelsOutPerGroup, rewriter, convOp.getLoc()); } - if (xSlices.size() != static_cast(group) || wSlices.size() != static_cast(group) - || (hasB && bSlices.size() != static_cast(group))) { + if (xSlices.size() != static_cast(state.group) || wSlices.size() != static_cast(state.group) + || (state.hasBias && bSlices.size() != static_cast(state.group))) { convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering"); return 
failure(); } SmallVector groupResults; - groupResults.reserve(group); - auto groupOutType = - RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType()); - Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); - for (int64_t groupId = 0; groupId < group; groupId++) { + groupResults.reserve(state.group); + auto groupOutType = RankedTensorType::get( + {state.batchSize, state.numChannelsOutPerGroup, state.outHeight, state.outWidth}, state.outType.getElementType()); + Value noBias = ONNXNoneOp::create(rewriter, convOp.getLoc(), rewriter.getNoneType()); + for (int64_t groupId = 0; groupId < state.group; groupId++) { Value groupX = xSlices[groupId]; Value groupW = wSlices[groupId]; - Value groupB = hasB ? bSlices[groupId] : noBias; - groupResults.push_back(lowerSingleConvGroup(groupX, - groupW, - groupB, - cast(groupX.getType()), - cast(groupW.getType()), - groupOutType, - padHeightBegin, - padHeightEnd, - padWidthBegin, - padWidthEnd, - strideHeight, - strideWidth, - dilationHeight, - dilationWidth, - rewriter, - loc)); + Value groupB = state.hasBias ? 
bSlices[groupId] : noBias; + ConvLoweringState groupState = makeGroupedConvLoweringState(state, groupX, groupW, groupB, groupOutType); + groupResults.push_back(standard::rewriteConv(groupState, rewriter, convOp.getLoc())); } Value result; if (llvm::all_of(groupResults, isCompileTimeComputable)) { - result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults); + result = createSpatConcat(rewriter, convOp.getLoc(), /*axis=*/1, groupResults); } else { - auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) { - spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args)); - }); + auto concatCompute = + createSpatCompute(rewriter, convOp.getLoc(), TypeRange {state.outType}, {}, groupResults, [&](ValueRange args) { + spatial::SpatYieldOp::create( + rewriter, convOp.getLoc(), createSpatConcat(rewriter, convOp.getLoc(), /*axis=*/1, args)); + }); result = concatCompute.getResult(0); } @@ -790,6 +1280,27 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, return success(); } +} // namespace + +LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, + ONNXConvOpAdaptor convOpAdaptor, + ConversionPatternRewriter& rewriter) const { + FailureOr state = analyzeConvLoweringState(convOp, convOpAdaptor); + if (failed(state)) + return failure(); + + if (isDepthwiseConv(state->group, state->numChannelsIn, state->numChannelsOut, state->numChannelsInPerGroup)) { + if (depthwise::canUseStructuredRewrite(*state)) + return rewriteDepthwiseConv(convOp, *state, rewriter); + return rewriteGroupedConv(convOp, *state, rewriter); + } + + if (state->group == 1) + return rewriteUngroupedConv(convOp, *state, rewriter); + + return rewriteGroupedConv(convOp, *state, rewriter); +} + void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp 
b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index e068259..e6914d5 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -690,11 +690,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, Value b = gemmOpAdaptor.getB(); Value c = gemmOpAdaptor.getC(); - if (gemmOpAdaptor.getTransA()) { - gemmOp.emitOpError("requires transA=false before tiled Spatial Gemm lowering"); - return failure(); - } - auto aType = dyn_cast(a.getType()); auto bType = dyn_cast(b.getType()); auto outType = dyn_cast(gemmOp.getY().getType()); @@ -725,9 +720,12 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, return failure(); } - const int64_t numOutRows = outType.getDimSize(0); - const int64_t numOutCols = outType.getDimSize(1); - const int64_t reductionSize = aType.getDimSize(1); + if (gemmOpAdaptor.getTransA()) { + auto aShape = aType.getShape(); + auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType(), aType.getEncoding()); + a = ONNXTransposeOp::create(rewriter, loc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})).getResult(); + aType = transposedType; + } if (gemmOpAdaptor.getTransB()) { auto bShape = bType.getShape(); @@ -736,6 +734,10 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, bType = transposedType; } + const int64_t numOutRows = outType.getDimSize(0); + const int64_t numOutCols = outType.getDimSize(1); + const int64_t reductionSize = aType.getDimSize(1); + if (!isCompileTimeComputable(b)) { bool hasC = hasGemmBias(c); float alpha = gemmOpAdaptor.getAlpha().convertToFloat(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index 922c05f..4b888b1 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -22,13 +22,87 
@@ namespace { static FailureOr> inferSupportedBatchShape(ArrayRef lhsBatchShape, ArrayRef rhsBatchShape) { - if (lhsBatchShape.empty()) - return SmallVector(rhsBatchShape.begin(), rhsBatchShape.end()); - if (rhsBatchShape.empty()) - return SmallVector(lhsBatchShape.begin(), lhsBatchShape.end()); - if (!llvm::equal(lhsBatchShape, rhsBatchShape)) - return failure(); - return SmallVector(lhsBatchShape.begin(), lhsBatchShape.end()); + const int64_t resultRank = std::max(lhsBatchShape.size(), rhsBatchShape.size()); + SmallVector resultShape(resultRank, 1); + for (int64_t resultIndex = resultRank - 1, lhsIndex = lhsBatchShape.size() - 1, rhsIndex = rhsBatchShape.size() - 1; + resultIndex >= 0; + --resultIndex, --lhsIndex, --rhsIndex) { + const int64_t lhsDim = lhsIndex >= 0 ? lhsBatchShape[lhsIndex] : 1; + const int64_t rhsDim = rhsIndex >= 0 ? rhsBatchShape[rhsIndex] : 1; + if (lhsDim != rhsDim && lhsDim != 1 && rhsDim != 1) + return failure(); + resultShape[resultIndex] = std::max(lhsDim, rhsDim); + } + return resultShape; +} + +static int64_t mapStaticBroadcastedBatchIndex(int64_t outputBatchIndex, + ArrayRef sourceBatchShape, + ArrayRef outputBatchShape) { + if (sourceBatchShape.empty() || getStaticShapeElementCount(sourceBatchShape) == 1) + return 0; + if (llvm::equal(sourceBatchShape, outputBatchShape)) + return outputBatchIndex; + + SmallVector outputStrides = computeRowMajorStrides(outputBatchShape); + SmallVector sourceStrides = computeRowMajorStrides(sourceBatchShape); + int64_t sourceFlatIndex = 0; + for (int64_t sourceDimIndex = 0; sourceDimIndex < static_cast(sourceBatchShape.size()); ++sourceDimIndex) { + if (sourceBatchShape[sourceDimIndex] == 1) + continue; + const int64_t outputDimIndex = outputBatchShape.size() - sourceBatchShape.size() + sourceDimIndex; + const int64_t outputDimStride = outputStrides.empty() ? 1 : outputStrides[outputDimIndex]; + const int64_t outputDimIndexValue = outputDimStride == 1 + ? 
outputBatchIndex % outputBatchShape[outputDimIndex] + : (outputBatchIndex / outputDimStride) % outputBatchShape[outputDimIndex]; + sourceFlatIndex += outputDimIndexValue * sourceStrides[sourceDimIndex]; + } + return sourceFlatIndex; +} + +static Value computeFlatBatchIndexCoordinate( + Value flatBatchIndex, ArrayRef batchShape, int64_t dimIndex, PatternRewriter& rewriter, Location loc) { + if (batchShape[dimIndex] == 1) + return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + + const int64_t dimStride = dimIndex + 1 == static_cast(batchShape.size()) + ? 1 + : getStaticShapeElementCount(batchShape.drop_front(dimIndex + 1)); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value dimCoordinate = flatBatchIndex; + if (dimStride != 1) + dimCoordinate = affineFloorDivConst(rewriter, loc, dimCoordinate, dimStride, anchorOp); + return affineModConst(rewriter, loc, dimCoordinate, batchShape[dimIndex], anchorOp); +} + +static Value mapOutputBatchIndexToSourceBatchIndex(Value outputBatchIndex, + ArrayRef sourceBatchShape, + ArrayRef outputBatchShape, + PatternRewriter& rewriter, + Location loc) { + if (sourceBatchShape.empty() || getStaticShapeElementCount(sourceBatchShape) == 1) + return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + if (llvm::equal(sourceBatchShape, outputBatchShape)) + return outputBatchIndex; + + Value sourceBatchIndex = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + SmallVector sourceStrides = computeRowMajorStrides(sourceBatchShape); + for (int64_t sourceDimIndex = 0; sourceDimIndex < static_cast(sourceBatchShape.size()); ++sourceDimIndex) { + if (sourceBatchShape[sourceDimIndex] == 1) + continue; + const int64_t outputDimIndex = outputBatchShape.size() - sourceBatchShape.size() + sourceDimIndex; + Value outputCoordinate = + computeFlatBatchIndexCoordinate(outputBatchIndex, outputBatchShape, outputDimIndex, rewriter, 
loc); + Value contribution = sourceStrides[sourceDimIndex] == 1 + ? outputCoordinate + : affineMulConst(rewriter, + loc, + outputCoordinate, + sourceStrides[sourceDimIndex], + rewriter.getInsertionBlock()->getParentOp()); + sourceBatchIndex = arith::AddIOp::create(rewriter, loc, sourceBatchIndex, contribution); + } + return sourceBatchIndex; } static Value @@ -67,6 +141,52 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded); } +static Value createMatrixFromVector(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) { + auto buildExpanded = [&](Value input) -> Value { + return tensor::ExpandShapeOp::create(rewriter, + loc, + resultType, + input, + SmallVector { + {0, 1} + }); + }; + return materializeOrComputeUnary(value, resultType, rewriter, loc, buildExpanded); +} + +static SmallVector buildCollapseReassociation(ArrayRef removedAxes) { + SmallVector reassociation; + ReassociationIndices currentGroup; + for (auto [axis, removeAxis] : llvm::enumerate(removedAxes)) { + currentGroup.push_back(axis); + if (!removeAxis) { + reassociation.push_back(currentGroup); + currentGroup.clear(); + } + } + + if (!currentGroup.empty()) { + if (reassociation.empty()) + reassociation.push_back(std::move(currentGroup)); + else + reassociation.back().append(currentGroup.begin(), currentGroup.end()); + } + return reassociation; +} + +static Value squeezeUnitDims( + Value value, RankedTensorType resultType, ArrayRef removedAxes, PatternRewriter& rewriter, Location loc) { + if (cast(value.getType()) == resultType) + return value; + + SmallVector reassociation = + resultType.getRank() == 0 ? 
SmallVector {} : buildCollapseReassociation(removedAxes); + auto buildCollapsed = [&](Value input) -> Value { + return tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation).getResult(); + }; + return materializeOrComputeUnary(value, resultType, rewriter, loc, buildCollapsed); +} + static Value ensureBatchedTensor( Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) { auto type = cast(value.getType()); @@ -171,8 +291,11 @@ static Value createPaddedBatchedInputCompute(Value input, return computeOp.getResult(0); } -static FailureOr materializePaddedBatchedWeight( - Value value, int64_t sourceBatch, int64_t targetBatch, RankedTensorType resultType, PatternRewriter& rewriter) { +static FailureOr materializePaddedBatchedWeight(Value value, + ArrayRef sourceBatchShape, + ArrayRef targetBatchShape, + RankedTensorType resultType, + PatternRewriter& rewriter) { auto sourceType = cast(value.getType()); if (sourceType == resultType) return value; @@ -183,13 +306,15 @@ static FailureOr materializePaddedBatchedWeight( const int64_t sourceRows = sourceType.getRank() == 2 ? sourceType.getDimSize(0) : sourceType.getDimSize(1); const int64_t sourceCols = sourceType.getRank() == 2 ? sourceType.getDimSize(1) : sourceType.getDimSize(2); + const int64_t targetBatch = targetBatchShape.empty() ? 1 : getStaticShapeElementCount(targetBatchShape); const int64_t targetRows = resultType.getDimSize(1); const int64_t targetCols = resultType.getDimSize(2); SmallVector sourceValues(denseAttr.getValues()); SmallVector resultValues(resultType.getNumElements(), rewriter.getZeroAttr(resultType.getElementType())); for (int64_t batchIdx = 0; batchIdx < targetBatch; ++batchIdx) { - const int64_t sourceBatchIdx = sourceType.getRank() == 2 ? 0 : (sourceBatch == 1 ? 0 : batchIdx); + const int64_t sourceBatchIdx = + sourceType.getRank() == 2 ? 
0 : mapStaticBroadcastedBatchIndex(batchIdx, sourceBatchShape, targetBatchShape); const int64_t sourceBatchBase = sourceType.getRank() == 2 ? 0 : sourceBatchIdx * sourceRows * sourceCols; const int64_t targetBatchBase = batchIdx * targetRows * targetCols; for (int64_t row = 0; row < sourceRows; ++row) @@ -202,16 +327,18 @@ static FailureOr materializePaddedBatchedWeight( } static Value extractBatchedATile(Value a, - int64_t sourceBatchCount, - Value batch, + ArrayRef sourceBatchShape, + ArrayRef outputBatchShape, + Value outputBatchIndex, Value row, Value kOffset, RankedTensorType aTileType, PatternRewriter& rewriter, Location loc) { auto aSliceType = RankedTensorType::get({1, 1, aTileType.getDimSize(1)}, aTileType.getElementType()); - SmallVector offsets { - sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), row, kOffset}; + Value sourceBatchIndex = + mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc); + SmallVector offsets {OpFoldResult(sourceBatchIndex), row, kOffset}; SmallVector sizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(aTileType.getDimSize(1))}; auto slice = @@ -227,8 +354,9 @@ static Value extractBatchedATile(Value a, } static Value extractBatchedBTile(Value b, - int64_t sourceBatchCount, - Value batch, + ArrayRef sourceBatchShape, + ArrayRef outputBatchShape, + Value outputBatchIndex, Value kOffset, Value hOffset, RankedTensorType bTileType, @@ -236,8 +364,9 @@ static Value extractBatchedBTile(Value b, Location loc) { auto bSliceType = RankedTensorType::get({1, bTileType.getDimSize(0), bTileType.getDimSize(1)}, bTileType.getElementType()); - SmallVector offsets { - sourceBatchCount == 1 ? 
OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), kOffset, hOffset}; + Value sourceBatchIndex = + mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc); + SmallVector offsets {OpFoldResult(sourceBatchIndex), kOffset, hOffset}; SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(bTileType.getDimSize(0)), rewriter.getIndexAttr(bTileType.getDimSize(1))}; @@ -262,9 +391,10 @@ static Value getBatchLaneIndex( static FailureOr createBatchedVmmBatch(Value a, Value b, RankedTensorType aType, - int64_t aBatchCount, + ArrayRef aBatchShape, RankedTensorType bType, - int64_t bBatchCount, + ArrayRef bBatchShape, + ArrayRef outputBatchShape, RankedTensorType partialPiecesType, int64_t numOutRows, int64_t numKSlices, @@ -298,10 +428,10 @@ static FailureOr createBatchedVmmBatch(Value a, auto pieceType = RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, partialPiecesType.getElementType()); - Value aTile = - extractBatchedATile(args.inputs.front(), aBatchCount, batch, row, kOffset, aTileType, rewriter, loc); - Value bTile = - extractBatchedBTile(args.weights.front(), bBatchCount, batch, kOffset, hOffset, bTileType, rewriter, loc); + Value aTile = extractBatchedATile( + args.inputs.front(), aBatchShape, outputBatchShape, batch, row, kOffset, aTileType, rewriter, loc); + Value bTile = extractBatchedBTile( + args.weights.front(), bBatchShape, outputBatchShape, batch, kOffset, hOffset, bTileType, rewriter, loc); Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult(); SmallVector pieceOffsets {args.lane, rewriter.getIndexAttr(0)}; @@ -315,17 +445,17 @@ static FailureOr createBatchedVmmBatch(Value a, } static Value extractDynamicBatchedBColumn(Value matrix, - int64_t sourceBatchCount, - Value batch, + ArrayRef sourceBatchShape, + ArrayRef outputBatchShape, + Value outputBatchIndex, Value column, RankedTensorType vectorType, PatternRewriter& rewriter, Location 
loc) { auto columnSliceType = RankedTensorType::get({1, vectorType.getDimSize(1), 1}, vectorType.getElementType()); - SmallVector offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) - : OpFoldResult(batch), - rewriter.getIndexAttr(0), - column}; + Value sourceBatchIndex = + mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc); + SmallVector offsets {OpFoldResult(sourceBatchIndex), rewriter.getIndexAttr(0), column}; SmallVector sizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1)), rewriter.getIndexAttr(1)}; SmallVector strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; @@ -350,17 +480,17 @@ static Value extractDynamicBatchedBColumn(Value matrix, } static Value extractDynamicBatchedRowVector(Value matrix, - int64_t sourceBatchCount, - Value batch, + ArrayRef sourceBatchShape, + ArrayRef outputBatchShape, + Value outputBatchIndex, Value row, RankedTensorType vectorType, PatternRewriter& rewriter, Location loc) { auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType()); - SmallVector offsets {sourceBatchCount == 1 ? 
OpFoldResult(rewriter.getIndexAttr(0)) - : OpFoldResult(batch), - row, - rewriter.getIndexAttr(0)}; + Value sourceBatchIndex = + mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc); + SmallVector offsets {OpFoldResult(sourceBatchIndex), row, rewriter.getIndexAttr(0)}; SmallVector sizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))}; auto rowSlice = @@ -376,9 +506,10 @@ static Value extractDynamicBatchedRowVector(Value matrix, } static FailureOr createBatchedVvdmulBatch(Value a, - int64_t aBatchCount, + ArrayRef aBatchShape, Value b, - int64_t bBatchCount, + ArrayRef bBatchShape, + ArrayRef outputBatchShape, RankedTensorType aType, RankedTensorType bType, RankedTensorType scalarPiecesType, @@ -406,10 +537,10 @@ static FailureOr createBatchedVvdmulBatch(Value a, auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); - Value aVector = - extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc); - Value bVector = - extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc); + Value aVector = extractDynamicBatchedRowVector( + args.inputs[0], aBatchShape, outputBatchShape, batch, row, vectorType, rewriter, loc); + Value bVector = extractDynamicBatchedBColumn( + args.inputs[1], bBatchShape, outputBatchShape, batch, column, vectorType, rewriter, loc); Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult(); SmallVector outputOffsets {args.lane, rewriter.getIndexAttr(0)}; SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; @@ -629,11 +760,17 @@ static FailureOr createBatchedReductionCompute(Value partialPieces, return computeOp->getResult(0); } -struct MatMulShapeInfo { +struct NormalizedMatMulInfo { 
   RankedTensorType lhsType;
   RankedTensorType rhsType;
   RankedTensorType outType;
-  SmallVector batchShape;
+  RankedTensorType normalizedLhsType;
+  RankedTensorType normalizedRhsType;
+  SmallVector lhsBatchShape;
+  SmallVector rhsBatchShape;
+  SmallVector outputBatchShape;
+  bool lhsWasVector;
+  bool rhsWasVector;
   int64_t lhsBatch;
   int64_t rhsBatch;
   int64_t batch;
@@ -642,46 +779,170 @@ struct MatMulShapeInfo {
   int64_t n;
 };

-static FailureOr analyzeMatMulShape(ONNXMatMulOp matmulOp) {
+struct MatMulLoweringPlan {
+  Value lhs;
+  Value rhs;
+  RankedTensorType lhsType;
+  RankedTensorType rhsType;
+  SmallVector lhsBatchShape;
+  SmallVector rhsBatchShape;
+  SmallVector outputBatchShape;
+  int64_t lhsBatch;
+  int64_t rhsBatch;
+  int64_t batch;
+  int64_t m;
+  int64_t k;
+  int64_t n;
+  bool transposedResult;
+};
+
+static SmallVector computeExpectedMatMulOutputShape(
+    ArrayRef batchShape, int64_t m, int64_t n, bool lhsWasVector, bool rhsWasVector) {
+  SmallVector shape(batchShape.begin(), batchShape.end());
+  if (lhsWasVector && rhsWasVector)
+    return shape;
+  if (lhsWasVector) {
+    shape.push_back(n);
+    return shape;
+  }
+  if (rhsWasVector) {
+    shape.push_back(m);
+    return shape;
+  }
+  shape.push_back(m);
+  shape.push_back(n);
+  return shape;
+}
+
+static FailureOr analyzeMatMulShape(ONNXMatMulOp matmulOp) {
   auto lhsType = dyn_cast(matmulOp.getA().getType());
   auto rhsType = dyn_cast(matmulOp.getB().getType());
   auto outType = dyn_cast(matmulOp.getY().getType());
   if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
       || !outType.hasStaticShape())
     return failure();
-  if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
+  if (lhsType.getRank() < 1 || rhsType.getRank() < 1)
     return failure();
   if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType))
     return failure();

-  SmallVector lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
-  SmallVector rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
-  auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
-  if (failed(batchShape))
+  const bool lhsWasVector = lhsType.getRank() == 1;
+  const bool rhsWasVector = rhsType.getRank() == 1;
+  auto normalizedLhsType =
+      lhsWasVector ? RankedTensorType::get({1, lhsType.getDimSize(0)}, lhsType.getElementType(), lhsType.getEncoding())
+                   : lhsType;
+  auto normalizedRhsType =
+      rhsWasVector ? RankedTensorType::get({rhsType.getDimSize(0), 1}, rhsType.getElementType(), rhsType.getEncoding())
+                   : rhsType;
+
+  SmallVector lhsBatchShape(normalizedLhsType.getShape().begin(), normalizedLhsType.getShape().end() - 2);
+  SmallVector rhsBatchShape(normalizedRhsType.getShape().begin(), normalizedRhsType.getShape().end() - 2);
+  auto outputBatchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
+  if (failed(outputBatchShape))
     return failure();

   const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
   const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
-  const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
-  const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
-  const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
-  const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
-  const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
+  const int64_t batch = outputBatchShape->empty() ? 1 : getStaticShapeElementCount(*outputBatchShape);
+  const int64_t m = normalizedLhsType.getDimSize(normalizedLhsType.getRank() - 2);
+  const int64_t k = normalizedLhsType.getDimSize(normalizedLhsType.getRank() - 1);
+  const int64_t rhsK = normalizedRhsType.getDimSize(normalizedRhsType.getRank() - 2);
+  const int64_t n = normalizedRhsType.getDimSize(normalizedRhsType.getRank() - 1);
   if (k != rhsK)
     return failure();

-  if (outType.getRank() == 2) {
-    if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
-      return failure();
-  }
-  else {
-    SmallVector outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
-    if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
-        || outType.getDimSize(outType.getRank() - 1) != n)
-      return failure();
+  if (SmallVector(outType.getShape().begin(), outType.getShape().end())
+      != computeExpectedMatMulOutputShape(*outputBatchShape, m, n, lhsWasVector, rhsWasVector)) {
+    return failure();
   }

-  return MatMulShapeInfo {lhsType, rhsType, outType, *batchShape, lhsBatch, rhsBatch, batch, m, k, n};
+  return NormalizedMatMulInfo {lhsType,
+                               rhsType,
+                               outType,
+                               normalizedLhsType,
+                               normalizedRhsType,
+                               lhsBatchShape,
+                               rhsBatchShape,
+                               *outputBatchShape,
+                               lhsWasVector,
+                               rhsWasVector,
+                               lhsBatch,
+                               rhsBatch,
+                               batch,
+                               m,
+                               k,
+                               n};
+}
+
+static MatMulLoweringPlan buildLoweringPlan(Value normalizedLhs,
+                                            Value normalizedRhs,
+                                            const NormalizedMatMulInfo& info,
+                                            bool useTransposedForm,
+                                            PatternRewriter& rewriter,
+                                            Location loc) {
+  MatMulLoweringPlan plan {normalizedLhs,
+                           normalizedRhs,
+                           cast(normalizedLhs.getType()),
+                           cast(normalizedRhs.getType()),
+                           info.lhsBatchShape,
+                           info.rhsBatchShape,
+                           info.outputBatchShape,
+                           info.lhsBatch,
+                           info.rhsBatch,
+                           info.batch,
+                           info.m,
+                           info.k,
+                           info.n,
+                           false};
+  if (!useTransposedForm)
+    return plan;
+
+  plan.lhs = transposeLastTwoDims(normalizedRhs, rewriter, loc);
+  plan.rhs = transposeLastTwoDims(normalizedLhs, rewriter, loc);
+  plan.lhsType = cast(plan.lhs.getType());
+  plan.rhsType = cast(plan.rhs.getType());
+  std::swap(plan.lhsBatchShape, plan.rhsBatchShape);
+  std::swap(plan.lhsBatch, plan.rhsBatch);
+  plan.m = info.n;
+  plan.n = info.m;
+  plan.transposedResult = true;
+  return plan;
+}
+
+static Value normalizeMatMulOperand(
+    Value value, RankedTensorType normalizedType, bool wasVector, PatternRewriter& rewriter, Location loc) {
+  if (!wasVector)
+    return value;
+  return createMatrixFromVector(value, normalizedType, rewriter, loc);
+}
+
+static Value finalizeNormalizedMatMulResult(Value value,
+                                            RankedTensorType directOutType,
+                                            const NormalizedMatMulInfo& info,
+                                            PatternRewriter& rewriter,
+                                            Location loc) {
+  // The direct lowered result is always [flatBatch, normalizedM, normalizedN].
+  // Restore ONNX MatMul result rank by expanding right-aligned batch dimensions
+  // and removing the synthetic unit matrix axes introduced for vector operands.
+  Value result = value;
+  RankedTensorType currentType = directOutType;
+  if (info.outputBatchShape.size() > 1) {
+    SmallVector expandedShape(info.outputBatchShape.begin(), info.outputBatchShape.end());
+    expandedShape.push_back(info.m);
+    expandedShape.push_back(info.n);
+    auto expandedType = RankedTensorType::get(expandedShape, info.outType.getElementType(), info.outType.getEncoding());
+    result = expandBatchDims(result, expandedType, info.outputBatchShape.size(), rewriter, loc);
+    currentType = expandedType;
+  }
+
+  SmallVector removedAxes(currentType.getRank(), false);
+  if (info.outputBatchShape.empty())
+    removedAxes[0] = true;
+  if (info.lhsWasVector)
+    removedAxes[currentType.getRank() - 2] = true;
+  if (info.rhsWasVector)
+    removedAxes[currentType.getRank() - 1] = true;
+  return squeezeUnitDims(result, info.outType, removedAxes, rewriter, loc);
 }

 struct MatMulToGemm : OpRewritePattern {
@@ -689,7 +950,7 @@ struct MatMulToGemm : OpRewritePattern {

   LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
     auto shapeInfo = analyzeMatMulShape(matmulOp);
-    if (failed(shapeInfo) || shapeInfo->outType.getRank() != 2)
+    if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector || !shapeInfo->outputBatchShape.empty())
       return failure();

     Location loc = matmulOp.getLoc();
@@ -742,61 +1003,56 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern {
     auto shapeInfo = analyzeMatMulShape(matmulOp);
     if (failed(shapeInfo))
       return failure();
-    if (shapeInfo->outType.getRank() == 2)
+    if (!shapeInfo->lhsWasVector && !shapeInfo->rhsWasVector && shapeInfo->outputBatchShape.empty())
       return failure();

     Location loc = matmulOp.getLoc();
-    bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
+    bool useTransposedForm = !shapeInfo->lhsWasVector && !shapeInfo->rhsWasVector
+                          && isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());

-    Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
-    Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
-    int64_t lhsBatchForGemm = shapeInfo->lhsBatch;
-    int64_t rhsBatchForGemm = shapeInfo->rhsBatch;
-    int64_t gemmM = shapeInfo->m;
-    int64_t gemmK = shapeInfo->k;
-    int64_t gemmN = shapeInfo->n;
-    if (useTransposedForm) {
-      lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
-      lhsBatchForGemm = shapeInfo->rhsBatch;
-      rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
-      rhsBatchForGemm = shapeInfo->lhsBatch;
-      gemmM = shapeInfo->n;
-      gemmN = shapeInfo->m;
-    }
+    Value lhs =
+        normalizeMatMulOperand(matmulOp.getA(), shapeInfo->normalizedLhsType, shapeInfo->lhsWasVector, rewriter, loc);
+    Value rhs =
+        normalizeMatMulOperand(matmulOp.getB(), shapeInfo->normalizedRhsType, shapeInfo->rhsWasVector, rewriter, loc);
+    lhs = collapseBatchDims(lhs, shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
+    rhs = collapseBatchDims(rhs, shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
+    MatMulLoweringPlan plan = buildLoweringPlan(lhs, rhs, *shapeInfo, useTransposedForm, rewriter, loc);

-    lhs = ensureBatchedTensor(lhs, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
-    rhs = ensureBatchedTensor(rhs, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
-    auto lhsBatchedType = cast(lhs.getType());
-    auto rhsBatchedType = cast(rhs.getType());
-    auto directOutType = RankedTensorType::get({shapeInfo->batch, gemmM, gemmN}, shapeInfo->outType.getElementType());
+    plan.lhs = ensureBatchedTensor(plan.lhs, plan.lhsBatch, plan.m, plan.k, rewriter, loc);
+    plan.rhs = ensureBatchedTensor(plan.rhs, plan.rhsBatch, plan.k, plan.n, rewriter, loc);
+    plan.lhsType = cast(plan.lhs.getType());
+    plan.rhsType = cast(plan.rhs.getType());
+    auto directOutType = RankedTensorType::get(
+        {plan.batch, plan.m, plan.n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding());

-    if (isCompileTimeComputable(rhs)) {
-      const int64_t numKSlices = ceilIntegerDivide(gemmK, crossbarSize.getValue());
-      const int64_t numOutHSlices = ceilIntegerDivide(gemmN, crossbarSize.getValue());
+    if (isCompileTimeComputable(plan.rhs)) {
+      const int64_t numKSlices = ceilIntegerDivide(plan.k, crossbarSize.getValue());
+      const int64_t numOutHSlices = ceilIntegerDivide(plan.n, crossbarSize.getValue());
       const int64_t paddedReductionSize = numKSlices * static_cast(crossbarSize.getValue());
       const int64_t paddedOutCols = numOutHSlices * static_cast(crossbarSize.getValue());
       auto paddedLhsType = RankedTensorType::get(
-          {lhsBatchForGemm, gemmM, paddedReductionSize}, lhsBatchedType.getElementType(), lhsBatchedType.getEncoding());
-      auto paddedRhsType = RankedTensorType::get({shapeInfo->batch, paddedReductionSize, paddedOutCols},
-                                                 rhsBatchedType.getElementType(),
-                                                 rhsBatchedType.getEncoding());
+          {plan.lhsBatch, plan.m, paddedReductionSize}, plan.lhsType.getElementType(), plan.lhsType.getEncoding());
+      auto paddedRhsType = RankedTensorType::get(
+          {plan.batch, paddedReductionSize, paddedOutCols}, plan.rhsType.getElementType(), plan.rhsType.getEncoding());
       auto paddedOutType =
-          RankedTensorType::get({shapeInfo->batch, gemmM, paddedOutCols}, shapeInfo->outType.getElementType());
+          RankedTensorType::get({plan.batch, plan.m, paddedOutCols}, shapeInfo->outType.getElementType());

-      auto paddedRhs = materializePaddedBatchedWeight(rhs, rhsBatchForGemm, shapeInfo->batch, paddedRhsType, rewriter);
+      auto paddedRhs =
+          materializePaddedBatchedWeight(plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, paddedRhsType, rewriter);
       if (succeeded(paddedRhs)) {
-        Value paddedLhs = createPaddedBatchedInputCompute(lhs, paddedLhsType, rewriter, loc);
-        const int64_t laneCount = shapeInfo->batch * gemmM * numKSlices * numOutHSlices;
+        Value paddedLhs = createPaddedBatchedInputCompute(plan.lhs, paddedLhsType, rewriter, loc);
+        const int64_t laneCount = plan.batch * plan.m * numKSlices * numOutHSlices;
         auto partialPiecesType = RankedTensorType::get({laneCount, static_cast(crossbarSize.getValue())},
                                                        shapeInfo->outType.getElementType());
         auto batchOp = createBatchedVmmBatch(paddedLhs,
                                              *paddedRhs,
                                              paddedLhsType,
-                                             lhsBatchForGemm,
+                                             plan.lhsBatchShape,
                                              paddedRhsType,
-                                             rhsBatchForGemm,
+                                             plan.rhsBatchShape,
+                                             plan.outputBatchShape,
                                              partialPiecesType,
-                                             gemmM,
+                                             plan.m,
                                              numKSlices,
                                              numOutHSlices,
                                              rewriter,
@@ -807,34 +1063,35 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern {
                                              partialPiecesType,
                                              directOutType,
                                              paddedOutType,
-                                             shapeInfo->batch,
+                                             plan.batch,
                                              numKSlices,
                                              rewriter,
                                              loc);
         if (failed(result))
           return failure();
         Value finalResult = *result;
-        if (useTransposedForm) {
-          auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
+        if (plan.transposedResult) {
+          auto transposedOutType = RankedTensorType::get({plan.batch, shapeInfo->m, shapeInfo->n},
                                                          shapeInfo->outType.getElementType(),
                                                          shapeInfo->outType.getEncoding());
           finalResult =
               ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
                   .getResult();
         }
-        finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
+        finalResult = finalizeNormalizedMatMulResult(finalResult, directOutType, *shapeInfo, rewriter, loc);
         rewriter.replaceOp(matmulOp, finalResult);
         return success();
       }
     }

-    const int64_t laneCount = shapeInfo->batch * gemmM * gemmN;
+    const int64_t laneCount = plan.batch * plan.m * plan.n;
     auto scalarPiecesType = RankedTensorType::get({laneCount, 1}, shapeInfo->outType.getElementType());
-    auto batchOp = createBatchedVvdmulBatch(lhs,
-                                            lhsBatchForGemm,
-                                            rhs,
-                                            rhsBatchForGemm,
-                                            lhsBatchedType,
-                                            rhsBatchedType,
+    auto batchOp = createBatchedVvdmulBatch(plan.lhs,
+                                            plan.lhsBatchShape,
+                                            plan.rhs,
+                                            plan.rhsBatchShape,
+                                            plan.outputBatchShape,
+                                            plan.lhsType,
+                                            plan.rhsType,
                                             scalarPiecesType,
                                             directOutType,
                                             rewriter,
@@ -846,15 +1103,15 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern {
     if (failed(result))
       return failure();
     Value finalResult = *result;
-    if (useTransposedForm) {
-      auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
+    if (plan.transposedResult) {
+      auto transposedOutType = RankedTensorType::get({plan.batch, shapeInfo->m, shapeInfo->n},
                                                      shapeInfo->outType.getElementType(),
                                                      shapeInfo->outType.getEncoding());
       finalResult =
           ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
               .getResult();
     }
-    finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
+    finalResult = finalizeNormalizedMatMulResult(finalResult, directOutType, *shapeInfo, rewriter, loc);
     rewriter.replaceOp(matmulOp, finalResult);
     return success();
   }
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp
index bf4ddab..2f03115 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp
@@ -238,14 +238,8 @@ static Value squeezeReducedAxes(Value keepdimsValue,
                                 ArrayRef reducedAxes,
                                 ConversionPatternRewriter& rewriter,
                                 Location loc) {
-  if (resultType.getRank() == 0) {
-    SmallVector indices(cast(keepdimsValue.getType()).getRank(),
-                        getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0));
-    Value element = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, indices);
-    return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
-  }
-
-  auto reassociation = buildCollapseReassociation(reducedAxes);
+  SmallVector reassociation =
+      resultType.getRank() == 0 ? SmallVector {} : buildCollapseReassociation(reducedAxes);
   if (isCompileTimeComputable(keepdimsValue))
     return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();

diff --git a/validation/operations/add/scalar_runtime/add_scalar_runtime.onnx b/validation/operations/add/scalar_runtime/add_scalar_runtime.onnx
deleted file mode 100644
index 4d138c1..0000000
Binary files a/validation/operations/add/scalar_runtime/add_scalar_runtime.onnx and /dev/null differ
diff --git a/validation/operations/conv/grouped_many_groups/conv_grouped_many_groups.onnx b/validation/operations/conv/grouped_many_groups/conv_grouped_many_groups.onnx
index 7df7a8e..53e138f 100644
Binary files a/validation/operations/conv/grouped_many_groups/conv_grouped_many_groups.onnx and b/validation/operations/conv/grouped_many_groups/conv_grouped_many_groups.onnx differ
diff --git a/validation/operations/div/runtime_scalar_lhs/div_runtime_scalar_lhs.onnx b/validation/operations/div/runtime_scalar_lhs/div_runtime_scalar_lhs.onnx
deleted file mode 100644
index 3401328..0000000
Binary files a/validation/operations/div/runtime_scalar_lhs/div_runtime_scalar_lhs.onnx and /dev/null differ
diff --git a/validation/operations/gen_tests.py b/validation/operations/gen_tests.py
index 49f491b..ba09a7a 100644
--- a/validation/operations/gen_tests.py
+++ b/validation/operations/gen_tests.py
@@ -262,7 +262,7 @@ def conv_grouped_many_groups():
     X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1024, 2, 2])
     Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 2, 2])
     W = numpy_helper.from_array(
-        np.random.default_rng(77).uniform(-1, 1, (1024, 64, 1, 1)).astype(np.float32), name="W")
+        np.random.default_rng(77).uniform(-1, 1, (1024, 16, 1, 1)).astype(np.float32), name="W")
     node = helper.make_node("Conv", ["X", "W"], ["Y"],
                             kernel_shape=[1, 1], strides=[1, 1], pads=[0, 0, 0, 0], group=64)
     graph = helper.make_graph([node], "conv_grouped_many_groups", [X], [Y], initializer=[W])
@@ -779,28 +779,6 @@ def matmul_matrix_vector():
     save_model(model, "matmul/matrix_vector", "matmul_matrix_vector.onnx")


-def matmul_vector_vector_dot():
-    """Vector-vector MatMul producing a scalar output."""
-    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1024])
-    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [])
-    B = numpy_helper.from_array(np.random.default_rng(97).uniform(-1, 1, (1024,)).astype(np.float32), name="B")
-    node = helper.make_node("MatMul", ["A", "B"], ["Y"])
-    graph = helper.make_graph([node], "matmul_vector_vector_dot", [A], [Y], initializer=[B])
-    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
-    save_model(model, "matmul/vector_vector_dot", "matmul_vector_vector_dot.onnx")
-
-
-def matmul_batched_4d_broadcast():
-    """Batched 4D MatMul with broadcast across leading dimensions."""
-    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 1, 3, 4])
-    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 5, 3, 6])
-    B = numpy_helper.from_array(np.random.default_rng(98).uniform(-1, 1, (1, 5, 4, 6)).astype(np.float32), name="B")
-    node = helper.make_node("MatMul", ["A", "B"], ["Y"])
-    graph = helper.make_graph([node], "matmul_batched_4d_broadcast", [A], [Y], initializer=[B])
-    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
-    save_model(model, "matmul/batched_4d_broadcast", "matmul_batched_4d_broadcast.onnx")
-
-
 # ---------------------------------------------------------------------------
 # Pooling tests
 # ---------------------------------------------------------------------------
@@ -1560,17 +1538,6 @@ def add_channel_broadcast_1024():
     save_model(model, "add/channel_broadcast_1024", "add_channel_broadcast_1024.onnx")


-def add_scalar_runtime():
-    """Elementwise Add with a runtime scalar RHS."""
-    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 1024, 1, 1])
-    B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 1, 1, 1])
-    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1])
-    node = helper.make_node("Add", ["A", "B"], ["Y"])
-    graph = helper.make_graph([node], "add_scalar_runtime", [A, B], [Y])
-    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
-    save_model(model, "add/scalar_runtime", "add_scalar_runtime.onnx")
-
-
 def add_leading_dimension_broadcast():
     """Elementwise Add with trailing-dimension broadcasting."""
     A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4])
@@ -1635,17 +1602,6 @@ def mul_channel_broadcast_1024():
     save_model(model, "mul/channel_broadcast_1024", "mul_channel_broadcast_1024.onnx")


-def mul_scalar_runtime():
-    """Elementwise Mul with a runtime scalar RHS."""
-    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 1024, 1, 1])
-    B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 1, 1, 1])
-    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1])
-    node = helper.make_node("Mul", ["A", "B"], ["Y"])
-    graph = helper.make_graph([node], "mul_scalar_runtime", [A, B], [Y])
-    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
-    save_model(model, "mul/scalar_runtime", "mul_scalar_runtime.onnx")
-
-
 def mul_leading_dimension_broadcast():
     """Elementwise Mul with trailing-dimension broadcasting."""
     A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4])
@@ -1721,17 +1677,6 @@ def div_runtime_scalar_rhs():
     save_model(model, "div/runtime_scalar_rhs", "div_runtime_scalar_rhs.onnx")


-def div_runtime_scalar_lhs():
-    """Elementwise Div with a scalar constant numerator."""
-    B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 1024, 1, 1])
-    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1])
-    A = numpy_helper.from_array(np.asarray([[[[2.0]]]], dtype=np.float32), name="A")
-    node = helper.make_node("Div", ["A", "B"], ["Y"])
-    graph = helper.make_graph([node], "div_runtime_scalar_lhs", [B], [Y], initializer=[A])
-    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
-    save_model(model, "div/runtime_scalar_lhs", "div_runtime_scalar_lhs.onnx")
-
-
 def div_leading_dimension_broadcast():
     """Elementwise Div with trailing-dimension broadcasting."""
     A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4])
@@ -1812,8 +1757,6 @@ if __name__ == "__main__":
     matmul_huge_1024()
     matmul_vector_matrix()
     matmul_matrix_vector()
-    matmul_vector_vector_dot()
-    matmul_batched_4d_broadcast()

     print("\nGenerating Pooling tests:")
     maxpool_basic()
@@ -1899,7 +1842,6 @@ if __name__ == "__main__":
     add_broadcast_row()
     add_after_gemm()
     add_channel_broadcast_1024()
-    add_scalar_runtime()
     add_leading_dimension_broadcast()

     print("\nGenerating Mul tests:")
@@ -1907,7 +1849,6 @@ if __name__ == "__main__":
     mul_scalar_constant()
     mul_after_conv()
     mul_channel_broadcast_1024()
-    mul_scalar_runtime()
     mul_leading_dimension_broadcast()

     print("\nGenerating Div tests:")
@@ -1916,7 +1857,6 @@ if __name__ == "__main__":
     div_after_gemm()
     div_channel_broadcast_1024()
     div_runtime_scalar_rhs()
-    div_runtime_scalar_lhs()
     div_leading_dimension_broadcast()

     print("\nDone.")
diff --git a/validation/operations/matmul/batched_4d_broadcast/matmul_batched_4d_broadcast.onnx b/validation/operations/matmul/batched_4d_broadcast/matmul_batched_4d_broadcast.onnx
deleted file mode 100644
index 25382b8..0000000
Binary files a/validation/operations/matmul/batched_4d_broadcast/matmul_batched_4d_broadcast.onnx and /dev/null differ
diff --git a/validation/operations/matmul/vector_vector_dot/matmul_vector_vector_dot.onnx b/validation/operations/matmul/vector_vector_dot/matmul_vector_vector_dot.onnx
deleted file mode 100644
index cf96880..0000000
Binary files a/validation/operations/matmul/vector_vector_dot/matmul_vector_vector_dot.onnx and /dev/null differ
diff --git a/validation/operations/mul/scalar_runtime/mul_scalar_runtime.onnx b/validation/operations/mul/scalar_runtime/mul_scalar_runtime.onnx
deleted file mode 100644
index 55e63e0..0000000
Binary files a/validation/operations/mul/scalar_runtime/mul_scalar_runtime.onnx and /dev/null differ