Files
Raptor/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp
T
NiccoloN 4a98e88e97
Validate Operations / validate-operations (push) Waiting to run
less affine code and better affine helpers
2026-06-29 14:34:31 +02:00

3846 lines
184 KiB
C++

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cctype>
#include <map>
#include <mutex>
#include <optional>
#include <string>
#include <vector>
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Conversion pattern for `onnx.Conv` (ONNXConvOp).
/// Only the declaration is visible here; `matchAndRewrite` is defined out of line.
/// NOTE(review): the name suggests a GEMM-based lowering — confirm against the definition.
struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXConvOp convOp,
ONNXConvOpAdaptor convOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
/// Outcome of choosing a Conv lowering strategy.
/// NOTE(review): field meanings are inferred from their names and from the matching
/// ConvReportEntry fields (`reason`, `fallbackReason`, `rejectedAutoStrategy`); confirm at the producer.
struct ConvLoweringDecision {
// Selected strategy. Note: no default initializer.
PimConvLoweringType strategy;
std::string reason;
// Presumably true when `strategy` came from the `auto` policy rather than a forced option.
bool isAuto = false;
std::string fallbackReason;
std::string rejectedAutoStrategy;
};
/// A Conv input value paired with a ranked tensor type.
/// NOTE(review): presumably `type` is the (static) type of `value` after preparation — confirm at the producer.
struct PreparedConvInput {
Value value;
RankedTensorType type;
};
/// Static cost / materialization estimate filled in by estimateConvStrategy.
/// Defaults describe "nothing estimated": one output fragment, no reductions, no materialization.
struct ConvStrategyEstimate {
// Estimated number of matrix-vector multiplications.
uint64_t estimatedMvmCount = 0;
// Estimated number of vector additions used to combine partial sums (and bias).
uint64_t estimatedReductionVAddCount = 0;
// Estimated number of separate output pieces produced by the lowering.
uint64_t estimatedOutputFragments = 1;
// True when partial results must be reduced per output position (more than one K slice).
bool perOutputPositionReduction = false;
// True when the Conv result flows directly into func.return.
bool requiresFuncReturnMaterialization = false;
// True when a large number of fragments must be concatenated by a single collector.
bool giantCollectorConcatExpected = false;
// True when more than one output fragment is expected (set by the input-K-tiled estimate).
bool fullInputBroadcastExpected = false;
// Number of operands of the expected collector concat (0 when none is expected).
uint64_t concatOperandCount = 0;
// "none", "func.return" or "single_collector_concat" in estimateConvStrategy.
std::string materializationKind = "none";
// "-" when no collector is expected.
std::string collectorCore = "-";
};
/// One row (plus per-conv details) of the Conv lowering report.
/// The first fields map onto the report table columns printed by writeConvReportTable
/// (id, where, mode, strategy, input/weight/output shapes, groups, K, C, P, X, pack, im2col, ...).
/// The estimate/materialization fields mirror ConvStrategyEstimate.
/// Yes/no style flags are stored as strings.
struct ConvReportEntry {
uint64_t id;
// Summarized MLIR location of the Conv op.
std::string where;
std::string strategy;
// "auto" vs. forced-option selection.
std::string mode;
std::string inputShape;
std::string weightShape;
std::string outputShape;
int64_t groups;
// Logical reduction size (see report legend: CinPerGroup * Kh * Kw).
int64_t k;
// Logical output-channel width handled by the strategy.
int64_t c;
// Total output positions (see report legend: N * Hout * Wout).
int64_t p;
int64_t xbarSize;
int64_t pack;
uint64_t im2colElements;
uint64_t im2colBudget;
std::string chunkText;
int64_t batchSize;
int64_t numberOfBatches;
std::string spatialComputeBatch;
std::string batchedInstructionEmission;
std::string reason;
std::string fallbackReason;
std::string rejectedAutoStrategy;
uint64_t estimatedMvmCount;
uint64_t estimatedReductionVAddCount;
uint64_t estimatedOutputFragments;
std::string materializationRequiredAtReturn;
std::string materializationKind;
uint64_t concatOperandCount;
std::string collectorCore;
std::string giantCollectorConcatExpected;
std::string fullInputBroadcastExpected;
};
/// Kind of a consumer op recorded as a DistributedTensorStep by analyzeDistributedConvConsumers.
/// Note: in the visible analysis, an `onnx.Conv` consumer is always rejected, so `Conv` steps are not produced there.
enum class DistributedTensorOpKind {
Relu,
Sigmoid,
Add,
Sub,
Mul,
Div,
Conv,
};
/// Reason why the chain of distributed consumers following a Conv stopped.
/// Printed via stringifyDistributedConvBarrierKind.
enum class DistributedConvBarrierKind {
// The value feeds func.return.
Return,
// The next user is not supported for distribution.
UnsupportedConsumer,
// The value has more than one user.
Fanout,
// The value has no users.
DeadValue,
// Not assigned by the analysis visible in this file section.
GroupedConv,
// Not assigned by the analysis visible in this file section.
Depthwise,
};
/// Shape class of the constant operand of a distributed binary step.
enum class DistributedTensorConstantKind {
// No constant operand (unary steps such as Relu/Sigmoid).
None,
// A splat constant (single repeated value).
Splat,
// A non-splat constant that varies only along the channel axis.
PerChannel,
};
/// Layout of a distributed tensor's fragments. Only one layout exists today.
/// NOTE(review): name suggests NCHW tensors split into row strips — confirm with the producer of DistributedTensorInfo.
enum class DistributedTensorLayoutKind {
NchwRowStrip,
};
/// Describes one fragment of a distributed tensor.
/// NOTE(review): presumably offsets/sizes/strides locate the fragment inside the logical tensor
/// (slice semantics) and `producerLane` is the lane that produces it — confirm at the producer.
struct DistributedFragmentInfo {
SmallVector<int64_t, 4> offsets;
SmallVector<int64_t, 4> sizes;
SmallVector<int64_t, 4> strides;
int64_t producerLane = 0;
};
/// Description of a tensor whose value is held as a set of fragments rather than one dense value.
/// NOTE(review): `storage` is presumably the value that physically holds the fragments, while
/// `logicalType` is the tensor type they represent together — confirm at the producer.
struct DistributedTensorInfo {
Value storage;
RankedTensorType logicalType;
DistributedTensorLayoutKind layoutKind = DistributedTensorLayoutKind::NchwRowStrip;
SmallVector<DistributedFragmentInfo, 8> fragments;
int64_t laneCount = 0;
// Presumably the number of rows per fragment in the row-strip layout.
int64_t fragmentHeight = 1;
int64_t channels = 0;
int64_t height = 0;
int64_t width = 0;
// True when the fragments use the NCHW row-strip layout.
bool isRowStripNchw() const { return layoutKind == DistributedTensorLayoutKind::NchwRowStrip; }
};
/// Side table associating SSA values with their distributed-tensor description.
struct DistributedTensorRegistry {
  llvm::DenseMap<Value, DistributedTensorInfo> infos;

  /// Records `info` for `value`, replacing any previous entry.
  void bind(Value value, const DistributedTensorInfo& info) { infos[value] = info; }

  /// Returns the description bound to `value`, or nullptr when nothing was recorded.
  /// The pointer is invalidated by later insertions into the map.
  const DistributedTensorInfo* lookup(Value value) const {
    const auto found = infos.find(value);
    return found == infos.end() ? nullptr : &found->second;
  }
};
/// One consumer accepted into a distributed chain after a Conv.
/// Built with positional brace-initialization elsewhere in this file, so member order is significant.
struct DistributedTensorStep {
// The consumer op.
Operation* op = nullptr;
DistributedTensorOpKind kind;
// Constant operand for binary steps (null for unary steps).
DenseElementsAttr constantAttr;
DistributedTensorConstantKind constantKind = DistributedTensorConstantKind::None;
// False when the distributed value is the right-hand operand (e.g. `c - x`).
bool fragmentOnLhs = true;
// Not populated by the visible classification code (always std::nullopt there).
std::optional<ConvLoweringState> convState;
};
/// Result of walking the single-use chain of consumers that follows a Conv.
struct DistributedConvAnalysis {
  /// Consumers accepted as distributed steps, in chain order.
  SmallVector<DistributedTensorStep, 4> steps;
  /// Last op of the accepted chain (the Conv itself when no step was accepted).
  Operation* replacementOp = nullptr;
  /// Why the walk stopped.
  DistributedConvBarrierKind barrierKind = DistributedConvBarrierKind::UnsupportedConsumer;
  /// Human-readable explanation of the barrier, used in reports.
  std::string barrierDetail;

  /// True when at least one consumer was accepted.
  bool hasLocalConsumers() const { return steps.size() != 0; }

  /// True when any accepted step is itself a Conv.
  bool hasDistributedConvConsumer() const {
    auto isConvStep = [](const DistributedTensorStep& step) { return step.kind == DistributedTensorOpKind::Conv; };
    return llvm::find_if(steps, isConvStep) != steps.end();
  }
};
/// One chain entry of the distributed-consumption report.
/// Written by rewriteDistributedConvReport, one field per line.
/// Only some fields are populated by recordDistributedConvOutcome; the others keep these defaults.
struct DistributedChainReportEntry {
uint64_t chainId = 0;
// Number of ops in the chain, including the producing Conv.
uint64_t chainLength = 0;
std::string producerKind;
// Comma-separated list of distributed step kinds.
std::string distributedOps;
std::string materializationPoints;
std::string firstMaterializationReason;
uint64_t maxLiveFragments = 0;
uint64_t maxFragmentFanout = 0;
uint64_t patchBuilderCoreCount = 0;
std::string centralJunctionDetected = "no";
std::string convInputMaterializationKind = "dense_materialization_fallback";
uint64_t localPatchFragments = 0;
uint64_t remotePatchFragments = 0;
uint64_t haloTransferCount = 0;
uint64_t groupedTransferCount = 0;
// Printed only when non-empty.
std::string fallbackReason;
};
/// Running totals for the distributed-consumption report, accumulated by recordDistributedConvOutcome.
struct DistributedConvReportTotals {
uint64_t totalConvs = 0;
uint64_t distributedTensorsCreated = 0;
uint64_t distributedValuesPropagated = 0;
uint64_t distributedConsumersHandled = 0;
uint64_t distributedConvInputsSeen = 0;
uint64_t distributedConvInputsConsumed = 0;
uint64_t materializationBarriersInserted = 0;
// Barrier key -> count, for Convs with no accepted consumer.
std::map<std::string, uint64_t> fallbackReasons;
// Barrier key -> count, for Convs with at least one accepted consumer.
std::map<std::string, uint64_t> barrierReasons;
// Rolling window of the most recent chains (recordDistributedConvOutcome keeps at most 16).
SmallVector<DistributedChainReportEntry, 16> chains;
};
// Forward declarations of helpers defined later in this file (definitions not visible here).
// NOTE(review): purposes below are inferred from names only — confirm at the definitions.
// Presumably builds an all-zero bias tensor matching `gemmResultType`.
static Value createZeroGemmBias(RankedTensorType gemmResultType, PatternRewriter& rewriter);
// Presumably packs `rows` into the row-strip layout described by `state`; may fail.
static FailureOr<Value> createRowStripPackedRows(Value rows,
const ConvLoweringState& state,
PatternRewriter& rewriter,
Location loc);
// Presumably derives ConvLoweringState (geometry/attributes) from a Conv and its operands; may fail.
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b);
/// Returns the report spelling of a distributed-chain barrier kind.
static StringRef stringifyDistributedConvBarrierKind(DistributedConvBarrierKind kind) {
  using Kind = DistributedConvBarrierKind;
  switch (kind) {
  case Kind::Return:
    return "func.return";
  case Kind::UnsupportedConsumer:
    return "unsupported_consumer";
  case Kind::Fanout:
    return "fanout";
  case Kind::DeadValue:
    return "dead_value";
  case Kind::GroupedConv:
    return "grouped_conv";
  case Kind::Depthwise:
    return "depthwise_conv";
  }
  llvm_unreachable("unknown distributed conv barrier kind");
}
/// Returns the option-style spelling (e.g. "packed-im2col") of a Conv lowering strategy.
static StringRef stringifyConvLoweringStrategy(PimConvLoweringType strategy) {
  const char* spelling = nullptr;
  switch (strategy) {
  case PimConvLoweringAuto:
    spelling = "auto";
    break;
  case PimConvLoweringLegacy:
    spelling = "legacy";
    break;
  case PimConvLoweringDepthwise:
    spelling = "depthwise";
    break;
  case PimConvLoweringPackedIm2Col:
    spelling = "packed-im2col";
    break;
  case PimConvLoweringStreamedPatch:
    spelling = "streamed-patch";
    break;
  case PimConvLoweringStreamedPacked:
    spelling = "streamed-packed";
    break;
  case PimConvLoweringOutputChannelTiled:
    spelling = "output-channel-tiled";
    break;
  case PimConvLoweringInputKTiled:
    spelling = "input-k-tiled";
    break;
  case PimConvLoweringTiled2D:
    spelling = "tiled-2d";
    break;
  }
  if (spelling)
    return spelling;
  llvm_unreachable("unknown conv lowering strategy");
}
/// True when the Conv result goes straight to func.return without any accepted distributed consumer.
static bool requiresFuncReturnMaterialization(const DistributedConvAnalysis& analysis) {
  if (analysis.hasLocalConsumers())
    return false;
  return analysis.barrierKind == DistributedConvBarrierKind::Return;
}
/// Produces a static cost/materialization estimate for lowering the Conv described by `geo`
/// with `strategy`. The returned fields mirror the estimate fields of ConvReportEntry.
/// `analysis` only decides whether func.return materialization is assumed.
/// Depthwise and Auto get no strategy-specific counts (ConvStrategyEstimate defaults are kept).
static ConvStrategyEstimate estimateConvStrategy(const ConvGeometry& geo,
PimConvLoweringType strategy,
const DistributedConvAnalysis& analysis) {
ConvStrategyEstimate estimate;
estimate.requiresFuncReturnMaterialization = requiresFuncReturnMaterialization(analysis);
estimate.materializationKind = estimate.requiresFuncReturnMaterialization ? "func.return" : "none";
switch (strategy) {
// One MVM per output position (P, clamped to at least 1).
case PimConvLoweringLegacy:
case PimConvLoweringPackedIm2Col:
estimate.estimatedMvmCount = static_cast<uint64_t>(std::max<int64_t>(1, geo.p));
break;
// Streamed / output-channel-tiled: one MVM per output position. The positions are grouped into
// chunks of chooseStreamChunkPositions(geo, 1), giving one output fragment per chunk.
case PimConvLoweringStreamedPatch:
case PimConvLoweringStreamedPacked:
case PimConvLoweringOutputChannelTiled: {
uint64_t chunkPositions = chooseStreamChunkPositions(geo, /*packFactor=*/1);
estimate.estimatedMvmCount = static_cast<uint64_t>(std::max<int64_t>(1, geo.p));
estimate.estimatedOutputFragments =
std::max<uint64_t>(1, static_cast<uint64_t>(ceilIntegerDivide(geo.p, static_cast<int64_t>(chunkPositions))));
break;
}
// Input-K tiling: the reduction dimension K is split into ceil(K / xbarSize) slices.
case PimConvLoweringInputKTiled: {
const int64_t numKSlices = ceilIntegerDivide(geo.k, geo.xbarSize);
// Lanes per batch are bounded by the per-core crossbar count divided by (numKSlices * 4).
// NOTE(review): the factor 4 is not explained here — confirm its origin.
const uint64_t maxLanesPerBatch =
std::max<uint64_t>(1,
static_cast<uint64_t>(crossbarCountInCore.getValue())
/ static_cast<uint64_t>(std::max<int64_t>(1, numKSlices * 4)));
// Output positions per fragment along a row: min(stream chunk, lane budget, output width), at least 1.
const uint64_t rowChunkWidth = std::max<uint64_t>(
1,
std::min<uint64_t>({chooseStreamChunkPositions(geo, /*packFactor=*/1),
maxLanesPerBatch,
static_cast<uint64_t>(std::max<int64_t>(1, geo.outWidth))}));
// One MVM per (output position, K slice).
estimate.estimatedMvmCount =
static_cast<uint64_t>(std::max<int64_t>(1, geo.p)) * static_cast<uint64_t>(std::max<int64_t>(1, numKSlices));
// Per position: (numKSlices - 1) additions to combine partial sums, plus one for the bias if present.
estimate.estimatedReductionVAddCount =
static_cast<uint64_t>(std::max<int64_t>(1, geo.p))
* static_cast<uint64_t>(std::max<int64_t>(0, numKSlices - 1) + (geo.hasBias ? 1 : 0));
// One fragment per (batch, output row, row chunk).
estimate.estimatedOutputFragments = static_cast<uint64_t>(std::max<int64_t>(1, geo.batchSize))
* static_cast<uint64_t>(std::max<int64_t>(1, geo.outHeight))
* static_cast<uint64_t>(
ceilIntegerDivide(geo.outWidth, static_cast<int64_t>(rowChunkWidth)));
estimate.perOutputPositionReduction = numKSlices > 1;
estimate.fullInputBroadcastExpected = estimate.estimatedOutputFragments > 1;
// Returning >= 128 fragments from the function is flagged as one large collector concat
// (threshold hard-coded here).
if (estimate.requiresFuncReturnMaterialization && estimate.estimatedOutputFragments >= 128) {
estimate.giantCollectorConcatExpected = true;
estimate.materializationKind = "single_collector_concat";
estimate.concatOperandCount = estimate.estimatedOutputFragments;
estimate.collectorCore = "scheduled_post_merge";
}
break;
}
// Tiled-2D: one MVM per output position in this estimate.
case PimConvLoweringTiled2D:
estimate.estimatedMvmCount = static_cast<uint64_t>(std::max<int64_t>(1, geo.p));
break;
case PimConvLoweringDepthwise:
case PimConvLoweringAuto:
break;
}
// NOTE(review): materializationKind is already "func.return" whenever materialization is required
// (set before the switch, only ever overwritten with "single_collector_concat"), so this check
// appears to be a no-op safety net.
if (estimate.requiresFuncReturnMaterialization && estimate.materializationKind == "none")
estimate.materializationKind = "func.return";
return estimate;
}
/// Formats a shape as "[d0xd1x...]", e.g. "[1x64x56x56]"; an empty shape yields "[]".
static std::string formatShape(ArrayRef<int64_t> dims) {
  std::string text = "[";
  bool firstDim = true;
  for (int64_t dim : dims) {
    if (!firstDim)
      text += 'x';
    firstDim = false;
    text += std::to_string(dim);
  }
  text += ']';
  return text;
}
/// Collapses every run of whitespace in `text` into a single space and trims whitespace at
/// both ends. Example: "  a \n\t b  " -> "a b".
/// (Previously trailing whitespace was reduced to one space instead of being removed.)
static std::string collapseWhitespace(StringRef text) {
  std::string out;
  out.reserve(text.size());
  // A separator is only emitted once the next non-space character arrives, so leading and
  // trailing whitespace never reach the output.
  bool pendingSpace = false;
  for (char c : text) {
    if (std::isspace(static_cast<unsigned char>(c))) {
      pendingSpace = true;
      continue;
    }
    if (pendingSpace && !out.empty())
      out.push_back(' ');
    pendingSpace = false;
    out.push_back(c);
  }
  return out;
}
/// Shortens `text` to at most `maxLen` characters. It keeps the beginning and marks the cut
/// with a trailing "...". When `maxLen` is too small to hold the ellipsis (< 3), the text is
/// hard-truncated instead. Previously `maxLen - 3` underflowed in that case and returned the
/// whole text plus "...".
static std::string abbreviate(StringRef text, size_t maxLen) {
  if (text.size() <= maxLen)
    return text.str();
  if (maxLen < 3)
    return text.take_front(maxLen).str();
  return (text.take_front(maxLen - 3) + "...").str();
}
/// Shortens `text` to at most `maxLen` characters. It keeps the end and marks the cut with a
/// leading "...". When `maxLen` is too small to hold the ellipsis (< 3), only the last `maxLen`
/// characters are kept. Previously `maxLen - 3` underflowed in that case and returned "..."
/// plus the whole text.
static std::string abbreviateFromEnd(StringRef text, size_t maxLen) {
  if (text.size() <= maxLen)
    return text.str();
  if (maxLen < 3)
    return text.take_back(maxLen).str();
  return ("..." + text.take_back(maxLen - 3)).str();
}
/// Prints `loc` on one line (whitespace collapsed) and shortens it to `maxLen` characters.
/// Path-like text (containing '/' or '#') keeps its tail, since the most specific part is at the
/// end. Any other text keeps its head.
static std::string summarizeLocation(Location loc, size_t maxLen = 44) {
  std::string printed;
  {
    llvm::raw_string_ostream stream(printed);
    loc.print(stream);
    stream.flush();
  }
  std::string oneLine = collapseWhitespace(printed);
  if (oneLine.size() <= maxLen)
    return oneLine;
  const bool looksLikePath = oneLine.find_first_of("/#") != std::string::npos;
  return looksLikePath ? abbreviateFromEnd(oneLine, maxLen) : abbreviate(oneLine, maxLen);
}
/// Pads `text` with spaces up to `width` characters, on the left when `rightAlign` is set and
/// on the right otherwise. Text that is already `width` or longer is returned unchanged.
static std::string alignCell(StringRef text, size_t width, bool rightAlign = false) {
  if (text.size() >= width)
    return text.str();
  const std::string padding(width - text.size(), ' ');
  return rightAlign ? padding + text.str() : text.str() + padding;
}
/// True when `value` has a statically shaped ranked tensor type equal to `expectedType`.
static bool hasSameStaticTensorType(Value value, Type expectedType) {
  auto actualTensorType = dyn_cast<RankedTensorType>(value.getType());
  auto wantedTensorType = dyn_cast<RankedTensorType>(expectedType);
  if (!actualTensorType || !wantedTensorType)
    return false;
  if (!actualTensorType.hasStaticShape())
    return false;
  return actualTensorType == wantedTensorType;
}
/// True when `value` is a host-side dense constant holding a single repeated value.
/// Always stores the looked-up constant (possibly null) into `denseAttr`.
static bool isSplatConstantValue(Value value, DenseElementsAttr& denseAttr) {
  denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr)
    return false;
  return denseAttr.isSplat();
}
/// True when `value` is a non-splat host-side dense constant that, under ONNX multidirectional
/// (numpy-style, right-aligned) broadcasting against `currentType`, varies only along the
/// channel axis (dim 1). For an NCHW tensor that means `[C,1,1]` or `[1,C,1,1]`.
///
/// A rank-1 `[C]` or rank-2 `[1,C]` constant aligns with the trailing (W / H,W) axes of a
/// rank-4 tensor, not with C. Accepting those as "per-channel" was wrong whenever W == C, so
/// they are now rejected for rank-4 targets. For a rank-2 `[N,C]` target, a `[C]` constant
/// still aligns with C and is accepted.
/// Always stores the looked-up constant (possibly null) into `denseAttr`.
/// NOTE(review): consumers of PerChannel constants should read the C values in flat order;
/// `[C,1,1]` and `[1,C,1,1]` have identical flat layouts.
static bool isPerChannelConstantValue(Value value, RankedTensorType currentType, DenseElementsAttr& denseAttr) {
  denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr || denseAttr.isSplat())
    return false;
  auto constantType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!constantType || !constantType.hasStaticShape())
    return false;

  const int64_t targetRank = currentType.getRank();
  const int64_t constantRank = constantType.getRank();
  // The target needs a channel axis, and the right-aligned constant must be long enough to reach it
  // without having more dimensions than the target.
  if (targetRank < 2 || constantRank > targetRank || constantRank < targetRank - 1)
    return false;

  const int64_t channels = currentType.getDimSize(1);
  const int64_t rankOffset = targetRank - constantRank;
  for (int64_t dim = 0; dim < constantRank; ++dim) {
    // Broadcasting aligns constant dim `dim` with target dim `dim + rankOffset`.
    const int64_t expectedSize = (dim + rankOffset == 1) ? channels : 1;
    if (constantType.getDimSize(dim) != expectedSize)
      return false;
  }
  return true;
}
/// Tries to accept a binary elementwise `user` as a distributed step of kind `kind`.
/// `currentValue` (the distributed value) must be one operand; the other operand must be a splat
/// or per-channel constant. The fragment may sit on the right-hand side only when
/// `allowFragmentOnRhs` is set. `user` must produce exactly one result of the same static type.
/// On rejection returns std::nullopt and explains why in `failureDetail`.
static std::optional<DistributedTensorStep>
classifyDistributedBinaryConsumer(Operation* user,
                                  Value currentValue,
                                  Value lhs,
                                  Value rhs,
                                  DistributedTensorOpKind kind,
                                  bool allowFragmentOnRhs,
                                  std::string& failureDetail) {
  if (user->getNumResults() != 1 || !hasSameStaticTensorType(user->getResult(0), currentValue.getType())) {
    failureDetail = "result type mismatch";
    return std::nullopt;
  }
  auto currentType = cast<RankedTensorType>(currentValue.getType());

  // Classifies the constant operand paired with the distributed fragment.
  auto classifyConstantOperand = [&](Value constantOperand,
                                     bool fragmentOnLhs,
                                     const char* unsupportedDetail) -> std::optional<DistributedTensorStep> {
    DenseElementsAttr constantAttr;
    DistributedTensorConstantKind constantKind;
    if (isSplatConstantValue(constantOperand, constantAttr)) {
      constantKind = DistributedTensorConstantKind::Splat;
    }
    else if (isPerChannelConstantValue(constantOperand, currentType, constantAttr)) {
      constantKind = DistributedTensorConstantKind::PerChannel;
    }
    else {
      failureDetail = unsupportedDetail;
      return std::nullopt;
    }
    return DistributedTensorStep {user, kind, constantAttr, constantKind, fragmentOnLhs, std::nullopt};
  };

  if (lhs == currentValue)
    return classifyConstantOperand(rhs, /*fragmentOnLhs=*/true, "unsupported rhs broadcast");
  if (rhs == currentValue && allowFragmentOnRhs)
    return classifyConstantOperand(lhs, /*fragmentOnLhs=*/false, "unsupported lhs broadcast");

  failureDetail = "conv result is not the supported binary operand";
  return std::nullopt;
}
/// True when `acquired` spans all of `needed`: acquired.begin <= needed.begin and
/// acquired.end >= needed.end.
static bool covers(RowInterval acquired, RowInterval needed) {
  const bool startsInside = needed.begin >= acquired.begin;
  const bool endsInside = needed.end <= acquired.end;
  return startsInside && endsInside;
}
static bool canConsumeRowStripHwcInput(const ConvLoweringState& state, StringRef& failureReason) {
if (state.batchSize != 1) {
failureReason = "unsupported_batch";
return false;
}
if (state.group != 1) {
failureReason = "unsupported_groups";
return false;
}
if (isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup)) {
failureReason = "unsupported_depthwise";
return false;
}
if (state.strideHeight != 1 || state.strideWidth != 1) {
failureReason = "unsupported_stride";
return false;
}
if (state.dilationHeight != 1 || state.dilationWidth != 1) {
failureReason = "unsupported_dilation";
return false;
}
if (state.padHeightBegin != state.padHeightEnd || state.padWidthBegin != state.padWidthEnd) {
failureReason = "unsupported_padding";
return false;
}
if (state.padHeightBegin != 1 || state.padWidthBegin != 1) {
failureReason = "unsupported_padding";
return false;
}
if (state.wHeight != 3 || state.wWidth != 3) {
failureReason = "unsupported_kernel";
return false;
}
if (state.outHeight != state.xHeight || state.outWidth != state.xWidth) {
failureReason = "unsupported_output_shape";
return false;
}
if (!getHostConstDenseElementsAttr(state.w)) {
failureReason = "non_constant_weights";
return false;
}
if (state.hasBias && !getHostConstDenseElementsAttr(state.b)) {
failureReason = "non_constant_bias";
return false;
}
failureReason = "";
return true;
}
/// Returns the ONNX-style op name ("Relu", "Add", ...) of a distributed step kind.
static std::string stringifyDistributedTensorOpKind(DistributedTensorOpKind kind) {
  using Kind = DistributedTensorOpKind;
  switch (kind) {
  case Kind::Relu:
    return "Relu";
  case Kind::Sigmoid:
    return "Sigmoid";
  case Kind::Add:
    return "Add";
  case Kind::Sub:
    return "Sub";
  case Kind::Mul:
    return "Mul";
  case Kind::Div:
    return "Div";
  case Kind::Conv:
    return "Conv";
  }
  llvm_unreachable("unknown distributed tensor op kind");
}
/// Follows the single-use chain of consumers starting at `convOp`'s result. It accepts
/// elementwise ops that can operate on the distributed fragments directly (Relu, Sigmoid, and
/// Add/Sub/Mul/Div with a splat or per-channel constant). The walk stops at the first barrier:
/// no users, more than one user, func.return, or an unsupported consumer.
/// The returned analysis lists the accepted steps, the last accepted op (`replacementOp`) and the
/// barrier that ended the walk.
[[maybe_unused]] static DistributedConvAnalysis analyzeDistributedConvConsumers(ONNXConvOp convOp) {
  // Classifies `user` (the sole user of `value`). On rejection returns std::nullopt and stores a
  // human-readable explanation in `failureDetail`.
  auto classifyUser =
      [](Operation* user, Value value, std::string& failureDetail) -> std::optional<DistributedTensorStep> {
    if (auto reluOp = dyn_cast<ONNXReluOp>(user)) {
      if (hasSameStaticTensorType(reluOp.getResult(), value.getType()))
        return DistributedTensorStep {
            user, DistributedTensorOpKind::Relu, {}, DistributedTensorConstantKind::None, true, std::nullopt};
      failureDetail = "relu result type mismatch";
      return std::nullopt;
    }
    if (auto sigmoidOp = dyn_cast<ONNXSigmoidOp>(user)) {
      if (hasSameStaticTensorType(sigmoidOp.getResult(), value.getType()))
        return DistributedTensorStep {
            user, DistributedTensorOpKind::Sigmoid, {}, DistributedTensorConstantKind::None, true, std::nullopt};
      failureDetail = "sigmoid result type mismatch";
      return std::nullopt;
    }
    if (auto addOp = dyn_cast<ONNXAddOp>(user))
      return classifyDistributedBinaryConsumer(
          user, value, addOp.getA(), addOp.getB(), DistributedTensorOpKind::Add, /*allowFragmentOnRhs=*/true,
          failureDetail);
    if (auto subOp = dyn_cast<ONNXSubOp>(user))
      return classifyDistributedBinaryConsumer(
          user, value, subOp.getA(), subOp.getB(), DistributedTensorOpKind::Sub, /*allowFragmentOnRhs=*/true,
          failureDetail);
    if (auto mulOp = dyn_cast<ONNXMulOp>(user))
      return classifyDistributedBinaryConsumer(
          user, value, mulOp.getA(), mulOp.getB(), DistributedTensorOpKind::Mul, /*allowFragmentOnRhs=*/true,
          failureDetail);
    if (auto divOp = dyn_cast<ONNXDivOp>(user)) {
      // The fragment must be the dividend, and the divisor must be a floating-point constant.
      // NOTE(review): per-channel floating-point divisors also pass this check although the
      // message below mentions "splat" — confirm which behavior is intended.
      std::optional<DistributedTensorStep> divStep = classifyDistributedBinaryConsumer(
          user, value, divOp.getA(), divOp.getB(), DistributedTensorOpKind::Div, /*allowFragmentOnRhs=*/false,
          failureDetail);
      if (divStep && !isa<DenseFPElementsAttr>(divStep->constantAttr)) {
        failureDetail = "div requires floating-point splat constant";
        return std::nullopt;
      }
      return divStep;
    }
    if (isa<ONNXConvOp>(user)) {
      failureDetail =
          "onnx.Conv distributed consumer blocked by dim0-only whole-batch materialization in MergeComputeNodes";
      return std::nullopt;
    }
    failureDetail = (user->getName().getStringRef() + " is not distributed-aware yet").str();
    return std::nullopt;
  };

  DistributedConvAnalysis analysis;
  analysis.replacementOp = convOp;
  Value currentValue = convOp.getResult();
  for (;;) {
    if (currentValue.use_empty()) {
      analysis.barrierKind = DistributedConvBarrierKind::DeadValue;
      analysis.barrierDetail = "result has no users";
      return analysis;
    }
    if (!currentValue.hasOneUse()) {
      analysis.barrierKind = DistributedConvBarrierKind::Fanout;
      analysis.barrierDetail = "value has multiple users";
      return analysis;
    }
    Operation* user = *currentValue.getUsers().begin();
    if (isa<func::ReturnOp>(user)) {
      analysis.barrierKind = DistributedConvBarrierKind::Return;
      analysis.barrierDetail = "materialize at func.return";
      return analysis;
    }

    std::string failureDetail;
    std::optional<DistributedTensorStep> step = classifyUser(user, currentValue, failureDetail);
    if (!step) {
      analysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer;
      analysis.barrierDetail = failureDetail;
      return analysis;
    }

    // Accept the consumer and continue the walk from its (single) result.
    analysis.replacementOp = user;
    analysis.steps.push_back(*step);
    currentValue = user->getResult(0);
  }
}
/// Rewrites the "conv_distributed_consumption_report" file from `totals`.
/// The file contains the totals, the barrier and fallback reason histograms (each omitted when
/// empty) and the recorded chains. Does nothing if the report file cannot be opened.
static void rewriteDistributedConvReport(const DistributedConvReportTotals& totals) {
  std::fstream reportFile = openReportFile("conv_distributed_consumption_report");
  if (!reportFile.is_open())
    return;

  // Emits "<prefix><value>\n".
  auto writeLine = [&reportFile](const char* prefix, const auto& value) { reportFile << prefix << value << "\n"; };
  // Emits a titled "- reason: count" list followed by a blank line; skipped when `reasons` is empty.
  auto writeReasonSection = [&reportFile](const char* title, const std::map<std::string, uint64_t>& reasons) {
    if (reasons.empty())
      return;
    reportFile << title;
    for (const auto& entry : reasons)
      reportFile << "- " << entry.first << ": " << entry.second << "\n";
    reportFile << "\n";
  };

  reportFile << "# PIM Conv Distributed Consumption Report\n\n";
  reportFile << "Totals:\n";
  writeLine("- convs_seen: ", totals.totalConvs);
  writeLine("- distributed_tensors_created: ", totals.distributedTensorsCreated);
  writeLine("- distributed_values_propagated: ", totals.distributedValuesPropagated);
  writeLine("- distributed_consumers_handled_locally: ", totals.distributedConsumersHandled);
  writeLine("- distributed_conv_inputs_seen: ", totals.distributedConvInputsSeen);
  writeLine("- distributed_conv_inputs_consumed: ", totals.distributedConvInputsConsumed);
  reportFile << "- materialization_barriers_inserted: " << totals.materializationBarriersInserted << "\n\n";

  writeReasonSection("Materialization barriers:\n", totals.barrierReasons);
  writeReasonSection("Fallback / no-distribution reasons:\n", totals.fallbackReasons);

  if (totals.chains.empty())
    return;
  reportFile << "Chains:\n";
  for (const DistributedChainReportEntry& chain : totals.chains) {
    writeLine("- chain_id: ", chain.chainId);
    writeLine(" chain_length: ", chain.chainLength);
    writeLine(" producer_kind: ", chain.producerKind);
    writeLine(" distributed_ops: ", chain.distributedOps);
    writeLine(" materialization_points: ", chain.materializationPoints);
    writeLine(" first_materialization_reason: ", chain.firstMaterializationReason);
    writeLine(" max_live_fragments: ", chain.maxLiveFragments);
    writeLine(" max_fragment_fanout: ", chain.maxFragmentFanout);
    writeLine(" patch_builder_core_count: ", chain.patchBuilderCoreCount);
    writeLine(" central_junction_detected: ", chain.centralJunctionDetected);
    writeLine(" conv_input_materialization_kind: ", chain.convInputMaterializationKind);
    writeLine(" local_patch_fragments: ", chain.localPatchFragments);
    writeLine(" remote_patch_fragments: ", chain.remotePatchFragments);
    writeLine(" halo_transfer_count: ", chain.haloTransferCount);
    writeLine(" grouped_transfer_count: ", chain.groupedTransferCount);
    if (!chain.fallbackReason.empty())
      writeLine(" fallback_reason: ", chain.fallbackReason);
  }
}
[[maybe_unused]] static void recordDistributedConvOutcome(const DistributedConvAnalysis& analysis) {
static std::mutex reportMutex;
static DistributedConvReportTotals totals;
std::string barrierKey = stringifyDistributedConvBarrierKind(analysis.barrierKind).str();
if (!analysis.barrierDetail.empty())
barrierKey += ": " + analysis.barrierDetail;
std::lock_guard<std::mutex> guard(reportMutex);
totals.totalConvs++;
if (analysis.hasLocalConsumers()) {
totals.distributedTensorsCreated++;
totals.distributedValuesPropagated += analysis.steps.size();
totals.distributedConsumersHandled += llvm::count_if(analysis.steps, [](const DistributedTensorStep& step) {
return step.kind != DistributedTensorOpKind::Conv;
});
totals.distributedConvInputsSeen += llvm::count_if(analysis.steps, [](const DistributedTensorStep& step) {
return step.kind == DistributedTensorOpKind::Conv;
});
totals.distributedConvInputsConsumed += llvm::count_if(analysis.steps, [](const DistributedTensorStep& step) {
return step.kind == DistributedTensorOpKind::Conv;
});
totals.materializationBarriersInserted++;
totals.barrierReasons[barrierKey]++;
}
else {
totals.fallbackReasons[barrierKey]++;
}
DistributedChainReportEntry chain;
chain.chainId = totals.totalConvs;
chain.chainLength = analysis.steps.size() + 1;
chain.producerKind = "Conv";
chain.materializationPoints = stringifyDistributedConvBarrierKind(analysis.barrierKind).str();
chain.firstMaterializationReason = analysis.barrierDetail;
chain.maxFragmentFanout = analysis.barrierKind == DistributedConvBarrierKind::Fanout ? 2 : 1;
std::string ops;
for (size_t index = 0; index < analysis.steps.size(); ++index) {
if (!ops.empty())
ops += ", ";
ops += stringifyDistributedTensorOpKind(analysis.steps[index].kind);
}
chain.distributedOps = ops;
if (analysis.hasDistributedConvConsumer()) {
chain.convInputMaterializationKind = "distributed_with_halo_exchange";
chain.patchBuilderCoreCount = 1;
}
if (totals.chains.size() == 16)
totals.chains.erase(totals.chains.begin());
totals.chains.push_back(std::move(chain));
rewriteDistributedConvReport(totals);
}
/// Builds an ASCII table divider such as "+------+----+". Each column gets `width + 2` dashes
/// (the cell plus one space of padding on each side).
static std::string makeDivider(ArrayRef<size_t> widths) {
  std::string divider(1, '+');
  for (size_t columnWidth : widths)
    divider += std::string(columnWidth + 2, '-') + '+';
  return divider;
}
/// Writes the title and column legend of the Conv lowering report.
/// Each legend key names the abbreviated header actually printed in the table (e.g. `grp`),
/// followed by its long name in parentheses. Previously the legend used only the long names
/// (`groups`, `batches`, ...), which do not appear in the table header row.
static void printConvReportLegend(std::fstream& reportFile) {
  // Each entry carries its own trailing newline(s) and is emitted verbatim, in order.
  static constexpr const char* kLegendLines[] = {
      "# PIM Conv Lowering Report\n\n",
      "Legend:\n",
      "- `id`: sequential Conv report entry index within this compiler invocation.\n",
      "- `where`: summarized MLIR location of the Conv op.\n",
      "- `mode`: whether the selected strategy came from `auto` policy or a forced compiler option.\n",
      "- `strategy`: selected Conv lowering algorithm.\n",
      "- `input`, `weight`, `output`: tensor shapes of the Conv operands/result.\n",
      "- `grp` (`groups`): ONNX Conv group count.\n",
      "- `K`: logical reduction size, `CinPerGroup * Kh * Kw`.\n",
      "- `C`: logical output-channel width handled by the selected strategy.\n",
      "- `P`: total output positions, `N * Hout * Wout`.\n",
      "- `X`: crossbar size.\n",
      "- `pack`: packed spatial positions per MVM group, `floor(X / max(K, C))`.\n",
      "- `im2col`: total explicit im2col element count, `P * K`.\n",
      "- `budget` (`im2col_budget`): maximum im2col element budget allowed by the compiler option.\n",
      "- `chunk` (`stream_chunk`): output positions materialized per streamed chunk, or `-` when not applicable.\n",
      "- `batch` (`batch_size`): logical batch size passed into the compute lowering for this Conv form.\n",
      "- `nbat` (`batches`): number of repeated compute batches emitted for the Conv.\n",
      "- `scb` (`spatial_compute_batch`): whether ONNX-to-Spatial used `spat.compute_batch` for this Conv.\n",
      "- `bie` (`batched_instruction_emission`): whether the lowering is expected to reach batched PIM "
      "emission.\n",
      "- `reason`: strategy-selection reason.\n",
      "- Per-conv details below the table include profitability estimates and materialization diagnostics.\n",
      "- placeholders like `[7]`: value too long for the table cell; see the appendix at the end.\n\n",
  };
  for (const char* line : kLegendLines)
    reportFile << line;
}
/// A table cell value that was too wide and was replaced by a "[N]" placeholder.
/// fitConvReportCell records one per distinct (column, value) pair; positional brace-initialization
/// there depends on this member order.
struct ConvReportOverflow {
// N in the "[N]" placeholder.
uint64_t placeholderId;
// Report row ids that display this placeholder.
SmallVector<uint64_t, 4> rowIds;
std::string column;
// The full, untruncated cell value.
std::string value;
};
/// Returns a one-sentence human-readable description of a Conv lowering strategy.
static StringRef describeConvLoweringStrategy(PimConvLoweringType strategy) {
  const char* description = nullptr;
  switch (strategy) {
  case PimConvLoweringAuto:
    description = "Automatic policy selection.";
    break;
  case PimConvLoweringLegacy:
    description = "Legacy Conv lowering path kept for compatibility.";
    break;
  case PimConvLoweringDepthwise:
    description = "Specialized depthwise lowering that avoids generic cross-channel GEMM mixing.";
    break;
  case PimConvLoweringPackedIm2Col:
    description = "Explicit im2col plus packed GEMM lowering for Conv shapes that fit well in one crossbar.";
    break;
  case PimConvLoweringStreamedPatch:
    description = "Chunked per-patch streaming Conv lowering without global im2col materialization.";
    break;
  case PimConvLoweringStreamedPacked:
    description = "Chunked streamed Conv lowering that still packs multiple output positions per MVM group.";
    break;
  case PimConvLoweringOutputChannelTiled:
    description = "Conv lowering that splits output channels across tiles when C exceeds one crossbar width.";
    break;
  case PimConvLoweringInputKTiled:
    description = "Conv lowering that splits the reduction dimension K across tiles and accumulates partial sums.";
    break;
  case PimConvLoweringTiled2D:
    description = "Conv lowering that tiles both K and output channels because neither dimension fits one crossbar.";
    break;
  }
  if (description)
    return description;
  llvm_unreachable("unknown conv lowering strategy");
}
/// Renders one report table cell of `width` characters.
/// - A value that fits is aligned and printed verbatim.
/// - A value that does not fit is replaced by a "[N]" placeholder and recorded in `overflows` for
///   the appendix. Identical values in the same column share one placeholder; `rowId` is added to
///   that placeholder's row list once.
/// - New placeholders take their id from `nextPlaceholderId`, which is then incremented.
static std::string fitConvReportCell(StringRef text,
                                     size_t width,
                                     uint64_t rowId,
                                     StringRef column,
                                     std::vector<ConvReportOverflow>& overflows,
                                     uint64_t& nextPlaceholderId,
                                     bool rightAlign = false) {
  if (text.size() <= width)
    return alignCell(text, width, rightAlign);

  auto makePlaceholder = [](uint64_t id) { return "[" + std::to_string(id) + "]"; };

  auto existing = llvm::find_if(overflows, [&](const ConvReportOverflow& overflow) {
    return overflow.column == column && overflow.value == text;
  });
  if (existing != overflows.end()) {
    if (!llvm::is_contained(existing->rowIds, rowId))
      existing->rowIds.push_back(rowId);
    return alignCell(makePlaceholder(existing->placeholderId), width, rightAlign);
  }

  const uint64_t placeholderId = nextPlaceholderId++;
  overflows.push_back({placeholderId, {rowId}, column.str(), text.str()});
  return alignCell(makePlaceholder(placeholderId), width, rightAlign);
}
// Writes the body of the Conv lowering report to `reportFile`:
//   1. a fixed-width summary table with one row per entry (21 columns);
//   2. an "Appendix" that resolves the "[N]" placeholders produced by
//      fitConvReportCell for values too long for their column;
//   3. a "Per-Conv Details" section with estimate/materialization fields;
//   4. a description of every strategy that appears in `entries`.
static void writeConvReportTable(std::fstream& reportFile, ArrayRef<ConvReportEntry> entries) {
  // Column widths in characters.
  static constexpr size_t kIdWidth = 4;
  static constexpr size_t kWhereWidth = 24;
  static constexpr size_t kModeWidth = 6;
  static constexpr size_t kStrategyWidth = 20;
  static constexpr size_t kShapeWidth = 14;
  static constexpr size_t kGroupsWidth = 3;
  static constexpr size_t kSmallWidth = 5;
  static constexpr size_t kPWidth = 10;
  static constexpr size_t kIm2colWidth = 10;
  static constexpr size_t kChunkWidth = 8;
  static constexpr size_t kFlagWidth = 3;
  static constexpr size_t kReasonWidth = 24;
  // One width per column, in the same order as the header and row cells below.
  const SmallVector<size_t, 21> widths = {
      kIdWidth,    kWhereWidth, kModeWidth,   kStrategyWidth, kShapeWidth,  kShapeWidth, kShapeWidth,
      kGroupsWidth, kSmallWidth, kSmallWidth, kPWidth,        kSmallWidth,  kSmallWidth, kIm2colWidth,
      kIm2colWidth, kChunkWidth, kSmallWidth, kSmallWidth,    kFlagWidth,   kFlagWidth,  kReasonWidth,
  };
  const std::string divider = makeDivider(widths);
  // Overflowed cell values collected while rendering rows; printed in the appendix.
  std::vector<ConvReportOverflow> overflows;
  uint64_t nextPlaceholderId = 1;
  // Emits one table row as "| cell | cell | ... |".
  auto printRow = [&](ArrayRef<std::string> cells) {
    reportFile << "|";
    for (size_t i = 0; i < cells.size(); ++i)
      reportFile << " " << cells[i] << " |";
    reportFile << "\n";
  };
  reportFile << divider << "\n";
  printRow({
      alignCell("id", kIdWidth, true),
      alignCell("where", kWhereWidth),
      alignCell("mode", kModeWidth),
      alignCell("strategy", kStrategyWidth),
      alignCell("input", kShapeWidth),
      alignCell("weight", kShapeWidth),
      alignCell("output", kShapeWidth),
      alignCell("grp", kGroupsWidth, true),
      alignCell("K", kSmallWidth, true),
      alignCell("C", kSmallWidth, true),
      alignCell("P", kPWidth, true),
      alignCell("X", kSmallWidth, true),
      alignCell("pack", kSmallWidth, true),
      alignCell("im2col", kIm2colWidth, true),
      alignCell("budget", kIm2colWidth, true),
      alignCell("chunk", kChunkWidth, true),
      alignCell("batch", kSmallWidth, true),
      alignCell("nbat", kSmallWidth, true),
      alignCell("scb", kFlagWidth),
      alignCell("bie", kFlagWidth),
      alignCell("reason", kReasonWidth),
  });
  reportFile << divider << "\n";
  // Free-text columns go through fitConvReportCell (placeholder on overflow);
  // numeric and flag columns are aligned directly.
  for (const ConvReportEntry& entry : entries) {
    printRow({
        alignCell(std::to_string(entry.id), kIdWidth, true),
        fitConvReportCell(entry.where, kWhereWidth, entry.id, "where", overflows, nextPlaceholderId),
        alignCell(entry.mode, kModeWidth),
        fitConvReportCell(entry.strategy, kStrategyWidth, entry.id, "strategy", overflows, nextPlaceholderId),
        fitConvReportCell(entry.inputShape, kShapeWidth, entry.id, "input", overflows, nextPlaceholderId),
        fitConvReportCell(entry.weightShape, kShapeWidth, entry.id, "weight", overflows, nextPlaceholderId),
        fitConvReportCell(entry.outputShape, kShapeWidth, entry.id, "output", overflows, nextPlaceholderId),
        alignCell(std::to_string(entry.groups), kGroupsWidth, true),
        alignCell(std::to_string(entry.k), kSmallWidth, true),
        alignCell(std::to_string(entry.c), kSmallWidth, true),
        alignCell(std::to_string(entry.p), kPWidth, true),
        alignCell(std::to_string(entry.xbarSize), kSmallWidth, true),
        alignCell(std::to_string(entry.pack), kSmallWidth, true),
        alignCell(std::to_string(entry.im2colElements), kIm2colWidth, true),
        alignCell(std::to_string(entry.im2colBudget), kIm2colWidth, true),
        // The appendix labels overflowed values of this column "stream_chunk"
        // (the table header says "chunk").
        fitConvReportCell(entry.chunkText, kChunkWidth, entry.id, "stream_chunk", overflows, nextPlaceholderId, true),
        alignCell(std::to_string(entry.batchSize), kSmallWidth, true),
        alignCell(std::to_string(entry.numberOfBatches), kSmallWidth, true),
        alignCell(entry.spatialComputeBatch, kFlagWidth),
        alignCell(entry.batchedInstructionEmission, kFlagWidth),
        fitConvReportCell(entry.reason, kReasonWidth, entry.id, "reason", overflows, nextPlaceholderId),
    });
  }
  reportFile << divider << "\n";
  // Appendix: one line per placeholder, listing every row that uses it.
  if (overflows.empty())
    reportFile << "\n";
  else {
    reportFile << "\nAppendix:\n";
    for (const ConvReportOverflow& overflow : overflows) {
      reportFile << "  [" << overflow.placeholderId << "] rows ";
      for (size_t i = 0; i < overflow.rowIds.size(); ++i) {
        if (i != 0)
          reportFile << ", ";
        reportFile << overflow.rowIds[i];
      }
      reportFile << ", " << overflow.column << ": " << overflow.value << "\n";
    }
    reportFile << "\n";
  }
  // Per-Conv details: values that do not fit the table layout.
  reportFile << "Per-Conv Details:\n";
  for (const ConvReportEntry& entry : entries) {
    reportFile << "- Conv " << entry.id << ": mode=" << entry.mode << ", strategy=" << entry.strategy
               << ", reason=" << entry.reason << "\n";
    reportFile << "  K=" << entry.k << ", Cout=" << entry.c << ", output_positions=" << entry.p
               << ", estimated_mvm_count=" << entry.estimatedMvmCount
               << ", estimated_reduction_vadd_count=" << entry.estimatedReductionVAddCount
               << ", estimated_output_fragments=" << entry.estimatedOutputFragments << "\n";
    reportFile << "  materialization_required_at_func_return=" << entry.materializationRequiredAtReturn
               << ", materialization_kind=" << entry.materializationKind
               << ", giant_collector_concat_expected=" << entry.giantCollectorConcatExpected
               << ", concat_operand_count=" << entry.concatOperandCount
               << ", collector_core=" << entry.collectorCore
               << ", full_input_broadcast_expected=" << entry.fullInputBroadcastExpected << "\n";
    if (!entry.rejectedAutoStrategy.empty())
      reportFile << "  rejected_auto_strategy=" << entry.rejectedAutoStrategy << "\n";
    if (!entry.fallbackReason.empty())
      reportFile << "  fallback_reason=" << entry.fallbackReason << "\n";
  }
  reportFile << "\n";
  // Map each entry's strategy string back to its enum value by comparing it
  // with stringifyConvLoweringStrategy. Unknown names fall back to Auto. The
  // distinct strategies are kept in first-seen order.
  llvm::SmallVector<PimConvLoweringType, 8> usedStrategies;
  for (const ConvReportEntry& entry : entries) {
    PimConvLoweringType strategy = PimConvLoweringAuto;
    for (PimConvLoweringType candidate : {
             PimConvLoweringAuto,
             PimConvLoweringLegacy,
             PimConvLoweringDepthwise,
             PimConvLoweringPackedIm2Col,
             PimConvLoweringStreamedPatch,
             PimConvLoweringStreamedPacked,
             PimConvLoweringOutputChannelTiled,
             PimConvLoweringInputKTiled,
             PimConvLoweringTiled2D,
         }) {
      if (entry.strategy == stringifyConvLoweringStrategy(candidate)) {
        strategy = candidate;
        break;
      }
    }
    if (llvm::find(usedStrategies, strategy) == usedStrategies.end())
      usedStrategies.push_back(strategy);
  }
  reportFile << "Strategies used in this report:\n";
  for (PimConvLoweringType strategy : usedStrategies)
    reportFile << "- `" << stringifyConvLoweringStrategy(strategy).str() << "`: "
               << describeConvLoweringStrategy(strategy).str() << "\n";
}
// Regenerates the "conv_lowering_report" file from scratch: the legend,
// followed by the table and details for all `entries`. Does nothing if the
// report file cannot be opened.
static void rewriteConvLoweringReport(ArrayRef<ConvReportEntry> entries) {
  std::fstream reportFile = openReportFile("conv_lowering_report");
  if (reportFile.is_open()) {
    printConvReportLegend(reportFile);
    writeConvReportTable(reportFile, entries);
  }
}
// Resolves the Conv lowering strategy the user asked for.
//
// Without --use-experimental-conv-impl, --pim-conv-lowering is returned
// unchanged. With it, the result is packed-im2col. That is only compatible
// with an `auto` or explicit packed-im2col request; any other explicit
// strategy is reported as an error on `convOp`, and failure() is returned.
[[maybe_unused]] static FailureOr<PimConvLoweringType> resolveRequestedConvLoweringStrategy(ONNXConvOp convOp) {
  const PimConvLoweringType requested = pimConvLowering.getValue();
  if (!useExperimentalConvImpl)
    return requested;

  const bool compatible = requested == PimConvLoweringAuto || requested == PimConvLoweringPackedIm2Col;
  if (compatible)
    return PimConvLoweringPackedIm2Col;

  convOp.emitOpError() << "--use-experimental-conv-impl conflicts with --pim-conv-lowering="
                       << stringifyConvLoweringStrategy(requested);
  return failure();
}
// Picks the Conv lowering strategy.
//
// A non-auto `requested` strategy is returned as-is (marked as forced).
// Otherwise the choice depends on whether the reduction size K and the
// output channel count C each fit one crossbar (`geo.xbarSize`), on the
// packing factor, and on the im2col budget. When K exceeds the crossbar but
// C fits, the input-K-tiled path is rejected: the decision falls back to
// legacy, and `analysis`-based estimates explain why.
static ConvLoweringDecision chooseConvLoweringStrategy(const ConvGeometry& geo,
                                                       PimConvLoweringType requested,
                                                       const DistributedConvAnalysis& analysis) {
  if (requested != PimConvLoweringAuto)
    return {requested, "forced by compiler option", /*isAuto=*/false, "", ""};

  // Transform-based convolution is intentionally not selected for this ISA:
  // it would require explicit transform sequences and staging traffic on top of
  // the same crossbar MVM primitive, which is not attractive here.
  if (geo.isDepthwise)
    return {PimConvLoweringDepthwise, "depthwise convolution", /*isAuto=*/true, "", ""};

  const bool reductionFits = geo.k <= geo.xbarSize;
  const bool outputsFit = geo.c <= geo.xbarSize;

  if (reductionFits && outputsFit) {
    // Everything fits one crossbar: pick by packing usefulness and im2col size.
    if (geo.pack < 2)
      return {PimConvLoweringStreamedPatch, "fits crossbar but packing is not useful", /*isAuto=*/true, "", ""};
    if (geo.im2colElements > pimConvIm2colMaxElements)
      return {PimConvLoweringStreamedPacked,
              "fits crossbar and packing useful, but global im2col exceeds budget",
              /*isAuto=*/true,
              "",
              ""};
    return {PimConvLoweringPackedIm2Col,
            "fits crossbar, packing useful, and global im2col fits budget",
            /*isAuto=*/true,
            "",
            ""};
  }

  if (reductionFits)
    return {PimConvLoweringOutputChannelTiled,
            "output channels exceed one crossbar width",
            /*isAuto=*/true,
            "",
            ""};

  if (!outputsFit)
    return {PimConvLoweringTiled2D, "both reduction K and output channels exceed one crossbar", /*isAuto=*/true, "", ""};

  // Only K overflows: input-k-tiled would apply but is force-only, so fall back
  // to legacy and record the estimated costs that motivate the rejection.
  ConvStrategyEstimate inputKEstimate = estimateConvStrategy(geo, PimConvLoweringInputKTiled, analysis);
  std::string whyRejected = "auto rejects input-k-tiled because the reduction-heavy path is force-only for now";
  const bool giantReturnConcat = inputKEstimate.requiresFuncReturnMaterialization
                                 && inputKEstimate.perOutputPositionReduction
                                 && inputKEstimate.giantCollectorConcatExpected;
  if (giantReturnConcat) {
    whyRejected += "; func.return would materialize " + std::to_string(inputKEstimate.concatOperandCount)
                   + " output fragments through a single collector concat after per-position reductions";
  }
  if (inputKEstimate.fullInputBroadcastExpected)
    whyRejected += "; the current lowering also broadcasts the padded input to many workers";
  return {PimConvLoweringLegacy,
          "fall back to legacy explicit-im2col for the current auto policy",
          /*isAuto=*/true,
          whyRejected,
          stringifyConvLoweringStrategy(PimConvLoweringInputKTiled).str()};
}
// Checks that a forced Conv lowering `strategy` is applicable to `geo`.
//
// Auto and legacy are always accepted. For every other strategy the
// geometric preconditions it relies on are checked. On mismatch, an op error
// describing the requirement is emitted on `convOp` and failure is returned.
[[maybe_unused]] static LogicalResult verifyForcedConvLoweringStrategy(ONNXConvOp convOp,
                                                                       const ConvGeometry& geo,
                                                                       PimConvLoweringType strategy) {
  // Capacity predicates shared by several strategies (X = crossbar size).
  const bool kFits = geo.k <= geo.xbarSize;
  const bool cFits = geo.c <= geo.xbarSize;
  const bool packingUseful = geo.pack >= 2;
  const bool im2colFits = geo.im2colElements <= pimConvIm2colMaxElements;

  switch (strategy) {
  case PimConvLoweringAuto:
  case PimConvLoweringLegacy: return success();
  case PimConvLoweringDepthwise:
    if (!geo.isDepthwise)
      return convOp.emitOpError("forced depthwise Conv lowering requires a depthwise convolution");
    return success();
  case PimConvLoweringPackedIm2Col:
    if (!(kFits && cFits && packingUseful && im2colFits))
      return convOp.emitOpError(
          "forced packed-im2col Conv lowering requires K/C to fit, pack >= 2, and im2col within budget");
    return success();
  case PimConvLoweringStreamedPatch:
    if (!(kFits && cFits))
      return convOp.emitOpError("forced streamed-patch Conv lowering requires K and C to each fit one crossbar");
    return success();
  case PimConvLoweringStreamedPacked:
    if (!(kFits && cFits && packingUseful))
      return convOp.emitOpError("forced streamed-packed Conv lowering requires K/C to fit and pack >= 2");
    return success();
  case PimConvLoweringOutputChannelTiled:
    if (!(kFits && !cFits))
      return convOp.emitOpError("forced output-channel-tiled Conv lowering requires K <= X and C > X");
    return success();
  case PimConvLoweringInputKTiled:
    if (!(!kFits && cFits))
      return convOp.emitOpError("forced input-k-tiled Conv lowering requires K > X and C <= X");
    return success();
  case PimConvLoweringTiled2D:
    if (!(!kFits && !cFits))
      return convOp.emitOpError("forced tiled-2d Conv lowering requires K > X and C > X");
    return success();
  }
  llvm_unreachable("unknown conv lowering strategy");
}
// Records one Conv lowering decision as a report row and regenerates the
// whole "conv_lowering_report" file with every row recorded so far. Does
// nothing unless `pimReportConvLowering` is set.
//
// The row list and the row-id counter are function-local statics shared by
// all calls. Both are read and written only while `reportMutex` is held.
static void reportConvLoweringDecision(ONNXConvOp convOp,
                                       const ConvGeometry& geo,
                                       const ConvLoweringDecision& decision,
                                       const ConvStrategyEstimate& estimate,
                                       int64_t batchSize,
                                       int64_t numberOfBatches,
                                       bool usesComputeBatch,
                                       bool usesBatchedInstructionEmission,
                                       std::optional<uint64_t> streamChunkPositions = std::nullopt) {
  if (!pimReportConvLowering)
    return;
  // Cell text is formatted before taking the lock; none of it touches shared state.
  const std::string location = summarizeLocation(convOp.getLoc());
  const std::string strategy = stringifyConvLoweringStrategy(decision.strategy).str();
  const std::string mode = decision.isAuto ? "auto" : "forced";
  const std::string inputShape = formatShape(cast<RankedTensorType>(convOp.getX().getType()).getShape());
  const std::string weightShape = formatShape(cast<RankedTensorType>(convOp.getW().getType()).getShape());
  const std::string outputShape = formatShape(cast<RankedTensorType>(convOp.getY().getType()).getShape());
  const std::string chunkText = streamChunkPositions ? std::to_string(*streamChunkPositions) : "-";
  const std::string scbText = usesComputeBatch ? "yes" : "no";
  const std::string bieText = usesBatchedInstructionEmission ? "yes" : "no";

  static std::mutex reportMutex;
  static std::vector<ConvReportEntry> reportEntries;
  static uint64_t reportIndex = 0;
  std::lock_guard<std::mutex> lock(reportMutex);
  // The row id must be assigned under the same lock that guards
  // `reportEntries`. Incrementing the counter outside the lock is a data race
  // when this function runs on several threads, which is the case the mutex
  // exists for. It could also append rows out of id order.
  const uint64_t currentIndex = ++reportIndex;
  reportEntries.push_back({
      currentIndex,
      location,
      strategy,
      mode,
      inputShape,
      weightShape,
      outputShape,
      geo.group,
      geo.k,
      geo.c,
      geo.p,
      geo.xbarSize,
      geo.pack,
      geo.im2colElements,
      pimConvIm2colMaxElements,
      chunkText,
      batchSize,
      numberOfBatches,
      scbText,
      bieText,
      decision.reason,
      decision.fallbackReason,
      decision.rejectedAutoStrategy,
      estimate.estimatedMvmCount,
      estimate.estimatedReductionVAddCount,
      estimate.estimatedOutputFragments,
      estimate.requiresFuncReturnMaterialization ? "yes" : "no",
      estimate.materializationKind,
      estimate.concatOperandCount,
      estimate.collectorCore,
      estimate.giantCollectorConcatExpected ? "yes" : "no",
      estimate.fullInputBroadcastExpected ? "yes" : "no",
  });
  // Rewritten while still holding the lock so concurrent writers cannot interleave.
  rewriteConvLoweringReport(reportEntries);
}
// Reshapes a rank-1 bias of shape [N] into [1, N] with tensor.expand_shape.
// Biases of any other rank are returned unchanged.
static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location loc) {
  auto biasType = cast<RankedTensorType>(bias.getType());
  if (biasType.getRank() == 1) {
    auto rowType = RankedTensorType::get({1, biasType.getDimSize(0)}, biasType.getElementType());
    // The single source dimension expands into both result dimensions.
    SmallVector<ReassociationIndices> reassociation {{0, 1}};
    return tensor::ExpandShapeOp::create(rewriter, loc, rowType, bias, reassociation);
  }
  return bias;
}
// Returns the largest divisor of `value` that is <= `limit`.
// Returns 1 when no divisor greater than 1 qualifies, including when
// `limit` < 1. `value` must be positive.
static int64_t findLargestDivisorAtMost(int64_t value, int64_t limit) {
  assert(value > 0 && "expected positive value");
  int64_t candidate = std::min(value, limit);
  // Walk down until a divisor is found; 1 always divides, so stop there.
  while (candidate > 1 && value % candidate != 0)
    --candidate;
  return candidate > 1 ? candidate : 1;
}
// Pads `value` with zeros up to `resultType` using a tensor.pad op.
//
// `lowPadValues` / `highPadValues` give the static padding before and after
// each dimension. If `value` already has `resultType`, it is returned
// unchanged and no op is created. On return the rewriter's insertion point
// is just after the created pad op.
static Value createZeroPaddedTensor(Value value,
                                    RankedTensorType resultType,
                                    ArrayRef<int64_t> lowPadValues,
                                    ArrayRef<int64_t> highPadValues,
                                    PatternRewriter& rewriter,
                                    Location loc) {
  auto valueType = cast<RankedTensorType>(value.getType());
  if (valueType == resultType)
    return value;
  SmallVector<OpFoldResult> lowPads;
  SmallVector<OpFoldResult> highPads;
  lowPads.reserve(lowPadValues.size());
  highPads.reserve(highPadValues.size());
  for (int64_t lowPad : lowPadValues)
    lowPads.push_back(rewriter.getIndexAttr(lowPad));
  for (int64_t highPad : highPadValues)
    highPads.push_back(rewriter.getIndexAttr(highPad));
  auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);

  // The pad body takes one index argument per result dimension. Creating the
  // block through the rewriter (instead of a raw `new Block` + push_back)
  // notifies the rewriter/listener about the inserted block. It also moves the
  // insertion point into that block.
  const int64_t rank = resultType.getRank();
  SmallVector<Type> bodyArgTypes(rank, rewriter.getIndexType());
  SmallVector<Location> bodyArgLocs(rank, loc);
  rewriter.createBlock(&padOp.getRegion(), {}, bodyArgTypes, bodyArgLocs);

  // Every padded element yields zero of the result element type.
  Type elementType = resultType.getElementType();
  auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(elementType), elementType);
  tensor::YieldOp::create(rewriter, loc, zero);
  rewriter.setInsertionPointAfter(padOp);
  return padOp.getResult();
}
// Extracts one convolution input patch of type `patchType`
// ([1, patchChannels, kernelHeight, kernelWidth]) from the rank-4 `input`.
// The patch starts at (batchIndex, channelOffset, inputHeightOffset,
// inputWidthOffset); the offset order suggests input is indexed
// (batch, channel, height, width).
//
// Without dilation this is a single contiguous tensor.extract_slice. With
// dilation, each kernel tap (kernelH, kernelW) is read separately as a
// [1, patchChannels, 1, 1] slice from
//   (inputHeightOffset + kernelH * dilationHeight,
//    inputWidthOffset + kernelW * dilationWidth)
// and inserted at (0, 0, kernelH, kernelW) of a tensor.empty patch.
static Value createConvInputPatch(Value input,
                                  RankedTensorType patchType,
                                  Value batchIndex,
                                  Value channelOffset,
                                  Value inputHeightOffset,
                                  Value inputWidthOffset,
                                  int64_t dilationHeight,
                                  int64_t dilationWidth,
                                  PatternRewriter& rewriter,
                                  Location loc) {
  const int64_t patchChannels = patchType.getDimSize(1);
  const int64_t kernelHeight = patchType.getDimSize(2);
  const int64_t kernelWidth = patchType.getDimSize(3);
  // Undilated kernel: the whole receptive field is one contiguous slice.
  if (dilationHeight == 1 && dilationWidth == 1) {
    SmallVector<OpFoldResult> offsets {batchIndex, channelOffset, inputHeightOffset, inputWidthOffset};
    SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1),
                                     rewriter.getIndexAttr(patchChannels),
                                     rewriter.getIndexAttr(kernelHeight),
                                     rewriter.getIndexAttr(kernelWidth)};
    return tensor::ExtractSliceOp::create(rewriter, loc, patchType, input, offsets, sizes, getUnitStrides(rewriter, 4));
  }
  // Dilated kernel: gather each tap individually (kernelHeight * kernelWidth
  // extract/insert pairs).
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  auto elementType = patchType.getElementType();
  auto pixelType = RankedTensorType::get({1, patchChannels, 1, 1}, elementType, patchType.getEncoding());
  // NOTE(review): this tensor.empty does not carry patchType's encoding while
  // pixelType does. Confirm this is intended when an encoding is present.
  Value patch = tensor::EmptyOp::create(rewriter, loc, patchType.getShape(), elementType);
  for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
    Value sourceHeightOffset = affineAddConst(rewriter, loc, inputHeightOffset, kernelH * dilationHeight, anchorOp);
    for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
      Value sourceWidthOffset = affineAddConst(rewriter, loc, inputWidthOffset, kernelW * dilationWidth, anchorOp);
      SmallVector<OpFoldResult> sourceOffsets {batchIndex, channelOffset, sourceHeightOffset, sourceWidthOffset};
      SmallVector<OpFoldResult> sourceSizes {rewriter.getIndexAttr(1),
                                             rewriter.getIndexAttr(patchChannels),
                                             rewriter.getIndexAttr(1),
                                             rewriter.getIndexAttr(1)};
      Value sourcePixel = tensor::ExtractSliceOp::create(
          rewriter, loc, pixelType, input, sourceOffsets, sourceSizes, getUnitStrides(rewriter, 4));
      SmallVector<OpFoldResult> targetOffsets {
          rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(kernelH), rewriter.getIndexAttr(kernelW)};
      // The inserted slice has the same [1, C, 1, 1] shape, so sourceSizes is reused.
      patch = tensor::InsertSliceOp::create(
          rewriter, loc, sourcePixel, patch, targetOffsets, sourceSizes, getUnitStrides(rewriter, 4));
    }
  }
  return patch;
}
// Forward declarations of static helpers defined further down in this
// translation unit, so code above their definitions (e.g.
// depthwise::rewriteConv) can call them.

// Turns the GEMM-shaped conv result rows (`gemmRows`, of `gemmOutType`) into
// the final Conv output of `outType`, going through an NHWC-shaped
// intermediate (`nhwcType`) judging by the parameter names.
// NOTE(review): see the definition for exact semantics.
static Value createCollectedConvOutput(ValueRange gemmRows,
                                       Type convType,
                                       RankedTensorType gemmOutType,
                                       RankedTensorType nhwcType,
                                       RankedTensorType outType,
                                       int64_t numPatches,
                                       int64_t numChannelsOut,
                                       int64_t packFactor,
                                       ArrayRef<DistributedTensorStep> distributedConsumers,
                                       PatternRewriter& rewriter,
                                       Location loc);
// Collects the Conv operands/attributes into a ConvLoweringState, or fails.
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b);
namespace depthwise {
// How the depthwise channels are grouped into crossbar-sized tiles; produced
// by computeTiling below.
struct Tiling {
  // Output channels per input channel (numChannelsOut / numChannelsIn).
  int64_t outputMultiplier;
  // Kernel taps per channel (wHeight * wWidth).
  int64_t kernelElements;
  // Input channels packed into one tile.
  int64_t channelsPerTile;
  // Crossbar rows used by one tile (channelsPerTile * kernelElements).
  int64_t tileInputRows;
  // Crossbar columns used by one tile (channelsPerTile * outputMultiplier).
  int64_t tileOutputChannels;
  // Number of tiles covering all input channels (numChannelsIn / channelsPerTile).
  int64_t numChannelTiles;
  // Output positions per batch element (outHeight * outWidth).
  int64_t spatialPatchesPerBatch;
  // Output positions over the whole batch (batchSize * outHeight * outWidth).
  int64_t totalPatches;
};
// Chooses a crossbar tiling for a depthwise convolution.
//
// Each tile handles `channelsPerTile` input channels. It needs
// channelsPerTile * kernelElements crossbar rows and
// channelsPerTile * outputMultiplier columns, and both must fit the
// configured crossbar size. channelsPerTile is the largest divisor of
// numChannelsIn within that limit, so all tiles have the same size.
//
// Returns std::nullopt when no such tiling exists. That includes
// non-positive (e.g. dynamic) input channel counts and output channel counts
// that are not a positive multiple of the input channel count.
static std::optional<Tiling> computeTiling(int64_t batchSize,
                                           int64_t numChannelsIn,
                                           int64_t numChannelsOut,
                                           int64_t wHeight,
                                           int64_t wWidth,
                                           int64_t outHeight,
                                           int64_t outWidth) {
  // Guard the Cout / Cin division below. Without this, Cin == 0 divides by
  // zero, and a Cout that is not a multiple of Cin truncates the multiplier,
  // so the packed weights would silently drop output channels.
  if (numChannelsIn <= 0 || numChannelsOut % numChannelsIn != 0)
    return std::nullopt;
  const int64_t kernelElements = wHeight * wWidth;
  const int64_t outputMultiplier = numChannelsOut / numChannelsIn;
  const int64_t xbarDim = static_cast<int64_t>(crossbarSize.getValue());
  if (kernelElements <= 0 || outputMultiplier <= 0 || kernelElements > xbarDim || outputMultiplier > xbarDim)
    return std::nullopt;
  const int64_t maxChannelsPerTile = std::min(xbarDim / kernelElements, xbarDim / outputMultiplier);
  if (maxChannelsPerTile <= 0)
    return std::nullopt;
  const int64_t channelsPerTile = findLargestDivisorAtMost(numChannelsIn, maxChannelsPerTile);
  const int64_t tileInputRows = channelsPerTile * kernelElements;
  const int64_t tileOutputChannels = channelsPerTile * outputMultiplier;
  // Defensive: guaranteed by maxChannelsPerTile, kept as a cheap sanity check.
  if (tileInputRows > xbarDim || tileOutputChannels > xbarDim)
    return std::nullopt;
  return Tiling {
      outputMultiplier,
      kernelElements,
      channelsPerTile,
      tileInputRows,
      tileOutputChannels,
      numChannelsIn / channelsPerTile,
      outHeight * outWidth,
      batchSize * outHeight * outWidth,
  };
}
// Repacks the constant depthwise weights into a
// [numChannelTiles, tileInputRows, tileOutputChannels] constant, zero-filled
// outside each tile's block-diagonal structure.
//
// For output channel `o`, the weights are read from the flat range
// o * (d1 * d2 * d3) + [0, kernelElements), where d1..d3 are wType dims 1..3.
// They are written to column (o mod tileOutputChannels) of tile
// (o / tileOutputChannels), at rows localChannel * kernelElements + tap. The
// local input channel is (o mod tileOutputChannels) / outputMultiplier.
static Value buildPackedWeights(DenseElementsAttr wDenseAttr,
                                RankedTensorType wType,
                                const Tiling& tiling,
                                PatternRewriter& rewriter,
                                Location loc) {
  Type elementType = wType.getElementType();
  auto packedWeightType = RankedTensorType::get(
      {tiling.numChannelTiles, tiling.tileInputRows, tiling.tileOutputChannels}, elementType);
  SmallVector<Attribute> packedValues(packedWeightType.getNumElements(),
                                      cast<Attribute>(rewriter.getZeroAttr(elementType)));
  SmallVector<Attribute> sourceValues(wDenseAttr.getValues<Attribute>());

  const int64_t sourceChannelStride = wType.getDimSize(1) * wType.getDimSize(2) * wType.getDimSize(3);
  const int64_t tileStride = tiling.tileInputRows * tiling.tileOutputChannels;
  const int64_t totalOutChannels = tiling.numChannelTiles * tiling.tileOutputChannels;
  for (int64_t outChannel = 0; outChannel < totalOutChannels; ++outChannel) {
    const int64_t tileIndex = outChannel / tiling.tileOutputChannels;
    const int64_t column = outChannel % tiling.tileOutputChannels;
    const int64_t localChannel = column / tiling.outputMultiplier;
    const int64_t sourceBase = outChannel * sourceChannelStride;
    const int64_t targetBase = tileIndex * tileStride + column;
    // Taps are laid out row-major (kernelH * width + kernelW) in both tensors.
    for (int64_t tap = 0; tap < tiling.kernelElements; ++tap) {
      const int64_t row = localChannel * tiling.kernelElements + tap;
      packedValues[targetBase + row * tiling.tileOutputChannels] = sourceValues[sourceBase + tap];
    }
  }

  auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  return getOrCreateConstant(rewriter, anchorOp, packedAttr, packedWeightType);
}
// Zero-pads the spatial (height, width) dimensions of a rank-4 `input`.
// The padding is done inside a single-input SpatCompute that yields the
// padded tensor. When all four padding amounts are zero, `input` is returned
// unchanged.
static Value createPaddedInput(Value input,
                               RankedTensorType inputType,
                               int64_t padHeightBegin,
                               int64_t padHeightEnd,
                               int64_t padWidthBegin,
                               int64_t padWidthEnd,
                               PatternRewriter& rewriter,
                               Location loc) {
  const bool needsPadding = padHeightBegin != 0 || padHeightEnd != 0 || padWidthBegin != 0 || padWidthEnd != 0;
  if (!needsPadding)
    return input;

  const int64_t paddedHeight = inputType.getDimSize(2) + padHeightBegin + padHeightEnd;
  const int64_t paddedWidth = inputType.getDimSize(3) + padWidthBegin + padWidthEnd;
  auto paddedInputType = RankedTensorType::get(
      {inputType.getDimSize(0), inputType.getDimSize(1), paddedHeight, paddedWidth}, inputType.getElementType());
  // Dimensions 0 and 1 are never padded.
  const SmallVector<int64_t, 4> lowPads {0, 0, padHeightBegin, padWidthBegin};
  const SmallVector<int64_t, 4> highPads {0, 0, padHeightEnd, padWidthEnd};

  auto padCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
        Value padded = createZeroPaddedTensor(computeInput, paddedInputType, lowPads, highPads, rewriter, loc);
        spatial::SpatYieldOp::create(rewriter, loc, padded);
      });
  return padCompute.getResult(0);
}
// Builds the [1, tileInputRows] input row vector for one
// (output position, channel tile) pair of the depthwise lowering.
//
// `patchIndex` is a flat index over batch * outHeight * outWidth. It is split
// into (batch, output row, output column), mapped to the top-left input
// coordinate via the strides, and used together with the tile's first
// channel to extract the receptive field (createConvInputPatch). The result
// is then flattened.
static Value createInputTile(Value input,
                             Value patchIndex,
                             Value channelTileIndex,
                             RankedTensorType inputTileType,
                             const Tiling& tiling,
                             int64_t strideHeight,
                             int64_t strideWidth,
                             int64_t dilationHeight,
                             int64_t dilationWidth,
                             int64_t outWidth,
                             PatternRewriter& rewriter,
                             Location loc) {
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  // Multiplies `index` by `factor`; a unit factor emits no affine op.
  auto scaled = [&](Value index, int64_t factor) -> Value {
    if (factor == 1)
      return index;
    return affineMulConst(rewriter, loc, index, factor, anchorOp);
  };

  const int64_t patchesPerBatch = tiling.spatialPatchesPerBatch;
  Value batchIndex = affineFloorDivConst(rewriter, loc, patchIndex, patchesPerBatch, anchorOp);
  Value positionInBatch = affineModConst(rewriter, loc, patchIndex, patchesPerBatch, anchorOp);
  Value outRow = affineFloorDivConst(rewriter, loc, positionInBatch, outWidth, anchorOp);
  Value outCol = affineModConst(rewriter, loc, positionInBatch, outWidth, anchorOp);

  Value rowOrigin = scaled(outRow, strideHeight);
  Value colOrigin = scaled(outCol, strideWidth);
  Value firstChannel = scaled(channelTileIndex, tiling.channelsPerTile);

  Value patch = createConvInputPatch(input,
                                     inputTileType,
                                     batchIndex,
                                     firstChannel,
                                     rowOrigin,
                                     colOrigin,
                                     dilationHeight,
                                     dilationWidth,
                                     rewriter,
                                     loc);

  // Collapse [1, C, kH, kW] into [1, C * kH * kW].
  auto rowType = RankedTensorType::get({1, tiling.tileInputRows}, inputTileType.getElementType());
  SmallVector<ReassociationIndices> reassociation {{0}, {1, 2, 3}};
  return tensor::CollapseShapeOp::create(rewriter, loc, rowType, patch, reassociation);
}
// Selects tile `channelTileIndex` from the packed
// [numChannelTiles, tileInputRows, tileOutputChannels] weights and returns it
// as a [tileInputRows, tileOutputChannels] matrix.
static Value createWeightTile(Value packedWeights,
                              Value channelTileIndex,
                              RankedTensorType packedWeightType,
                              const Tiling& tiling,
                              PatternRewriter& rewriter,
                              Location loc) {
  const int64_t rows = tiling.tileInputRows;
  const int64_t cols = tiling.tileOutputChannels;
  Type elementType = packedWeightType.getElementType();

  auto tile3DType = RankedTensorType::get({1, rows, cols}, elementType);
  SmallVector<OpFoldResult> tileOffsets {channelTileIndex, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> tileSizes {
      rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
  Value tile3D = tensor::ExtractSliceOp::create(
      rewriter, loc, tile3DType, packedWeights, tileOffsets, tileSizes, getUnitStrides(rewriter, 3));

  // Fold the unit tile dimension into the row dimension.
  auto matrixType = RankedTensorType::get({rows, cols}, elementType);
  SmallVector<ReassociationIndices> reassociation {{0, 1}, {2}};
  return tensor::CollapseShapeOp::create(rewriter, loc, matrixType, tile3D, reassociation);
}
// Slices the [1, tileOutputChannels] portion of a [1, Cout] bias for channel
// tile `channelTileIndex`, i.e. columns starting at
// channelTileIndex * tileOutputChannels.
static Value createBiasTile(
    Value bias, Value channelTileIndex, const Tiling& tiling, PatternRewriter& rewriter, Location loc) {
  const int64_t width = tiling.tileOutputChannels;
  Type elementType = cast<RankedTensorType>(bias.getType()).getElementType();
  auto tileType = RankedTensorType::get({1, width}, elementType);

  Value firstColumn = channelTileIndex;
  if (width != 1) {
    Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
    firstColumn = affineMulConst(rewriter, loc, channelTileIndex, width, anchorOp);
  }

  SmallVector<OpFoldResult> sliceOffsets {rewriter.getIndexAttr(0), firstColumn};
  SmallVector<OpFoldResult> sliceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(width)};
  return tensor::ExtractSliceOp::create(
      rewriter, loc, tileType, bias, sliceOffsets, sliceSizes, getUnitStrides(rewriter, 2));
}
// Writes a [1, tileOutputChannels] `rowTile` into `rowAccumulator` at the
// columns owned by channel tile `channelTileIndex`, starting at
// channelTileIndex * tileOutputChannels. Returns the updated accumulator.
static Value insertOutputTile(Value rowTile,
                              Value rowAccumulator,
                              Value channelTileIndex,
                              const Tiling& tiling,
                              PatternRewriter& rewriter,
                              Location loc) {
  const int64_t width = tiling.tileOutputChannels;
  Value firstColumn = channelTileIndex;
  if (width != 1) {
    Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
    firstColumn = affineMulConst(rewriter, loc, channelTileIndex, width, anchorOp);
  }

  SmallVector<OpFoldResult> insertOffsets {rewriter.getIndexAttr(0), firstColumn};
  SmallVector<OpFoldResult> insertSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(width)};
  return tensor::InsertSliceOp::create(
      rewriter, loc, rowTile, rowAccumulator, insertOffsets, insertSizes, getUnitStrides(rewriter, 2));
}
// Reassembles the per-(channel tile, output position) rows produced by the
// depthwise batch into a GEMM-shaped matrix of `gemmOutType`
// ([totalPatches, Cout]).
//
// Row `channelTile * totalPatches + patch` of `pieces` (width
// tileOutputChannels) is copied into row `patch` of the result, at columns
// starting at channelTile * tileOutputChannels. All of this runs inside one
// single-input SpatCompute as two nested loops built with
// buildNormalizedScfFor: output positions outer, channel tiles inner.
// Returns failure if either loop cannot be built.
//
// NOTE(review): each result row starts as tensor.empty. Every column is
// written only if numChannelTiles * tileOutputChannels equals
// gemmOutType's dim 1. Confirm this holds for all callers.
static FailureOr<Value> reconstructDepthwiseGemmRows(Value pieces,
                                                     RankedTensorType piecesType,
                                                     RankedTensorType gemmOutType,
                                                     const Tiling& tiling,
                                                     PatternRewriter& rewriter,
                                                     Location loc) {
  auto collectedOp = createSpatCompute<1>(rewriter, loc, TypeRange {gemmOutType}, {}, pieces, [&](Value piecesArg) {
    auto rowType = RankedTensorType::get({1, gemmOutType.getDimSize(1)}, gemmOutType.getElementType());
    Value outputInit = tensor::EmptyOp::create(rewriter, loc, gemmOutType.getShape(), gemmOutType.getElementType());
    Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
    Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
    Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
    Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, tiling.totalPatches);
    Value cNumChannelTiles = getOrCreateIndexConstant(rewriter, anchorOp, tiling.numChannelTiles);
    // Outer loop: one full output row per output position, carried in outputAcc.
    auto patchLoop = buildNormalizedScfFor(
        rewriter,
        loc,
        c0,
        cNumPatches,
        c1,
        ValueRange {outputInit},
        [&](OpBuilder&,
            Location nestedLoc,
            Value patchIndex,
            ValueRange patchIterArgs,
            SmallVectorImpl<Value>& patchYielded) {
          Value outputAcc = patchIterArgs.front();
          Value rowInit = tensor::EmptyOp::create(rewriter, nestedLoc, rowType.getShape(), rowType.getElementType());
          // Inner loop: fill the row tile by tile, carried in rowAcc.
          auto tileLoop = buildNormalizedScfFor(
              rewriter,
              nestedLoc,
              c0,
              cNumChannelTiles,
              c1,
              ValueRange {rowInit},
              [&](OpBuilder&,
                  Location tileLoc,
                  Value channelTileIndex,
                  ValueRange tileIterArgs,
                  SmallVectorImpl<Value>& tileYielded) {
                Value rowAcc = tileIterArgs.front();
                MLIRContext* context = rewriter.getContext();
                AffineExpr d0 = getAffineDimExpr(0, context);
                AffineExpr d1 = getAffineDimExpr(1, context);
                // Piece row for this (channel tile, patch): tile-major, matching
                // how the depthwise batch lanes are numbered.
                Value laneIndex = createOrFoldAffineApply(
                    rewriter, tileLoc, (d0 * tiling.totalPatches) + d1, ValueRange {channelTileIndex, patchIndex}, anchorOp);
                auto rowTileType = RankedTensorType::get({1, tiling.tileOutputChannels}, piecesType.getElementType());
                SmallVector<OpFoldResult> pieceOffsets {laneIndex, rewriter.getIndexAttr(0)};
                SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1),
                                                      rewriter.getIndexAttr(tiling.tileOutputChannels)};
                Value rowTile = tensor::ExtractSliceOp::create(
                    rewriter, tileLoc, rowTileType, piecesArg, pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2));
                Value rowNext = insertOutputTile(rowTile, rowAcc, channelTileIndex, tiling, rewriter, tileLoc);
                tileYielded.push_back(rowNext);
                return success();
              });
          if (failed(tileLoop))
            return failure();
          // Store the completed row at row `patchIndex` of the output.
          SmallVector<OpFoldResult> rowOffsets {patchIndex, rewriter.getIndexAttr(0)};
          SmallVector<OpFoldResult> rowSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(gemmOutType.getDimSize(1))};
          Value outputNext = tensor::InsertSliceOp::create(rewriter,
                                                           nestedLoc,
                                                           tileLoop->results.front(),
                                                           outputAcc,
                                                           rowOffsets,
                                                           rowSizes,
                                                           getUnitStrides(rewriter, 2))
                                 .getResult();
          patchYielded.push_back(outputNext);
          return success();
        });
    if (failed(patchLoop))
      return failure();
    spatial::SpatYieldOp::create(rewriter, loc, patchLoop->results.front());
    return success();
  });
  if (failed(collectedOp))
    return failure();
  return collectedOp->getResult(0);
}
// Returns true when the structured depthwise lowering (rewriteConv) applies.
// That requires host-constant weights (they are repacked at compile time), a
// crossbar tiling from computeTiling, and, if a bias is present, a ranked
// bias shaped [Cout] or [1, Cout].
//
// NOTE(review): this uses state.batchSize / numChannelsIn / numChannelsOut,
// while rewriteConv derives the same tiling from xType/outType/wType dims.
// Confirm both describe the same values.
static bool canUseStructuredRewrite(const ConvLoweringState& state) {
  if (!getHostConstDenseElementsAttr(state.w))
    return false;
  const bool hasTiling = computeTiling(state.batchSize,
                                       state.numChannelsIn,
                                       state.numChannelsOut,
                                       state.wHeight,
                                       state.wWidth,
                                       state.outHeight,
                                       state.outWidth)
                             .has_value();
  if (!hasTiling)
    return false;
  if (!state.hasBias)
    return true;

  auto biasType = dyn_cast<RankedTensorType>(state.b.getType());
  if (!biasType)
    return false;
  switch (biasType.getRank()) {
  case 1: return biasType.getDimSize(0) == state.numChannelsOut;
  case 2: return biasType.getDimSize(0) == 1 && biasType.getDimSize(1) == state.numChannelsOut;
  default: return false;
  }
}
// Lowers a depthwise Conv to Spatial ops using crossbar-sized channel tiles:
//  - pads the input once (createPaddedInput) when any spatial padding is set;
//  - repacks the constant weights into
//    [numChannelTiles, tileInputRows, tileOutputChannels] (buildPackedWeights);
//  - runs one SpatComputeBatch with totalPatches * numChannelTiles lanes. Lane
//    `lane` handles patch (lane % totalPatches) of channel tile
//    (lane / totalPatches): one VMM, plus a bias add when a bias is present.
//    It writes a [1, tileOutputChannels] row at row `lane` of the pieces
//    tensor;
//  - with more than one channel tile, reassembles the pieces into a
//    [totalPatches, Cout] matrix (reconstructDepthwiseGemmRows);
//  - hands the rows to createCollectedConvOutput with packFactor 1.
// Emits an op error and returns failure if the weights are not host
// constants, no tiling fits, the bias cannot be viewed as tensor<1xCout>, or
// the batch body does not see the expected block arguments.
static FailureOr<Value>
rewriteConv(Operation* convOp, const ConvLoweringState& state, PatternRewriter& rewriter, Location loc) {
  auto wDenseAttr = getHostConstDenseElementsAttr(state.w);
  if (!wDenseAttr) {
    convOp->emitOpError("requires constant-derived weights for structured depthwise Spatial lowering");
    return failure();
  }
  auto tiling = computeTiling(state.xType.getDimSize(0),
                              state.xType.getDimSize(1),
                              state.outType.getDimSize(1),
                              state.wType.getDimSize(2),
                              state.wType.getDimSize(3),
                              state.outType.getDimSize(2),
                              state.outType.getDimSize(3));
  if (!tiling) {
    convOp->emitOpError("failed to derive a structured depthwise tiling that fits Spatial weighted VMM lowering");
    return failure();
  }
  Value paddedInput = createPaddedInput(state.x,
                                        state.xType,
                                        state.padHeightBegin,
                                        state.padHeightEnd,
                                        state.padWidthBegin,
                                        state.padWidthEnd,
                                        rewriter,
                                        loc);
  Value packedWeights = buildPackedWeights(wDenseAttr, state.wType, *tiling, rewriter, loc);
  // Batch inputs: the padded input, plus the bias reshaped to [1, Cout] if present.
  Value expandedBias;
  SmallVector<Value> batchInputs {paddedInput};
  if (state.hasBias) {
    expandedBias = expandBiasIfNeeded(state.b, rewriter, loc);
    auto biasType = dyn_cast<RankedTensorType>(expandedBias.getType());
    if (!biasType || biasType.getRank() != 2 || biasType.getDimSize(0) != 1
        || biasType.getDimSize(1) != state.outType.getDimSize(1)) {
      convOp->emitOpError("requires bias sliceable as tensor<1xCout> for structured depthwise Spatial lowering");
      return failure();
    }
    batchInputs.push_back(expandedBias);
  }
  // gemmOutType: one row per output position with all Cout channels.
  // piecesType: one row per batch lane, holding one channel tile's outputs.
  auto gemmOutType =
      RankedTensorType::get({tiling->totalPatches, state.outType.getDimSize(1)}, state.outType.getElementType());
  auto piecesType = RankedTensorType::get({tiling->totalPatches * tiling->numChannelTiles, tiling->tileOutputChannels},
                                          state.outType.getElementType());
  auto paddedInputType = cast<RankedTensorType>(paddedInput.getType());
  auto inputTileType =
      RankedTensorType::get({1, tiling->channelsPerTile, state.wType.getDimSize(2), state.wType.getDimSize(3)},
                            paddedInputType.getElementType());
  // With a single channel tile, the weight matrix is extracted once here.
  // Otherwise the whole packed tensor is passed and each lane slices its tile.
  SmallVector<Value> batchWeights;
  if (tiling->numChannelTiles == 1) {
    Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
    batchWeights.push_back(createWeightTile(packedWeights,
                                            c0,
                                            cast<RankedTensorType>(packedWeights.getType()),
                                            *tiling,
                                            rewriter,
                                            loc));
  }
  else {
    batchWeights.push_back(packedWeights);
  }
  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {piecesType},
      tiling->totalPatches * tiling->numChannelTiles,
      batchWeights,
      batchInputs,
      [&](detail::SpatComputeBatchBodyArgs args) {
        // Batch inputs are identified by rank: 4 = padded input, 2 = bias.
        auto pickInputByRank = [&](int64_t rank) -> Value {
          for (Value input : args.inputs) {
            auto inputType = dyn_cast<RankedTensorType>(input.getType());
            if (inputType && inputType.getRank() == rank)
              return input;
          }
          return Value();
        };
        Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
        // Lanes are numbered tile-major: lane = channelTile * totalPatches + patch.
        Value patchIndex = tiling->numChannelTiles == 1
                               ? args.lane
                               : affineModConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp);
        Value channelTileIndex = tiling->numChannelTiles == 1
                                     ? getOrCreateIndexConstant(rewriter, anchorOp, 0)
                                     : affineFloorDivConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp);
        Value paddedInputArg = pickInputByRank(/*rank=*/4);
        if (!paddedInputArg) {
          convOp->emitOpError("structured depthwise batch body requires a rank-4 padded input block argument");
          return failure();
        }
        Value inputTile = createInputTile(paddedInputArg,
                                          patchIndex,
                                          channelTileIndex,
                                          inputTileType,
                                          *tiling,
                                          state.strideHeight,
                                          state.strideWidth,
                                          state.dilationHeight,
                                          state.dilationWidth,
                                          state.outType.getDimSize(3),
                                          rewriter,
                                          loc);
        Value weightTile = tiling->numChannelTiles == 1
                               ? args.weights.front()
                               : createWeightTile(args.weights.front(),
                                                  channelTileIndex,
                                                  cast<RankedTensorType>(args.weights.front().getType()),
                                                  *tiling,
                                                  rewriter,
                                                  loc);
        auto rowTileType = RankedTensorType::get({1, tiling->tileOutputChannels}, state.outType.getElementType());
        Value rowTile = spatial::SpatVMMOp::create(rewriter, loc, rowTileType, weightTile, inputTile).getResult();
        // A second batch input means a bias was supplied.
        if (args.inputs.size() > 1) {
          Value biasArg = pickInputByRank(/*rank=*/2);
          if (!biasArg) {
            convOp->emitOpError("structured depthwise batch body requires a rank-2 bias block argument when bias is present");
            return failure();
          }
          Value biasTile = tiling->numChannelTiles == 1 ? biasArg : createBiasTile(biasArg, channelTileIndex, *tiling, rewriter, loc);
          rowTile = spatial::SpatVAddOp::create(rewriter, loc, rowTileType, rowTile, biasTile).getResult();
        }
        // Each lane writes its own row of the pieces tensor.
        SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1),
                                               rewriter.getIndexAttr(tiling->tileOutputChannels)};
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, rowTile, args.outputs.front(), outputOffsets, outputSizes, getUnitStrides(rewriter, 2));
        return success();
      });
  if (failed(batchOp))
    return failure();
  auto nhwcType = RankedTensorType::get(
      {state.xType.getDimSize(0), state.outType.getDimSize(2), state.outType.getDimSize(3), state.outType.getDimSize(1)},
      state.outType.getElementType());
  // With one channel tile, pieces already have the [totalPatches, Cout] layout.
  Value collectedRows = batchOp->getResult(0);
  if (tiling->numChannelTiles != 1) {
    auto reconstructedRows =
        reconstructDepthwiseGemmRows(batchOp->getResult(0), piecesType, gemmOutType, *tiling, rewriter, loc);
    if (failed(reconstructedRows))
      return failure();
    collectedRows = *reconstructedRows;
  }
  return createCollectedConvOutput(ValueRange {collectedRows},
                                   state.outType,
                                   gemmOutType,
                                   nhwcType,
                                   state.outType,
                                   tiling->totalPatches,
                                   state.outType.getDimSize(1),
                                   /*packFactor=*/1,
                                   {},
                                   rewriter,
                                   loc);
}
} // namespace depthwise
namespace standard {
/// Static shape/packing plan for lowering one chunk of a convolution to an im2col GEMM.
///
/// A "patch" is one output spatial position and corresponds to one im2col row. Patches are numbered over the
/// flattened (batch, outHeight, outWidth) space; a plan covers the contiguous range
/// [chunkStart, chunkStart + chunkNumPatches). When `effectiveMaxParallelPixels` > 1, that many im2col rows are
/// laid side by side in one GEMM row and multiplied against a block-diagonal weight matrix.
///
/// Integer fields carry default member initializers so a default-constructed plan never exposes indeterminate
/// values; the tensor-type fields default to null types.
struct ConvGemmPlan {
  /// Elements per im2col row: Cin * KH * KW.
  int64_t patchSize = 0;
  /// Output positions per batch image: outHeight * outWidth.
  int64_t numPatchesPerBatch = 0;
  /// Output positions over the whole batch.
  int64_t globalNumPatches = 0;
  /// First global patch index covered by this plan.
  int64_t chunkStart = 0;
  /// Number of patches covered by this plan.
  int64_t chunkNumPatches = 0;
  /// Requested number of patches packed side by side into one GEMM row.
  int64_t maxParallelPixels = 1;
  /// Pack factor actually used; 1 unless both weights and bias can be packed as constants.
  int64_t effectiveMaxParallelPixels = 1;
  /// Rows of the packed GEMM operands: ceil(chunkNumPatches / effectiveMaxParallelPixels).
  int64_t packedNumRows = 0;
  /// [chunkNumPatches, patchSize] unpacked im2col matrix.
  RankedTensorType im2colType;
  /// [1, patchSize] single im2col row.
  RankedTensorType im2colRowType;
  /// [packedNumRows, effectiveMaxParallelPixels * patchSize] GEMM A operand.
  RankedTensorType gemmInputRowsType;
  /// [Cout, patchSize] weight with the (Cin, KH, KW) dimensions flattened.
  RankedTensorType wFlatType;
  /// [patchSize, Cout] transposed weight (the GEMM B operand when rows are not packed).
  RankedTensorType wTransType;
  /// [chunkNumPatches, Cout] unpacked GEMM result.
  RankedTensorType gemmOutType;
  /// [packedNumRows, effectiveMaxParallelPixels * Cout] packed GEMM result.
  RankedTensorType gemmOutputRowsType;
  /// [N, outHeight, outWidth, Cout] NHWC view of the full convolution output.
  RankedTensorType nhwcType;
};
// Forward declaration of the plan builder defined further below in this namespace. It is also where the default
// for `forcedPackFactor` lives: the out-of-line definition does not repeat it, so calls that omit the last
// argument (e.g. the one in rewritePackedIm2ColConv) depend on this declaration.
static ConvGemmPlan
buildConvGemmPlan(const ConvLoweringState& state,
bool canPackWeightsAsConstants,
bool canPackBiasAsConstants,
int64_t chunkStart,
int64_t chunkNumPatches,
std::optional<int64_t> forcedPackFactor = std::nullopt);
/// Returns the convolution input in the form im2col patch extraction reads from.
///
/// Without spatial padding this is the original input and its type. Otherwise a spat compute is emitted that
/// produces a zero-padded copy of shape [N, Cin, H + padTop + padBottom, W + padLeft + padRight], and that
/// compute's result is returned together with the padded type.
static PreparedConvInput prepareInputForIm2Col(const ConvLoweringState& state,
                                               PatternRewriter& rewriter,
                                               Location loc) {
  const bool needsPadding =
      state.padHeightBegin != 0 || state.padHeightEnd != 0 || state.padWidthBegin != 0 || state.padWidthEnd != 0;
  if (!needsPadding)
    return {state.x, state.xType};

  const int64_t paddedHeight = state.xHeight + state.padHeightBegin + state.padHeightEnd;
  const int64_t paddedWidth = state.xWidth + state.padWidthBegin + state.padWidthEnd;
  auto paddedType = RankedTensorType::get({state.batchSize, state.numChannelsIn, paddedHeight, paddedWidth},
                                          state.xType.getElementType());

  auto padCompute = createSpatCompute<1>(rewriter, loc, TypeRange {paddedType}, {}, state.x, [&](Value source) {
    Value zeroPadded = createZeroPaddedTensor(source,
                                              paddedType,
                                              {0, 0, state.padHeightBegin, state.padWidthBegin},
                                              {0, 0, state.padHeightEnd, state.padWidthEnd},
                                              rewriter,
                                              loc);
    spatial::SpatYieldOp::create(rewriter, loc, zeroPadded);
  });
  return {padCompute.getResult(0), paddedType};
}
/// Appends zero rows after the last row of a rank-2 tensor until it has `paddedRows` rows.
/// Returns `rows` unchanged when it already has exactly that many rows.
static Value createPaddedRows(Value rows,
                              RankedTensorType rowsType,
                              int64_t paddedRows,
                              PatternRewriter& rewriter,
                              Location loc) {
  const int64_t extraRows = paddedRows - rowsType.getDimSize(0);
  if (extraRows == 0)
    return rows;
  auto targetType =
      RankedTensorType::get({paddedRows, rowsType.getDimSize(1)}, rowsType.getElementType(), rowsType.getEncoding());
  return createZeroPaddedTensor(rows, targetType, {0, 0}, {extraRows, 0}, rewriter, loc);
}
/// Packs every `packFactor` consecutive rows of a rank-2 tensor into one wider row.
///
/// The rows are first zero-padded to a multiple of `packFactor`, then reshaped
/// [R, W] -> [R / packFactor, packFactor, W] -> [R / packFactor, packFactor * W].
/// Returns `rows` unchanged when `packFactor` is 1.
static Value packRowsForParallelGemm(
    Value rows, RankedTensorType rowsType, int64_t packFactor, PatternRewriter& rewriter, Location loc) {
  if (packFactor == 1)
    return rows;

  const int64_t rowWidth = rowsType.getDimSize(1);
  const int64_t groupCount = ceilIntegerDivide(rowsType.getDimSize(0), packFactor);
  Type elementType = rowsType.getElementType();
  Attribute encoding = rowsType.getEncoding();

  Value zeroPadded = createPaddedRows(rows, rowsType, groupCount * packFactor, rewriter, loc);

  auto groupedType = RankedTensorType::get({groupCount, packFactor, rowWidth}, elementType, encoding);
  SmallVector<ReassociationIndices> splitRows {{0, 1}, {2}};
  Value grouped = tensor::ExpandShapeOp::create(rewriter, loc, groupedType, zeroPadded, splitRows);

  auto packedType = RankedTensorType::get({groupCount, packFactor * rowWidth}, elementType, encoding);
  SmallVector<ReassociationIndices> mergeGroups {{0}, {1, 2}};
  return tensor::CollapseShapeOp::create(rewriter, loc, packedType, grouped, mergeGroups);
}
/// Inverse of packRowsForParallelGemm.
///
/// Splits each packed row of `packFactor * rowWidth` elements back into `packFactor` rows of `rowWidth`
/// elements. Trailing rows beyond `unpackedRows` (the zero padding added when packing) are then sliced off.
/// Returns `packedRows` unchanged when `packFactor` is 1.
static Value unpackRowsFromParallelGemm(Value packedRows,
                                        RankedTensorType packedRowsType,
                                        int64_t unpackedRows,
                                        int64_t rowWidth,
                                        int64_t packFactor,
                                        PatternRewriter& rewriter,
                                        Location loc) {
  if (packFactor == 1)
    return packedRows;

  Type elementType = packedRowsType.getElementType();
  Attribute encoding = packedRowsType.getEncoding();
  const int64_t groupCount = packedRowsType.getDimSize(0);
  const int64_t allRows = groupCount * packFactor;

  auto splitType = RankedTensorType::get({groupCount, packFactor, rowWidth}, elementType, encoding);
  SmallVector<ReassociationIndices> splitColumns {{0}, {1, 2}};
  Value split = tensor::ExpandShapeOp::create(rewriter, loc, splitType, packedRows, splitColumns);

  auto flatType = RankedTensorType::get({allRows, rowWidth}, elementType, encoding);
  SmallVector<ReassociationIndices> mergeRows {{0, 1}, {2}};
  Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, split, mergeRows);
  if (allRows == unpackedRows)
    return flat;

  // Drop the padding rows appended during packing.
  auto trimmedType = RankedTensorType::get({unpackedRows, rowWidth}, elementType, encoding);
  SmallVector<OpFoldResult> sliceOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> sliceSizes {rewriter.getIndexAttr(unpackedRows), rewriter.getIndexAttr(rowWidth)};
  return tensor::ExtractSliceOp::create(
      rewriter, loc, trimmedType, flat, sliceOffsets, sliceSizes, getUnitStrides(rewriter, 2));
}
/// Builds the GEMM B operand from the convolution weight:
/// W [Cout, Cin, KH, KW] -> collapse -> [Cout, patchSize] -> transpose -> [patchSize, Cout].
/// Compile-time computable weights are reshaped in place. Otherwise the reshaping is wrapped in its own spat
/// compute.
static Value createWeightMatrix(
    Value weights, const ConvGemmPlan& plan, PatternRewriter& rewriter, Location loc) {
  auto flattenAndTranspose = [&](Value source) -> Value {
    SmallVector<ReassociationIndices> keepRowsMergeKernel {{0}, {1, 2, 3}};
    Value rowMajor = tensor::CollapseShapeOp::create(rewriter, loc, plan.wFlatType, source, keepRowsMergeKernel);
    return ONNXTransposeOp::create(rewriter, loc, plan.wTransType, rowMajor, rewriter.getI64ArrayAttr({1, 0}))
        .getResult();
  };

  if (!isCompileTimeComputable(weights)) {
    auto wrapper = createSpatCompute<1>(
        rewriter, loc, TypeRange {plan.wTransType}, {}, ValueRange {weights}, [&](Value source) {
          spatial::SpatYieldOp::create(rewriter, loc, flattenAndTranspose(source));
        });
    return wrapper.getResult(0);
  }
  return flattenAndTranspose(weights);
}
/// Zero-pads a rank-2 matrix from `sourceType` up to `paddedType`. Padding is added only after the existing
/// rows and columns. Returns `matrix` unchanged when the two types already match.
static Value createPaddedConvMatrix(Value matrix,
                                    RankedTensorType sourceType,
                                    RankedTensorType paddedType,
                                    PatternRewriter& rewriter,
                                    Location loc) {
  if (sourceType == paddedType)
    return matrix;
  const int64_t extraRows = paddedType.getDimSize(0) - sourceType.getDimSize(0);
  const int64_t extraCols = paddedType.getDimSize(1) - sourceType.getDimSize(1);
  return createZeroPaddedTensor(matrix, paddedType, {0, 0}, {extraRows, extraCols}, rewriter, loc);
}
/// Materializes a rank-2 constant of `paddedType`.
///
/// The top-left `sourceType` corner holds `sourceAttr`'s elements in row-major order; every other element is
/// zero. Both types must be rank 2, and the source must fit inside the padded shape.
static Value createPaddedConstantMatrix(DenseElementsAttr sourceAttr,
                                        RankedTensorType sourceType,
                                        RankedTensorType paddedType,
                                        PatternRewriter& rewriter) {
  const int64_t sourceRows = sourceType.getDimSize(0);
  const int64_t sourceCols = sourceType.getDimSize(1);
  const int64_t paddedCols = paddedType.getDimSize(1);
  // The copy below indexes both buffers directly. A source larger than the destination would write past the end
  // of `paddedValues`, and an attribute with fewer elements than `sourceType` would read past `sourceValues`.
  assert(sourceRows <= paddedType.getDimSize(0) && sourceCols <= paddedCols
         && "padded constant matrix must contain the source matrix");
  SmallVector<Attribute> sourceValues(sourceAttr.getValues<Attribute>());
  assert(static_cast<int64_t>(sourceValues.size()) == sourceRows * sourceCols
         && "constant element count does not match its matrix shape");
  SmallVector<Attribute> paddedValues(
      paddedType.getNumElements(), cast<Attribute>(rewriter.getZeroAttr(paddedType.getElementType())));
  for (int64_t row = 0; row < sourceRows; ++row)
    for (int64_t col = 0; col < sourceCols; ++col)
      paddedValues[row * paddedCols + col] = sourceValues[row * sourceCols + col];
  auto paddedAttr = DenseElementsAttr::get(paddedType, paddedValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), paddedAttr, paddedType);
}
/// Builds the [paddedK, paddedC] weight constant used by the input-K-tiled convolution.
///
/// Element (patchIndex, outChannel) holds W[outChannel, inChannel, kh, kw], where
/// patchIndex = (inChannel * KH + kh) * KW + kw. In other words, the result is W^T with the (Cin, KH, KW)
/// dimensions flattened. Rows past the real patch size and columns past Cout are zero.
static Value createPaddedInputKTiledWeightConstant(DenseElementsAttr sourceAttr,
                                                   const ConvLoweringState& state,
                                                   int64_t paddedK,
                                                   int64_t paddedC,
                                                   PatternRewriter& rewriter) {
  const int64_t patchSize = state.numChannelsIn * state.wHeight * state.wWidth;
  // Guard the raw index arithmetic below against out-of-bounds reads/writes.
  assert(patchSize <= paddedK && state.numChannelsOut <= paddedC
         && "K-tiled weight must fit inside the padded matrix");
  auto paddedType = RankedTensorType::get({paddedK, paddedC}, state.wType.getElementType());
  SmallVector<Attribute> sourceValues(sourceAttr.getValues<Attribute>());
  assert(static_cast<int64_t>(sourceValues.size()) == state.numChannelsOut * patchSize
         && "weight constant does not match [Cout, Cin, KH, KW]");
  SmallVector<Attribute> paddedValues(
      paddedType.getNumElements(), cast<Attribute>(rewriter.getZeroAttr(paddedType.getElementType())));
  // W is stored row-major as [Cout, Cin, KH, KW], so its flat index is outChannel * patchSize + patchIndex.
  for (int64_t outChannel = 0; outChannel < state.numChannelsOut; ++outChannel)
    for (int64_t patchIndex = 0; patchIndex < patchSize; ++patchIndex)
      paddedValues[patchIndex * paddedC + outChannel] = sourceValues[outChannel * patchSize + patchIndex];
  auto paddedAttr = DenseElementsAttr::get(paddedType, paddedValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), paddedAttr, paddedType);
}
/// Lowers a convolution whose im2col patch (K) is tiled into crossbar-sized slices.
///
/// The weight becomes a compile-time [paddedK, xbarDim] matrix: K is rounded up to a multiple of the crossbar
/// size, and the output channels are zero-padded to the crossbar width.
///
/// A spat compute is emitted for each batch image, output row, and chunk of output columns. For every output
/// position in its chunk, that compute:
///   1. extracts the input patch and zero-pads it to paddedK;
///   2. accumulates `numKSlices` crossbar-sized VMMs;
///   3. adds the padded bias, if any;
///   4. slices the row back to Cout.
///
/// Returns failure, without creating any IR, when:
///   - the weight is not a host constant;
///   - the crossbar size is not positive or is smaller than Cout;
///   - the convolution has no output positions.
static FailureOr<Value> rewriteInputKTiledConv(const ConvLoweringState& state,
                                               ArrayRef<DistributedTensorStep> distributedConsumers,
                                               PatternRewriter& rewriter,
                                               Location loc) {
  // Check every precondition before touching the IR. Returning failure() after ops (padded input,
  // weight/bias constants, ...) have already been inserted would leave them behind.
  auto wDenseAttr = getHostConstDenseElementsAttr(state.w);
  if (!wDenseAttr)
    return failure();
  ConvGeometry geo = buildConvGeometry(state);
  const int64_t xbarDim = geo.xbarSize;
  // Output rows are computed xbarDim wide and then sliced down to Cout. The weight/bias constants are padded
  // (never truncated) to xbarDim columns. So Cout must fit within one crossbar.
  if (xbarDim <= 0 || state.numChannelsOut > xbarDim)
    return failure();
  const int64_t totalPatches = state.batchSize * state.outHeight * state.outWidth;
  // Without output positions no chunk would be produced below.
  if (totalPatches <= 0)
    return failure();

  PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc);
  const int64_t numKSlices = ceilIntegerDivide(geo.k, xbarDim);
  const int64_t paddedK = numKSlices * xbarDim;
  const uint64_t maxLanesPerBatch =
      std::max<uint64_t>(1,
                         static_cast<uint64_t>(crossbarCountInCore.getValue())
                             / static_cast<uint64_t>(std::max<int64_t>(1, numKSlices * 4)));
  const uint64_t rowChunkWidth = std::max<uint64_t>(
      1,
      std::min<uint64_t>({chooseStreamChunkPositions(geo, /*packFactor=*/1),
                          maxLanesPerBatch,
                          static_cast<uint64_t>(state.outWidth)}));
  const int64_t chunkWidth = static_cast<int64_t>(rowChunkWidth);
  const auto elementType = state.outType.getElementType();

  Value paddedWeight = createPaddedInputKTiledWeightConstant(wDenseAttr, state, paddedK, xbarDim, rewriter);
  Value paddedBias;
  if (state.hasBias) {
    Value biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc);
    auto biasMatrixType = cast<RankedTensorType>(biasMatrix.getType());
    auto paddedBiasType = RankedTensorType::get({1, xbarDim}, elementType);
    if (auto biasDenseAttr = getHostConstDenseElementsAttr(state.b))
      paddedBias = createPaddedConstantMatrix(biasDenseAttr, biasMatrixType, paddedBiasType, rewriter);
    else
      paddedBias = materializeOrComputeUnary(biasMatrix, paddedBiasType, rewriter, loc, [&](Value biasValue) {
        return createPaddedConvMatrix(biasValue, biasMatrixType, paddedBiasType, rewriter, loc);
      });
  }

  // Tensor types, attributes and compute operands shared by every chunk (hoisted out of the chunk loops).
  auto paddedRowType = RankedTensorType::get({1, xbarDim}, elementType);
  auto paddedChunkRowType = RankedTensorType::get({1, paddedK}, elementType);
  auto patchType = RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elementType);
  auto collapsedPatchType = RankedTensorType::get({1, geo.k}, elementType);
  auto weightTileType = RankedTensorType::get({xbarDim, xbarDim}, state.wType.getElementType());
  auto rowType = RankedTensorType::get({1, state.numChannelsOut}, elementType);
  auto zeroRowAttr = DenseElementsAttr::get(paddedRowType, rewriter.getZeroAttr(elementType));
  SmallVector<Value> inputsStorage {preparedInput.value};
  if (state.hasBias)
    inputsStorage.push_back(paddedBias);
  ValueRange inputs(inputsStorage);

  SmallVector<Value> chunkRows;
  chunkRows.reserve(state.batchSize * state.outHeight * ceilIntegerDivide(state.outWidth, chunkWidth));
  for (int64_t batchIndex = 0; batchIndex < state.batchSize; ++batchIndex) {
    for (int64_t outHeightIndex = 0; outHeightIndex < state.outHeight; ++outHeightIndex) {
      for (int64_t outWidthChunkStart = 0; outWidthChunkStart < state.outWidth; outWidthChunkStart += chunkWidth) {
        const int64_t chunkNumPatches = std::min<int64_t>(chunkWidth, state.outWidth - outWidthChunkStart);
        auto chunkRowsType = RankedTensorType::get({chunkNumPatches, state.numChannelsOut}, elementType);
        auto chunkCompute =
            spatial::SpatCompute::create(rewriter, loc, TypeRange {chunkRowsType}, ValueRange {paddedWeight}, inputs);
        // Ownership of the block passes to the compute's region on push_back.
        auto* block = new Block();
        block->addArgument(paddedWeight.getType(), loc);
        for (Value input : inputs)
          block->addArgument(input.getType(), loc);
        chunkCompute.getBody().push_back(block);
        rewriter.setInsertionPointToStart(block);
        auto buildChunk = [&]() -> LogicalResult {
          Value weightArg = block->getArgument(0);
          Value inputArg = block->getArgument(1);
          Value biasArg = state.hasBias ? block->getArgument(2) : Value();
          Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
          Value cBatchIndex = getOrCreateIndexConstant(rewriter, anchorOp, batchIndex);
          Value cZero = getOrCreateIndexConstant(rewriter, anchorOp, 0);
          Value cKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices);
          Value cOne = getOrCreateIndexConstant(rewriter, anchorOp, 1);
          Value cXbar = getOrCreateIndexConstant(rewriter, anchorOp, xbarDim);
          Value cInputHeightOffset =
              getOrCreateIndexConstant(rewriter, anchorOp, outHeightIndex * state.strideHeight);
          Value chunkRowsValue = tensor::EmptyOp::create(rewriter, loc, chunkRowsType.getShape(), elementType);
          auto widthLoop = buildNormalizedScfFor(
              rewriter,
              loc,
              cZero,
              getOrCreateIndexConstant(rewriter, anchorOp, chunkNumPatches),
              cOne,
              ValueRange {chunkRowsValue},
              [&](OpBuilder&, Location nestedLoc, Value widthIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
                Value laneWithChunkOffset =
                    affineAddConst(rewriter, nestedLoc, widthIndex, outWidthChunkStart, anchorOp);
                Value inputWidthOffset =
                    createOrFoldAffineApply(rewriter,
                                            nestedLoc,
                                            getAffineDimExpr(0, rewriter.getContext()) * state.strideWidth,
                                            ValueRange {laneWithChunkOffset},
                                            anchorOp);
                Value patch = createConvInputPatch(inputArg,
                                                   patchType,
                                                   cBatchIndex,
                                                   cZero,
                                                   cInputHeightOffset,
                                                   inputWidthOffset,
                                                   state.dilationHeight,
                                                   state.dilationWidth,
                                                   rewriter,
                                                   nestedLoc);
                Value patchRow = tensor::CollapseShapeOp::create(rewriter,
                                                                 nestedLoc,
                                                                 collapsedPatchType,
                                                                 patch,
                                                                 SmallVector<ReassociationIndices> {
                                                                     {0},
                                                                     {1, 2, 3}
                });
                Value paddedPatchRow = createZeroPaddedTensor(
                    patchRow, paddedChunkRowType, {0, 0}, {0, paddedK - geo.k}, rewriter, nestedLoc);
                Value zeroRow = getOrCreateConstant(rewriter, anchorOp, zeroRowAttr, paddedRowType);
                // Accumulate one crossbar-sized VMM per K slice.
                auto kLoop = buildNormalizedScfFor(
                    rewriter,
                    nestedLoc,
                    cZero,
                    cKSlices,
                    cOne,
                    ValueRange {zeroRow},
                    [&](OpBuilder&,
                        Location reduceLoc,
                        Value kSlice,
                        ValueRange reduceIterArgs,
                        SmallVectorImpl<Value>& reduceYielded) {
                      Value acc = reduceIterArgs.front();
                      Value kOffset = arith::MulIOp::create(rewriter, reduceLoc, kSlice, cXbar);
                      SmallVector<OpFoldResult> aOffsets {rewriter.getIndexAttr(0), kOffset};
                      SmallVector<OpFoldResult> aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)};
                      SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
                      Value aTile = tensor::ExtractSliceOp::create(
                          rewriter, reduceLoc, paddedRowType, paddedPatchRow, aOffsets, aSizes, unitStrides);
                      SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
                      SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
                      Value bTile = extractStaticSliceOrIdentity(
                          rewriter, reduceLoc, weightArg, weightTileType, bOffsets, bSizes, unitStrides);
                      Value piece =
                          spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
                      reduceYielded.push_back(
                          spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, acc, piece).getResult());
                      return success();
                    });
                if (failed(kLoop))
                  return failure();
                Value reduced = kLoop->results.front();
                if (state.hasBias)
                  reduced =
                      spatial::SpatVAddOp::create(rewriter, nestedLoc, paddedRowType, reduced, biasArg).getResult();
                // Drop the crossbar padding columns so the row is exactly Cout wide.
                Value row = reduced;
                if (state.numChannelsOut != xbarDim) {
                  SmallVector<OpFoldResult> rowOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
                  SmallVector<OpFoldResult> rowSizes {rewriter.getIndexAttr(1),
                                                      rewriter.getIndexAttr(state.numChannelsOut)};
                  row = tensor::ExtractSliceOp::create(
                      rewriter, nestedLoc, rowType, reduced, rowOffsets, rowSizes, getUnitStrides(rewriter, 2));
                }
                SmallVector<OpFoldResult> outputOffsets {widthIndex, rewriter.getIndexAttr(0)};
                SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1),
                                                       rewriter.getIndexAttr(state.numChannelsOut)};
                Value updatedRows = tensor::InsertSliceOp::create(
                    rewriter, nestedLoc, row, iterArgs.front(), outputOffsets, outputSizes, getUnitStrides(rewriter, 2));
                yielded.push_back(updatedRows);
                return success();
              });
          if (failed(widthLoop))
            return failure();
          spatial::SpatYieldOp::create(rewriter, loc, widthLoop->results.front());
          return success();
        };
        if (failed(buildChunk())) {
          rewriter.setInsertionPointAfter(chunkCompute);
          rewriter.eraseOp(chunkCompute);
          return failure();
        }
        rewriter.setInsertionPointAfter(chunkCompute);
        chunkRows.push_back(chunkCompute.getResult(0));
      }
    }
  }

  auto nhwcType =
      RankedTensorType::get({state.batchSize, state.outHeight, state.outWidth, state.numChannelsOut}, elementType);
  // NOTE(review): the first chunk's row type ({min(rowChunkWidth, outWidth), Cout}) is passed as the GEMM output
  // type while `totalPatches` spans every chunk; confirm this matches createCollectedConvOutput's expectations.
  return createCollectedConvOutput(chunkRows,
                                   state.outType,
                                   cast<RankedTensorType>(chunkRows.front().getType()),
                                   nhwcType,
                                   state.outType,
                                   totalPatches,
                                   state.numChannelsOut,
                                   /*packFactor=*/1,
                                   distributedConsumers,
                                   rewriter,
                                   loc);
}
/// Returns the GEMM B operand for `plan`.
///
/// Without row packing this is `wTrans` unchanged. With packing it is a constant block-diagonal matrix holding
/// `effectiveMaxParallelPixels` copies of W^T, built from `wDenseAttr`. Copy `r` occupies rows
/// [r * patchSize, (r + 1) * patchSize) and columns [r * Cout, (r + 1) * Cout); everything else is zero.
static Value buildPackedWeights(DenseElementsAttr wDenseAttr,
                                Value wTrans,
                                const ConvLoweringState& state,
                                const ConvGemmPlan& plan,
                                PatternRewriter& rewriter,
                                Location loc) {
  const int64_t replicas = plan.effectiveMaxParallelPixels;
  if (replicas == 1)
    return wTrans;

  const int64_t packedRows = replicas * plan.patchSize;
  const int64_t packedCols = replicas * state.numChannelsOut;
  auto packedWeightType = RankedTensorType::get({packedRows, packedCols}, state.wType.getElementType());
  // Elements per output-channel filter; W is row-major [Cout, Cin, KH, KW].
  const int64_t kernelVolume = state.numChannelsIn * state.wHeight * state.wWidth;

  SmallVector<Attribute> sourceValues(wDenseAttr.getValues<Attribute>());
  SmallVector<Attribute> packedValues(packedWeightType.getNumElements(),
                                      cast<Attribute>(rewriter.getZeroAttr(state.wType.getElementType())));
  for (int64_t replica = 0; replica < replicas; ++replica) {
    const int64_t rowBase = replica * plan.patchSize;
    const int64_t colBase = replica * state.numChannelsOut;
    for (int64_t outChannel = 0; outChannel < state.numChannelsOut; ++outChannel)
      for (int64_t kernelIndex = 0; kernelIndex < kernelVolume; ++kernelIndex)
        packedValues[(rowBase + kernelIndex) * packedCols + colBase + outChannel] =
            sourceValues[outChannel * kernelVolume + kernelIndex];
  }

  auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
}
/// Returns the GEMM C operand for `plan`.
///
/// Without a bias this is `gemmBias` unchanged. With a bias and no row packing it is `biasMatrix`. With row
/// packing it is a [1, effectiveMaxParallelPixels * Cout] constant built from `biasDenseAttr`, with the bias
/// repeated once per packed patch.
static Value buildPackedBias(Value gemmBias,
                             Value biasMatrix,
                             DenseElementsAttr biasDenseAttr,
                             const ConvLoweringState& state,
                             const ConvGemmPlan& plan,
                             PatternRewriter& rewriter,
                             Location loc) {
  if (!state.hasBias)
    return gemmBias;
  const int64_t replicas = plan.effectiveMaxParallelPixels;
  if (replicas == 1)
    return biasMatrix;

  SmallVector<Attribute> perChannel(biasDenseAttr.getValues<Attribute>());
  SmallVector<Attribute> repeated;
  repeated.reserve(replicas * state.numChannelsOut);
  for (int64_t replica = 0; replica < replicas; ++replica)
    repeated.append(perChannel.begin(), perChannel.end());

  auto repeatedType = RankedTensorType::get({1, replicas * state.numChannelsOut}, state.outType.getElementType());
  auto repeatedAttr = DenseElementsAttr::get(repeatedType, repeated);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), repeatedAttr, repeatedType);
}
/// Computes the static GEMM shapes for lowering patches [chunkStart, chunkStart + chunkNumPatches).
///
/// The requested pack factor is `forcedPackFactor` when given. Otherwise it is crossbarSize divided by the
/// larger weight-matrix dimension. Packing is only used when both weights and bias can be packed as constants;
/// otherwise the effective factor is 1. Both factors are clamped to at least 1 so they are always valid divisors.
static ConvGemmPlan
buildConvGemmPlan(const ConvLoweringState& state,
                  bool canPackWeightsAsConstants,
                  bool canPackBiasAsConstants,
                  int64_t chunkStart,
                  int64_t chunkNumPatches,
                  std::optional<int64_t> forcedPackFactor) {
  ConvGemmPlan plan;
  plan.patchSize = state.numChannelsIn * state.wHeight * state.wWidth;
  plan.numPatchesPerBatch = state.outHeight * state.outWidth;
  plan.globalNumPatches = state.batchSize * plan.numPatchesPerBatch;
  plan.chunkStart = chunkStart;
  plan.chunkNumPatches = chunkNumPatches;

  // A degenerate weight shape (zero-sized patch and zero output channels) or a non-positive forced factor
  // would otherwise cause a division by zero here or in ceilIntegerDivide below.
  const int64_t wMaxDim = std::max<int64_t>({1, plan.patchSize, state.numChannelsOut});
  const int64_t requestedPackFactor =
      forcedPackFactor ? *forcedPackFactor : static_cast<int64_t>(crossbarSize.getValue()) / wMaxDim;
  plan.maxParallelPixels = std::max<int64_t>(1, requestedPackFactor);
  plan.effectiveMaxParallelPixels =
      (canPackWeightsAsConstants && canPackBiasAsConstants) ? plan.maxParallelPixels : 1;
  plan.packedNumRows = ceilIntegerDivide(plan.chunkNumPatches, plan.effectiveMaxParallelPixels);

  auto elemType = state.xType.getElementType();
  auto outElemType = state.outType.getElementType();
  plan.im2colType = RankedTensorType::get({plan.chunkNumPatches, plan.patchSize}, elemType);
  plan.im2colRowType = RankedTensorType::get({1, plan.patchSize}, elemType);
  plan.gemmInputRowsType =
      RankedTensorType::get({plan.packedNumRows, plan.effectiveMaxParallelPixels * plan.patchSize}, elemType);
  plan.wFlatType = RankedTensorType::get({state.numChannelsOut, plan.patchSize}, state.wType.getElementType());
  plan.wTransType = RankedTensorType::get({plan.patchSize, state.numChannelsOut}, state.wType.getElementType());
  plan.gemmOutType = RankedTensorType::get({plan.chunkNumPatches, state.numChannelsOut}, outElemType);
  plan.gemmOutputRowsType =
      RankedTensorType::get({plan.packedNumRows, plan.effectiveMaxParallelPixels * state.numChannelsOut}, outElemType);
  plan.nhwcType =
      RankedTensorType::get({state.batchSize, state.outHeight, state.outWidth, state.numChannelsOut}, outElemType);
  return plan;
}
/// Emits a spat compute that gathers the im2col matrix for `plan`'s chunk of output positions.
///
/// Row `i` holds the flattened [Cin, KH, KW] input patch feeding global output position `plan.chunkStart + i`.
/// When `plan.effectiveMaxParallelPixels` > 1 the rows are additionally packed, so the result has
/// `plan.gemmInputRowsType`.
static Value createIm2colRows(const ConvLoweringState& state,
                              const PreparedConvInput& preparedInput,
                              const ConvGemmPlan& plan,
                              PatternRewriter& rewriter,
                              Location loc) {
  Type inputElementType = preparedInput.type.getElementType();
  RankedTensorType patchType =
      RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, inputElementType);

  auto im2colCompute = createSpatCompute<1>(
      rewriter,
      loc,
      TypeRange {plan.gemmInputRowsType},
      {},
      preparedInput.value,
      [&](Value inputArg) -> LogicalResult {
        // Standard im2col formulation with the filters placed in the B operand / crossbar columns:
        //   A (im2col):  [numPatches, patchSize] -- one row per output spatial position
        //   B (weights): [patchSize, cOut]
        //   GEMM output: [numPatches, cOut]
        Value emptyIm2col = tensor::EmptyOp::create(rewriter, loc, plan.im2colType.getShape(), inputElementType);
        Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
        Value zeroIndex = getOrCreateIndexConstant(rewriter, anchorOp, 0);
        Value oneIndex = getOrCreateIndexConstant(rewriter, anchorOp, 1);
        Value patchCount = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkNumPatches);
        auto gatherLoop = buildNormalizedScfFor(
            rewriter,
            loc,
            zeroIndex,
            patchCount,
            oneIndex,
            ValueRange {emptyIm2col},
            [&](OpBuilder&, Location bodyLoc, Value localPatch, ValueRange carried, SmallVectorImpl<Value>& yielded) {
              // Split the global patch index (chunkStart + localPatch) into (image, outRow, outCol).
              Value imageIndex = affineAddFloorDivConst(
                  rewriter, bodyLoc, localPatch, plan.chunkStart, plan.numPatchesPerBatch, anchorOp);
              Value positionInImage = affineAddModConst(
                  rewriter, bodyLoc, localPatch, plan.chunkStart, plan.numPatchesPerBatch, anchorOp);
              Value outRow = affineFloorDivConst(rewriter, bodyLoc, positionInImage, state.outWidth, anchorOp);
              Value outCol = affineModConst(rewriter, bodyLoc, positionInImage, state.outWidth, anchorOp);
              Value inputRowStart = affineMulConst(rewriter, bodyLoc, outRow, state.strideHeight, anchorOp);
              Value inputColStart = affineMulConst(rewriter, bodyLoc, outCol, state.strideWidth, anchorOp);
              Value patch = createConvInputPatch(inputArg,
                                                 patchType,
                                                 imageIndex,
                                                 zeroIndex,
                                                 inputRowStart,
                                                 inputColStart,
                                                 state.dilationHeight,
                                                 state.dilationWidth,
                                                 rewriter,
                                                 bodyLoc);
              SmallVector<ReassociationIndices> flattenPatch {{0}, {1, 2, 3}};
              Value patchRow =
                  tensor::CollapseShapeOp::create(rewriter, bodyLoc, plan.im2colRowType, patch, flattenPatch);
              SmallVector<OpFoldResult> insertOffsets {localPatch, rewriter.getIndexAttr(0)};
              SmallVector<OpFoldResult> insertSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(plan.patchSize)};
              Value withRow = tensor::InsertSliceOp::create(
                  rewriter, bodyLoc, patchRow, carried.front(), insertOffsets, insertSizes, getUnitStrides(rewriter, 2));
              yielded.push_back(withRow);
              return success();
            });
        if (failed(gatherLoop))
          return failure();

        Value gemmInputRows = gatherLoop->results.front();
        // Pack N im2col rows into one longer row so that a single GEMM covers N output positions. The packed weight
        // matrix holds N block-diagonal copies of W^T, and the GEMM result is unpacked back to one row per patch.
        if (plan.effectiveMaxParallelPixels != 1)
          gemmInputRows =
              packRowsForParallelGemm(gemmInputRows, plan.im2colType, plan.effectiveMaxParallelPixels, rewriter, loc);
        spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
        return success();
      });
  assert(succeeded(im2colCompute) && "Conv im2col compute construction must succeed");
  return im2colCompute->getResult(0);
}
/// Undoes row packing on one chunk's GEMM result.
///
/// When packing was used, a spat compute is emitted that yields [chunkNumPatches, Cout] rows. Returns `gemmRows`
/// untouched when the plan did not pack rows.
static Value maybeUnpackChunkRows(Value gemmRows,
                                  const ConvGemmPlan& plan,
                                  PatternRewriter& rewriter,
                                  Location loc) {
  const int64_t packFactor = plan.effectiveMaxParallelPixels;
  if (packFactor == 1)
    return gemmRows;

  RankedTensorType logicalType = plan.gemmOutType;
  const int64_t outChannels = logicalType.getDimSize(1);
  auto rowsType = RankedTensorType::get(
      {plan.chunkNumPatches, outChannels}, logicalType.getElementType(), logicalType.getEncoding());
  auto unpackOp = createSpatCompute<1>(rewriter, loc, TypeRange {rowsType}, {}, gemmRows, [&](Value packedArg) {
    auto packedType = cast<RankedTensorType>(packedArg.getType());
    Value rows =
        unpackRowsFromParallelGemm(packedArg, packedType, plan.chunkNumPatches, outChannels, packFactor, rewriter, loc);
    spatial::SpatYieldOp::create(rewriter, loc, rows);
  });
  return unpackOp.getResult(0);
}
/// Lowers the convolution as a sequence of im2col GEMMs.
///
/// Each GEMM covers a chunk of at most `chunkPositions` consecutive output positions. The per-chunk
/// [chunk, Cout] rows are then concatenated into one [totalPatches, Cout] tensor; with a single chunk, that
/// chunk's rows are returned directly.
///
/// `weightMatrix` is only read when rows are not packed (buildPackedWeights returns its own constant
/// otherwise).
static Value createChunkedConvRows(const ConvLoweringState& state,
                                   const PreparedConvInput& preparedInput,
                                   Value weightMatrix,
                                   Value biasMatrix,
                                   DenseElementsAttr wDenseAttr,
                                   DenseElementsAttr biasDenseAttr,
                                   int64_t forcedPackFactor,
                                   uint64_t chunkPositions,
                                   PatternRewriter& rewriter,
                                   Location loc) {
  const int64_t totalPatches = state.batchSize * state.outHeight * state.outWidth;
  // A zero chunk size would never advance the chunk loop below.
  const int64_t chunkSize = static_cast<int64_t>(std::max<uint64_t>(1, chunkPositions));
  const bool canPackWeights = static_cast<bool>(wDenseAttr);
  const bool canPackBias = !state.hasBias || static_cast<bool>(biasDenseAttr);

  SmallVector<Value> chunkRows;
  for (int64_t chunkStart = 0; chunkStart < totalPatches; chunkStart += chunkSize) {
    const int64_t chunkNumPatches = std::min<int64_t>(chunkSize, totalPatches - chunkStart);
    ConvGemmPlan chunkPlan =
        buildConvGemmPlan(state, canPackWeights, canPackBias, chunkStart, chunkNumPatches, forcedPackFactor);
    Value chunkInputRows = createIm2colRows(state, preparedInput, chunkPlan, rewriter, loc);
    Value chunkB = buildPackedWeights(wDenseAttr, weightMatrix, state, chunkPlan, rewriter, loc);
    // Only materialize the zero bias when there is no real bias; otherwise it would be dead IR.
    Value gemmBias;
    if (state.hasBias)
      gemmBias = state.b;
    else
      gemmBias = createZeroGemmBias(chunkPlan.gemmOutputRowsType, rewriter);
    Value chunkC = buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, chunkPlan, rewriter, loc);
    Value chunkGemmRows = ONNXGemmOp::create(rewriter,
                                             loc,
                                             chunkPlan.gemmOutputRowsType,
                                             chunkInputRows,
                                             chunkB,
                                             chunkC,
                                             APFloat(1.0f),
                                             APFloat(1.0f),
                                             /*transA=*/0,
                                             /*transB=*/0)
                              .getY();
    chunkRows.push_back(maybeUnpackChunkRows(chunkGemmRows, chunkPlan, rewriter, loc));
  }
  assert(!chunkRows.empty() && "convolution must have at least one output position");
  if (chunkRows.size() == 1)
    return chunkRows.front();

  auto rowType = RankedTensorType::get({totalPatches, state.numChannelsOut}, state.outType.getElementType());
  auto collectRows = createSpatCompute(rewriter, loc, TypeRange {rowType}, {}, chunkRows, [&](ValueRange rows) {
    spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, rows));
  });
  return collectRows.getResult(0);
}
/// Lowers the whole convolution as a single (possibly row-packed) im2col GEMM and collects the result into the
/// convolution's output layout.
static Value rewritePackedIm2ColConv(const ConvLoweringState& state,
                                     ArrayRef<DistributedTensorStep> distributedConsumers,
                                     PatternRewriter& rewriter,
                                     Location loc) {
  auto wDenseAttr = getHostConstDenseElementsAttr(state.w);
  PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc);
  Value biasMatrix;
  DenseElementsAttr biasDenseAttr;
  if (state.hasBias) {
    biasDenseAttr = getHostConstDenseElementsAttr(state.b);
    biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc);
  }
  ConvGemmPlan plan = buildConvGemmPlan(state,
                                        static_cast<bool>(wDenseAttr),
                                        !state.hasBias || static_cast<bool>(biasDenseAttr),
                                        /*chunkStart=*/0,
                                        state.batchSize * state.outHeight * state.outWidth);
  // Weight matrix for crossbar storage: W [Cout, Cin, KH, KW] -> [Cout, patchSize] -> [patchSize, Cout].
  // buildPackedWeights only uses it when rows are not packed (packing builds its own block-diagonal constant),
  // so skip emitting an unused reshape/transpose in that case.
  Value weightMatrix;
  if (plan.effectiveMaxParallelPixels == 1)
    weightMatrix = createWeightMatrix(state.w, plan, rewriter, loc);
  Value gemmInputRows = createIm2colRows(state, preparedInput, plan, rewriter, loc);
  Value gemmB = buildPackedWeights(wDenseAttr, weightMatrix, state, plan, rewriter, loc);
  // Only materialize the zero bias when there is no real bias; otherwise it would be dead IR.
  Value gemmBias;
  if (state.hasBias)
    gemmBias = state.b;
  else
    gemmBias = createZeroGemmBias(plan.gemmOutputRowsType, rewriter);
  Value gemmC = buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, plan, rewriter, loc);
  Value gemmRows = ONNXGemmOp::create(rewriter,
                                      loc,
                                      plan.gemmOutputRowsType,
                                      gemmInputRows,
                                      gemmB,
                                      gemmC,
                                      APFloat(1.0f),
                                      APFloat(1.0f),
                                      /*transA=*/0,
                                      /*transB=*/0)
                       .getY();
  return createCollectedConvOutput(ValueRange {gemmRows},
                                   state.outType,
                                   plan.gemmOutType,
                                   plan.nhwcType,
                                   state.outType,
                                   plan.chunkNumPatches,
                                   state.numChannelsOut,
                                   plan.effectiveMaxParallelPixels,
                                   distributedConsumers,
                                   rewriter,
                                   loc);
}
/// Lowers the convolution as a stream of chunked im2col GEMMs.
///
/// The chunk size is chosen by chooseStreamChunkPositions. The concatenated [totalPatches, Cout] rows are then
/// collected into the convolution's output layout.
static Value rewriteStreamedConv(const ConvLoweringState& state,
                                 ArrayRef<DistributedTensorStep> distributedConsumers,
                                 PatternRewriter& rewriter,
                                 Location loc,
                                 int64_t forcedPackFactor) {
  auto wDenseAttr = getHostConstDenseElementsAttr(state.w);
  PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc);
  Value biasMatrix;
  DenseElementsAttr biasDenseAttr;
  if (state.hasBias) {
    biasDenseAttr = getHostConstDenseElementsAttr(state.b);
    biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc);
  }
  ConvGemmPlan seedPlan = buildConvGemmPlan(
      state, static_cast<bool>(wDenseAttr), !state.hasBias || static_cast<bool>(biasDenseAttr), 0, 1, forcedPackFactor);
  // createChunkedConvRows builds every chunk plan from the same pack flags and factor as this seed plan. So the
  // transposed weight is consumed only when the seed plan does not pack rows; skip the unused reshape/transpose
  // otherwise.
  Value weightMatrix;
  if (seedPlan.effectiveMaxParallelPixels == 1)
    weightMatrix = createWeightMatrix(state.w, seedPlan, rewriter, loc);
  ConvGeometry geo = buildConvGeometry(state);
  // A zero chunk size would never advance the chunk loop.
  uint64_t chunkPositions = std::max<uint64_t>(1, chooseStreamChunkPositions(geo, forcedPackFactor));
  Value collectedRows = createChunkedConvRows(state,
                                              preparedInput,
                                              weightMatrix,
                                              biasMatrix,
                                              wDenseAttr,
                                              biasDenseAttr,
                                              forcedPackFactor,
                                              chunkPositions,
                                              rewriter,
                                              loc);
  auto gemmOutType = cast<RankedTensorType>(collectedRows.getType());
  auto nhwcType = RankedTensorType::get({state.batchSize, state.outHeight, state.outWidth, state.numChannelsOut},
                                        state.outType.getElementType());
  return createCollectedConvOutput(ValueRange {collectedRows},
                                   state.outType,
                                   gemmOutType,
                                   nhwcType,
                                   state.outType,
                                   gemmOutType.getDimSize(0),
                                   state.numChannelsOut,
                                   /*packFactor=*/1,
                                   distributedConsumers,
                                   rewriter,
                                   loc);
}
} // namespace standard
/// Returns the type of one row strip of a rank-4 tensor type: [dim0, dim1, 1, width]. The element type and
/// encoding are kept from `tensorType`.
static RankedTensorType getRowStripFragmentType(RankedTensorType tensorType, int64_t width) {
  const int64_t leadingDim = tensorType.getDimSize(0);
  const int64_t channels = tensorType.getDimSize(1);
  SmallVector<int64_t, 4> stripShape {leadingDim, channels, 1, width};
  return RankedTensorType::get(stripShape, tensorType.getElementType(), tensorType.getEncoding());
}
/// Describes `tensorType` as one fragment per row (dim 2).
///
/// Fragment `r` starts at {0, 0, r, 0}, spans {1, C, 1, W} with unit strides,
/// and carries `r` as its trailing field (presumably the owning lane).
/// NOTE(review): the batch extent is fixed to 1, so this assumes batch size 1.
static SmallVector<DistributedFragmentInfo, 8> buildRowStripFragments(RankedTensorType tensorType) {
  const int64_t channels = tensorType.getDimSize(1);
  const int64_t rowCount = tensorType.getDimSize(2);
  const int64_t width = tensorType.getDimSize(3);

  SmallVector<DistributedFragmentInfo, 8> strips;
  strips.reserve(rowCount);
  for (int64_t lane = 0; lane < rowCount; ++lane) {
    strips.push_back(DistributedFragmentInfo {{0, 0, lane, 0}, {1, channels, 1, width}, {1, 1, 1, 1}, lane});
  }
  return strips;
}
/// Wraps `storage` as a row-strip distributed tensor of `logicalType`.
///
/// One lane is used per row (dim 2). The fragment list comes from
/// `buildRowStripFragments`, and channels/height/width are taken from
/// dims 1/2/3 of `logicalType`.
static DistributedTensorInfo makeDistributedTensorInfo(Value storage, RankedTensorType logicalType) {
  const int64_t channels = logicalType.getDimSize(1);
  const int64_t rows = logicalType.getDimSize(2);
  const int64_t columns = logicalType.getDimSize(3);

  DistributedTensorInfo result;
  result.storage = storage;
  result.logicalType = logicalType;
  result.fragments = buildRowStripFragments(logicalType);
  result.laneCount = rows; // one lane per row strip
  result.channels = channels;
  result.height = rows;
  result.width = columns;
  return result;
}
/// Materializes a constant of `fragmentType` (N x C x H x W) in which every
/// element of channel `c` holds the per-channel value for `c` taken from
/// `denseAttr`.
///
/// Rank-1 and rank-2 attributes are read in flattened order (value `c` for
/// channel `c`).
///
/// For rank >= 3 attributes the channel axis is the one that lines up with
/// the fragment's channel dimension under right-aligned broadcasting, i.e.
/// axis `rank - 3`. For example, axis 1 of `[1, C, 1, 1]` and axis 0 of
/// `[C, 1, 1]`. A channel axis of size 1 is broadcast to every channel, as
/// is a single-element rank-1/2 attribute.
static Value createPerChannelConstantFragment(DenseElementsAttr denseAttr,
                                              RankedTensorType fragmentType,
                                              PatternRewriter& rewriter) {
  auto denseType = cast<RankedTensorType>(denseAttr.getType());
  const int64_t rank = denseType.getRank();
  const int64_t numChannels = fragmentType.getDimSize(1);
  SmallVector<Attribute> flattened(denseAttr.getValues<Attribute>());

  SmallVector<Attribute> channelValues;
  if (rank <= 2) {
    channelValues = flattened;
  }
  else {
    // Reading getDimSize(1) unconditionally mis-sized rank-3 `[C, 1, 1]` constants
    // (it yields 1) and led to out-of-bounds reads below.
    const int64_t channelAxis = rank - 3;
    const int64_t channelExtent = denseType.getDimSize(channelAxis);
    int64_t channelStride = 1;
    for (int64_t dim = channelAxis + 1; dim < rank; ++dim)
      channelStride *= denseType.getDimSize(dim);
    assert((channelExtent == 1 || channelExtent >= numChannels)
           && "per-channel constant does not cover every fragment channel");
    channelValues.reserve(numChannels);
    for (int64_t channel = 0; channel < numChannels; ++channel) {
      const int64_t sourceChannel = channelExtent == 1 ? 0 : channel;
      channelValues.push_back(flattened[sourceChannel * channelStride]);
    }
  }
  assert((channelValues.size() == 1 || static_cast<int64_t>(channelValues.size()) >= numChannels)
         && "per-channel constant does not cover every fragment channel");

  // A single value is broadcast to every channel.
  auto valueForChannel = [&](int64_t channel) -> Attribute {
    return channelValues.size() == 1 ? channelValues.front() : channelValues[channel];
  };

  // NCHW order: each (n, c) pair contributes H * W copies of the channel value.
  const int64_t spatialSize = fragmentType.getDimSize(2) * fragmentType.getDimSize(3);
  SmallVector<Attribute> values;
  values.reserve(fragmentType.getNumElements());
  for (int64_t n = 0; n < fragmentType.getDimSize(0); ++n)
    for (int64_t channel = 0; channel < numChannels; ++channel)
      values.append(static_cast<size_t>(spatialSize), valueForChannel(channel));

  auto attr = DenseElementsAttr::get(fragmentType, values);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), attr, fragmentType);
}
/// Returns an all-zero constant of `gemmResultType`. It is used as the C
/// operand of a GEMM when the convolution has no bias.
static Value createZeroGemmBias(RankedTensorType gemmResultType, PatternRewriter& rewriter) {
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  auto zeroElement = rewriter.getZeroAttr(gemmResultType.getElementType());
  auto zeroSplat = DenseElementsAttr::get(gemmResultType, zeroElement);
  return getOrCreateConstant(rewriter, anchorOp, zeroSplat, gemmResultType);
}
/// Returns true when the convolution can take the direct row-strip HWC path.
///
/// This requires that `canConsumeRowStripHwcInput` accepts the conv and that
/// the output channel count does not exceed the crossbar size. On rejection,
/// `failureReason` names the cause. On success it is reset to "".
static bool canDirectLowerRowStripConv(const ConvLoweringState& state, StringRef& failureReason) {
  if (!canConsumeRowStripHwcInput(state, failureReason))
    return false;
  ConvGeometry geometry = buildConvGeometry(state);
  const bool channelsFitCrossbar = state.numChannelsOut <= geometry.xbarSize;
  failureReason = channelsFitCrossbar ? "" : "unsupported_output_channels";
  return channelsFitCrossbar;
}
/// Repacks a static `[H * W, C]` row matrix into an `[H, W, C]` tensor, using
/// H, W and C from dims 2, 3 and 1 of `state.outType`.
///
/// Fails when the rows are not a static rank-2 tensor, the batch size is not
/// 1, the output is not a static rank-4 tensor, or the row shape does not
/// match the output.
static FailureOr<Value> createRowStripPackedRows(Value rows,
                                                 const ConvLoweringState& state,
                                                 PatternRewriter& rewriter,
                                                 Location loc) {
  auto rowsType = dyn_cast<RankedTensorType>(rows.getType());
  if (!rowsType || !rowsType.hasStaticShape() || rowsType.getRank() != 2 || state.batchSize != 1
      || state.outType.getRank() != 4 || !state.outType.hasStaticShape())
    return failure();

  const int64_t channels = state.outType.getDimSize(1);
  const int64_t height = state.outType.getDimSize(2);
  const int64_t width = state.outType.getDimSize(3);
  const bool shapeMatches = rowsType.getDimSize(0) == height * width && rowsType.getDimSize(1) == channels;
  if (!shapeMatches)
    return failure();

  // Split the flattened pixel dimension back into (row, column).
  auto hwcType = RankedTensorType::get({height, width, channels}, rowsType.getElementType(), rowsType.getEncoding());
  SmallVector<ReassociationIndices> splitPixels {{0, 1}, {2}};
  auto packOp = createSpatCompute<1>(rewriter, loc, TypeRange {hwcType}, {}, rows, [&](Value rowArg) {
    Value hwc = tensor::ExpandShapeOp::create(rewriter, loc, hwcType, rowArg, splitPixels);
    spatial::SpatYieldOp::create(rewriter, loc, hwc);
  });
  return packOp.getResult(0);
}
/// Lowers a convolution whose input is a single HWC image (`inputHwc`, shape
/// [xHeight, xWidth, numChannelsIn]) directly onto crossbar VMMs. The result
/// is a packed [outHeight, outWidth, numChannelsOut] tensor.
///
/// Emitted structure:
///  1. One `spat.compute` that zero-pads the HWC input by the conv pads.
///  2. One compute batch with `outHeight` lanes, one per output row. Each lane
///     loops over the output columns. For every column it:
///     - gathers an im2col patch row laid out channel-major (element
///       `channel * wHeight * wWidth + kh * wWidth + kw`);
///     - zero-pads that row to `numKSlices * xbarDim`;
///     - accumulates one VMM per K slice of the padded weight constant;
///     - adds the padded bias when present;
///     - trims the row to `numChannelsOut` and inserts it into the lane's row.
///
/// Returns failure in any of these cases:
///  - the input shape does not match `state`;
///  - `canDirectLowerRowStripConv` rejects the conv;
///  - the row demand is not covered;
///  - the weights are not a host constant;
///  - a loop body fails to build.
static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
                                                        const ConvLoweringState& state,
                                                        PatternRewriter& rewriter,
                                                        Location loc) {
  // The input must be the full (unpadded) HWC image described by `state`.
  auto inputType = dyn_cast<RankedTensorType>(inputHwc.getType());
  if (!inputType || !inputType.hasStaticShape() || inputType.getRank() != 3)
    return failure();
  if (inputType.getDimSize(0) != state.xHeight || inputType.getDimSize(1) != state.xWidth
      || inputType.getDimSize(2) != state.numChannelsIn)
    return failure();
  StringRef failureReason;
  if (!canDirectLowerRowStripConv(state, failureReason))
    return failure();
  // All output rows are produced here, so the acquired input rows must cover every needed row.
  ConvRowDemand demand = buildConvRowDemand(RowInterval {0, state.outHeight}, state);
  if (!covers(demand.acquiredInputRows, demand.neededInputRows))
    return failure();
  // The K (patch) dimension is tiled into crossbar-sized slices; paddedK rounds it up.
  ConvGeometry geometry = buildConvGeometry(state);
  const int64_t xbarDim = geometry.xbarSize;
  const int64_t patchSize = state.numChannelsIn * state.wHeight * state.wWidth;
  const int64_t numKSlices = ceilIntegerDivide(patchSize, xbarDim);
  const int64_t paddedK = numKSlices * xbarDim;
  auto elementType = inputType.getElementType();
  auto paddedInputType = RankedTensorType::get({state.xHeight + state.padHeightBegin + state.padHeightEnd,
                                                state.xWidth + state.padWidthBegin + state.padWidthEnd,
                                                state.numChannelsIn},
                                               elementType,
                                               inputType.getEncoding());
  // Per-channel kernel window [wHeight, wWidth, 1], then its flattened and row forms.
  auto paddedPatchType =
      RankedTensorType::get({state.wHeight, state.wWidth, 1}, elementType, inputType.getEncoding());
  auto flatPatchType = RankedTensorType::get({state.wHeight * state.wWidth}, elementType, inputType.getEncoding());
  auto rowChunkType = RankedTensorType::get({1, state.wHeight * state.wWidth}, elementType, inputType.getEncoding());
  auto rowType = RankedTensorType::get({1, state.numChannelsOut}, state.outType.getElementType());
  auto packedOutputType =
      RankedTensorType::get({state.outHeight, state.outWidth, state.numChannelsOut}, state.outType.getElementType());
  auto packedOutputSliceType =
      RankedTensorType::get({1, 1, state.numChannelsOut}, state.outType.getElementType());
  auto paddedRowType = RankedTensorType::get({1, xbarDim}, state.outType.getElementType());
  auto paddedPatchRowType = RankedTensorType::get({1, paddedK}, elementType, inputType.getEncoding());
  auto paddedWeightTileType = RankedTensorType::get({xbarDim, xbarDim}, state.wType.getElementType());
  // This path only supports weights that are host constants.
  auto weightDenseAttr = getHostConstDenseElementsAttr(state.w);
  if (!weightDenseAttr)
    return failure();
  Value paddedWeights = standard::createPaddedInputKTiledWeightConstant(weightDenseAttr, state, paddedK, xbarDim, rewriter);
  // Bias is padded to one crossbar row; constant bias is folded, otherwise padded at runtime.
  Value paddedBias;
  if (state.hasBias) {
    Value biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc);
    auto biasMatrixType = cast<RankedTensorType>(biasMatrix.getType());
    auto paddedBiasType = RankedTensorType::get({1, xbarDim}, state.outType.getElementType());
    if (auto biasDenseAttr = getHostConstDenseElementsAttr(state.b))
      paddedBias = standard::createPaddedConstantMatrix(biasDenseAttr, biasMatrixType, paddedBiasType, rewriter);
    else
      paddedBias = materializeOrComputeUnary(
          biasMatrix, paddedBiasType, rewriter, loc, [&](Value biasValue) {
            return standard::createPaddedConvMatrix(biasValue, biasMatrixType, paddedBiasType, rewriter, loc);
          });
  }
  // Stage 1: zero-pad the HWC input (H and W only; channels are untouched).
  auto paddedInputOp =
      createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, inputHwc, [&](Value hwcInputArg) {
        Value paddedInput = createZeroPaddedTensor(hwcInputArg,
                                                   paddedInputType,
                                                   {state.padHeightBegin, state.padWidthBegin, 0},
                                                   {state.padHeightEnd, state.padWidthEnd, 0},
                                                   rewriter,
                                                   loc);
        spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
      });
  // Batch inputs: [0] = padded input, [1] = padded bias (only when hasBias).
  SmallVector<Value> batchInputs {paddedInputOp.getResult(0)};
  if (state.hasBias)
    batchInputs.push_back(paddedBias);
  // Stage 2: one lane per output row.
  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {packedOutputType},
      state.outHeight,
      ValueRange {paddedWeights},
      batchInputs,
      [&](detail::SpatComputeBatchBodyArgs args) {
        Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
        Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
        Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
        Value cNumKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices);
        Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth);
        Value cNumChannels = getOrCreateIndexConstant(rewriter, anchorOp, state.numChannelsIn);
        // NOTE(review): the patch origin is (lane, widthIndex) in padded-input coordinates,
        // i.e. not scaled by stride/dilation; presumably canConsumeRowStripHwcInput only
        // admits stride-1, dilation-1 convs — confirm.
        Value localHeightOffset = args.lane;
        // NOTE(review): built with the input element type while the batch result uses the
        // output element type; assumes the two match.
        Value packedRowInit =
            tensor::EmptyOp::create(rewriter, loc, ArrayRef<int64_t> {1, state.outWidth, state.numChannelsOut}, elementType);
        // Loop over output columns, accumulating the lane's [1, outWidth, Cout] row.
        auto widthLoop = buildNormalizedScfFor(
            rewriter,
            loc,
            c0,
            cOutWidth,
            c1,
            ValueRange {packedRowInit},
            [&](OpBuilder&, Location widthLoc, Value widthIndex, ValueRange widthIterArgs, SmallVectorImpl<Value>& widthYielded) {
              Value localWidthOffset = widthIndex;
              // Gather the im2col row [1, patchSize], one kernel window per input channel.
              Value rowInit = tensor::EmptyOp::create(rewriter, widthLoc, ArrayRef<int64_t> {1, patchSize}, elementType);
              auto rowLoop = buildNormalizedScfFor(
                  rewriter,
                  widthLoc,
                  c0,
                  cNumChannels,
                  c1,
                  ValueRange {rowInit},
                  [&](OpBuilder&, Location rowLoc, Value channel, ValueRange rowIterArgs, SmallVectorImpl<Value>& rowYielded) {
                    SmallVector<OpFoldResult> patchOffsets {localHeightOffset, localWidthOffset, channel};
                    SmallVector<OpFoldResult> patchSizes {
                        rewriter.getIndexAttr(state.wHeight), rewriter.getIndexAttr(state.wWidth), rewriter.getIndexAttr(1)};
                    Value channelPatch = tensor::ExtractSliceOp::create(
                        rewriter, rowLoc, paddedPatchType, args.inputs.front(), patchOffsets, patchSizes, getUnitStrides(rewriter, 3));
                    // [wHeight, wWidth, 1] -> [wHeight * wWidth] -> [1, wHeight * wWidth].
                    Value flatPatch = tensor::CollapseShapeOp::create(
                        rewriter, rowLoc, flatPatchType, channelPatch, SmallVector<ReassociationIndices> {{0, 1, 2}});
                    Value rowChunk = tensor::ExpandShapeOp::create(
                        rewriter, rowLoc, rowChunkType, flatPatch, SmallVector<ReassociationIndices> {{0, 1}});
                    // Channel-major placement: this channel's window starts at channel * wHeight * wWidth.
                    Value flatOffset = affineMulConst(
                        rewriter, rowLoc, channel, state.wHeight * state.wWidth, anchorOp);
                    SmallVector<OpFoldResult> rowOffsets {rewriter.getIndexAttr(0), flatOffset};
                    SmallVector<OpFoldResult> rowSizes {
                        rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.wHeight * state.wWidth)};
                    Value nextRow = tensor::InsertSliceOp::create(
                        rewriter, rowLoc, rowChunk, rowIterArgs.front(), rowOffsets, rowSizes, getUnitStrides(rewriter, 2));
                    rowYielded.push_back(nextRow);
                    return success();
                  });
              if (failed(rowLoop))
                return failure();
              // Zero-extend the patch row to a whole number of crossbar K slices.
              Value paddedRow = rowLoop->results.front();
              if (patchSize != paddedK)
                paddedRow = createZeroPaddedTensor(
                    paddedRow, paddedPatchRowType, {0, 0}, {0, paddedK - patchSize}, rewriter, widthLoc);
              auto zeroAttr = DenseElementsAttr::get(paddedRowType, rewriter.getZeroAttr(state.outType.getElementType()));
              Value zeroRow = getOrCreateConstant(rewriter, anchorOp, zeroAttr, paddedRowType);
              // Reduce over K slices: acc += VMM(weightTile[k], patchRow[k]).
              auto kLoop = buildNormalizedScfFor(
                  rewriter,
                  widthLoc,
                  c0,
                  cNumKSlices,
                  c1,
                  ValueRange {zeroRow},
                  [&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl<Value>& reduceYielded) {
                    Value kOffset = affineMulConst(rewriter, reduceLoc, kSlice, xbarDim, anchorOp);
                    SmallVector<OpFoldResult> aOffsets {rewriter.getIndexAttr(0), kOffset};
                    SmallVector<OpFoldResult> aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)};
                    Value aTile = tensor::ExtractSliceOp::create(
                        rewriter, reduceLoc, paddedRowType, paddedRow, aOffsets, aSizes, getUnitStrides(rewriter, 2));
                    SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
                    SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
                    Value bTile = extractStaticSliceOrIdentity(rewriter,
                                                               reduceLoc,
                                                               args.weights.front(),
                                                               paddedWeightTileType,
                                                               bOffsets,
                                                               bSizes,
                                                               getUnitStrides(rewriter, 2));
                    Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
                    reduceYielded.push_back(
                        spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, reduceIterArgs.front(), piece).getResult());
                    return success();
                  });
              if (failed(kLoop))
                return failure();
              Value rowResult = kLoop->results.front();
              if (state.hasBias)
                rowResult =
                    spatial::SpatVAddOp::create(rewriter, widthLoc, paddedRowType, rowResult, args.inputs[1]).getResult();
              // Drop the crossbar padding columns beyond numChannelsOut.
              Value outputRow = rowResult;
              if (state.numChannelsOut != xbarDim) {
                SmallVector<OpFoldResult> outputOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
                SmallVector<OpFoldResult> outputSizes {
                    rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)};
                outputRow = tensor::ExtractSliceOp::create(
                    rewriter, widthLoc, rowType, rowResult, outputOffsets, outputSizes, getUnitStrides(rewriter, 2));
              }
              // [1, Cout] -> [1, 1, Cout], then insert at column widthIndex.
              Value outputFragment = tensor::ExpandShapeOp::create(rewriter,
                                                                   widthLoc,
                                                                   packedOutputSliceType,
                                                                   outputRow,
                                                                   SmallVector<ReassociationIndices> {{0}, {1, 2}});
              SmallVector<OpFoldResult> rowOffsets {rewriter.getIndexAttr(0), widthIndex, rewriter.getIndexAttr(0)};
              SmallVector<OpFoldResult> rowSizes {
                  rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)};
              Value nextPackedRow = tensor::InsertSliceOp::create(
                  rewriter, widthLoc, outputFragment, widthIterArgs.front(), rowOffsets, rowSizes, getUnitStrides(rewriter, 3));
              widthYielded.push_back(nextPackedRow);
              return success();
            });
        if (failed(widthLoop))
          return failure();
        // Publish this lane's row at output row `lane`.
        SmallVector<OpFoldResult> batchOffsets {args.lane, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> batchSizes {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.outWidth), rewriter.getIndexAttr(state.numChannelsOut)};
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, widthLoop->results.front(), args.outputs.front(), batchOffsets, batchSizes, getUnitStrides(rewriter, 3));
        return success();
      });
  if (failed(batchOp))
    return failure();
  return batchOp->getResult(0);
}
/// Entry point for convolutions fed by a row-strip HWC input. It forwards to
/// `createConvOutputFromRowStripHwc`.
///
/// `decision` is accepted for interface symmetry with the other strategies
/// and does not currently affect the emitted IR.
static FailureOr<Value> createConvRowsFromRowStripInput(const ConvLoweringState& state,
                                                        [[maybe_unused]] const ConvLoweringDecision& decision,
                                                        Value rowStripInput,
                                                        PatternRewriter& rewriter,
                                                        Location loc) {
  FailureOr<Value> packedOutput = createConvOutputFromRowStripHwc(rowStripInput, state, rewriter, loc);
  return packedOutput;
}
/// Materializes the constant operand of a distributed element-wise step as a
/// constant of `fragmentType`.
///
/// Per-channel constants are broadcast per channel. Any other constant must
/// be a splat, whose single value fills the whole fragment.
static Value createFragmentConstant(const DistributedTensorStep& step,
                                    RankedTensorType fragmentType,
                                    PatternRewriter& rewriter) {
  if (step.constantKind != DistributedTensorConstantKind::PerChannel) {
    Attribute splatValue = step.constantAttr.getSplatValue<Attribute>();
    auto filled = DenseElementsAttr::get(fragmentType, splatValue);
    Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
    return getOrCreateConstant(rewriter, anchorOp, filled, fragmentType);
  }
  return createPerChannelConstantFragment(step.constantAttr, fragmentType, rewriter);
}
/// Materializes the element-wise reciprocal (1 / c) of `step.constantAttr` as
/// a constant of `fragmentType`, so that a distributed `Div` can be lowered
/// to a multiplication.
///
/// Per-channel constants are read in flattened order (value `c` for channel
/// `c`) and broadcast over the N, H and W dimensions of the fragment. Any
/// other constant must be a splat floating-point attribute. Every divisor
/// must be finite and non-zero; this is asserted in debug builds.
static Value createFragmentReciprocalConstant(const DistributedTensorStep& step,
                                              RankedTensorType fragmentType,
                                              PatternRewriter& rewriter) {
  // 1 / divisor, rounded to nearest-even in the divisor's own semantics.
  // APFloat reports 1/0 as opDivByZero (not opInvalidOp), so the divisor itself
  // is checked instead of the division status.
  auto reciprocalOf = [](const APFloat& divisor) {
    assert(divisor.isFiniteNonZero() && "distributed conv div requires finite non-zero constant");
    APFloat reciprocal(divisor.getSemantics(), 1);
    (void)reciprocal.divide(divisor, APFloat::rmNearestTiesToEven);
    return reciprocal;
  };

  SmallVector<APFloat> values;
  if (step.constantKind == DistributedTensorConstantKind::PerChannel) {
    // Compute each channel's reciprocal once rather than once per element.
    SmallVector<APFloat> channelReciprocals;
    for (const APFloat& value : step.constantAttr.getValues<APFloat>())
      channelReciprocals.push_back(reciprocalOf(value));
    const int64_t spatialSize = fragmentType.getDimSize(2) * fragmentType.getDimSize(3);
    values.reserve(fragmentType.getNumElements());
    for (int64_t n = 0; n < fragmentType.getDimSize(0); ++n)
      for (int64_t channel = 0; channel < fragmentType.getDimSize(1); ++channel)
        values.append(static_cast<size_t>(spatialSize), channelReciprocals[channel]);
  }
  else {
    APFloat divisor = cast<DenseFPElementsAttr>(step.constantAttr).getSplatValue<APFloat>();
    values.assign(fragmentType.getNumElements(), reciprocalOf(divisor));
  }
  return getOrCreateConstant(rewriter,
                             rewriter.getInsertionBlock()->getParentOp(),
                             DenseFPElementsAttr::get(fragmentType, values),
                             fragmentType);
}
/// Builds the convolution's GEMM result rows using the lowering selected by
/// `decision.strategy`.
///
/// - Legacy / PackedIm2Col: emits a single `onnx.Gemm` over the full im2col
///   matrix. The bias is the conv bias, or a zero constant when there is none.
///   The result is then unpacked with `maybeUnpackChunkRows`.
/// - StreamedPatch / OutputChannelTiled / Tiled2D / StreamedPacked: emits
///   chunked rows via `createChunkedConvRows`. Only StreamedPacked uses the
///   geometry's pack factor; the others use 1.
///
/// Returns failure for any other strategy.
[[maybe_unused]] static FailureOr<Value> createConvRowsForStrategy(const ConvLoweringState& state,
                                                                   const ConvLoweringDecision& decision,
                                                                   PatternRewriter& rewriter,
                                                                   Location loc) {
  auto wDenseAttr = getHostConstDenseElementsAttr(state.w);
  PreparedConvInput preparedInput = standard::prepareInputForIm2Col(state, rewriter, loc);
  Value biasMatrix;
  DenseElementsAttr biasDenseAttr;
  if (state.hasBias) {
    biasDenseAttr = getHostConstDenseElementsAttr(state.b);
    biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc);
  }
  const bool weightsAreConstant = static_cast<bool>(wDenseAttr);
  const bool biasIsConstantOrAbsent = !state.hasBias || static_cast<bool>(biasDenseAttr);

  switch (decision.strategy) {
  case PimConvLoweringLegacy:
  case PimConvLoweringPackedIm2Col: {
    standard::ConvGemmPlan plan = standard::buildConvGemmPlan(
        state, weightsAreConstant, biasIsConstantOrAbsent, 0, state.batchSize * state.outHeight * state.outWidth);
    Value weightMatrix = standard::createWeightMatrix(state.w, plan, rewriter, loc);
    Value gemmInputRows = standard::createIm2colRows(state, preparedInput, plan, rewriter, loc);
    Value gemmB = standard::buildPackedWeights(wDenseAttr, weightMatrix, state, plan, rewriter, loc);
    // Only materialize the zero bias when the conv has none; otherwise it would be dead IR.
    Value gemmBias = state.hasBias ? state.b : createZeroGemmBias(plan.gemmOutputRowsType, rewriter);
    Value gemmC = standard::buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, plan, rewriter, loc);
    Value gemmRows = ONNXGemmOp::create(rewriter,
                                        loc,
                                        plan.gemmOutputRowsType,
                                        gemmInputRows,
                                        gemmB,
                                        gemmC,
                                        APFloat(1.0f),
                                        APFloat(1.0f),
                                        /*transA=*/0,
                                        /*transB=*/0)
                         .getY();
    return standard::maybeUnpackChunkRows(gemmRows, plan, rewriter, loc);
  }
  case PimConvLoweringStreamedPatch:
  case PimConvLoweringOutputChannelTiled:
  case PimConvLoweringTiled2D:
  case PimConvLoweringStreamedPacked: {
    // Compute the geometry once. Only the packed streaming variant uses its pack factor.
    ConvGeometry geo = buildConvGeometry(state);
    int64_t packFactor = decision.strategy == PimConvLoweringStreamedPacked ? geo.pack : 1;
    standard::ConvGemmPlan seedPlan =
        standard::buildConvGemmPlan(state, weightsAreConstant, biasIsConstantOrAbsent, 0, 1, packFactor);
    Value weightMatrix = standard::createWeightMatrix(state.w, seedPlan, rewriter, loc);
    uint64_t chunkPositions = chooseStreamChunkPositions(geo, packFactor);
    return standard::createChunkedConvRows(state,
                                           preparedInput,
                                           weightMatrix,
                                           biasMatrix,
                                           wDenseAttr,
                                           biasDenseAttr,
                                           packFactor,
                                           chunkPositions,
                                           rewriter,
                                           loc);
  }
  default:
    return failure();
  }
}
/// Scatters a `[H * W, C]` row matrix into a row-strip distributed tensor of
/// `logicalType`, assumed to be [1, C, H, W].
///
/// A compute batch with H lanes is emitted. Lane `h` takes rows
/// `[h * W, (h + 1) * W)`, transposes them to [C, W], expands that to
/// [1, C, 1, W], and writes it at row `h` of the batch output.
[[maybe_unused]] static FailureOr<DistributedTensorInfo> createDistributedTensorFromRows(Value rows,
                                                                                         RankedTensorType logicalType,
                                                                                         PatternRewriter& rewriter,
                                                                                         Location loc) {
  const int64_t channels = logicalType.getDimSize(1);
  const int64_t height = logicalType.getDimSize(2);
  const int64_t width = logicalType.getDimSize(3);
  auto rowsEncoding = cast<RankedTensorType>(rows.getType()).getEncoding();
  auto elementType = logicalType.getElementType();

  // Per-lane shapes: [W, C] row slice -> [C, W] transpose -> [1, C, 1, W] strip.
  auto laneRowsType = RankedTensorType::get({width, channels}, elementType, rowsEncoding);
  auto laneTransposedType = RankedTensorType::get({channels, width}, elementType, rowsEncoding);
  auto laneStripType = getRowStripFragmentType(logicalType, width);

  auto buildLane = [&](detail::SpatComputeBatchBodyArgs args) {
    Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
    OpFoldResult zeroIdx = rewriter.getIndexAttr(0);
    OpFoldResult oneIdx = rewriter.getIndexAttr(1);

    Value firstPixel = affineMulConst(rewriter, loc, args.lane, width, anchorOp);
    SmallVector<OpFoldResult> sliceOffsets {firstPixel, zeroIdx};
    SmallVector<OpFoldResult> sliceSizes {rewriter.getIndexAttr(width), rewriter.getIndexAttr(channels)};
    Value laneRows = tensor::ExtractSliceOp::create(
        rewriter, loc, laneRowsType, args.inputs.front(), sliceOffsets, sliceSizes, getUnitStrides(rewriter, 2));

    Value laneTransposed =
        ONNXTransposeOp::create(rewriter, loc, laneTransposedType, laneRows, rewriter.getI64ArrayAttr({1, 0}))
            .getResult();
    Value laneStrip = tensor::ExpandShapeOp::create(
        rewriter, loc, laneStripType, laneTransposed, SmallVector<ReassociationIndices> {{0, 1}, {2, 3}});

    SmallVector<OpFoldResult> stripOffsets {zeroIdx, zeroIdx, args.lane, zeroIdx};
    SmallVector<OpFoldResult> stripSizes {oneIdx, rewriter.getIndexAttr(channels), oneIdx, rewriter.getIndexAttr(width)};
    createParallelInsertSliceIntoBatchOutput(
        rewriter, loc, laneStrip, args.outputs.front(), stripOffsets, stripSizes, getUnitStrides(rewriter, 4));
    return success();
  };

  auto batchOp =
      createSpatComputeBatch(rewriter, loc, TypeRange {logicalType}, height, {}, ValueRange {rows}, buildLane);
  if (failed(batchOp))
    return failure();
  return makeDistributedTensorInfo(batchOp->getResult(0), logicalType);
}
/// Applies one layout-preserving element-wise `step` to every row strip of a
/// distributed tensor. It returns a new distributed tensor with the same
/// logical type.
///
/// Relu/Sigmoid are applied directly. Add/Sub/Mul use the step's constant
/// (operand order for Sub follows `step.fragmentOnLhs`). Div multiplies by
/// the reciprocal constant. A Conv step cannot be applied this way and makes
/// the whole call fail.
[[maybe_unused]] static FailureOr<DistributedTensorInfo> applyDistributedPreservingStep(const DistributedTensorInfo& inputInfo,
                                                                                        const DistributedTensorStep& step,
                                                                                        PatternRewriter& rewriter,
                                                                                        Location loc) {
  auto logicalType = inputInfo.logicalType;
  const int64_t channels = logicalType.getDimSize(1);
  const int64_t width = logicalType.getDimSize(3);
  auto stripType = getRowStripFragmentType(logicalType, width);

  // Transforms a single [1, C, 1, W] strip according to `step`.
  auto applyStep = [&](Value strip) -> FailureOr<Value> {
    Value result = strip;
    switch (step.kind) {
    case DistributedTensorOpKind::Relu:
      result = spatial::SpatReluOp::create(rewriter, loc, stripType, strip).getResult();
      break;
    case DistributedTensorOpKind::Sigmoid:
      result = spatial::SpatSigmoidOp::create(rewriter, loc, stripType, strip).getResult();
      break;
    case DistributedTensorOpKind::Add: {
      Value addend = createFragmentConstant(step, stripType, rewriter);
      result = spatial::SpatVAddOp::create(rewriter, loc, stripType, strip, addend).getResult();
      break;
    }
    case DistributedTensorOpKind::Sub: {
      Value other = createFragmentConstant(step, stripType, rewriter);
      Value minuend = step.fragmentOnLhs ? strip : other;
      Value subtrahend = step.fragmentOnLhs ? other : strip;
      result = spatial::SpatVSubOp::create(rewriter, loc, stripType, minuend, subtrahend).getResult();
      break;
    }
    case DistributedTensorOpKind::Mul: {
      Value factor = createFragmentConstant(step, stripType, rewriter);
      result = spatial::SpatVMulOp::create(rewriter, loc, stripType, strip, factor).getResult();
      break;
    }
    case DistributedTensorOpKind::Div: {
      Value inverse = createFragmentReciprocalConstant(step, stripType, rewriter);
      result = spatial::SpatVMulOp::create(rewriter, loc, stripType, strip, inverse).getResult();
      break;
    }
    case DistributedTensorOpKind::Conv:
      return failure();
    }
    return result;
  };

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {logicalType},
      inputInfo.laneCount,
      {},
      ValueRange {inputInfo.storage},
      [&](detail::SpatComputeBatchBodyArgs args) {
        OpFoldResult zeroIdx = rewriter.getIndexAttr(0);
        OpFoldResult oneIdx = rewriter.getIndexAttr(1);
        SmallVector<OpFoldResult> stripOffsets {zeroIdx, zeroIdx, args.lane, zeroIdx};
        SmallVector<OpFoldResult> stripSizes {
            oneIdx, rewriter.getIndexAttr(channels), oneIdx, rewriter.getIndexAttr(width)};
        Value strip = tensor::ExtractSliceOp::create(
            rewriter, loc, stripType, args.inputs.front(), stripOffsets, stripSizes, getUnitStrides(rewriter, 4));
        FailureOr<Value> updated = applyStep(strip);
        if (failed(updated))
          return failure();
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, *updated, args.outputs.front(), stripOffsets, stripSizes, getUnitStrides(rewriter, 4));
        return success();
      });
  if (failed(batchOp))
    return failure();
  return makeDistributedTensorInfo(batchOp->getResult(0), logicalType);
}
/// Collects convolution GEMM rows into the final convolution result inside a
/// single `spat.compute` whose result type is `convType`.
///
/// Processing steps:
///  1. Each value in `gemmRows` has every step of `distributedConsumers`
///     applied: element-wise Relu/Sigmoid, or Add/Sub/Mul/Div against a
///     constant broadcast from a splat. These consumers must therefore carry
///     splat constants.
///  2. The rows are concatenated along axis 0.
///  3. When `packFactor != 1`, the result is unpacked with
///     `standard::unpackRowsFromParallelGemm(numPatches, numChannelsOut,
///     packFactor)`.
///  4. The [numPatches, numChannelsOut] matrix is expanded to `nhwcType` and
///     transposed with perm {0, 3, 1, 2} to `outType`.
///
/// \returns the single result of the emitted compute op.
static Value createCollectedConvOutput(ValueRange gemmRows,
                                       Type convType,
                                       RankedTensorType gemmOutType,
                                       RankedTensorType nhwcType,
                                       RankedTensorType outType,
                                       int64_t numPatches,
                                       int64_t numChannelsOut,
                                       int64_t packFactor,
                                       ArrayRef<DistributedTensorStep> distributedConsumers,
                                       PatternRewriter& rewriter,
                                       Location loc) {
  // Broadcast a splat attribute's value to a constant of `targetType`.
  auto materializeSplatTensor = [&](DenseElementsAttr denseAttr, RankedTensorType targetType) {
    Attribute splatValue = denseAttr.getSplatValue<Attribute>();
    auto targetAttr = DenseElementsAttr::get(targetType, splatValue);
    return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), targetAttr, targetType);
  };
  // Broadcast 1 / splat to a constant of `targetType` (Div is lowered as Mul).
  auto materializeReciprocalSplatTensor = [&](DenseFPElementsAttr denseAttr, RankedTensorType targetType) {
    APFloat divisor = denseAttr.getSplatValue<APFloat>();
    // APFloat reports 1/0 as opDivByZero (not opInvalidOp), so check the divisor itself.
    assert(divisor.isFiniteNonZero() && "distributed conv div consumer requires finite non-zero scalar");
    APFloat reciprocal(divisor.getSemantics(), 1);
    (void)reciprocal.divide(divisor, APFloat::rmNearestTiesToEven);
    return getOrCreateConstant(
        rewriter, rewriter.getInsertionBlock()->getParentOp(), DenseFPElementsAttr::get(targetType, reciprocal), targetType);
  };
  auto applyDistributedConsumers = [&](Value fragment) {
    Value current = fragment;
    for (const DistributedTensorStep& step : distributedConsumers) {
      auto fragmentType = cast<RankedTensorType>(current.getType());
      switch (step.kind) {
      case DistributedTensorOpKind::Relu:
        current = spatial::SpatReluOp::create(rewriter, loc, fragmentType, current).getResult();
        break;
      case DistributedTensorOpKind::Sigmoid:
        current = spatial::SpatSigmoidOp::create(rewriter, loc, fragmentType, current).getResult();
        break;
      case DistributedTensorOpKind::Add: {
        Value splat = materializeSplatTensor(step.constantAttr, fragmentType);
        current = spatial::SpatVAddOp::create(rewriter, loc, fragmentType, current, splat).getResult();
        break;
      }
      case DistributedTensorOpKind::Sub: {
        Value splat = materializeSplatTensor(step.constantAttr, fragmentType);
        Value lhs = step.fragmentOnLhs ? current : splat;
        Value rhs = step.fragmentOnLhs ? splat : current;
        current = spatial::SpatVSubOp::create(rewriter, loc, fragmentType, lhs, rhs).getResult();
        break;
      }
      case DistributedTensorOpKind::Mul: {
        Value splat = materializeSplatTensor(step.constantAttr, fragmentType);
        current = spatial::SpatVMulOp::create(rewriter, loc, fragmentType, current, splat).getResult();
        break;
      }
      case DistributedTensorOpKind::Div: {
        auto reciprocalAttr = cast<DenseFPElementsAttr>(step.constantAttr);
        Value reciprocal = materializeReciprocalSplatTensor(reciprocalAttr, fragmentType);
        current = spatial::SpatVMulOp::create(rewriter, loc, fragmentType, current, reciprocal).getResult();
        break;
      }
      case DistributedTensorOpKind::Conv:
        llvm_unreachable("conv-consuming distributed chains should not materialize through createCollectedConvOutput");
      }
    }
    return current;
  };
  auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
    SmallVector<Value> transformedRows;
    transformedRows.reserve(gemmRowArgs.size());
    for (Value row : gemmRowArgs)
      transformedRows.push_back(applyDistributedConsumers(row));
    // Both pack layouts start from the same row concatenation; packed GEMMs are then unpacked.
    Value gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, transformedRows);
    if (packFactor != 1)
      gemmOut = standard::unpackRowsFromParallelGemm(
          gemmOut, cast<RankedTensorType>(gemmOut.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc);
    // Restore output layout:
    //   [numPatches, numChannelsOut]
    //   -> [N, Hout, Wout, Cout]
    //   -> [N, Cout, Hout, Wout]
    Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
                                                  loc,
                                                  nhwcType,
                                                  gemmOut,
                                                  SmallVector<ReassociationIndices> {
                                                      {0, 1, 2},
                                                      {3}
    });
    Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
    spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
  });
  (void)gemmOutType;
  return collectComputeOp.getResult(0);
}
/// Validates an ONNX Conv for Spatial lowering and collects its geometry into
/// a `ConvLoweringState`.
///
/// Requirements checked:
///  - input, weight and result are ranked, statically shaped rank-4 tensors
///    (dim 0 = batch/out-channels, dim 1 = channels, dims 2/3 = spatial);
///  - `group` is >= 1 and divides both channel counts;
///  - the weight shape is consistent with the grouping and with the result's
///    channel dimension;
///  - at most two strides, two dilations and four pads are given.
///
/// Padding: explicit `pads` ([hBegin, wBegin, hEnd, wEnd]) take precedence.
/// Otherwise SAME_UPPER / SAME_LOWER auto-padding is resolved here, putting
/// the odd extra pad at the end / at the beginning respectively. NOTSET and
/// VALID leave all pads at zero.
///
/// \param x, w, b the conv operands to analyze. `b` may be null, or produced
///        by `ONNXNoneOp`, when the conv has no bias.
/// \returns the populated state, or failure after emitting a diagnostic on
///          `convOp`.
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b) {
  ConvLoweringState state;
  state.x = x;
  state.w = w;
  state.b = b;
  // Diagnose (rather than assert on) unranked operands, matching the plan-op overload.
  state.xType = dyn_cast<RankedTensorType>(state.x.getType());
  state.wType = dyn_cast<RankedTensorType>(state.w.getType());
  state.outType = dyn_cast<RankedTensorType>(convOp.getY().getType());
  if (!state.xType || !state.wType || !state.outType) {
    convOp.emitOpError("requires ranked tensor input, weight, and result for Spatial lowering");
    return failure();
  }
  if (!state.xType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
    return failure();
  }
  if (!state.wType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
    return failure();
  }
  if (!state.outType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
    return failure();
  }
  if (state.xType.getRank() != 4) {
    pim::emitUnsupportedRankDiagnostic(convOp, "conv input", state.xType.getRank(), {4});
    return failure();
  }
  if (state.wType.getRank() != 4) {
    pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", state.wType.getRank(), {4});
    return failure();
  }
  if (state.outType.getRank() != 4) {
    pim::emitUnsupportedRankDiagnostic(convOp, "conv result", state.outType.getRank(), {4});
    return failure();
  }
  state.group = convOp.getGroup();
  if (state.group < 1) {
    convOp.emitOpError("requires group >= 1 for Spatial lowering");
    return failure();
  }
  state.batchSize = state.xType.getDimSize(0);
  state.numChannelsIn = state.xType.getDimSize(1);
  state.xHeight = state.xType.getDimSize(2);
  state.xWidth = state.xType.getDimSize(3);
  state.numChannelsOut = state.wType.getDimSize(0);
  state.wHeight = state.wType.getDimSize(2);
  state.wWidth = state.wType.getDimSize(3);
  state.outHeight = state.outType.getDimSize(2);
  state.outWidth = state.outType.getDimSize(3);
  // getDefiningOp<OpTy>() is null-safe: a block-argument bias has no defining op, and
  // isa<> on that null pointer would assert.
  state.hasBias = state.b && !state.b.getDefiningOp<ONNXNoneOp>();
  if (state.numChannelsIn % state.group != 0) {
    convOp.emitOpError() << "requires input channels " << state.numChannelsIn << " to be divisible by group "
                         << state.group << " for Spatial lowering";
    return failure();
  }
  if (state.numChannelsOut % state.group != 0) {
    convOp.emitOpError() << "requires output channels " << state.numChannelsOut << " to be divisible by group "
                         << state.group << " for Spatial lowering";
    return failure();
  }
  state.numChannelsInPerGroup = state.numChannelsIn / state.group;
  state.numChannelsOutPerGroup = state.numChannelsOut / state.group;
  if (state.wType.getDimSize(1) != state.numChannelsInPerGroup) {
    convOp.emitOpError() << "requires grouped conv weight input channels " << state.wType.getDimSize(1)
                         << " to match input channels per group " << state.numChannelsInPerGroup
                         << " for Spatial lowering";
    return failure();
  }
  // numChannelsOut was read from the weight, so compare against the result's channel dim;
  // comparing wType.getDimSize(0) with numChannelsOut could never fail.
  if (state.outType.getDimSize(1) != state.numChannelsOut) {
    convOp.emitOpError() << "requires weight output channels " << state.numChannelsOut
                         << " to match result channels " << state.outType.getDimSize(1) << " for Spatial lowering";
    return failure();
  }
  const auto stridesAttr = convOp.getStrides();
  const auto dilationsAttr = convOp.getDilations();
  const auto padsAttr = convOp.getPads();
  if (stridesAttr && stridesAttr->size() != 2) {
    convOp.emitOpError("requires exactly two stride values for Spatial lowering");
    return failure();
  }
  if (dilationsAttr && dilationsAttr->size() != 2) {
    convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
    return failure();
  }
  if (padsAttr && padsAttr->size() != 4) {
    convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
    return failure();
  }
  state.strideHeight = getOptionalI64Attr(stridesAttr, 0, 1);
  state.strideWidth = getOptionalI64Attr(stridesAttr, 1, 1);
  state.dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1);
  state.dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1);
  state.padHeightBegin = 0;
  state.padHeightEnd = 0;
  state.padWidthBegin = 0;
  state.padWidthEnd = 0;
  if (padsAttr) {
    // ONNX pads layout for 2D: [hBegin, wBegin, hEnd, wEnd].
    state.padHeightBegin = getI64Attr(*padsAttr, 0);
    state.padWidthBegin = getI64Attr(*padsAttr, 1);
    state.padHeightEnd = getI64Attr(*padsAttr, 2);
    state.padWidthEnd = getI64Attr(*padsAttr, 3);
    return state;
  }
  const auto autoPad = convOp.getAutoPad();
  if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
    const int64_t effectiveKernelH = (state.wHeight - 1) * state.dilationHeight + 1;
    const int64_t effectiveKernelW = (state.wWidth - 1) * state.dilationWidth + 1;
    const int64_t totalPadH =
        std::max(static_cast<int64_t>(0), (state.outHeight - 1) * state.strideHeight + effectiveKernelH - state.xHeight);
    const int64_t totalPadW =
        std::max(static_cast<int64_t>(0), (state.outWidth - 1) * state.strideWidth + effectiveKernelW - state.xWidth);
    if (autoPad == "SAME_UPPER") {
      state.padHeightBegin = totalPadH / 2;
      state.padHeightEnd = totalPadH - state.padHeightBegin;
      state.padWidthBegin = totalPadW / 2;
      state.padWidthEnd = totalPadW - state.padWidthBegin;
    }
    else {
      state.padHeightEnd = totalPadH / 2;
      state.padHeightBegin = totalPadH - state.padHeightEnd;
      state.padWidthEnd = totalPadW / 2;
      state.padWidthBegin = totalPadW - state.padWidthEnd;
    }
    return state;
  }
  if (autoPad != "NOTSET" && autoPad != "VALID") {
    convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
    return failure();
  }
  return state;
}
/// Convenience overload for the ONNX conversion pattern: analyzes `convOp`
/// using the input, weight and (optional) bias operands taken from the
/// conversion adaptor.
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, ONNXConvOpAdaptor convOpAdaptor) {
  auto input = convOpAdaptor.getX();
  auto weight = convOpAdaptor.getW();
  auto bias = convOpAdaptor.getB();
  return analyzeConvLoweringState(convOp, input, weight, bias);
}
/// Rebuilds a ConvLoweringState from a Spatial Conv2D plan op.
///
/// Reads the input, weight and optional bias operands plus the pads, strides,
/// dilations and group attributes. It validates the shapes and attribute values
/// the downstream lowerings rely on, then derives the per-group channel counts.
/// Pads are read as [heightBegin, widthBegin, heightEnd, widthEnd] and
/// strides/dilations as [height, width].
///
/// Emits an error on `planOp` and returns failure if any requirement is not met.
static FailureOr<ConvLoweringState> analyzeConvLoweringState(spatial::SpatConv2DPlanOp planOp) {
  ConvLoweringState state;
  state.x = planOp.getInput();
  state.w = planOp.getWeight();
  state.b = planOp.getBias() ? planOp.getBias() : Value();
  state.xType = dyn_cast<RankedTensorType>(state.x.getType());
  state.wType = dyn_cast<RankedTensorType>(state.w.getType());
  state.outType = dyn_cast<RankedTensorType>(planOp.getOutput().getType());
  if (!state.xType || !state.wType || !state.outType)
    return planOp.emitOpError("requires ranked tensor input, weight, and output"), failure();
  if (!state.xType.hasStaticShape() || !state.wType.hasStaticShape() || !state.outType.hasStaticShape())
    return planOp.emitOpError("requires static input, weight, and output shapes"), failure();
  if (state.xType.getRank() != 4 || state.wType.getRank() != 4 || state.outType.getRank() != 4)
    return planOp.emitOpError("requires rank-4 input, weight, and output tensors"), failure();
  state.group = planOp.getGroup();
  if (state.group < 1)
    return planOp.emitOpError("requires group >= 1"), failure();
  state.batchSize = state.xType.getDimSize(0);
  state.numChannelsIn = state.xType.getDimSize(1);
  state.xHeight = state.xType.getDimSize(2);
  state.xWidth = state.xType.getDimSize(3);
  state.numChannelsOut = state.wType.getDimSize(0);
  state.wHeight = state.wType.getDimSize(2);
  state.wWidth = state.wType.getDimSize(3);
  state.outHeight = state.outType.getDimSize(2);
  state.outWidth = state.outType.getDimSize(3);
  state.hasBias = static_cast<bool>(planOp.getBias());
  if (state.numChannelsIn % state.group != 0 || state.numChannelsOut % state.group != 0)
    return planOp.emitOpError("requires input and output channels divisible by group"), failure();
  state.numChannelsInPerGroup = state.numChannelsIn / state.group;
  state.numChannelsOutPerGroup = state.numChannelsOut / state.group;
  if (state.wType.getDimSize(1) != state.numChannelsInPerGroup)
    return planOp.emitOpError("requires grouped conv weight channels to match input channels per group"), failure();
  auto pads = planOp.getPads();
  auto strides = planOp.getStrides();
  auto dilations = planOp.getDilations();
  if (pads.size() != 4 || strides.size() != 2 || dilations.size() != 2)
    return planOp.emitOpError("requires 4 pads, 2 strides, and 2 dilations"), failure();
  // A plan op is not guaranteed to come from a validated ONNX Conv. Zero or
  // negative strides/dilations and negative pads have no meaning for a
  // convolution and would feed invalid values into the window/geometry
  // arithmetic used by every lowering strategy, so reject them here with a
  // diagnostic.
  if (strides[0] < 1 || strides[1] < 1 || dilations[0] < 1 || dilations[1] < 1)
    return planOp.emitOpError("requires strides and dilations >= 1"), failure();
  if (pads[0] < 0 || pads[1] < 0 || pads[2] < 0 || pads[3] < 0)
    return planOp.emitOpError("requires non-negative pads"), failure();
  state.padHeightBegin = pads[0];
  state.padWidthBegin = pads[1];
  state.padHeightEnd = pads[2];
  state.padWidthEnd = pads[3];
  state.strideHeight = strides[0];
  state.strideWidth = strides[1];
  state.dilationHeight = dilations[0];
  state.dilationWidth = dilations[1];
  return state;
}
/// Determines which Conv lowering strategy the command line requests.
///
/// - Without `--use-experimental-conv-impl`, the `--pim-conv-lowering` value
///   is returned unchanged.
/// - With it, the result is pinned to packed-im2col.
/// - An explicit `--pim-conv-lowering` other than auto or packed-im2col is
///   rejected with a diagnostic on `op`.
static FailureOr<PimConvLoweringType> resolveRequestedConvLoweringStrategy(Operation* op) {
  const auto requested = pimConvLowering.getValue();
  if (!useExperimentalConvImpl)
    return requested;

  const bool compatibleWithExperimental =
      requested == PimConvLoweringAuto || requested == PimConvLoweringPackedIm2Col;
  if (compatibleWithExperimental)
    return PimConvLoweringPackedIm2Col;

  op->emitOpError() << "--use-experimental-conv-impl conflicts with --pim-conv-lowering="
                    << stringifyConvLoweringStrategy(requested);
  return failure();
}
/// Verifies that `strategy` can be applied to a convolution with geometry `geo`.
///
/// Auto and legacy are always accepted. Every other strategy has a
/// precondition expressed in terms of:
/// - `geo.k` and `geo.c` (called K and C in the diagnostics),
/// - the crossbar size `geo.xbarSize` (X),
/// - the packing factor `geo.pack`,
/// - for packed-im2col only, the im2col element budget.
///
/// If the precondition does not hold, an error is emitted on `op` and failure
/// is returned.
static LogicalResult verifyForcedConvLoweringStrategy(Operation* op,
                                                      const ConvGeometry& geo,
                                                      PimConvLoweringType strategy) {
  const bool kFitsCrossbar = geo.k <= geo.xbarSize;
  const bool cFitsCrossbar = geo.c <= geo.xbarSize;
  const bool canPack = geo.pack >= 2;

  switch (strategy) {
  case PimConvLoweringAuto:
  case PimConvLoweringLegacy:
    return success();

  case PimConvLoweringDepthwise:
    if (!geo.isDepthwise)
      return op->emitOpError("forced depthwise Conv lowering requires a depthwise convolution");
    return success();

  case PimConvLoweringPackedIm2Col:
    if (kFitsCrossbar && cFitsCrossbar && canPack && geo.im2colElements <= pimConvIm2colMaxElements)
      return success();
    return op->emitOpError("forced packed-im2col Conv lowering requires K/C to fit, pack >= 2, and im2col within budget");

  case PimConvLoweringStreamedPatch:
    if (kFitsCrossbar && cFitsCrossbar)
      return success();
    return op->emitOpError("forced streamed-patch Conv lowering requires K and C to each fit one crossbar");

  case PimConvLoweringStreamedPacked:
    if (kFitsCrossbar && cFitsCrossbar && canPack)
      return success();
    return op->emitOpError("forced streamed-packed Conv lowering requires K/C to fit and pack >= 2");

  case PimConvLoweringOutputChannelTiled:
    if (kFitsCrossbar && !cFitsCrossbar)
      return success();
    return op->emitOpError("forced output-channel-tiled Conv lowering requires K <= X and C > X");

  case PimConvLoweringInputKTiled:
    if (!kFitsCrossbar && cFitsCrossbar)
      return success();
    return op->emitOpError("forced input-k-tiled Conv lowering requires K > X and C <= X");

  case PimConvLoweringTiled2D:
    if (!kFitsCrossbar && !cFitsCrossbar)
      return success();
    return op->emitOpError("forced tiled-2d Conv lowering requires K > X and C > X");
  }
  llvm_unreachable("unknown conv lowering strategy");
}
// Forward declarations. lowerGroupedSelectedConvPlan and lowerDenseSelectedConvPlan
// (below) call buildGroupedConvValue and buildConvValueForStrategy before those
// helpers are defined.

// Lowers a plan with a fixed strategy through buildConvValueForStrategy, without
// distributed consumers (defined below).
static FailureOr<Value> lowerDenseSelectedConvPlan(Operation* op,
const ConvLoweringState& state,
PimConvLoweringType strategy,
PatternRewriter& rewriter,
Location loc);
// Derives the per-group lowering state from the parent state and one group's
// operand slices (defined below).
static ConvLoweringState makeGroupedConvLoweringState(const ConvLoweringState& parent,
Value groupX,
Value groupW,
Value groupB,
RankedTensorType groupOutType);
// Dispatches to the concrete rewrite helper for `decision.strategy` (defined below).
static FailureOr<Value> buildConvValueForStrategy(Operation* op,
Location loc,
const ConvLoweringState& state,
const ConvLoweringDecision& decision,
const DistributedConvAnalysis& analysis,
ArrayRef<DistributedTensorStep> distributedConsumers,
PatternRewriter& rewriter);
// Lowers each convolution group separately and concatenates the results along
// the channel axis (defined below).
static FailureOr<Value> buildGroupedConvValue(Operation* op,
Location loc,
const ConvLoweringState& state,
const ConvLoweringDecision& decision,
PatternRewriter& rewriter);
/// Lowers a grouped Conv plan with a fixed `strategy` by splitting it per group
/// through buildGroupedConvValue.
///
/// The synthesized decision carries only the strategy: no reason or fallback
/// text, and it is not marked as auto-selected.
static FailureOr<Value> lowerGroupedSelectedConvPlan(Operation* op,
                                                     const ConvLoweringState& state,
                                                     PimConvLoweringType strategy,
                                                     PatternRewriter& rewriter,
                                                     Location loc) {
  ConvLoweringDecision fixedDecision;
  fixedDecision.strategy = strategy;
  return buildGroupedConvValue(op, loc, state, fixedDecision, rewriter);
}
/// Lowers a Conv plan with a fixed `strategy` into a single dense value.
///
/// No distributed consumers are passed. The analysis passed along is marked as
/// an unsupported-consumer barrier with the detail "selected dense layout".
static FailureOr<Value> lowerDenseSelectedConvPlan(Operation* op,
                                                   const ConvLoweringState& state,
                                                   PimConvLoweringType strategy,
                                                   PatternRewriter& rewriter,
                                                   Location loc) {
  ConvLoweringDecision fixedDecision;
  fixedDecision.strategy = strategy;

  DistributedConvAnalysis denseAnalysis;
  denseAnalysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer;
  denseAnalysis.barrierDetail = "selected dense layout";

  const ArrayRef<DistributedTensorStep> noConsumers;
  return buildConvValueForStrategy(op, loc, state, fixedDecision, denseAnalysis, noConsumers, rewriter);
}
/// Emits the IR for `decision.strategy` and returns the resulting Conv value.
///
/// Strategies map to rewrite helpers as follows:
/// - legacy and packed-im2col share the packed-im2col rewrite;
/// - the streamed-patch, output-channel-tiled and tiled-2d strategies share the
///   streamed rewrite with a pack factor of 1;
/// - streamed-packed uses the geometry's pack factor.
///
/// `analysis` is not consulted here. An unresolved `Auto` strategy emits an
/// error on `op` and yields failure.
static FailureOr<Value> buildConvValueForStrategy(Operation* op,
                                                  Location loc,
                                                  const ConvLoweringState& state,
                                                  const ConvLoweringDecision& decision,
                                                  [[maybe_unused]] const DistributedConvAnalysis& analysis,
                                                  ArrayRef<DistributedTensorStep> distributedConsumers,
                                                  PatternRewriter& rewriter) {
  const ConvGeometry geometry = buildConvGeometry(state);
  switch (decision.strategy) {
  case PimConvLoweringDepthwise:
    return depthwise::rewriteConv(op, state, rewriter, loc);
  case PimConvLoweringLegacy:
  case PimConvLoweringPackedIm2Col:
    return standard::rewritePackedIm2ColConv(state, distributedConsumers, rewriter, loc);
  case PimConvLoweringStreamedPatch:
  case PimConvLoweringOutputChannelTiled:
  case PimConvLoweringTiled2D:
    return standard::rewriteStreamedConv(state, distributedConsumers, rewriter, loc, /*forcedPackFactor=*/1);
  case PimConvLoweringStreamedPacked:
    return standard::rewriteStreamedConv(
        state, distributedConsumers, rewriter, loc, /*forcedPackFactor=*/geometry.pack);
  case PimConvLoweringInputKTiled:
    return standard::rewriteInputKTiledConv(state, distributedConsumers, rewriter, loc);
  case PimConvLoweringAuto:
    break;
  }
  op->emitOpError("unexpected auto strategy at Conv lowering dispatch");
  return failure();
}
// Lowers `convOp` with `decision.strategy` (via buildConvValueForStrategy) and
// stores the produced value in `result`.
//
// After a successful lowering, reports the chosen strategy through
// reportConvLoweringDecision, together with the parameters that strategy uses:
// batch size, number of batches, whether compute batching and batched
// instruction emission are used, and the chunk width where one applies.
// If the lowering itself fails, nothing is reported and failure is returned.
static LogicalResult
createConvValueForStrategy(ONNXConvOp convOp,
const ConvLoweringState& state,
const ConvLoweringDecision& decision,
const DistributedConvAnalysis& analysis,
ArrayRef<DistributedTensorStep> distributedConsumers,
PatternRewriter& rewriter,
FailureOr<Value>& result) {
result = buildConvValueForStrategy(convOp, convOp.getLoc(), state, decision, analysis, distributedConsumers, rewriter);
if (failed(result))
return failure();
const ConvGeometry geo = buildConvGeometry(state);
const ConvStrategyEstimate estimate = estimateConvStrategy(geo, decision.strategy, analysis);
switch (decision.strategy) {
case PimConvLoweringDepthwise:
// Reported as a single batch whose size is geo.p.
reportConvLoweringDecision(
convOp, geo, decision, estimate, /*batchSize=*/geo.p, /*numberOfBatches=*/1, /*usesComputeBatch=*/true,
/*usesBatchedInstructionEmission=*/true, std::nullopt);
return success();
case PimConvLoweringLegacy:
case PimConvLoweringPackedIm2Col:
// Reported as a single batch whose size is the pack factor.
reportConvLoweringDecision(
convOp, geo, decision, estimate, /*batchSize=*/geo.pack, /*numberOfBatches=*/1, /*usesComputeBatch=*/true,
/*usesBatchedInstructionEmission=*/true, std::nullopt);
return success();
case PimConvLoweringStreamedPatch:
case PimConvLoweringOutputChannelTiled:
case PimConvLoweringTiled2D: {
// Streamed without packing: geo.p positions split into chunks of chunkPositions.
uint64_t chunkPositions = chooseStreamChunkPositions(geo, /*packFactor=*/1);
const int64_t batches = ceilIntegerDivide(geo.p, static_cast<int64_t>(chunkPositions));
reportConvLoweringDecision(convOp,
geo,
decision,
estimate,
/*batchSize=*/1,
batches,
/*usesComputeBatch=*/true,
/*usesBatchedInstructionEmission=*/true,
chunkPositions);
return success();
}
case PimConvLoweringInputKTiled: {
// NOTE(review): the row-chunk width is recomputed here only for reporting.
// It presumably has to match the chunking done in
// standard::rewriteInputKTiledConv; confirm this and keep the two in sync.
// The width is the smallest of three limits:
//   - the stream chunk size,
//   - the per-core crossbar count divided by (numKSlices * 4),
//   - the output width.
// The meaning of the factor 4 is not visible here.
const int64_t numKSlices = ceilIntegerDivide(geo.k, geo.xbarSize);
const uint64_t maxLanesPerBatch =
std::max<uint64_t>(1,
static_cast<uint64_t>(crossbarCountInCore.getValue())
/ static_cast<uint64_t>(std::max<int64_t>(1, numKSlices * 4)));
const uint64_t rowChunkWidth = std::max<uint64_t>(
1,
std::min<uint64_t>({chooseStreamChunkPositions(geo, /*packFactor=*/1),
maxLanesPerBatch,
static_cast<uint64_t>(state.outWidth)}));
// One batch per (image, output row, chunk of output columns).
const int64_t batches =
state.batchSize * state.outHeight * ceilIntegerDivide(state.outWidth, static_cast<int64_t>(rowChunkWidth));
reportConvLoweringDecision(convOp,
geo,
decision,
estimate,
/*batchSize=*/1,
batches,
/*usesComputeBatch=*/false,
/*usesBatchedInstructionEmission=*/false,
rowChunkWidth);
return success();
}
case PimConvLoweringStreamedPacked: {
// Streamed with packing: chunk size is chosen for the geometry's pack factor.
uint64_t chunkPositions = chooseStreamChunkPositions(geo, geo.pack);
const int64_t batches = ceilIntegerDivide(geo.p, static_cast<int64_t>(chunkPositions));
reportConvLoweringDecision(convOp,
geo,
decision,
estimate,
/*batchSize=*/geo.pack,
batches,
/*usesComputeBatch=*/true,
/*usesBatchedInstructionEmission=*/true,
chunkPositions);
return success();
}
case PimConvLoweringAuto:
// Defensive only: buildConvValueForStrategy already fails for Auto, which
// returns early above.
break;
}
return convOp.emitOpError("unexpected auto strategy at Conv lowering dispatch");
}
/// Lowers `convOp` with the selected strategy and splices the result into the IR.
///
/// Without local consumers, the Conv itself is replaced. Otherwise:
/// - the lowered value replaces `analysis.replacementOp`;
/// - the remaining consumer steps are erased in reverse order;
/// - the original Conv is erased last.
static LogicalResult
rewriteSelectedConv(ONNXConvOp convOp,
                    const ConvLoweringState& state,
                    const ConvLoweringDecision& decision,
                    const DistributedConvAnalysis& analysis,
                    PatternRewriter& rewriter) {
  FailureOr<Value> lowered = failure();
  if (failed(createConvValueForStrategy(convOp, state, decision, analysis, analysis.steps, rewriter, lowered)))
    return failure();
  Value replacement = *lowered;

  if (analysis.hasLocalConsumers()) {
    assert(analysis.replacementOp && "conv rewrite expects a replacement op");
    rewriter.replaceOp(analysis.replacementOp, replacement);
    // Erase later consumer steps before earlier ones; the replacement op is already gone.
    for (const auto& step : llvm::reverse(analysis.steps)) {
      if (step.op == analysis.replacementOp)
        continue;
      rewriter.eraseOp(step.op);
    }
    rewriter.eraseOp(convOp);
    return success();
  }

  rewriter.replaceOp(convOp, replacement);
  return success();
}
/// Alias of rewriteSelectedConv under an ungrouped-specific name; it forwards
/// every argument unchanged and returns that function's result.
[[maybe_unused]] static LogicalResult rewriteUngroupedConv(ONNXConvOp convOp,
                                                           const ConvLoweringState& state,
                                                           const ConvLoweringDecision& decision,
                                                           const DistributedConvAnalysis& analysis,
                                                           PatternRewriter& rewriter) {
  const LogicalResult outcome = rewriteSelectedConv(convOp, state, decision, analysis, rewriter);
  return outcome;
}
// Forward declaration of the grouped Conv rewrite; its definition (marked
// [[maybe_unused]]) appears after buildGroupedConvValue below.
static LogicalResult
rewriteGroupedConv(ONNXConvOp convOp,
const ConvLoweringState& state,
const ConvLoweringDecision& decision,
PatternRewriter& rewriter);
/// Derives the lowering state for a single convolution group.
///
/// Copies `parent`, which keeps its pads, strides, dilations and other
/// settings. It then swaps in the per-group operand slices and recomputes every
/// shape-derived field from their types. The result describes an ungrouped
/// convolution (group == 1), so all channels of the slice form one group.
///
/// `groupB` may be a null Value when the convolution has no bias; `hasBias`
/// follows from it.
static ConvLoweringState makeGroupedConvLoweringState(
    const ConvLoweringState& parent, Value groupX, Value groupW, Value groupB, RankedTensorType groupOutType) {
  ConvLoweringState state = parent;
  state.x = groupX;
  state.w = groupW;
  state.b = groupB;
  state.xType = cast<RankedTensorType>(groupX.getType());
  state.wType = cast<RankedTensorType>(groupW.getType());
  state.outType = groupOutType;
  state.batchSize = state.xType.getDimSize(0);
  state.numChannelsIn = state.xType.getDimSize(1);
  state.xHeight = state.xType.getDimSize(2);
  state.xWidth = state.xType.getDimSize(3);
  state.numChannelsOut = state.wType.getDimSize(0);
  state.wHeight = state.wType.getDimSize(2);
  state.wWidth = state.wType.getDimSize(3);
  state.outHeight = state.outType.getDimSize(2);
  state.outWidth = state.outType.getDimSize(3);
  // Each slice is lowered as a standalone, ungrouped convolution.
  state.group = 1;
  state.numChannelsInPerGroup = state.numChannelsIn;
  state.numChannelsOutPerGroup = state.numChannelsOut;
  state.hasBias = static_cast<bool>(groupB);
  return state;
}
/// Lowers a grouped convolution by lowering each group independently and
/// concatenating the per-group results along axis 1.
///
/// Operands are split as follows:
/// - the input along axis 1, `numChannelsInPerGroup` at a time;
/// - the weight along axis 0, `numChannelsOutPerGroup` at a time;
/// - an optional rank-1 or rank-2 bias, along the axis chosen below.
///
/// Each group is lowered with `decision` through buildConvValueForStrategy.
/// If every group result is compile-time computable, the results are
/// concatenated directly; otherwise the concat is wrapped in a Spatial compute.
/// Emits an error on `op` and returns failure for an unsupported bias rank, a
/// partitioning mismatch, or a failed per-group lowering.
static FailureOr<Value> buildGroupedConvValue(Operation* op,
                                              Location loc,
                                              const ConvLoweringState& state,
                                              const ConvLoweringDecision& decision,
                                              PatternRewriter& rewriter) {
  SmallVector<Value> inputPerGroup = sliceTensor(state.x, /*axis=*/1, state.numChannelsInPerGroup, rewriter, loc);
  SmallVector<Value> weightPerGroup = sliceTensor(state.w, /*axis=*/0, state.numChannelsOutPerGroup, rewriter, loc);
  SmallVector<Value> biasPerGroup;
  if (state.hasBias) {
    auto biasType = cast<RankedTensorType>(state.b.getType());
    int64_t biasAxis = -1;
    switch (biasType.getRank()) {
    case 1:
      biasAxis = 0;
      break;
    case 2:
      // A rank-2 bias with a leading dimension of 1 is split along its second axis.
      biasAxis = biasType.getDimSize(0) == 1 ? 1 : 0;
      break;
    default:
      op->emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
                        << biasType.getRank();
      return failure();
    }
    biasPerGroup = sliceTensor(state.b, biasAxis, state.numChannelsOutPerGroup, rewriter, loc);
  }

  const size_t groupCount = static_cast<size_t>(state.group);
  const bool partitioned = inputPerGroup.size() == groupCount && weightPerGroup.size() == groupCount
                           && (!state.hasBias || biasPerGroup.size() == groupCount);
  if (!partitioned) {
    op->emitOpError("failed to partition grouped convolution operands for Spatial lowering");
    return failure();
  }

  auto perGroupOutType = RankedTensorType::get(
      {state.batchSize, state.numChannelsOutPerGroup, state.outHeight, state.outWidth}, state.outType.getElementType());
  DistributedConvAnalysis perGroupAnalysis;
  perGroupAnalysis.barrierKind = DistributedConvBarrierKind::GroupedConv;
  perGroupAnalysis.barrierDetail = "grouped convolution still materializes densely";

  SmallVector<Value> perGroupResults;
  perGroupResults.reserve(state.group);
  for (size_t g = 0; g < groupCount; ++g) {
    Value groupBias = state.hasBias ? biasPerGroup[g] : Value();
    ConvLoweringState groupState =
        makeGroupedConvLoweringState(state, inputPerGroup[g], weightPerGroup[g], groupBias, perGroupOutType);
    FailureOr<Value> lowered =
        buildConvValueForStrategy(op, loc, groupState, decision, perGroupAnalysis, {}, rewriter);
    if (failed(lowered))
      return failure();
    perGroupResults.push_back(*lowered);
  }

  if (llvm::all_of(perGroupResults, isCompileTimeComputable))
    return createSpatConcat(rewriter, loc, /*axis=*/1, perGroupResults);

  auto concatCompute =
      createSpatCompute(rewriter, loc, TypeRange {state.outType}, {}, perGroupResults, [&](ValueRange args) {
        spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
      });
  return concatCompute.getResult(0);
}
/// Lowers a grouped `convOp` via buildGroupedConvValue and replaces the Conv
/// with the concatenated per-group result. Returns failure, leaving `convOp` in
/// place, if lowering fails.
[[maybe_unused]] static LogicalResult
rewriteGroupedConv(ONNXConvOp convOp,
                   const ConvLoweringState& state,
                   const ConvLoweringDecision& decision,
                   PatternRewriter& rewriter) {
  Operation* op = convOp.getOperation();
  FailureOr<Value> concatenated = buildGroupedConvValue(op, convOp.getLoc(), state, decision, rewriter);
  if (succeeded(concatenated)) {
    rewriter.replaceOp(convOp, *concatenated);
    return success();
  }
  return failure();
}
} // namespace
/// Replaces an ONNX Conv with a Spatial Conv2D plan op.
///
/// The pads, strides and dilations come from the analyzed lowering state. Pads
/// are encoded as [heightBegin, widthBegin, heightEnd, widthEnd] and
/// strides/dilations as [height, width]. The plan is created with the "nchw"
/// layout, and the bias operand is omitted when the Conv has none.
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
                                          ONNXConvOpAdaptor convOpAdaptor,
                                          ConversionPatternRewriter& rewriter) const {
  FailureOr<ConvLoweringState> analyzed = analyzeConvLoweringState(convOp, convOpAdaptor);
  if (failed(analyzed))
    return failure();
  const ConvLoweringState& conv = *analyzed;

  const SmallVector<int64_t> padValues {
      conv.padHeightBegin, conv.padWidthBegin, conv.padHeightEnd, conv.padWidthEnd};
  const SmallVector<int64_t> strideValues {conv.strideHeight, conv.strideWidth};
  const SmallVector<int64_t> dilationValues {conv.dilationHeight, conv.dilationWidth};
  Value biasOperand = conv.hasBias ? convOpAdaptor.getB() : Value();

  auto plan = spatial::SpatConv2DPlanOp::create(rewriter,
                                                convOp.getLoc(),
                                                convOp.getY().getType(),
                                                convOpAdaptor.getX(),
                                                convOpAdaptor.getW(),
                                                biasOperand,
                                                rewriter.getDenseI64ArrayAttr(padValues),
                                                rewriter.getDenseI64ArrayAttr(strideValues),
                                                rewriter.getDenseI64ArrayAttr(dilationValues),
                                                rewriter.getI64IntegerAttr(conv.group),
                                                rewriter.getStringAttr("nchw"));
  rewriter.replaceOp(convOp, plan.getResult());
  return success();
}
/// Registers the ONNX Conv -> Spatial Conv2D-plan conversion pattern.
void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ConvToGemm>(ctx);
}
/// Returns success if `planOp` can be lowered to the row-strip output layout.
///
/// Requirements:
/// - the plan analyzes cleanly;
/// - it is ungrouped, has a batch of 1, and has a static rank-4 result;
/// - the requested/auto-selected strategy passes verification;
/// - that strategy is one of the row-strip capable ones.
///
/// Depthwise, input-K-tiled and an unresolved auto strategy are rejected. An
/// auto-selected depthwise strategy that cannot use the structured rewrite
/// falls back to legacy first, as in lowerSelectedConv2DPlan.
LogicalResult canLowerConvPlanToRowStrip(spatial::SpatConv2DPlanOp planOp) {
  Operation* op = planOp.getOperation();
  FailureOr<ConvLoweringState> analyzed = analyzeConvLoweringState(planOp);
  if (failed(analyzed))
    return failure();
  ConvLoweringState& conv = *analyzed;

  const bool rowStripShapeOk = conv.group == 1 && conv.batchSize == 1 && conv.outType.getRank() == 4
                               && conv.outType.hasStaticShape();
  if (!rowStripShapeOk)
    return failure();

  FailureOr<PimConvLoweringType> requested = resolveRequestedConvLoweringStrategy(op);
  if (failed(requested))
    return failure();

  DistributedConvAnalysis analysis;
  analysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer;
  analysis.barrierDetail = "selected row-strip layout";
  ConvGeometry geometry = buildConvGeometry(conv);
  ConvLoweringDecision decision = chooseConvLoweringStrategy(geometry, *requested, analysis);

  const bool fallBackFromDepthwise = decision.strategy == PimConvLoweringDepthwise
                                     && !depthwise::canUseStructuredRewrite(conv)
                                     && *requested == PimConvLoweringAuto;
  if (fallBackFromDepthwise) {
    decision = {PimConvLoweringLegacy,
                "depthwise auto fallback when structured depthwise lowering is not representable",
                /*isAuto=*/true,
                "",
                ""};
  }
  if (failed(verifyForcedConvLoweringStrategy(op, geometry, decision.strategy)))
    return failure();

  switch (decision.strategy) {
  case PimConvLoweringAuto:
  case PimConvLoweringDepthwise:
  case PimConvLoweringInputKTiled:
    return failure();
  case PimConvLoweringLegacy:
  case PimConvLoweringPackedIm2Col:
  case PimConvLoweringStreamedPatch:
  case PimConvLoweringOutputChannelTiled:
  case PimConvLoweringTiled2D:
  case PimConvLoweringStreamedPacked:
    return success();
  }
  llvm_unreachable("unknown conv lowering strategy");
}
/// Returns success if `planOp` can both consume a row-strip input and produce
/// row-strip output, as decided by canDirectLowerRowStripConv. The reason
/// string that helper reports is discarded.
LogicalResult canConsumeAndProduceRowStrip(spatial::SpatConv2DPlanOp planOp) {
  FailureOr<ConvLoweringState> analyzed = analyzeConvLoweringState(planOp);
  if (failed(analyzed))
    return failure();
  StringRef ignoredReason;
  if (!canDirectLowerRowStripConv(*analyzed, ignoredReason))
    return failure();
  return success();
}
/// Lowers a Spatial Conv2D plan with the requested (or auto-selected) strategy.
///
/// With `emitRowStripLayout` set:
/// - if `rowStripInput` is provided, rows are produced directly from it;
/// - otherwise rows are computed and then packed into the row-strip layout.
/// Each path is gated on the matching capability check.
///
/// Without it, a dense result is produced. Non-depthwise grouped convolutions
/// are lowered per group; everything else is lowered directly.
///
/// An auto-selected depthwise strategy that cannot use the structured rewrite
/// falls back to legacy. Errors are emitted on `planOp` where noted below.
FailureOr<Value>
lowerSelectedConv2DPlan(spatial::SpatConv2DPlanOp planOp,
                        std::optional<Value> rowStripInput,
                        bool emitRowStripLayout,
                        PatternRewriter& rewriter) {
  Operation* op = planOp.getOperation();
  const Location loc = planOp.getLoc();

  FailureOr<ConvLoweringState> analyzed = analyzeConvLoweringState(planOp);
  if (failed(analyzed))
    return failure();
  ConvLoweringState& conv = *analyzed;

  FailureOr<PimConvLoweringType> requested = resolveRequestedConvLoweringStrategy(op);
  if (failed(requested))
    return failure();

  DistributedConvAnalysis analysis;
  analysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer;
  analysis.barrierDetail = emitRowStripLayout ? "selected row-strip layout" : "selected dense layout";
  ConvGeometry geometry = buildConvGeometry(conv);
  ConvLoweringDecision decision = chooseConvLoweringStrategy(geometry, *requested, analysis);

  const bool fallBackFromDepthwise = decision.strategy == PimConvLoweringDepthwise
                                     && !depthwise::canUseStructuredRewrite(conv)
                                     && *requested == PimConvLoweringAuto;
  if (fallBackFromDepthwise) {
    decision = {PimConvLoweringLegacy,
                "depthwise auto fallback when structured depthwise lowering is not representable",
                /*isAuto=*/true,
                "",
                ""};
  }
  if (failed(verifyForcedConvLoweringStrategy(op, geometry, decision.strategy)))
    return failure();

  if (!emitRowStripLayout) {
    // Depthwise handles groups itself; other grouped convs are split per group.
    if (decision.strategy != PimConvLoweringDepthwise && conv.group != 1)
      return lowerGroupedSelectedConvPlan(op, conv, decision.strategy, rewriter, loc);
    return lowerDenseSelectedConvPlan(op, conv, decision.strategy, rewriter, loc);
  }

  if (rowStripInput) {
    if (failed(canConsumeAndProduceRowStrip(planOp))) {
      planOp.emitOpError("selected row-strip input/output layout is not supported for this Conv plan");
      return failure();
    }
    return createConvRowsFromRowStripInput(conv, decision, *rowStripInput, rewriter, loc);
  }

  if (failed(canLowerConvPlanToRowStrip(planOp))) {
    planOp.emitOpError("selected row-strip layout is not supported for this Conv plan");
    return failure();
  }
  FailureOr<Value> rows = createConvRowsForStrategy(conv, decision, rewriter, loc);
  if (failed(rows))
    return failure();
  FailureOr<Value> packedRows = createRowStripPackedRows(*rows, conv, rewriter, loc);
  if (failed(packedRows)) {
    planOp.emitOpError("failed to pack Conv rows into the selected row-strip physical layout");
    return failure();
  }
  return *packedRows;
}
} // namespace onnx_mlir