3846 lines
184 KiB
C++
3846 lines
184 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
#include <algorithm>
|
|
#include <cctype>
|
|
#include <map>
|
|
#include <mutex>
|
|
#include <optional>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
|
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Conversion pattern for `onnx.Conv`. The rewrite logic lives in the
/// out-of-line definition of `matchAndRewrite`.
struct ConvToGemm : public OpConversionPattern<ONNXConvOp> {
  // Reuse the base-class constructors unchanged.
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXConvOp convOp,
                                ONNXConvOpAdaptor convOpAdaptor,
                                ConversionPatternRewriter& rewriter) const override;
};
|
|
|
|
struct ConvLoweringDecision {
|
|
PimConvLoweringType strategy;
|
|
std::string reason;
|
|
bool isAuto = false;
|
|
std::string fallbackReason;
|
|
std::string rejectedAutoStrategy;
|
|
};
|
|
|
|
struct PreparedConvInput {
|
|
Value value;
|
|
RankedTensorType type;
|
|
};
|
|
|
|
/// Heuristic cost and materialization estimate for one Conv lowering
/// strategy, as produced by `estimateConvStrategy`.
struct ConvStrategyEstimate {
  // Work estimates.
  uint64_t estimatedMvmCount{0};
  uint64_t estimatedReductionVAddCount{0};
  uint64_t estimatedOutputFragments{1};

  // Structural flags.
  bool perOutputPositionReduction{false};
  bool requiresFuncReturnMaterialization{false};
  bool giantCollectorConcatExpected{false};
  bool fullInputBroadcastExpected{false};

  // Materialization diagnostics; "none" / "-" mean not applicable.
  uint64_t concatOperandCount{0};
  std::string materializationKind{"none"};
  std::string collectorCore{"-"};
};
|
|
|
|
/// One row of the PIM Conv lowering report (columns match the report legend).
/// Field order matters if this struct is aggregate-initialized; keep it stable.
struct ConvReportEntry {
  // Sequential report entry index within this compiler invocation.
  uint64_t id;
  // Summarized MLIR location of the Conv op.
  std::string where;
  // Selected Conv lowering algorithm.
  std::string strategy;
  // Whether the strategy came from the `auto` policy or a forced option.
  std::string mode;
  // Tensor shapes of the Conv operands/result.
  std::string inputShape;
  std::string weightShape;
  std::string outputShape;
  // ONNX Conv group count.
  int64_t groups;
  // Logical reduction size `CinPerGroup * Kh * Kw`.
  int64_t k;
  // Logical output-channel width handled by the selected strategy.
  int64_t c;
  // Total output positions `N * Hout * Wout`.
  int64_t p;
  // Crossbar size.
  int64_t xbarSize;
  // Packed spatial positions per MVM group.
  int64_t pack;
  // Total explicit im2col element count and its allowed budget.
  uint64_t im2colElements;
  uint64_t im2colBudget;
  // Output positions per streamed chunk, or "-" when not applicable.
  std::string chunkText;
  // Logical batch size and number of repeated compute batches.
  int64_t batchSize;
  int64_t numberOfBatches;
  // Whether `spat.compute_batch` was used / batched PIM emission is expected.
  std::string spatialComputeBatch;
  std::string batchedInstructionEmission;
  // Strategy-selection reason.
  std::string reason;
  // NOTE(review): presumably mirror ConvLoweringDecision's fallback fields — verify.
  std::string fallbackReason;
  std::string rejectedAutoStrategy;
  // Per-conv estimate details (presumably copied from ConvStrategyEstimate — TODO confirm).
  uint64_t estimatedMvmCount;
  uint64_t estimatedReductionVAddCount;
  uint64_t estimatedOutputFragments;
  std::string materializationRequiredAtReturn;
  std::string materializationKind;
  uint64_t concatOperandCount;
  std::string collectorCore;
  std::string giantCollectorConcatExpected;
  std::string fullInputBroadcastExpected;
};
|
|
|
|
/// Kinds of ops that may appear as steps in a distributed Conv chain
/// (see `analyzeDistributedConvConsumers`).
enum class DistributedTensorOpKind : int {
  Relu,     // Elementwise unary.
  Sigmoid,  // Elementwise unary.
  Add,      // Binary with a splat/per-channel constant.
  Sub,      // Binary with a splat/per-channel constant.
  Mul,      // Binary with a splat/per-channel constant.
  Div,      // Binary; fragment must be the lhs.
  Conv,     // A following Conv consuming the distributed value.
};
|
|
|
|
/// Why a distributed Conv chain stops and its value has to be materialized.
enum class DistributedConvBarrierKind : int {
  Return,               // Value flows into func.return.
  UnsupportedConsumer,  // Consumer is not distributed-aware.
  Fanout,               // Value has more than one use.
  DeadValue,            // Value has no uses.
  GroupedConv,          // Grouped Conv.
  Depthwise,            // Depthwise Conv.
};
|
|
|
|
/// Shape class of the constant operand of a distributed binary step.
enum class DistributedTensorConstantKind : int {
  None,        // No constant operand (e.g. unary steps).
  Splat,       // Single value broadcast everywhere.
  PerChannel,  // One value per channel.
};
|
|
|
|
/// Layouts a distributed tensor may use. Only the NCHW row-strip layout exists.
enum class DistributedTensorLayoutKind : int {
  NchwRowStrip,
};
|
|
|
|
struct DistributedFragmentInfo {
|
|
SmallVector<int64_t, 4> offsets;
|
|
SmallVector<int64_t, 4> sizes;
|
|
SmallVector<int64_t, 4> strides;
|
|
int64_t producerLane = 0;
|
|
};
|
|
|
|
struct DistributedTensorInfo {
|
|
Value storage;
|
|
RankedTensorType logicalType;
|
|
DistributedTensorLayoutKind layoutKind = DistributedTensorLayoutKind::NchwRowStrip;
|
|
SmallVector<DistributedFragmentInfo, 8> fragments;
|
|
int64_t laneCount = 0;
|
|
int64_t fragmentHeight = 1;
|
|
int64_t channels = 0;
|
|
int64_t height = 0;
|
|
int64_t width = 0;
|
|
|
|
bool isRowStripNchw() const { return layoutKind == DistributedTensorLayoutKind::NchwRowStrip; }
|
|
};
|
|
|
|
struct DistributedTensorRegistry {
|
|
llvm::DenseMap<Value, DistributedTensorInfo> infos;
|
|
|
|
void bind(Value value, const DistributedTensorInfo& info) { infos[value] = info; }
|
|
|
|
const DistributedTensorInfo* lookup(Value value) const {
|
|
auto it = infos.find(value);
|
|
if (it == infos.end())
|
|
return nullptr;
|
|
return &it->second;
|
|
}
|
|
};
|
|
|
|
struct DistributedTensorStep {
|
|
Operation* op = nullptr;
|
|
DistributedTensorOpKind kind;
|
|
DenseElementsAttr constantAttr;
|
|
DistributedTensorConstantKind constantKind = DistributedTensorConstantKind::None;
|
|
bool fragmentOnLhs = true;
|
|
std::optional<ConvLoweringState> convState;
|
|
};
|
|
|
|
struct DistributedConvAnalysis {
|
|
SmallVector<DistributedTensorStep, 4> steps;
|
|
Operation* replacementOp = nullptr;
|
|
DistributedConvBarrierKind barrierKind = DistributedConvBarrierKind::UnsupportedConsumer;
|
|
std::string barrierDetail;
|
|
|
|
bool hasLocalConsumers() const { return !steps.empty(); }
|
|
bool hasDistributedConvConsumer() const {
|
|
return llvm::any_of(steps, [](const DistributedTensorStep& step) { return step.kind == DistributedTensorOpKind::Conv; });
|
|
}
|
|
};
|
|
|
|
/// Per-chain diagnostics written to the distributed consumption report.
struct DistributedChainReportEntry {
  // Chain identity and contents.
  uint64_t chainId{0};
  uint64_t chainLength{0};
  std::string producerKind;
  std::string distributedOps;

  // Where and why the chain was materialized.
  std::string materializationPoints;
  std::string firstMaterializationReason;

  // Fragment / patch-building statistics.
  uint64_t maxLiveFragments{0};
  uint64_t maxFragmentFanout{0};
  uint64_t patchBuilderCoreCount{0};
  std::string centralJunctionDetected{"no"};
  std::string convInputMaterializationKind{"dense_materialization_fallback"};
  uint64_t localPatchFragments{0};
  uint64_t remotePatchFragments{0};
  uint64_t haloTransferCount{0};
  uint64_t groupedTransferCount{0};

  // Only printed when non-empty.
  std::string fallbackReason;
};
|
|
|
|
struct DistributedConvReportTotals {
|
|
uint64_t totalConvs = 0;
|
|
uint64_t distributedTensorsCreated = 0;
|
|
uint64_t distributedValuesPropagated = 0;
|
|
uint64_t distributedConsumersHandled = 0;
|
|
uint64_t distributedConvInputsSeen = 0;
|
|
uint64_t distributedConvInputsConsumed = 0;
|
|
uint64_t materializationBarriersInserted = 0;
|
|
std::map<std::string, uint64_t> fallbackReasons;
|
|
std::map<std::string, uint64_t> barrierReasons;
|
|
SmallVector<DistributedChainReportEntry, 16> chains;
|
|
};
|
|
|
|
// Forward declarations of helpers defined elsewhere in this file (definitions
// are outside this excerpt; the notes below are inferred from names only).

// Presumably creates an all-zero bias for a Gemm producing `gemmResultType` — TODO confirm.
static Value createZeroGemmBias(RankedTensorType gemmResultType, PatternRewriter& rewriter);

// NOTE(review): appears to repack `rows` into the row-strip layout described by
// `state`, failing when unsupported — verify against the definition.
static FailureOr<Value> createRowStripPackedRows(Value rows,
                                                 const ConvLoweringState& state,
                                                 PatternRewriter& rewriter,
                                                 Location loc);

// Presumably derives the ConvLoweringState (strides, padding, dilation, shapes, ...)
// for `convOp` from input `x`, weights `w` and bias `b` — TODO confirm.
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b);
|
|
|
|
/// Returns the report key used for a materialization-barrier kind.
static StringRef stringifyDistributedConvBarrierKind(DistributedConvBarrierKind kind) {
  StringRef name;
  switch (kind) {
  case DistributedConvBarrierKind::Return:
    name = "func.return";
    break;
  case DistributedConvBarrierKind::UnsupportedConsumer:
    name = "unsupported_consumer";
    break;
  case DistributedConvBarrierKind::Fanout:
    name = "fanout";
    break;
  case DistributedConvBarrierKind::DeadValue:
    name = "dead_value";
    break;
  case DistributedConvBarrierKind::GroupedConv:
    name = "grouped_conv";
    break;
  case DistributedConvBarrierKind::Depthwise:
    name = "depthwise_conv";
    break;
  }
  // Every valid enumerator maps to a non-empty name.
  if (name.empty())
    llvm_unreachable("unknown distributed conv barrier kind");
  return name;
}
|
|
|
|
/// Returns the command-line style name of a Conv lowering strategy.
static StringRef stringifyConvLoweringStrategy(PimConvLoweringType strategy) {
  StringRef name;
  switch (strategy) {
  case PimConvLoweringAuto:
    name = "auto";
    break;
  case PimConvLoweringLegacy:
    name = "legacy";
    break;
  case PimConvLoweringDepthwise:
    name = "depthwise";
    break;
  case PimConvLoweringPackedIm2Col:
    name = "packed-im2col";
    break;
  case PimConvLoweringStreamedPatch:
    name = "streamed-patch";
    break;
  case PimConvLoweringStreamedPacked:
    name = "streamed-packed";
    break;
  case PimConvLoweringOutputChannelTiled:
    name = "output-channel-tiled";
    break;
  case PimConvLoweringInputKTiled:
    name = "input-k-tiled";
    break;
  case PimConvLoweringTiled2D:
    name = "tiled-2d";
    break;
  }
  if (name.empty())
    llvm_unreachable("unknown conv lowering strategy");
  return name;
}
|
|
|
|
static bool requiresFuncReturnMaterialization(const DistributedConvAnalysis& analysis) {
|
|
return !analysis.hasLocalConsumers() && analysis.barrierKind == DistributedConvBarrierKind::Return;
|
|
}
|
|
|
|
/// Produces a coarse, per-strategy estimate of MVM / reduction work and of how
/// the Conv result will be materialized. The numbers feed the lowering report
/// and are heuristics, not exact instruction counts.
///
/// \param geo       Conv geometry (`p` output positions, `k` reduction size, ...).
/// \param strategy  Lowering strategy to estimate.
/// \param analysis  Distributed-consumer analysis of the Conv result; decides
///                  whether the result must be materialized at func.return.
/// \returns The estimate; fields a strategy does not model keep their defaults.
static ConvStrategyEstimate estimateConvStrategy(const ConvGeometry& geo,
                                                 PimConvLoweringType strategy,
                                                 const DistributedConvAnalysis& analysis) {
  ConvStrategyEstimate estimate;
  estimate.requiresFuncReturnMaterialization = requiresFuncReturnMaterialization(analysis);
  // Baseline materialization kind. This already covers the func.return case;
  // only the InputKTiled branch may upgrade it to a single collector concat.
  estimate.materializationKind = estimate.requiresFuncReturnMaterialization ? "func.return" : "none";

  // Output positions, clamped so degenerate geometries still count one MVM.
  const uint64_t positions = static_cast<uint64_t>(std::max<int64_t>(1, geo.p));

  switch (strategy) {
  case PimConvLoweringLegacy:
  case PimConvLoweringPackedIm2Col:
  case PimConvLoweringTiled2D:
    estimate.estimatedMvmCount = positions;
    break;
  case PimConvLoweringStreamedPatch:
  case PimConvLoweringStreamedPacked:
  case PimConvLoweringOutputChannelTiled: {
    // Clamp the chunk size so a zero chunk cannot reach the division below.
    const uint64_t chunkPositions = std::max<uint64_t>(1, chooseStreamChunkPositions(geo, /*packFactor=*/1));
    estimate.estimatedMvmCount = positions;
    estimate.estimatedOutputFragments =
        std::max<uint64_t>(1, static_cast<uint64_t>(ceilIntegerDivide(geo.p, static_cast<int64_t>(chunkPositions))));
    break;
  }
  case PimConvLoweringInputKTiled: {
    // K is split into crossbar-sized slices; each slice needs its own MVM.
    const int64_t numKSlices = ceilIntegerDivide(geo.k, geo.xbarSize);
    const uint64_t kSlices = static_cast<uint64_t>(std::max<int64_t>(1, numKSlices));
    // Lanes per batch are bounded by the crossbars in one core, using a fixed
    // factor of 4 crossbars per K slice in this heuristic.
    const uint64_t maxLanesPerBatch =
        std::max<uint64_t>(1,
                           static_cast<uint64_t>(crossbarCountInCore.getValue())
                               / static_cast<uint64_t>(std::max<int64_t>(1, numKSlices * 4)));
    const uint64_t rowChunkWidth = std::max<uint64_t>(
        1,
        std::min<uint64_t>({chooseStreamChunkPositions(geo, /*packFactor=*/1),
                            maxLanesPerBatch,
                            static_cast<uint64_t>(std::max<int64_t>(1, geo.outWidth))}));

    // One MVM per (output position, K slice).
    estimate.estimatedMvmCount = positions * kSlices;
    // Per output position: (slices - 1) partial-sum adds, plus one bias add.
    estimate.estimatedReductionVAddCount =
        positions * static_cast<uint64_t>(std::max<int64_t>(0, numKSlices - 1) + (geo.hasBias ? 1 : 0));

    // Fragments = N * Hout * ceil(Wout / rowChunkWidth). Every factor is clamped
    // to >= 1 (matching the field's default of 1 and the streamed branch) so a
    // non-positive width cannot report zero fragments.
    const int64_t widthChunks =
        std::max<int64_t>(1, ceilIntegerDivide(geo.outWidth, static_cast<int64_t>(rowChunkWidth)));
    estimate.estimatedOutputFragments = static_cast<uint64_t>(std::max<int64_t>(1, geo.batchSize))
                                      * static_cast<uint64_t>(std::max<int64_t>(1, geo.outHeight))
                                      * static_cast<uint64_t>(widthChunks);
    estimate.perOutputPositionReduction = numKSlices > 1;
    estimate.fullInputBroadcastExpected = estimate.estimatedOutputFragments > 1;

    // Returning many fragments directly implies one large collector concat.
    if (estimate.requiresFuncReturnMaterialization && estimate.estimatedOutputFragments >= 128) {
      estimate.giantCollectorConcatExpected = true;
      estimate.materializationKind = "single_collector_concat";
      estimate.concatOperandCount = estimate.estimatedOutputFragments;
      estimate.collectorCore = "scheduled_post_merge";
    }
    break;
  }
  case PimConvLoweringDepthwise:
  case PimConvLoweringAuto:
    // Not modeled: keep the default estimate.
    break;
  }

  return estimate;
}
|
|
|
|
/// Formats a shape as "[d0xd1x...]" for the report, e.g. "[1x64x56x56]".
/// Dynamic dimensions are printed as "?" instead of the raw
/// `ShapedType::kDynamic` sentinel (a huge negative number).
static std::string formatShape(ArrayRef<int64_t> dims) {
  std::string text;
  llvm::raw_string_ostream os(text);
  os << "[";
  bool first = true;
  for (int64_t dim : dims) {
    if (!first)
      os << "x";
    first = false;
    if (ShapedType::isDynamic(dim))
      os << "?";
    else
      os << dim;
  }
  os << "]";
  return text;
}
|
|
|
|
/// Collapses every run of whitespace in `text` into a single space and drops
/// leading and trailing whitespace, e.g. "  a \n\t b  " -> "a b".
static std::string collapseWhitespace(StringRef text) {
  std::string out;
  out.reserve(text.size());
  bool pendingSpace = false;
  for (char c : text) {
    if (std::isspace(static_cast<unsigned char>(c))) {
      // Only remember that a separator is needed; it is emitted once the next
      // non-space arrives, so no trailing space is produced. Before the first
      // word no separator is remembered, which trims leading whitespace.
      pendingSpace = !out.empty();
      continue;
    }
    if (pendingSpace)
      out.push_back(' ');
    pendingSpace = false;
    out.push_back(c);
  }
  return out;
}
|
|
|
|
/// Truncates `text` to at most `maxLen` characters, replacing the end with
/// "..." when it does not fit. If `maxLen` is too small to hold the ellipsis,
/// the text is simply cut to `maxLen` characters.
static std::string abbreviate(StringRef text, size_t maxLen) {
  constexpr size_t kEllipsisLength = 3;  // length of "..."
  if (text.size() <= maxLen)
    return text.str();
  // Guard `maxLen - kEllipsisLength` against size_t wrap-around, which would
  // otherwise keep the whole text and append "..." (exceeding maxLen).
  if (maxLen < kEllipsisLength)
    return text.take_front(maxLen).str();
  return (text.take_front(maxLen - kEllipsisLength) + "...").str();
}
|
|
|
|
/// Keeps the last characters of `text` so the result is at most `maxLen`
/// long, prefixing "..." when something was cut. If `maxLen` is too small to
/// hold the ellipsis, only the last `maxLen` characters are kept.
static std::string abbreviateFromEnd(StringRef text, size_t maxLen) {
  constexpr size_t kEllipsisLength = 3;  // length of "..."
  if (text.size() <= maxLen)
    return text.str();
  // Guard `maxLen - kEllipsisLength` against size_t wrap-around.
  if (maxLen < kEllipsisLength)
    return text.take_back(maxLen).str();
  return ("..." + text.take_back(maxLen - kEllipsisLength)).str();
}
|
|
|
|
/// Prints `loc` on one line and shortens it to `maxLen` characters. Texts that
/// look like paths or anchors (contain '/' or '#') keep their end, since the
/// distinguishing part is usually there; others keep their start.
static std::string summarizeLocation(Location loc, size_t maxLen = 44) {
  std::string printed;
  llvm::raw_string_ostream stream(printed);
  loc.print(stream);
  stream.flush();

  std::string oneLine = collapseWhitespace(printed);
  if (oneLine.size() <= maxLen)
    return oneLine;

  const bool keepTail = oneLine.find_first_of("/#") != std::string::npos;
  return keepTail ? abbreviateFromEnd(oneLine, maxLen) : abbreviate(oneLine, maxLen);
}
|
|
|
|
/// Pads `text` with spaces to `width` characters (on the left when
/// `rightAlign`, otherwise on the right). Longer text is returned unchanged.
static std::string alignCell(StringRef text, size_t width, bool rightAlign = false) {
  if (text.size() >= width)
    return text.str();
  const std::string fill(width - text.size(), ' ');
  return rightAlign ? fill + text.str() : text.str() + fill;
}
|
|
|
|
/// True when `value` is a statically shaped ranked tensor whose type equals
/// `expectedType` (which must itself be a ranked tensor type).
static bool hasSameStaticTensorType(Value value, Type expectedType) {
  auto actual = dyn_cast<RankedTensorType>(value.getType());
  if (!actual || !actual.hasStaticShape())
    return false;
  auto expected = dyn_cast<RankedTensorType>(expectedType);
  return expected && actual == expected;
}
|
|
|
|
/// True when `value` is a host constant holding a splat. `denseAttr` receives
/// the constant's attribute (null if `value` is not a host constant).
static bool isSplatConstantValue(Value value, DenseElementsAttr& denseAttr) {
  denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr)
    return false;
  return denseAttr.isSplat();
}
|
|
|
|
/// True when `value` is a non-splat host constant that, under ONNX
/// (numpy-style, trailing-aligned) broadcasting against `currentType`, supplies
/// one value per channel (axis 1), i.e. has static shape [1, C, 1, 1].
/// `denseAttr` receives the constant's attribute (null if `value` is not a
/// host constant) regardless of the result.
///
/// Rank-1 [C] and rank-2 [1, C] constants are rejected on purpose: ONNX aligns
/// broadcast shapes from the last dimension, so against an NCHW tensor they
/// vary along W, not C. They only type-check when C == W, and treating them as
/// per-channel would silently compute the wrong values.
static bool isPerChannelConstantValue(Value value, RankedTensorType currentType, DenseElementsAttr& denseAttr) {
  denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr || denseAttr.isSplat())
    return false;

  // The channel axis (1) only exists for rank >= 2; avoid asserting in getDimSize.
  if (currentType.getRank() < 2)
    return false;

  auto constantType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!constantType || !constantType.hasStaticShape())
    return false;

  // NOTE(review): a rank-3 [C, 1, 1] constant is also per-channel for rank-4
  // results but is not accepted until downstream handling is confirmed.
  const int64_t channels = currentType.getDimSize(1);
  return constantType.getRank() == 4 && constantType.getDimSize(0) == 1 && constantType.getDimSize(1) == channels
      && constantType.getDimSize(2) == 1 && constantType.getDimSize(3) == 1;
}
|
|
|
|
static std::optional<DistributedTensorStep>
|
|
classifyDistributedBinaryConsumer(Operation* user,
|
|
Value currentValue,
|
|
Value lhs,
|
|
Value rhs,
|
|
DistributedTensorOpKind kind,
|
|
bool allowFragmentOnRhs,
|
|
std::string& failureDetail) {
|
|
if (user->getNumResults() != 1 || !hasSameStaticTensorType(user->getResult(0), currentValue.getType())) {
|
|
failureDetail = "result type mismatch";
|
|
return std::nullopt;
|
|
}
|
|
|
|
auto currentType = cast<RankedTensorType>(currentValue.getType());
|
|
DenseElementsAttr constantAttr;
|
|
if (lhs == currentValue) {
|
|
DistributedTensorConstantKind constantKind = DistributedTensorConstantKind::None;
|
|
if (isSplatConstantValue(rhs, constantAttr))
|
|
constantKind = DistributedTensorConstantKind::Splat;
|
|
else if (isPerChannelConstantValue(rhs, currentType, constantAttr))
|
|
constantKind = DistributedTensorConstantKind::PerChannel;
|
|
else
|
|
failureDetail = "unsupported rhs broadcast";
|
|
if (constantKind == DistributedTensorConstantKind::None)
|
|
return std::nullopt;
|
|
return DistributedTensorStep {user, kind, constantAttr, constantKind, /*fragmentOnLhs=*/true, std::nullopt};
|
|
}
|
|
|
|
if (rhs == currentValue && allowFragmentOnRhs) {
|
|
DistributedTensorConstantKind constantKind = DistributedTensorConstantKind::None;
|
|
if (isSplatConstantValue(lhs, constantAttr))
|
|
constantKind = DistributedTensorConstantKind::Splat;
|
|
else if (isPerChannelConstantValue(lhs, currentType, constantAttr))
|
|
constantKind = DistributedTensorConstantKind::PerChannel;
|
|
else
|
|
failureDetail = "unsupported lhs broadcast";
|
|
if (constantKind == DistributedTensorConstantKind::None)
|
|
return std::nullopt;
|
|
return DistributedTensorStep {user, kind, constantAttr, constantKind, /*fragmentOnLhs=*/false, std::nullopt};
|
|
}
|
|
|
|
failureDetail = "conv result is not the supported binary operand";
|
|
return std::nullopt;
|
|
}
|
|
|
|
/// True when interval `acquired` fully contains interval `needed`.
static bool covers(RowInterval acquired, RowInterval needed) {
  const bool startsInside = acquired.begin <= needed.begin;
  const bool endsInside = needed.end <= acquired.end;
  return startsInside && endsInside;
}
|
|
|
|
static bool canConsumeRowStripHwcInput(const ConvLoweringState& state, StringRef& failureReason) {
|
|
if (state.batchSize != 1) {
|
|
failureReason = "unsupported_batch";
|
|
return false;
|
|
}
|
|
if (state.group != 1) {
|
|
failureReason = "unsupported_groups";
|
|
return false;
|
|
}
|
|
if (isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup)) {
|
|
failureReason = "unsupported_depthwise";
|
|
return false;
|
|
}
|
|
if (state.strideHeight != 1 || state.strideWidth != 1) {
|
|
failureReason = "unsupported_stride";
|
|
return false;
|
|
}
|
|
if (state.dilationHeight != 1 || state.dilationWidth != 1) {
|
|
failureReason = "unsupported_dilation";
|
|
return false;
|
|
}
|
|
if (state.padHeightBegin != state.padHeightEnd || state.padWidthBegin != state.padWidthEnd) {
|
|
failureReason = "unsupported_padding";
|
|
return false;
|
|
}
|
|
if (state.padHeightBegin != 1 || state.padWidthBegin != 1) {
|
|
failureReason = "unsupported_padding";
|
|
return false;
|
|
}
|
|
if (state.wHeight != 3 || state.wWidth != 3) {
|
|
failureReason = "unsupported_kernel";
|
|
return false;
|
|
}
|
|
if (state.outHeight != state.xHeight || state.outWidth != state.xWidth) {
|
|
failureReason = "unsupported_output_shape";
|
|
return false;
|
|
}
|
|
if (!getHostConstDenseElementsAttr(state.w)) {
|
|
failureReason = "non_constant_weights";
|
|
return false;
|
|
}
|
|
if (state.hasBias && !getHostConstDenseElementsAttr(state.b)) {
|
|
failureReason = "non_constant_bias";
|
|
return false;
|
|
}
|
|
failureReason = "";
|
|
return true;
|
|
}
|
|
|
|
/// Returns the display name of a distributed chain step kind.
static std::string stringifyDistributedTensorOpKind(DistributedTensorOpKind kind) {
  const char* name = nullptr;
  switch (kind) {
  case DistributedTensorOpKind::Relu:
    name = "Relu";
    break;
  case DistributedTensorOpKind::Sigmoid:
    name = "Sigmoid";
    break;
  case DistributedTensorOpKind::Add:
    name = "Add";
    break;
  case DistributedTensorOpKind::Sub:
    name = "Sub";
    break;
  case DistributedTensorOpKind::Mul:
    name = "Mul";
    break;
  case DistributedTensorOpKind::Div:
    name = "Div";
    break;
  case DistributedTensorOpKind::Conv:
    name = "Conv";
    break;
  }
  if (!name)
    llvm_unreachable("unknown distributed tensor op kind");
  return name;
}
|
|
|
|
/// Walks the single-use chain starting at `convOp`'s result and records every
/// consumer that can operate on the distributed Conv result without
/// materializing it. Handled consumers:
/// - Relu and Sigmoid, when the result type is unchanged;
/// - Add, Sub and Mul with a splat or per-channel constant on either side;
/// - Div with a floating-point constant on the rhs.
///
/// The walk stops at the first barrier: a dead value, fan-out, func.return, or
/// an unsupported consumer. The barrier kind and a human-readable detail are
/// stored in the result. `replacementOp` is the last op absorbed into the
/// chain, or the Conv itself when nothing is absorbed.
[[maybe_unused]] static DistributedConvAnalysis analyzeDistributedConvConsumers(ONNXConvOp convOp) {
  DistributedConvAnalysis analysis;
  analysis.replacementOp = convOp;

  Value currentValue = convOp.getResult();
  while (true) {
    if (currentValue.use_empty()) {
      analysis.barrierKind = DistributedConvBarrierKind::DeadValue;
      analysis.barrierDetail = "result has no users";
      return analysis;
    }

    // Any second use (including the same op using the value twice) ends the chain.
    if (!currentValue.hasOneUse()) {
      analysis.barrierKind = DistributedConvBarrierKind::Fanout;
      analysis.barrierDetail = "value has multiple users";
      return analysis;
    }

    Operation* user = *currentValue.getUsers().begin();
    if (isa<func::ReturnOp>(user)) {
      analysis.barrierKind = DistributedConvBarrierKind::Return;
      analysis.barrierDetail = "materialize at func.return";
      return analysis;
    }

    std::optional<DistributedTensorStep> step;
    std::string failureDetail;
    if (auto reluOp = dyn_cast<ONNXReluOp>(user)) {
      if (hasSameStaticTensorType(reluOp.getResult(), currentValue.getType()))
        step = DistributedTensorStep {
            user, DistributedTensorOpKind::Relu, {}, DistributedTensorConstantKind::None, true, std::nullopt};
      else
        failureDetail = "relu result type mismatch";
    }
    else if (auto sigmoidOp = dyn_cast<ONNXSigmoidOp>(user)) {
      if (hasSameStaticTensorType(sigmoidOp.getResult(), currentValue.getType()))
        step = DistributedTensorStep {
            user, DistributedTensorOpKind::Sigmoid, {}, DistributedTensorConstantKind::None, true, std::nullopt};
      else
        failureDetail = "sigmoid result type mismatch";
    }
    else if (auto addOp = dyn_cast<ONNXAddOp>(user)) {
      step = classifyDistributedBinaryConsumer(
          user, currentValue, addOp.getA(), addOp.getB(), DistributedTensorOpKind::Add, /*allowFragmentOnRhs=*/true,
          failureDetail);
    }
    else if (auto subOp = dyn_cast<ONNXSubOp>(user)) {
      step = classifyDistributedBinaryConsumer(
          user, currentValue, subOp.getA(), subOp.getB(), DistributedTensorOpKind::Sub, /*allowFragmentOnRhs=*/true,
          failureDetail);
    }
    else if (auto mulOp = dyn_cast<ONNXMulOp>(user)) {
      step = classifyDistributedBinaryConsumer(
          user, currentValue, mulOp.getA(), mulOp.getB(), DistributedTensorOpKind::Mul, /*allowFragmentOnRhs=*/true,
          failureDetail);
    }
    else if (auto divOp = dyn_cast<ONNXDivOp>(user)) {
      // Division is not commutative: the fragment must be the dividend (lhs),
      // and the divisor constant must be floating point.
      step = classifyDistributedBinaryConsumer(
          user, currentValue, divOp.getA(), divOp.getB(), DistributedTensorOpKind::Div, /*allowFragmentOnRhs=*/false,
          failureDetail);
      if (step && !isa<DenseFPElementsAttr>(step->constantAttr)) {
        failureDetail = "div requires floating-point splat constant";
        step.reset();
      }
    }
    else if (isa<ONNXConvOp>(user)) {
      // A following Conv is recognized but not yet consumed in distributed form.
      failureDetail = "onnx.Conv distributed consumer blocked by dim0-only whole-batch materialization in MergeComputeNodes";
    }
    else {
      failureDetail = (user->getName().getStringRef() + " is not distributed-aware yet").str();
    }

    if (!step) {
      analysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer;
      analysis.barrierDetail = failureDetail;
      return analysis;
    }

    // Absorb the consumer and continue from its result.
    analysis.replacementOp = user;
    analysis.steps.push_back(*step);
    currentValue = user->getResult(0);
  }
}
|
|
|
|
static void rewriteDistributedConvReport(const DistributedConvReportTotals& totals) {
|
|
std::fstream reportFile = openReportFile("conv_distributed_consumption_report");
|
|
if (!reportFile.is_open())
|
|
return;
|
|
|
|
reportFile << "# PIM Conv Distributed Consumption Report\n\n";
|
|
reportFile << "Totals:\n";
|
|
reportFile << "- convs_seen: " << totals.totalConvs << "\n";
|
|
reportFile << "- distributed_tensors_created: " << totals.distributedTensorsCreated << "\n";
|
|
reportFile << "- distributed_values_propagated: " << totals.distributedValuesPropagated << "\n";
|
|
reportFile << "- distributed_consumers_handled_locally: " << totals.distributedConsumersHandled << "\n";
|
|
reportFile << "- distributed_conv_inputs_seen: " << totals.distributedConvInputsSeen << "\n";
|
|
reportFile << "- distributed_conv_inputs_consumed: " << totals.distributedConvInputsConsumed << "\n";
|
|
reportFile << "- materialization_barriers_inserted: " << totals.materializationBarriersInserted << "\n\n";
|
|
|
|
if (!totals.barrierReasons.empty()) {
|
|
reportFile << "Materialization barriers:\n";
|
|
for (const auto& [reason, count] : totals.barrierReasons)
|
|
reportFile << "- " << reason << ": " << count << "\n";
|
|
reportFile << "\n";
|
|
}
|
|
|
|
if (!totals.fallbackReasons.empty()) {
|
|
reportFile << "Fallback / no-distribution reasons:\n";
|
|
for (const auto& [reason, count] : totals.fallbackReasons)
|
|
reportFile << "- " << reason << ": " << count << "\n";
|
|
reportFile << "\n";
|
|
}
|
|
|
|
if (!totals.chains.empty()) {
|
|
reportFile << "Chains:\n";
|
|
for (const DistributedChainReportEntry& chain : totals.chains) {
|
|
reportFile << "- chain_id: " << chain.chainId << "\n";
|
|
reportFile << " chain_length: " << chain.chainLength << "\n";
|
|
reportFile << " producer_kind: " << chain.producerKind << "\n";
|
|
reportFile << " distributed_ops: " << chain.distributedOps << "\n";
|
|
reportFile << " materialization_points: " << chain.materializationPoints << "\n";
|
|
reportFile << " first_materialization_reason: " << chain.firstMaterializationReason << "\n";
|
|
reportFile << " max_live_fragments: " << chain.maxLiveFragments << "\n";
|
|
reportFile << " max_fragment_fanout: " << chain.maxFragmentFanout << "\n";
|
|
reportFile << " patch_builder_core_count: " << chain.patchBuilderCoreCount << "\n";
|
|
reportFile << " central_junction_detected: " << chain.centralJunctionDetected << "\n";
|
|
reportFile << " conv_input_materialization_kind: " << chain.convInputMaterializationKind << "\n";
|
|
reportFile << " local_patch_fragments: " << chain.localPatchFragments << "\n";
|
|
reportFile << " remote_patch_fragments: " << chain.remotePatchFragments << "\n";
|
|
reportFile << " halo_transfer_count: " << chain.haloTransferCount << "\n";
|
|
reportFile << " grouped_transfer_count: " << chain.groupedTransferCount << "\n";
|
|
if (!chain.fallbackReason.empty())
|
|
reportFile << " fallback_reason: " << chain.fallbackReason << "\n";
|
|
}
|
|
}
|
|
}
|
|
|
|
[[maybe_unused]] static void recordDistributedConvOutcome(const DistributedConvAnalysis& analysis) {
|
|
static std::mutex reportMutex;
|
|
static DistributedConvReportTotals totals;
|
|
|
|
std::string barrierKey = stringifyDistributedConvBarrierKind(analysis.barrierKind).str();
|
|
if (!analysis.barrierDetail.empty())
|
|
barrierKey += ": " + analysis.barrierDetail;
|
|
|
|
std::lock_guard<std::mutex> guard(reportMutex);
|
|
totals.totalConvs++;
|
|
if (analysis.hasLocalConsumers()) {
|
|
totals.distributedTensorsCreated++;
|
|
totals.distributedValuesPropagated += analysis.steps.size();
|
|
totals.distributedConsumersHandled += llvm::count_if(analysis.steps, [](const DistributedTensorStep& step) {
|
|
return step.kind != DistributedTensorOpKind::Conv;
|
|
});
|
|
totals.distributedConvInputsSeen += llvm::count_if(analysis.steps, [](const DistributedTensorStep& step) {
|
|
return step.kind == DistributedTensorOpKind::Conv;
|
|
});
|
|
totals.distributedConvInputsConsumed += llvm::count_if(analysis.steps, [](const DistributedTensorStep& step) {
|
|
return step.kind == DistributedTensorOpKind::Conv;
|
|
});
|
|
totals.materializationBarriersInserted++;
|
|
totals.barrierReasons[barrierKey]++;
|
|
}
|
|
else {
|
|
totals.fallbackReasons[barrierKey]++;
|
|
}
|
|
DistributedChainReportEntry chain;
|
|
chain.chainId = totals.totalConvs;
|
|
chain.chainLength = analysis.steps.size() + 1;
|
|
chain.producerKind = "Conv";
|
|
chain.materializationPoints = stringifyDistributedConvBarrierKind(analysis.barrierKind).str();
|
|
chain.firstMaterializationReason = analysis.barrierDetail;
|
|
chain.maxFragmentFanout = analysis.barrierKind == DistributedConvBarrierKind::Fanout ? 2 : 1;
|
|
std::string ops;
|
|
for (size_t index = 0; index < analysis.steps.size(); ++index) {
|
|
if (!ops.empty())
|
|
ops += ", ";
|
|
ops += stringifyDistributedTensorOpKind(analysis.steps[index].kind);
|
|
}
|
|
chain.distributedOps = ops;
|
|
if (analysis.hasDistributedConvConsumer()) {
|
|
chain.convInputMaterializationKind = "distributed_with_halo_exchange";
|
|
chain.patchBuilderCoreCount = 1;
|
|
}
|
|
if (totals.chains.size() == 16)
|
|
totals.chains.erase(totals.chains.begin());
|
|
totals.chains.push_back(std::move(chain));
|
|
rewriteDistributedConvReport(totals);
|
|
}
|
|
|
|
/// Builds an ASCII table divider such as "+------+----+": one dash segment of
/// `width + 2` per column (room for a space on each side of the cell).
static std::string makeDivider(ArrayRef<size_t> widths) {
  std::string divider(1, '+');
  for (size_t width : widths)
    divider += std::string(width + 2, '-') + '+';
  return divider;
}
|
|
|
|
/// Writes the Conv lowering report header and the legend explaining each
/// table column, one "- ..." bullet per entry, followed by a blank line.
static void printConvReportLegend(std::fstream& reportFile) {
  static const char* const kLegendEntries[] = {
      "`id`: sequential Conv report entry index within this compiler invocation.",
      "`where`: summarized MLIR location of the Conv op.",
      "`mode`: whether the selected strategy came from `auto` policy or a forced compiler option.",
      "`strategy`: selected Conv lowering algorithm.",
      "`input`, `weight`, `output`: tensor shapes of the Conv operands/result.",
      "`groups`: ONNX Conv group count.",
      "`K`: logical reduction size, `CinPerGroup * Kh * Kw`.",
      "`C`: logical output-channel width handled by the selected strategy.",
      "`P`: total output positions, `N * Hout * Wout`.",
      "`X`: crossbar size.",
      "`pack`: packed spatial positions per MVM group, `floor(X / max(K, C))`.",
      "`im2col`: total explicit im2col element count, `P * K`.",
      "`im2col_budget`: maximum im2col element budget allowed by the compiler option.",
      "`stream_chunk`: output positions materialized per streamed chunk, or `-` when not applicable.",
      "`batch_size`: logical batch size passed into the compute lowering for this Conv form.",
      "`batches`: number of repeated compute batches emitted for the Conv.",
      "`spatial_compute_batch`: whether ONNX-to-Spatial used `spat.compute_batch` for this Conv.",
      "`batched_instruction_emission`: whether the lowering is expected to reach batched PIM emission.",
      "`reason`: strategy-selection reason.",
      "Per-conv details below the table include profitability estimates and materialization diagnostics.",
      "placeholders like `[7]`: value too long for the table cell; see the appendix at the end.",
  };

  reportFile << "# PIM Conv Lowering Report\n\n"
             << "Legend:\n";
  for (const char* entry : kLegendEntries)
    reportFile << "- " << entry << "\n";
  reportFile << "\n";
}
|
|
|
|
struct ConvReportOverflow {
|
|
uint64_t placeholderId;
|
|
SmallVector<uint64_t, 4> rowIds;
|
|
std::string column;
|
|
std::string value;
|
|
};
|
|
|
|
/// Returns a one-sentence, human-readable description of a Conv lowering
/// strategy for the report.
static StringRef describeConvLoweringStrategy(PimConvLoweringType strategy) {
  StringRef description;
  switch (strategy) {
  case PimConvLoweringAuto:
    description = "Automatic policy selection.";
    break;
  case PimConvLoweringLegacy:
    description = "Legacy Conv lowering path kept for compatibility.";
    break;
  case PimConvLoweringDepthwise:
    description = "Specialized depthwise lowering that avoids generic cross-channel GEMM mixing.";
    break;
  case PimConvLoweringPackedIm2Col:
    description = "Explicit im2col plus packed GEMM lowering for Conv shapes that fit well in one crossbar.";
    break;
  case PimConvLoweringStreamedPatch:
    description = "Chunked per-patch streaming Conv lowering without global im2col materialization.";
    break;
  case PimConvLoweringStreamedPacked:
    description = "Chunked streamed Conv lowering that still packs multiple output positions per MVM group.";
    break;
  case PimConvLoweringOutputChannelTiled:
    description = "Conv lowering that splits output channels across tiles when C exceeds one crossbar width.";
    break;
  case PimConvLoweringInputKTiled:
    description = "Conv lowering that splits the reduction dimension K across tiles and accumulates partial sums.";
    break;
  case PimConvLoweringTiled2D:
    description = "Conv lowering that tiles both K and output channels because neither dimension fits one crossbar.";
    break;
  }
  if (description.empty())
    llvm_unreachable("unknown conv lowering strategy");
  return description;
}
|
|
|
|
static std::string fitConvReportCell(StringRef text,
|
|
size_t width,
|
|
uint64_t rowId,
|
|
StringRef column,
|
|
std::vector<ConvReportOverflow>& overflows,
|
|
uint64_t& nextPlaceholderId,
|
|
bool rightAlign = false) {
|
|
if (text.size() <= width)
|
|
return alignCell(text, width, rightAlign);
|
|
|
|
for (ConvReportOverflow& overflow : overflows) {
|
|
if (overflow.column == column && overflow.value == text) {
|
|
if (llvm::find(overflow.rowIds, rowId) == overflow.rowIds.end())
|
|
overflow.rowIds.push_back(rowId);
|
|
std::string placeholder = "[" + std::to_string(overflow.placeholderId) + "]";
|
|
return alignCell(placeholder, width, rightAlign);
|
|
}
|
|
}
|
|
|
|
std::string placeholder = "[" + std::to_string(nextPlaceholderId++) + "]";
|
|
overflows.push_back({nextPlaceholderId - 1, {rowId}, column.str(), text.str()});
|
|
return alignCell(placeholder, width, rightAlign);
|
|
}
|
|
|
|
static void writeConvReportTable(std::fstream& reportFile, ArrayRef<ConvReportEntry> entries) {
|
|
static constexpr size_t kIdWidth = 4;
|
|
static constexpr size_t kWhereWidth = 24;
|
|
static constexpr size_t kModeWidth = 6;
|
|
static constexpr size_t kStrategyWidth = 20;
|
|
static constexpr size_t kShapeWidth = 14;
|
|
static constexpr size_t kGroupsWidth = 3;
|
|
static constexpr size_t kSmallWidth = 5;
|
|
static constexpr size_t kPWidth = 10;
|
|
static constexpr size_t kIm2colWidth = 10;
|
|
static constexpr size_t kChunkWidth = 8;
|
|
static constexpr size_t kFlagWidth = 3;
|
|
static constexpr size_t kReasonWidth = 24;
|
|
|
|
const SmallVector<size_t, 21> widths = {
|
|
kIdWidth, kWhereWidth, kModeWidth, kStrategyWidth, kShapeWidth, kShapeWidth, kShapeWidth, kGroupsWidth,
|
|
kSmallWidth, kSmallWidth, kPWidth, kSmallWidth, kSmallWidth, kIm2colWidth, kIm2colWidth,
|
|
kChunkWidth, kSmallWidth, kSmallWidth, kFlagWidth, kFlagWidth, kReasonWidth,
|
|
};
|
|
const std::string divider = makeDivider(widths);
|
|
std::vector<ConvReportOverflow> overflows;
|
|
uint64_t nextPlaceholderId = 1;
|
|
|
|
auto printRow = [&](ArrayRef<std::string> cells) {
|
|
reportFile << "|";
|
|
for (size_t i = 0; i < cells.size(); ++i)
|
|
reportFile << " " << cells[i] << " |";
|
|
reportFile << "\n";
|
|
};
|
|
|
|
reportFile << divider << "\n";
|
|
printRow({
|
|
alignCell("id", kIdWidth, true),
|
|
alignCell("where", kWhereWidth),
|
|
alignCell("mode", kModeWidth),
|
|
alignCell("strategy", kStrategyWidth),
|
|
alignCell("input", kShapeWidth),
|
|
alignCell("weight", kShapeWidth),
|
|
alignCell("output", kShapeWidth),
|
|
alignCell("grp", kGroupsWidth, true),
|
|
alignCell("K", kSmallWidth, true),
|
|
alignCell("C", kSmallWidth, true),
|
|
alignCell("P", kPWidth, true),
|
|
alignCell("X", kSmallWidth, true),
|
|
alignCell("pack", kSmallWidth, true),
|
|
alignCell("im2col", kIm2colWidth, true),
|
|
alignCell("budget", kIm2colWidth, true),
|
|
alignCell("chunk", kChunkWidth, true),
|
|
alignCell("batch", kSmallWidth, true),
|
|
alignCell("nbat", kSmallWidth, true),
|
|
alignCell("scb", kFlagWidth),
|
|
alignCell("bie", kFlagWidth),
|
|
alignCell("reason", kReasonWidth),
|
|
});
|
|
reportFile << divider << "\n";
|
|
|
|
for (const ConvReportEntry& entry : entries) {
|
|
printRow({
|
|
alignCell(std::to_string(entry.id), kIdWidth, true),
|
|
fitConvReportCell(entry.where, kWhereWidth, entry.id, "where", overflows, nextPlaceholderId),
|
|
alignCell(entry.mode, kModeWidth),
|
|
fitConvReportCell(entry.strategy, kStrategyWidth, entry.id, "strategy", overflows, nextPlaceholderId),
|
|
fitConvReportCell(entry.inputShape, kShapeWidth, entry.id, "input", overflows, nextPlaceholderId),
|
|
fitConvReportCell(entry.weightShape, kShapeWidth, entry.id, "weight", overflows, nextPlaceholderId),
|
|
fitConvReportCell(entry.outputShape, kShapeWidth, entry.id, "output", overflows, nextPlaceholderId),
|
|
alignCell(std::to_string(entry.groups), kGroupsWidth, true),
|
|
alignCell(std::to_string(entry.k), kSmallWidth, true),
|
|
alignCell(std::to_string(entry.c), kSmallWidth, true),
|
|
alignCell(std::to_string(entry.p), kPWidth, true),
|
|
alignCell(std::to_string(entry.xbarSize), kSmallWidth, true),
|
|
alignCell(std::to_string(entry.pack), kSmallWidth, true),
|
|
alignCell(std::to_string(entry.im2colElements), kIm2colWidth, true),
|
|
alignCell(std::to_string(entry.im2colBudget), kIm2colWidth, true),
|
|
fitConvReportCell(entry.chunkText, kChunkWidth, entry.id, "stream_chunk", overflows, nextPlaceholderId, true),
|
|
alignCell(std::to_string(entry.batchSize), kSmallWidth, true),
|
|
alignCell(std::to_string(entry.numberOfBatches), kSmallWidth, true),
|
|
alignCell(entry.spatialComputeBatch, kFlagWidth),
|
|
alignCell(entry.batchedInstructionEmission, kFlagWidth),
|
|
fitConvReportCell(entry.reason, kReasonWidth, entry.id, "reason", overflows, nextPlaceholderId),
|
|
});
|
|
}
|
|
reportFile << divider << "\n";
|
|
|
|
if (overflows.empty())
|
|
reportFile << "\n";
|
|
else {
|
|
reportFile << "\nAppendix:\n";
|
|
for (const ConvReportOverflow& overflow : overflows) {
|
|
reportFile << " [" << overflow.placeholderId << "] rows ";
|
|
for (size_t i = 0; i < overflow.rowIds.size(); ++i) {
|
|
if (i != 0)
|
|
reportFile << ", ";
|
|
reportFile << overflow.rowIds[i];
|
|
}
|
|
reportFile << ", " << overflow.column << ": " << overflow.value << "\n";
|
|
}
|
|
reportFile << "\n";
|
|
}
|
|
|
|
reportFile << "Per-Conv Details:\n";
|
|
for (const ConvReportEntry& entry : entries) {
|
|
reportFile << "- Conv " << entry.id << ": mode=" << entry.mode << ", strategy=" << entry.strategy
|
|
<< ", reason=" << entry.reason << "\n";
|
|
reportFile << " K=" << entry.k << ", Cout=" << entry.c << ", output_positions=" << entry.p
|
|
<< ", estimated_mvm_count=" << entry.estimatedMvmCount
|
|
<< ", estimated_reduction_vadd_count=" << entry.estimatedReductionVAddCount
|
|
<< ", estimated_output_fragments=" << entry.estimatedOutputFragments << "\n";
|
|
reportFile << " materialization_required_at_func_return=" << entry.materializationRequiredAtReturn
|
|
<< ", materialization_kind=" << entry.materializationKind
|
|
<< ", giant_collector_concat_expected=" << entry.giantCollectorConcatExpected
|
|
<< ", concat_operand_count=" << entry.concatOperandCount
|
|
<< ", collector_core=" << entry.collectorCore
|
|
<< ", full_input_broadcast_expected=" << entry.fullInputBroadcastExpected << "\n";
|
|
if (!entry.rejectedAutoStrategy.empty())
|
|
reportFile << " rejected_auto_strategy=" << entry.rejectedAutoStrategy << "\n";
|
|
if (!entry.fallbackReason.empty())
|
|
reportFile << " fallback_reason=" << entry.fallbackReason << "\n";
|
|
}
|
|
reportFile << "\n";
|
|
|
|
llvm::SmallVector<PimConvLoweringType, 8> usedStrategies;
|
|
for (const ConvReportEntry& entry : entries) {
|
|
PimConvLoweringType strategy = PimConvLoweringAuto;
|
|
for (PimConvLoweringType candidate : {
|
|
PimConvLoweringAuto,
|
|
PimConvLoweringLegacy,
|
|
PimConvLoweringDepthwise,
|
|
PimConvLoweringPackedIm2Col,
|
|
PimConvLoweringStreamedPatch,
|
|
PimConvLoweringStreamedPacked,
|
|
PimConvLoweringOutputChannelTiled,
|
|
PimConvLoweringInputKTiled,
|
|
PimConvLoweringTiled2D,
|
|
}) {
|
|
if (entry.strategy == stringifyConvLoweringStrategy(candidate)) {
|
|
strategy = candidate;
|
|
break;
|
|
}
|
|
}
|
|
if (llvm::find(usedStrategies, strategy) == usedStrategies.end())
|
|
usedStrategies.push_back(strategy);
|
|
}
|
|
|
|
reportFile << "Strategies used in this report:\n";
|
|
for (PimConvLoweringType strategy : usedStrategies)
|
|
reportFile << "- `" << stringifyConvLoweringStrategy(strategy).str() << "`: "
|
|
<< describeConvLoweringStrategy(strategy).str() << "\n";
|
|
}
|
|
|
|
static void rewriteConvLoweringReport(ArrayRef<ConvReportEntry> entries) {
|
|
std::fstream reportFile = openReportFile("conv_lowering_report");
|
|
if (!reportFile.is_open())
|
|
return;
|
|
printConvReportLegend(reportFile);
|
|
writeConvReportTable(reportFile, entries);
|
|
}
|
|
|
|
[[maybe_unused]] static FailureOr<PimConvLoweringType> resolveRequestedConvLoweringStrategy(ONNXConvOp convOp) {
|
|
if (!useExperimentalConvImpl)
|
|
return pimConvLowering.getValue();
|
|
|
|
if (pimConvLowering != PimConvLoweringAuto && pimConvLowering != PimConvLoweringPackedIm2Col) {
|
|
convOp.emitOpError() << "--use-experimental-conv-impl conflicts with --pim-conv-lowering="
|
|
<< stringifyConvLoweringStrategy(pimConvLowering);
|
|
return failure();
|
|
}
|
|
return PimConvLoweringPackedIm2Col;
|
|
}
|
|
|
|
static ConvLoweringDecision chooseConvLoweringStrategy(const ConvGeometry& geo,
|
|
PimConvLoweringType requested,
|
|
const DistributedConvAnalysis& analysis) {
|
|
if (requested != PimConvLoweringAuto)
|
|
return {requested, "forced by compiler option", /*isAuto=*/false, "", ""};
|
|
|
|
// Transform-based convolution is intentionally not selected for this ISA:
|
|
// it would require explicit transform sequences and staging traffic on top of
|
|
// the same crossbar MVM primitive, which is not attractive here.
|
|
if (geo.isDepthwise)
|
|
return {PimConvLoweringDepthwise, "depthwise convolution", /*isAuto=*/true, "", ""};
|
|
if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2 && geo.im2colElements <= pimConvIm2colMaxElements)
|
|
return {PimConvLoweringPackedIm2Col,
|
|
"fits crossbar, packing useful, and global im2col fits budget",
|
|
/*isAuto=*/true,
|
|
"",
|
|
""};
|
|
if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2 && geo.im2colElements > pimConvIm2colMaxElements)
|
|
return {PimConvLoweringStreamedPacked,
|
|
"fits crossbar and packing useful, but global im2col exceeds budget",
|
|
/*isAuto=*/true,
|
|
"",
|
|
""};
|
|
if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize)
|
|
return {PimConvLoweringStreamedPatch, "fits crossbar but packing is not useful", /*isAuto=*/true, "", ""};
|
|
if (geo.k <= geo.xbarSize && geo.c > geo.xbarSize)
|
|
return {PimConvLoweringOutputChannelTiled,
|
|
"output channels exceed one crossbar width",
|
|
/*isAuto=*/true,
|
|
"",
|
|
""};
|
|
if (geo.k > geo.xbarSize && geo.c <= geo.xbarSize) {
|
|
ConvStrategyEstimate estimate = estimateConvStrategy(geo, PimConvLoweringInputKTiled, analysis);
|
|
std::string fallbackReason = "auto rejects input-k-tiled because the reduction-heavy path is force-only for now";
|
|
if (estimate.requiresFuncReturnMaterialization && estimate.perOutputPositionReduction
|
|
&& estimate.giantCollectorConcatExpected) {
|
|
fallbackReason += "; func.return would materialize " + std::to_string(estimate.concatOperandCount)
|
|
+ " output fragments through a single collector concat after per-position reductions";
|
|
}
|
|
if (estimate.fullInputBroadcastExpected)
|
|
fallbackReason += "; the current lowering also broadcasts the padded input to many workers";
|
|
return {PimConvLoweringLegacy,
|
|
"fall back to legacy explicit-im2col for the current auto policy",
|
|
/*isAuto=*/true,
|
|
fallbackReason,
|
|
stringifyConvLoweringStrategy(PimConvLoweringInputKTiled).str()};
|
|
}
|
|
return {PimConvLoweringTiled2D, "both reduction K and output channels exceed one crossbar", /*isAuto=*/true, "", ""};
|
|
}
|
|
|
|
[[maybe_unused]] static LogicalResult verifyForcedConvLoweringStrategy(ONNXConvOp convOp,
|
|
const ConvGeometry& geo,
|
|
PimConvLoweringType strategy) {
|
|
switch (strategy) {
|
|
case PimConvLoweringAuto:
|
|
case PimConvLoweringLegacy:
|
|
return success();
|
|
case PimConvLoweringDepthwise:
|
|
if (geo.isDepthwise)
|
|
return success();
|
|
return convOp.emitOpError("forced depthwise Conv lowering requires a depthwise convolution");
|
|
case PimConvLoweringPackedIm2Col:
|
|
if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2 && geo.im2colElements <= pimConvIm2colMaxElements)
|
|
return success();
|
|
return convOp.emitOpError("forced packed-im2col Conv lowering requires K/C to fit, pack >= 2, and im2col within budget");
|
|
case PimConvLoweringStreamedPatch:
|
|
if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize)
|
|
return success();
|
|
return convOp.emitOpError("forced streamed-patch Conv lowering requires K and C to each fit one crossbar");
|
|
case PimConvLoweringStreamedPacked:
|
|
if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2)
|
|
return success();
|
|
return convOp.emitOpError("forced streamed-packed Conv lowering requires K/C to fit and pack >= 2");
|
|
case PimConvLoweringOutputChannelTiled:
|
|
if (geo.k <= geo.xbarSize && geo.c > geo.xbarSize)
|
|
return success();
|
|
return convOp.emitOpError("forced output-channel-tiled Conv lowering requires K <= X and C > X");
|
|
case PimConvLoweringInputKTiled:
|
|
if (geo.k > geo.xbarSize && geo.c <= geo.xbarSize)
|
|
return success();
|
|
return convOp.emitOpError("forced input-k-tiled Conv lowering requires K > X and C <= X");
|
|
case PimConvLoweringTiled2D:
|
|
if (geo.k > geo.xbarSize && geo.c > geo.xbarSize)
|
|
return success();
|
|
return convOp.emitOpError("forced tiled-2d Conv lowering requires K > X and C > X");
|
|
}
|
|
llvm_unreachable("unknown conv lowering strategy");
|
|
}
|
|
|
|
static void reportConvLoweringDecision(ONNXConvOp convOp,
|
|
const ConvGeometry& geo,
|
|
const ConvLoweringDecision& decision,
|
|
const ConvStrategyEstimate& estimate,
|
|
int64_t batchSize,
|
|
int64_t numberOfBatches,
|
|
bool usesComputeBatch,
|
|
bool usesBatchedInstructionEmission,
|
|
std::optional<uint64_t> streamChunkPositions = std::nullopt) {
|
|
if (!pimReportConvLowering)
|
|
return;
|
|
|
|
const std::string location = summarizeLocation(convOp.getLoc());
|
|
const std::string strategy = stringifyConvLoweringStrategy(decision.strategy).str();
|
|
const std::string mode = decision.isAuto ? "auto" : "forced";
|
|
const std::string inputShape = formatShape(cast<RankedTensorType>(convOp.getX().getType()).getShape());
|
|
const std::string weightShape = formatShape(cast<RankedTensorType>(convOp.getW().getType()).getShape());
|
|
const std::string outputShape = formatShape(cast<RankedTensorType>(convOp.getY().getType()).getShape());
|
|
const std::string chunkText = streamChunkPositions ? std::to_string(*streamChunkPositions) : "-";
|
|
const std::string scbText = usesComputeBatch ? "yes" : "no";
|
|
const std::string bieText = usesBatchedInstructionEmission ? "yes" : "no";
|
|
|
|
static uint64_t reportIndex = 0;
|
|
const uint64_t currentIndex = ++reportIndex;
|
|
static std::mutex reportMutex;
|
|
static std::vector<ConvReportEntry> reportEntries;
|
|
std::lock_guard<std::mutex> lock(reportMutex);
|
|
reportEntries.push_back({
|
|
currentIndex,
|
|
location,
|
|
strategy,
|
|
mode,
|
|
inputShape,
|
|
weightShape,
|
|
outputShape,
|
|
geo.group,
|
|
geo.k,
|
|
geo.c,
|
|
geo.p,
|
|
geo.xbarSize,
|
|
geo.pack,
|
|
geo.im2colElements,
|
|
pimConvIm2colMaxElements,
|
|
chunkText,
|
|
batchSize,
|
|
numberOfBatches,
|
|
scbText,
|
|
bieText,
|
|
decision.reason,
|
|
decision.fallbackReason,
|
|
decision.rejectedAutoStrategy,
|
|
estimate.estimatedMvmCount,
|
|
estimate.estimatedReductionVAddCount,
|
|
estimate.estimatedOutputFragments,
|
|
estimate.requiresFuncReturnMaterialization ? "yes" : "no",
|
|
estimate.materializationKind,
|
|
estimate.concatOperandCount,
|
|
estimate.collectorCore,
|
|
estimate.giantCollectorConcatExpected ? "yes" : "no",
|
|
estimate.fullInputBroadcastExpected ? "yes" : "no",
|
|
});
|
|
rewriteConvLoweringReport(reportEntries);
|
|
}
|
|
|
|
static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location loc) {
|
|
auto biasType = cast<RankedTensorType>(bias.getType());
|
|
if (biasType.getRank() != 1)
|
|
return bias;
|
|
|
|
auto expandedBiasType = RankedTensorType::get({1, biasType.getDimSize(0)}, biasType.getElementType());
|
|
return tensor::ExpandShapeOp::create(rewriter,
|
|
loc,
|
|
expandedBiasType,
|
|
bias,
|
|
SmallVector<ReassociationIndices> {
|
|
{0, 1}
|
|
});
|
|
}
|
|
|
|
static int64_t findLargestDivisorAtMost(int64_t value, int64_t limit) {
|
|
assert(value > 0 && "expected positive value");
|
|
limit = std::min(value, limit);
|
|
for (int64_t candidate = limit; candidate >= 1; --candidate)
|
|
if (value % candidate == 0)
|
|
return candidate;
|
|
return 1;
|
|
}
|
|
|
|
static Value createZeroPaddedTensor(Value value,
|
|
RankedTensorType resultType,
|
|
ArrayRef<int64_t> lowPadValues,
|
|
ArrayRef<int64_t> highPadValues,
|
|
PatternRewriter& rewriter,
|
|
Location loc) {
|
|
auto valueType = cast<RankedTensorType>(value.getType());
|
|
if (valueType == resultType)
|
|
return value;
|
|
|
|
SmallVector<OpFoldResult> lowPads;
|
|
SmallVector<OpFoldResult> highPads;
|
|
lowPads.reserve(lowPadValues.size());
|
|
highPads.reserve(highPadValues.size());
|
|
for (auto lowPad : lowPadValues)
|
|
lowPads.push_back(rewriter.getIndexAttr(lowPad));
|
|
for (auto highPad : highPadValues)
|
|
highPads.push_back(rewriter.getIndexAttr(highPad));
|
|
|
|
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
|
|
auto* padBlock = new Block();
|
|
for (int64_t dim = 0, rank = resultType.getRank(); dim < rank; ++dim)
|
|
padBlock->addArgument(rewriter.getIndexType(), loc);
|
|
padOp.getRegion().push_back(padBlock);
|
|
rewriter.setInsertionPointToStart(padBlock);
|
|
auto zero = getOrCreateConstant(
|
|
rewriter, padOp.getOperation(), rewriter.getZeroAttr(resultType.getElementType()), resultType.getElementType());
|
|
tensor::YieldOp::create(rewriter, loc, zero);
|
|
rewriter.setInsertionPointAfter(padOp);
|
|
return padOp.getResult();
|
|
}
|
|
|
|
static Value createConvInputPatch(Value input,
|
|
RankedTensorType patchType,
|
|
Value batchIndex,
|
|
Value channelOffset,
|
|
Value inputHeightOffset,
|
|
Value inputWidthOffset,
|
|
int64_t dilationHeight,
|
|
int64_t dilationWidth,
|
|
PatternRewriter& rewriter,
|
|
Location loc) {
|
|
const int64_t patchChannels = patchType.getDimSize(1);
|
|
const int64_t kernelHeight = patchType.getDimSize(2);
|
|
const int64_t kernelWidth = patchType.getDimSize(3);
|
|
if (dilationHeight == 1 && dilationWidth == 1) {
|
|
SmallVector<OpFoldResult> offsets {batchIndex, channelOffset, inputHeightOffset, inputWidthOffset};
|
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1),
|
|
rewriter.getIndexAttr(patchChannels),
|
|
rewriter.getIndexAttr(kernelHeight),
|
|
rewriter.getIndexAttr(kernelWidth)};
|
|
return tensor::ExtractSliceOp::create(rewriter, loc, patchType, input, offsets, sizes, getUnitStrides(rewriter, 4));
|
|
}
|
|
|
|
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
|
auto elementType = patchType.getElementType();
|
|
auto pixelType = RankedTensorType::get({1, patchChannels, 1, 1}, elementType, patchType.getEncoding());
|
|
Value patch = tensor::EmptyOp::create(rewriter, loc, patchType.getShape(), elementType);
|
|
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
|
Value sourceHeightOffset = affineAddConst(rewriter, loc, inputHeightOffset, kernelH * dilationHeight, anchorOp);
|
|
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
|
Value sourceWidthOffset = affineAddConst(rewriter, loc, inputWidthOffset, kernelW * dilationWidth, anchorOp);
|
|
SmallVector<OpFoldResult> sourceOffsets {batchIndex, channelOffset, sourceHeightOffset, sourceWidthOffset};
|
|
SmallVector<OpFoldResult> sourceSizes {rewriter.getIndexAttr(1),
|
|
rewriter.getIndexAttr(patchChannels),
|
|
rewriter.getIndexAttr(1),
|
|
rewriter.getIndexAttr(1)};
|
|
Value sourcePixel = tensor::ExtractSliceOp::create(
|
|
rewriter, loc, pixelType, input, sourceOffsets, sourceSizes, getUnitStrides(rewriter, 4));
|
|
SmallVector<OpFoldResult> targetOffsets {
|
|
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(kernelH), rewriter.getIndexAttr(kernelW)};
|
|
patch = tensor::InsertSliceOp::create(
|
|
rewriter, loc, sourcePixel, patch, targetOffsets, sourceSizes, getUnitStrides(rewriter, 4));
|
|
}
|
|
}
|
|
return patch;
|
|
}
|
|
|
|
static Value createCollectedConvOutput(ValueRange gemmRows,
|
|
Type convType,
|
|
RankedTensorType gemmOutType,
|
|
RankedTensorType nhwcType,
|
|
RankedTensorType outType,
|
|
int64_t numPatches,
|
|
int64_t numChannelsOut,
|
|
int64_t packFactor,
|
|
ArrayRef<DistributedTensorStep> distributedConsumers,
|
|
PatternRewriter& rewriter,
|
|
Location loc);
|
|
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b);
|
|
|
|
namespace depthwise {
|
|
|
|
struct Tiling {
|
|
int64_t outputMultiplier;
|
|
int64_t kernelElements;
|
|
int64_t channelsPerTile;
|
|
int64_t tileInputRows;
|
|
int64_t tileOutputChannels;
|
|
int64_t numChannelTiles;
|
|
int64_t spatialPatchesPerBatch;
|
|
int64_t totalPatches;
|
|
};
|
|
|
|
static std::optional<Tiling> computeTiling(int64_t batchSize,
|
|
int64_t numChannelsIn,
|
|
int64_t numChannelsOut,
|
|
int64_t wHeight,
|
|
int64_t wWidth,
|
|
int64_t outHeight,
|
|
int64_t outWidth) {
|
|
const int64_t kernelElements = wHeight * wWidth;
|
|
const int64_t outputMultiplier = numChannelsOut / numChannelsIn;
|
|
const int64_t xbarDim = static_cast<int64_t>(crossbarSize.getValue());
|
|
if (kernelElements <= 0 || outputMultiplier <= 0 || kernelElements > xbarDim || outputMultiplier > xbarDim)
|
|
return std::nullopt;
|
|
|
|
const int64_t maxChannelsPerTile = std::min(xbarDim / kernelElements, xbarDim / outputMultiplier);
|
|
if (maxChannelsPerTile <= 0)
|
|
return std::nullopt;
|
|
|
|
const int64_t channelsPerTile = findLargestDivisorAtMost(numChannelsIn, maxChannelsPerTile);
|
|
const int64_t tileInputRows = channelsPerTile * kernelElements;
|
|
const int64_t tileOutputChannels = channelsPerTile * outputMultiplier;
|
|
if (tileInputRows > xbarDim || tileOutputChannels > xbarDim)
|
|
return std::nullopt;
|
|
|
|
return Tiling {
|
|
outputMultiplier,
|
|
kernelElements,
|
|
channelsPerTile,
|
|
tileInputRows,
|
|
tileOutputChannels,
|
|
numChannelsIn / channelsPerTile,
|
|
outHeight * outWidth,
|
|
batchSize * outHeight * outWidth,
|
|
};
|
|
}
|
|
|
|
static Value buildPackedWeights(DenseElementsAttr wDenseAttr,
|
|
RankedTensorType wType,
|
|
const Tiling& tiling,
|
|
PatternRewriter& rewriter,
|
|
Location loc) {
|
|
auto packedWeightType = RankedTensorType::get(
|
|
{tiling.numChannelTiles, tiling.tileInputRows, tiling.tileOutputChannels}, wType.getElementType());
|
|
SmallVector<Attribute> packedValues(packedWeightType.getNumElements(),
|
|
cast<Attribute>(rewriter.getZeroAttr(wType.getElementType())));
|
|
SmallVector<Attribute> sourceValues(wDenseAttr.getValues<Attribute>());
|
|
|
|
for (int64_t tileIndex = 0; tileIndex < tiling.numChannelTiles; ++tileIndex) {
|
|
const int64_t channelBase = tileIndex * tiling.channelsPerTile;
|
|
for (int64_t localChannel = 0; localChannel < tiling.channelsPerTile; ++localChannel) {
|
|
const int64_t globalChannel = channelBase + localChannel;
|
|
for (int64_t kernelIndex = 0; kernelIndex < tiling.kernelElements; ++kernelIndex) {
|
|
const int64_t kernelH = kernelIndex / wType.getDimSize(3);
|
|
const int64_t kernelW = kernelIndex % wType.getDimSize(3);
|
|
const int64_t targetRow = localChannel * tiling.kernelElements + kernelIndex;
|
|
for (int64_t multiplierIndex = 0; multiplierIndex < tiling.outputMultiplier; ++multiplierIndex) {
|
|
const int64_t globalOutChannel = globalChannel * tiling.outputMultiplier + multiplierIndex;
|
|
const int64_t sourceFlatIndex =
|
|
((globalOutChannel * wType.getDimSize(1) * wType.getDimSize(2)) + kernelH) * wType.getDimSize(3) + kernelW;
|
|
const int64_t targetCol = localChannel * tiling.outputMultiplier + multiplierIndex;
|
|
const int64_t targetFlatIndex =
|
|
((tileIndex * tiling.tileInputRows) + targetRow) * tiling.tileOutputChannels + targetCol;
|
|
packedValues[targetFlatIndex] = sourceValues[sourceFlatIndex];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
|
|
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
|
|
}
|
|
|
|
static Value createPaddedInput(Value input,
|
|
RankedTensorType inputType,
|
|
int64_t padHeightBegin,
|
|
int64_t padHeightEnd,
|
|
int64_t padWidthBegin,
|
|
int64_t padWidthEnd,
|
|
PatternRewriter& rewriter,
|
|
Location loc) {
|
|
if (padHeightBegin == 0 && padHeightEnd == 0 && padWidthBegin == 0 && padWidthEnd == 0)
|
|
return input;
|
|
|
|
auto paddedInputType = RankedTensorType::get({inputType.getDimSize(0),
|
|
inputType.getDimSize(1),
|
|
inputType.getDimSize(2) + padHeightBegin + padHeightEnd,
|
|
inputType.getDimSize(3) + padWidthBegin + padWidthEnd},
|
|
inputType.getElementType());
|
|
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
|
|
Value padded = createZeroPaddedTensor(computeInput,
|
|
paddedInputType,
|
|
{0, 0, padHeightBegin, padWidthBegin},
|
|
{0, 0, padHeightEnd, padWidthEnd},
|
|
rewriter,
|
|
loc);
|
|
spatial::SpatYieldOp::create(rewriter, loc, padded);
|
|
});
|
|
return computeOp.getResult(0);
|
|
}
|
|
|
|
static Value createInputTile(Value input,
|
|
Value patchIndex,
|
|
Value channelTileIndex,
|
|
RankedTensorType inputTileType,
|
|
const Tiling& tiling,
|
|
int64_t strideHeight,
|
|
int64_t strideWidth,
|
|
int64_t dilationHeight,
|
|
int64_t dilationWidth,
|
|
int64_t outWidth,
|
|
PatternRewriter& rewriter,
|
|
Location loc) {
|
|
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
|
Value batchIndex = affineFloorDivConst(rewriter, loc, patchIndex, tiling.spatialPatchesPerBatch, anchorOp);
|
|
Value batchPatchIndex = affineModConst(rewriter, loc, patchIndex, tiling.spatialPatchesPerBatch, anchorOp);
|
|
Value outHeightIndex = affineFloorDivConst(rewriter, loc, batchPatchIndex, outWidth, anchorOp);
|
|
Value outWidthIndex = affineModConst(rewriter, loc, batchPatchIndex, outWidth, anchorOp);
|
|
Value inputHeightOffset =
|
|
strideHeight == 1 ? outHeightIndex : affineMulConst(rewriter, loc, outHeightIndex, strideHeight, anchorOp);
|
|
Value inputWidthOffset =
|
|
strideWidth == 1 ? outWidthIndex : affineMulConst(rewriter, loc, outWidthIndex, strideWidth, anchorOp);
|
|
Value channelOffset = tiling.channelsPerTile == 1
|
|
? channelTileIndex
|
|
: affineMulConst(rewriter, loc, channelTileIndex, tiling.channelsPerTile, anchorOp);
|
|
Value tile4D = createConvInputPatch(input,
|
|
inputTileType,
|
|
batchIndex,
|
|
channelOffset,
|
|
inputHeightOffset,
|
|
inputWidthOffset,
|
|
dilationHeight,
|
|
dilationWidth,
|
|
rewriter,
|
|
loc);
|
|
auto collapsedType = RankedTensorType::get({1, tiling.tileInputRows}, inputTileType.getElementType());
|
|
return tensor::CollapseShapeOp::create(rewriter,
|
|
loc,
|
|
collapsedType,
|
|
tile4D,
|
|
SmallVector<ReassociationIndices> {
|
|
{0},
|
|
{1, 2, 3}
|
|
});
|
|
}
|
|
|
|
static Value createWeightTile(Value packedWeights,
|
|
Value channelTileIndex,
|
|
RankedTensorType packedWeightType,
|
|
const Tiling& tiling,
|
|
PatternRewriter& rewriter,
|
|
Location loc) {
|
|
SmallVector<OpFoldResult> offsets {channelTileIndex, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1),
|
|
rewriter.getIndexAttr(tiling.tileInputRows),
|
|
rewriter.getIndexAttr(tiling.tileOutputChannels)};
|
|
auto sliceType =
|
|
RankedTensorType::get({1, tiling.tileInputRows, tiling.tileOutputChannels}, packedWeightType.getElementType());
|
|
Value slice = tensor::ExtractSliceOp::create(
|
|
rewriter, loc, sliceType, packedWeights, offsets, sizes, getUnitStrides(rewriter, 3));
|
|
auto collapsedType =
|
|
RankedTensorType::get({tiling.tileInputRows, tiling.tileOutputChannels}, packedWeightType.getElementType());
|
|
return tensor::CollapseShapeOp::create(rewriter,
|
|
loc,
|
|
collapsedType,
|
|
slice,
|
|
SmallVector<ReassociationIndices> {
|
|
{0, 1},
|
|
{2}
|
|
});
|
|
}
|
|
|
|
static Value createBiasTile(
|
|
Value bias, Value channelTileIndex, const Tiling& tiling, PatternRewriter& rewriter, Location loc) {
|
|
auto biasType = cast<RankedTensorType>(bias.getType());
|
|
auto biasTileType = RankedTensorType::get({1, tiling.tileOutputChannels}, biasType.getElementType());
|
|
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
|
Value channelOffset = tiling.tileOutputChannels == 1
|
|
? channelTileIndex
|
|
: affineMulConst(rewriter, loc, channelTileIndex, tiling.tileOutputChannels, anchorOp);
|
|
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), channelOffset};
|
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tiling.tileOutputChannels)};
|
|
return tensor::ExtractSliceOp::create(rewriter, loc, biasTileType, bias, offsets, sizes, getUnitStrides(rewriter, 2));
|
|
}
|
|
|
|
static Value insertOutputTile(Value rowTile,
|
|
Value rowAccumulator,
|
|
Value channelTileIndex,
|
|
const Tiling& tiling,
|
|
PatternRewriter& rewriter,
|
|
Location loc) {
|
|
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
|
Value channelOffset = tiling.tileOutputChannels == 1
|
|
? channelTileIndex
|
|
: affineMulConst(rewriter, loc, channelTileIndex, tiling.tileOutputChannels, anchorOp);
|
|
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), channelOffset};
|
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tiling.tileOutputChannels)};
|
|
return tensor::InsertSliceOp::create(
|
|
rewriter, loc, rowTile, rowAccumulator, offsets, sizes, getUnitStrides(rewriter, 2));
|
|
}
|
|
|
|
/// Reassembles the per-lane rows produced by the depthwise compute batch into
/// GEMM-shaped output rows.
///
/// Lane `channelTileIndex * tiling.totalPatches + patchIndex` of `pieces` holds
/// the [1, tiling.tileOutputChannels] tile of channel tile `channelTileIndex`
/// for patch `patchIndex` (see the affine map built in the inner loop). Inside
/// a single Spatial compute op, this helper does the following for every patch:
///   - gather all `tiling.numChannelTiles` tiles of that patch into one
///     [1, gemmOutType.getDimSize(1)] row via `insertOutputTile`;
///   - insert that row at row `patchIndex` of a `gemmOutType` tensor.
///
/// @param pieces      Batch result holding one row tile per lane.
/// @param piecesType  Type of `pieces`; only its element type is read here.
/// @param gemmOutType Result type; dim 1 is the full row width (the caller in
///                    this file passes [totalPatches, Cout]).
/// @param tiling      Reads totalPatches, numChannelTiles and tileOutputChannels.
/// @return The reassembled rows, or failure if building either loop fails.
static FailureOr<Value> reconstructDepthwiseGemmRows(Value pieces,
                                                     RankedTensorType piecesType,
                                                     RankedTensorType gemmOutType,
                                                     const Tiling& tiling,
                                                     PatternRewriter& rewriter,
                                                     Location loc) {
  auto collectedOp = createSpatCompute<1>(rewriter, loc, TypeRange {gemmOutType}, {}, pieces, [&](Value piecesArg) {
    // One full output row (all channel tiles of a single patch).
    auto rowType = RankedTensorType::get({1, gemmOutType.getDimSize(1)}, gemmOutType.getElementType());
    Value outputInit = tensor::EmptyOp::create(rewriter, loc, gemmOutType.getShape(), gemmOutType.getElementType());
    // Index constants are anchored on the op that owns the current block
    // (the compute op being built).
    Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
    Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
    Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
    Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, tiling.totalPatches);
    Value cNumChannelTiles = getOrCreateIndexConstant(rewriter, anchorOp, tiling.numChannelTiles);

    // Outer loop: one iteration per patch, threading the output tensor.
    auto patchLoop = buildNormalizedScfFor(
        rewriter,
        loc,
        c0,
        cNumPatches,
        c1,
        ValueRange {outputInit},
        [&](OpBuilder&,
            Location nestedLoc,
            Value patchIndex,
            ValueRange patchIterArgs,
            SmallVectorImpl<Value>& patchYielded) {
          Value outputAcc = patchIterArgs.front();
          Value rowInit = tensor::EmptyOp::create(rewriter, nestedLoc, rowType.getShape(), rowType.getElementType());
          // Inner loop: gather every channel tile of this patch into one row.
          auto tileLoop = buildNormalizedScfFor(
              rewriter,
              nestedLoc,
              c0,
              cNumChannelTiles,
              c1,
              ValueRange {rowInit},
              [&](OpBuilder&,
                  Location tileLoc,
                  Value channelTileIndex,
                  ValueRange tileIterArgs,
                  SmallVectorImpl<Value>& tileYielded) {
                Value rowAcc = tileIterArgs.front();
                MLIRContext* context = rewriter.getContext();
                AffineExpr d0 = getAffineDimExpr(0, context);
                AffineExpr d1 = getAffineDimExpr(1, context);
                // lane = channelTileIndex * totalPatches + patchIndex
                Value laneIndex = createOrFoldAffineApply(
                    rewriter, tileLoc, (d0 * tiling.totalPatches) + d1, ValueRange {channelTileIndex, patchIndex}, anchorOp);
                auto rowTileType = RankedTensorType::get({1, tiling.tileOutputChannels}, piecesType.getElementType());
                SmallVector<OpFoldResult> pieceOffsets {laneIndex, rewriter.getIndexAttr(0)};
                SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1),
                                                      rewriter.getIndexAttr(tiling.tileOutputChannels)};
                Value rowTile = tensor::ExtractSliceOp::create(
                    rewriter, tileLoc, rowTileType, piecesArg, pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2));
                // NOTE(review): presumably places the tile at column
                // channelTileIndex * tileOutputChannels of the row; confirm in insertOutputTile.
                Value rowNext = insertOutputTile(rowTile, rowAcc, channelTileIndex, tiling, rewriter, tileLoc);
                tileYielded.push_back(rowNext);
                return success();
              });
          if (failed(tileLoop))
            return failure();

          // Write the assembled row into row `patchIndex` of the output.
          SmallVector<OpFoldResult> rowOffsets {patchIndex, rewriter.getIndexAttr(0)};
          SmallVector<OpFoldResult> rowSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(gemmOutType.getDimSize(1))};
          Value outputNext = tensor::InsertSliceOp::create(rewriter,
                                                           nestedLoc,
                                                           tileLoop->results.front(),
                                                           outputAcc,
                                                           rowOffsets,
                                                           rowSizes,
                                                           getUnitStrides(rewriter, 2))
                                 .getResult();
          patchYielded.push_back(outputNext);
          return success();
        });
    if (failed(patchLoop))
      return failure();

    spatial::SpatYieldOp::create(rewriter, loc, patchLoop->results.front());
    return success();
  });
  if (failed(collectedOp))
    return failure();
  return collectedOp->getResult(0);
}
|
|
|
|
static bool canUseStructuredRewrite(const ConvLoweringState& state) {
|
|
if (!getHostConstDenseElementsAttr(state.w))
|
|
return false;
|
|
|
|
auto tiling = computeTiling(state.batchSize,
|
|
state.numChannelsIn,
|
|
state.numChannelsOut,
|
|
state.wHeight,
|
|
state.wWidth,
|
|
state.outHeight,
|
|
state.outWidth);
|
|
if (!tiling)
|
|
return false;
|
|
|
|
if (!state.hasBias)
|
|
return true;
|
|
|
|
auto biasType = dyn_cast<RankedTensorType>(state.b.getType());
|
|
if (!biasType)
|
|
return false;
|
|
if (biasType.getRank() == 1)
|
|
return biasType.getDimSize(0) == state.numChannelsOut;
|
|
if (biasType.getRank() != 2)
|
|
return false;
|
|
return biasType.getDimSize(0) == 1 && biasType.getDimSize(1) == state.numChannelsOut;
|
|
}
|
|
|
|
/// Structured depthwise lowering of a convolution onto Spatial weighted VMMs.
///
/// Steps, in the order they are emitted:
///   1. Require host-constant weights and a tiling from `computeTiling`;
///      otherwise emit an op error on `convOp` and fail.
///   2. Pad the input and pack the weights. If a bias is present, expand it;
///      it must become tensor<1xCout>.
///   3. Emit a compute batch with `totalPatches * numChannelTiles` lanes.
///      - With a single channel tile, each lane is one patch.
///      - Otherwise lane `l` handles patch `l mod totalPatches` and channel
///        tile `l floordiv totalPatches`.
///      Each lane computes one [1, tileOutputChannels] row (VMM, plus the bias
///      tile when present) and writes it to row `l` of the batch result.
///   4. With several channel tiles, regroup the rows per patch via
///      reconstructDepthwiseGemmRows.
///   5. Hand the [totalPatches, Cout] rows to createCollectedConvOutput.
///
/// NOTE(review): the bias-shape failure in step 2 is reported after the padded
/// input and packed weights have already been created.
/// NOTE(review): with a single channel tile the batch result is used directly
/// as the [totalPatches, Cout] rows, which assumes tileOutputChannels == Cout
/// in that case. Confirm against computeTiling.
static FailureOr<Value>
rewriteConv(Operation* convOp, const ConvLoweringState& state, PatternRewriter& rewriter, Location loc) {
  auto wDenseAttr = getHostConstDenseElementsAttr(state.w);
  if (!wDenseAttr) {
    convOp->emitOpError("requires constant-derived weights for structured depthwise Spatial lowering");
    return failure();
  }

  // Tiling is derived from (N, Cin, Cout, kH, kW, outH, outW) read off the types.
  auto tiling = computeTiling(state.xType.getDimSize(0),
                              state.xType.getDimSize(1),
                              state.outType.getDimSize(1),
                              state.wType.getDimSize(2),
                              state.wType.getDimSize(3),
                              state.outType.getDimSize(2),
                              state.outType.getDimSize(3));
  if (!tiling) {
    convOp->emitOpError("failed to derive a structured depthwise tiling that fits Spatial weighted VMM lowering");
    return failure();
  }

  Value paddedInput = createPaddedInput(state.x,
                                        state.xType,
                                        state.padHeightBegin,
                                        state.padHeightEnd,
                                        state.padWidthBegin,
                                        state.padWidthEnd,
                                        rewriter,
                                        loc);
  Value packedWeights = buildPackedWeights(wDenseAttr, state.wType, *tiling, rewriter, loc);

  // Batch inputs: the padded input (rank 4) and, when present, the bias (rank 2).
  Value expandedBias;
  SmallVector<Value> batchInputs {paddedInput};
  if (state.hasBias) {
    expandedBias = expandBiasIfNeeded(state.b, rewriter, loc);
    auto biasType = dyn_cast<RankedTensorType>(expandedBias.getType());
    if (!biasType || biasType.getRank() != 2 || biasType.getDimSize(0) != 1
        || biasType.getDimSize(1) != state.outType.getDimSize(1)) {
      convOp->emitOpError("requires bias sliceable as tensor<1xCout> for structured depthwise Spatial lowering");
      return failure();
    }
    batchInputs.push_back(expandedBias);
  }

  auto gemmOutType =
      RankedTensorType::get({tiling->totalPatches, state.outType.getDimSize(1)}, state.outType.getElementType());
  // One [1, tileOutputChannels] row per batch lane.
  auto piecesType = RankedTensorType::get({tiling->totalPatches * tiling->numChannelTiles, tiling->tileOutputChannels},
                                          state.outType.getElementType());
  auto paddedInputType = cast<RankedTensorType>(paddedInput.getType());
  auto inputTileType =
      RankedTensorType::get({1, tiling->channelsPerTile, state.wType.getDimSize(2), state.wType.getDimSize(3)},
                            paddedInputType.getElementType());
  // With a single channel tile the weight tile is extracted once up front;
  // otherwise each lane slices its own tile out of the packed weights.
  SmallVector<Value> batchWeights;
  if (tiling->numChannelTiles == 1) {
    Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
    batchWeights.push_back(createWeightTile(packedWeights,
                                            c0,
                                            cast<RankedTensorType>(packedWeights.getType()),
                                            *tiling,
                                            rewriter,
                                            loc));
  }
  else {
    batchWeights.push_back(packedWeights);
  }

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {piecesType},
      tiling->totalPatches * tiling->numChannelTiles,
      batchWeights,
      batchInputs,
      [&](detail::SpatComputeBatchBodyArgs args) {
        // Block arguments are identified by rank: 4 = padded input, 2 = bias.
        auto pickInputByRank = [&](int64_t rank) -> Value {
          for (Value input : args.inputs) {
            auto inputType = dyn_cast<RankedTensorType>(input.getType());
            if (inputType && inputType.getRank() == rank)
              return input;
          }
          return Value();
        };

        Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
        // lane = channelTileIndex * totalPatches + patchIndex
        Value patchIndex = tiling->numChannelTiles == 1
                             ? args.lane
                             : affineModConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp);
        Value channelTileIndex = tiling->numChannelTiles == 1
                                   ? getOrCreateIndexConstant(rewriter, anchorOp, 0)
                                   : affineFloorDivConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp);
        Value paddedInputArg = pickInputByRank(/*rank=*/4);
        if (!paddedInputArg) {
          convOp->emitOpError("structured depthwise batch body requires a rank-4 padded input block argument");
          return failure();
        }

        Value inputTile = createInputTile(paddedInputArg,
                                          patchIndex,
                                          channelTileIndex,
                                          inputTileType,
                                          *tiling,
                                          state.strideHeight,
                                          state.strideWidth,
                                          state.dilationHeight,
                                          state.dilationWidth,
                                          state.outType.getDimSize(3),
                                          rewriter,
                                          loc);
        Value weightTile = tiling->numChannelTiles == 1
                             ? args.weights.front()
                             : createWeightTile(args.weights.front(),
                                                channelTileIndex,
                                                cast<RankedTensorType>(args.weights.front().getType()),
                                                *tiling,
                                                rewriter,
                                                loc);
        auto rowTileType = RankedTensorType::get({1, tiling->tileOutputChannels}, state.outType.getElementType());
        Value rowTile = spatial::SpatVMMOp::create(rewriter, loc, rowTileType, weightTile, inputTile).getResult();
        // A second batch input is only present when the conv has a bias.
        if (args.inputs.size() > 1) {
          Value biasArg = pickInputByRank(/*rank=*/2);
          if (!biasArg) {
            convOp->emitOpError("structured depthwise batch body requires a rank-2 bias block argument when bias is present");
            return failure();
          }
          Value biasTile = tiling->numChannelTiles == 1 ? biasArg : createBiasTile(biasArg, channelTileIndex, *tiling, rewriter, loc);
          rowTile = spatial::SpatVAddOp::create(rewriter, loc, rowTileType, rowTile, biasTile).getResult();
        }
        // Each lane writes exactly its own row of the batch result.
        SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1),
                                               rewriter.getIndexAttr(tiling->tileOutputChannels)};
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, rowTile, args.outputs.front(), outputOffsets, outputSizes, getUnitStrides(rewriter, 2));
        return success();
      });
  if (failed(batchOp))
    return failure();

  auto nhwcType = RankedTensorType::get(
      {state.xType.getDimSize(0), state.outType.getDimSize(2), state.outType.getDimSize(3), state.outType.getDimSize(1)},
      state.outType.getElementType());
  Value collectedRows = batchOp->getResult(0);
  // Multiple channel tiles: regroup the lane rows into one full row per patch.
  if (tiling->numChannelTiles != 1) {
    auto reconstructedRows =
        reconstructDepthwiseGemmRows(batchOp->getResult(0), piecesType, gemmOutType, *tiling, rewriter, loc);
    if (failed(reconstructedRows))
      return failure();
    collectedRows = *reconstructedRows;
  }

  return createCollectedConvOutput(ValueRange {collectedRows},
                                   state.outType,
                                   gemmOutType,
                                   nhwcType,
                                   state.outType,
                                   tiling->totalPatches,
                                   state.outType.getDimSize(1),
                                   /*packFactor=*/1,
                                   {},
                                   rewriter,
                                   loc);
}
|
|
|
|
} // namespace depthwise
|
|
|
|
namespace standard {
|
|
|
|
struct ConvGemmPlan {
|
|
int64_t patchSize;
|
|
int64_t numPatchesPerBatch;
|
|
int64_t globalNumPatches;
|
|
int64_t chunkStart;
|
|
int64_t chunkNumPatches;
|
|
int64_t maxParallelPixels;
|
|
int64_t effectiveMaxParallelPixels;
|
|
int64_t packedNumRows;
|
|
|
|
RankedTensorType im2colType;
|
|
RankedTensorType im2colRowType;
|
|
RankedTensorType gemmInputRowsType;
|
|
RankedTensorType wFlatType;
|
|
RankedTensorType wTransType;
|
|
RankedTensorType gemmOutType;
|
|
RankedTensorType gemmOutputRowsType;
|
|
RankedTensorType nhwcType;
|
|
};
|
|
|
|
// Forward declaration; the definition appears later in this namespace.
// Builds the GEMM plan for the `chunkNumPatches` output patches starting at
// global patch index `chunkStart` of the convolution described by `state`.
// When `forcedPackFactor` is set it replaces the crossbar-derived pack factor.
static ConvGemmPlan
buildConvGemmPlan(const ConvLoweringState& state,
                  bool canPackWeightsAsConstants,
                  bool canPackBiasAsConstants,
                  int64_t chunkStart,
                  int64_t chunkNumPatches,
                  std::optional<int64_t> forcedPackFactor = std::nullopt);
|
|
|
|
/// Produces the convolution input used for im2col patch extraction.
///
/// When no spatial padding is requested, the original input is returned
/// unchanged. Otherwise a Spatial compute op yields a zero-padded copy of
/// shape [N, Cin, H + padTop + padBottom, W + padLeft + padRight].
static PreparedConvInput prepareInputForIm2Col(const ConvLoweringState& state,
                                               PatternRewriter& rewriter,
                                               Location loc) {
  const bool needsPadding = state.padHeightBegin != 0 || state.padHeightEnd != 0 || state.padWidthBegin != 0
                            || state.padWidthEnd != 0;
  if (!needsPadding)
    return {state.x, state.xType};

  const int64_t paddedHeight = state.xHeight + state.padHeightBegin + state.padHeightEnd;
  const int64_t paddedWidth = state.xWidth + state.padWidthBegin + state.padWidthEnd;
  auto targetType = RankedTensorType::get({state.batchSize, state.numChannelsIn, paddedHeight, paddedWidth},
                                          state.xType.getElementType());

  auto padOp = createSpatCompute<1>(rewriter, loc, TypeRange {targetType}, {}, state.x, [&](Value source) {
    Value padded = createZeroPaddedTensor(source,
                                          targetType,
                                          {0, 0, state.padHeightBegin, state.padWidthBegin},
                                          {0, 0, state.padHeightEnd, state.padWidthEnd},
                                          rewriter,
                                          loc);
    spatial::SpatYieldOp::create(rewriter, loc, padded);
  });
  return {padOp.getResult(0), targetType};
}
|
|
|
|
/// Zero-pads the 2-D tensor `rows` at the bottom so that it has `paddedRows`
/// rows. The column count, element type and encoding are preserved. Returns
/// `rows` itself when it already has exactly `paddedRows` rows.
static Value createPaddedRows(Value rows,
                              RankedTensorType rowsType,
                              int64_t paddedRows,
                              PatternRewriter& rewriter,
                              Location loc) {
  const int64_t currentRows = rowsType.getDimSize(0);
  if (currentRows == paddedRows)
    return rows;

  const int64_t extraRows = paddedRows - currentRows;
  auto targetType =
      RankedTensorType::get({paddedRows, rowsType.getDimSize(1)}, rowsType.getElementType(), rowsType.getEncoding());
  return createZeroPaddedTensor(rows, targetType, {0, 0}, {extraRows, 0}, rewriter, loc);
}
|
|
|
|
/// Packs every `packFactor` consecutive rows of the 2-D tensor `rows`
/// ([R, W]) side by side. The result is [ceil(R / packFactor), packFactor * W].
///
/// The row count is first zero-padded up to a multiple of `packFactor`. The
/// packing is then expressed as expand_shape to
/// [groups, packFactor, W] followed by collapse_shape of the last two dims.
/// `packFactor == 1` returns `rows` untouched.
static Value packRowsForParallelGemm(
    Value rows, RankedTensorType rowsType, int64_t packFactor, PatternRewriter& rewriter, Location loc) {
  if (packFactor == 1)
    return rows;

  const int64_t width = rowsType.getDimSize(1);
  const int64_t groups = ceilIntegerDivide(rowsType.getDimSize(0), packFactor);
  const int64_t alignedRows = groups * packFactor;
  Type elemType = rowsType.getElementType();
  Attribute encoding = rowsType.getEncoding();

  Value aligned = createPaddedRows(rows, rowsType, alignedRows, rewriter, loc);

  // [alignedRows, W] -> [groups, packFactor, W]
  auto groupedType = RankedTensorType::get({groups, packFactor, width}, elemType, encoding);
  SmallVector<ReassociationIndices> splitRows {{0, 1}, {2}};
  Value grouped = tensor::ExpandShapeOp::create(rewriter, loc, groupedType, aligned, splitRows);

  // [groups, packFactor, W] -> [groups, packFactor * W]
  auto packedType = RankedTensorType::get({groups, packFactor * width}, elemType, encoding);
  SmallVector<ReassociationIndices> mergeTail {{0}, {1, 2}};
  return tensor::CollapseShapeOp::create(rewriter, loc, packedType, grouped, mergeTail);
}
|
|
|
|
/// Inverse of row packing: turns [P, packFactor * rowWidth] into
/// [unpackedRows, rowWidth].
///
/// Each row of `packedRows` is split into `packFactor` rows of width
/// `rowWidth`, giving [P * packFactor, rowWidth]. Trailing rows beyond
/// `unpackedRows` are then dropped with an extract_slice. `packFactor == 1`
/// returns `packedRows` untouched.
static Value unpackRowsFromParallelGemm(Value packedRows,
                                        RankedTensorType packedRowsType,
                                        int64_t unpackedRows,
                                        int64_t rowWidth,
                                        int64_t packFactor,
                                        PatternRewriter& rewriter,
                                        Location loc) {
  if (packFactor == 1)
    return packedRows;

  Type elemType = packedRowsType.getElementType();
  Attribute encoding = packedRowsType.getEncoding();
  const int64_t groups = packedRowsType.getDimSize(0);
  const int64_t flatRows = groups * packFactor;

  // [P, packFactor * rowWidth] -> [P, packFactor, rowWidth]
  auto splitType = RankedTensorType::get({groups, packFactor, rowWidth}, elemType, encoding);
  SmallVector<ReassociationIndices> splitTail {{0}, {1, 2}};
  Value split = tensor::ExpandShapeOp::create(rewriter, loc, splitType, packedRows, splitTail);

  // [P, packFactor, rowWidth] -> [P * packFactor, rowWidth]
  auto flatType = RankedTensorType::get({flatRows, rowWidth}, elemType, encoding);
  SmallVector<ReassociationIndices> mergeHead {{0, 1}, {2}};
  Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, split, mergeHead);

  if (flatRows == unpackedRows)
    return flat;

  // Drop the padding rows that were added when packing.
  auto trimmedType = RankedTensorType::get({unpackedRows, rowWidth}, elemType, encoding);
  SmallVector<OpFoldResult> origin {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> extent {rewriter.getIndexAttr(unpackedRows), rewriter.getIndexAttr(rowWidth)};
  return tensor::ExtractSliceOp::create(rewriter, loc, trimmedType, flat, origin, extent, getUnitStrides(rewriter, 2));
}
|
|
|
|
/// Converts convolution weights into the GEMM B operand of type
/// `plan.wTransType` ([patchSize, Cout]).
///
/// The rank-4 weights are collapsed to `plan.wFlatType` ([Cout, patchSize]) and
/// then transposed with ONNXTransposeOp. Compile-time-computable weights are
/// transformed inline. Otherwise the transformation is wrapped in a Spatial
/// compute op that receives the weights as its input.
static Value createWeightMatrix(
    Value weights, const ConvGemmPlan& plan, PatternRewriter& rewriter, Location loc) {
  auto toGemmLayout = [&](Value source) -> Value {
    SmallVector<ReassociationIndices> flattenKernel {{0}, {1, 2, 3}};
    Value flat = tensor::CollapseShapeOp::create(rewriter, loc, plan.wFlatType, source, flattenKernel);
    auto transposeOp =
        ONNXTransposeOp::create(rewriter, loc, plan.wTransType, flat, rewriter.getI64ArrayAttr({1, 0}));
    return transposeOp.getResult();
  };

  if (!isCompileTimeComputable(weights)) {
    auto wrapped =
        createSpatCompute<1>(rewriter, loc, TypeRange {plan.wTransType}, {}, ValueRange {weights}, [&](Value weightArg) {
          spatial::SpatYieldOp::create(rewriter, loc, toGemmLayout(weightArg));
        });
    return wrapped.getResult(0);
  }
  return toGemmLayout(weights);
}
|
|
|
|
/// Zero-pads the 2-D `matrix` (of type `sourceType`) at the bottom and right
/// up to `paddedType`. Returns `matrix` itself when the two types are equal.
static Value createPaddedConvMatrix(Value matrix,
                                    RankedTensorType sourceType,
                                    RankedTensorType paddedType,
                                    PatternRewriter& rewriter,
                                    Location loc) {
  if (sourceType != paddedType) {
    const int64_t extraRows = paddedType.getDimSize(0) - sourceType.getDimSize(0);
    const int64_t extraCols = paddedType.getDimSize(1) - sourceType.getDimSize(1);
    return createZeroPaddedTensor(matrix, paddedType, {0, 0}, {extraRows, extraCols}, rewriter, loc);
  }
  return matrix;
}
|
|
|
|
/// Materializes a constant of type `paddedType` whose top-left corner holds
/// the row-major contents of `sourceAttr` (interpreted with the 2-D shape
/// `sourceType`); every other element is zero.
///
/// `sourceType` must fit inside `paddedType`. The copy indexes the padded
/// buffer as row * paddedCols + col, so an oversized source would otherwise
/// write into the wrong row or past the end of the buffer without any
/// diagnostic.
static Value createPaddedConstantMatrix(DenseElementsAttr sourceAttr,
                                        RankedTensorType sourceType,
                                        RankedTensorType paddedType,
                                        PatternRewriter& rewriter) {
  const int64_t sourceRows = sourceType.getDimSize(0);
  const int64_t sourceCols = sourceType.getDimSize(1);
  const int64_t paddedCols = paddedType.getDimSize(1);
  assert(sourceRows <= paddedType.getDimSize(0) && sourceCols <= paddedCols
         && "source matrix must fit inside the padded matrix");

  SmallVector<Attribute> paddedValues(
      paddedType.getNumElements(), cast<Attribute>(rewriter.getZeroAttr(paddedType.getElementType())));
  SmallVector<Attribute> sourceValues(sourceAttr.getValues<Attribute>());
  for (int64_t row = 0; row < sourceRows; ++row)
    for (int64_t col = 0; col < sourceCols; ++col)
      paddedValues[row * paddedCols + col] = sourceValues[row * sourceCols + col];
  auto paddedAttr = DenseElementsAttr::get(paddedType, paddedValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), paddedAttr, paddedType);
}
|
|
|
|
/// Materializes the zero-padded [paddedK, paddedC] weight constant used by
/// the input-K-tiled lowering.
///
/// The weights in `sourceAttr` are read as row-major [Cout, Cin, kH, kW] (per
/// the index arithmetic below). They are placed transposed, as a
/// [patchSize, Cout] matrix with patchSize = Cin * kH * kW, in the top-left
/// corner; everything else is zero.
///
/// The transposed matrix must fit, i.e. patchSize <= paddedK and
/// Cout <= paddedC. Otherwise the flat writes below would land in the wrong
/// row or past the end of the buffer.
static Value createPaddedInputKTiledWeightConstant(DenseElementsAttr sourceAttr,
                                                   const ConvLoweringState& state,
                                                   int64_t paddedK,
                                                   int64_t paddedC,
                                                   PatternRewriter& rewriter) {
  const int64_t patchSize = state.numChannelsIn * state.wHeight * state.wWidth;
  assert(patchSize <= paddedK && state.numChannelsOut <= paddedC
         && "transposed weight matrix must fit inside the padded constant");

  auto paddedType = RankedTensorType::get({paddedK, paddedC}, state.wType.getElementType());
  SmallVector<Attribute> sourceValues(sourceAttr.getValues<Attribute>());
  SmallVector<Attribute> paddedValues(
      paddedType.getNumElements(), cast<Attribute>(rewriter.getZeroAttr(paddedType.getElementType())));
  // Source element (outChannel, inChannel, kh, kw) sits at flat index
  // outChannel * patchSize + patchIndex, where
  // patchIndex = (inChannel * kH + kh) * kW + kw is its im2col column.
  for (int64_t outChannel = 0; outChannel < state.numChannelsOut; ++outChannel)
    for (int64_t patchIndex = 0; patchIndex < patchSize; ++patchIndex)
      paddedValues[patchIndex * paddedC + outChannel] = sourceValues[outChannel * patchSize + patchIndex];
  auto paddedAttr = DenseElementsAttr::get(paddedType, paddedValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), paddedAttr, paddedType);
}
|
|
|
|
/// Lowers a convolution by tiling its reduction dimension K = Cin * kH * kW
/// (geo.k) into `numKSlices` slices of `xbarDim` (geo.xbarSize) elements.
///
/// Operands:
///   - the weights become one zero-padded [paddedK, xbarDim] constant;
///   - the bias, if any, becomes a zero-padded [1, xbarDim] row.
///
/// One Spatial compute op is emitted per (batch index, output row, chunk of at
/// most `rowChunkWidth` output columns). For each output position in the
/// chunk, that op:
///   1. extracts the input patch and zero-pads it to paddedK;
///   2. accumulates one VMM + VAdd per K slice;
///   3. adds the bias;
///   4. keeps the first Cout lanes and inserts the row into the chunk result.
/// The chunk results are assembled by createCollectedConvOutput.
///
/// Every bail-out condition that does not depend on emitted IR is checked
/// before any op is created, so such failures leave the IR untouched.
///
/// @param state                Convolution shape/attribute summary.
/// @param distributedConsumers Forwarded unchanged to createCollectedConvOutput.
/// @return The convolution result, or failure when any of these hold:
///         the weights are not host constants; Cout exceeds xbarDim; the
///         output has no positions; building a loop fails.
static FailureOr<Value> rewriteInputKTiledConv(const ConvLoweringState& state,
                                               ArrayRef<DistributedTensorStep> distributedConsumers,
                                               PatternRewriter& rewriter,
                                               Location loc) {
  // --- Validation (no IR is created in this section). ---
  auto wDenseAttr = getHostConstDenseElementsAttr(state.w);
  if (!wDenseAttr)
    return failure();

  ConvGeometry geo = buildConvGeometry(state);
  const int64_t xbarDim = geo.xbarSize;
  // Each output position is produced as a single [1, xbarDim] VMM result and
  // then sliced down to Cout lanes. The weight and bias constants are also
  // padded to xbarDim columns, so a wider Cout cannot be represented.
  if (state.numChannelsOut > xbarDim)
    return failure();
  // At least one output position is required; the collected output below is
  // typed after the first chunk.
  if (state.batchSize <= 0 || state.outHeight <= 0 || state.outWidth <= 0)
    return failure();

  const int64_t numKSlices = ceilIntegerDivide(geo.k, xbarDim);
  const int64_t paddedK = numKSlices * xbarDim;
  const uint64_t maxLanesPerBatch =
      std::max<uint64_t>(1,
                         static_cast<uint64_t>(crossbarCountInCore.getValue())
                             / static_cast<uint64_t>(std::max<int64_t>(1, numKSlices * 4)));
  const uint64_t rowChunkWidth = std::max<uint64_t>(
      1,
      std::min<uint64_t>({chooseStreamChunkPositions(geo, /*packFactor=*/1),
                          maxLanesPerBatch,
                          static_cast<uint64_t>(state.outWidth)}));
  const int64_t chunkWidth = static_cast<int64_t>(rowChunkWidth);
  const auto elementType = state.outType.getElementType();

  // --- Shared operands. ---
  PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc);
  Value paddedWeight = createPaddedInputKTiledWeightConstant(wDenseAttr, state, paddedK, xbarDim, rewriter);

  Value paddedBias;
  RankedTensorType paddedBiasType;
  if (state.hasBias) {
    Value biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc);
    auto biasMatrixType = cast<RankedTensorType>(biasMatrix.getType());
    paddedBiasType = RankedTensorType::get({1, xbarDim}, elementType);
    if (auto biasDenseAttr = getHostConstDenseElementsAttr(state.b))
      paddedBias = createPaddedConstantMatrix(biasDenseAttr, biasMatrixType, paddedBiasType, rewriter);
    else
      paddedBias = materializeOrComputeUnary(
          biasMatrix, paddedBiasType, rewriter, loc, [&](Value biasValue) {
            return createPaddedConvMatrix(biasValue, biasMatrixType, paddedBiasType, rewriter, loc);
          });
  }

  // --- Loop-invariant types, attributes and compute operands. ---
  auto paddedRowType = RankedTensorType::get({1, xbarDim}, elementType);
  auto paddedChunkRowType = RankedTensorType::get({1, paddedK}, elementType);
  auto patchType = RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elementType);
  auto collapsedPatchType = RankedTensorType::get({1, geo.k}, elementType);
  auto weightTileType = RankedTensorType::get({xbarDim, xbarDim}, state.wType.getElementType());
  auto rowType = RankedTensorType::get({1, state.numChannelsOut}, elementType);
  auto zeroAttr = DenseElementsAttr::get(paddedRowType, rewriter.getZeroAttr(elementType));
  // Inputs of every chunk compute op: the (possibly padded) input, then the
  // padded bias when present.
  SmallVector<Value> inputsStorage {preparedInput.value};
  if (state.hasBias)
    inputsStorage.push_back(paddedBias);
  ValueRange inputs(inputsStorage);

  SmallVector<Value> chunkRows;
  const int64_t totalPatches = state.batchSize * state.outHeight * state.outWidth;
  chunkRows.reserve(state.batchSize * state.outHeight * ceilIntegerDivide(state.outWidth, chunkWidth));
  for (int64_t batchIndex = 0; batchIndex < state.batchSize; ++batchIndex) {
    for (int64_t outHeightIndex = 0; outHeightIndex < state.outHeight; ++outHeightIndex) {
      for (int64_t outWidthChunkStart = 0; outWidthChunkStart < state.outWidth; outWidthChunkStart += chunkWidth) {
        const int64_t chunkNumPatches = std::min<int64_t>(chunkWidth, state.outWidth - outWidthChunkStart);
        auto chunkRowsType = RankedTensorType::get({chunkNumPatches, state.numChannelsOut}, elementType);
        auto chunkCompute =
            spatial::SpatCompute::create(rewriter, loc, TypeRange {chunkRowsType}, ValueRange {paddedWeight}, inputs);
        // Block arguments: weight first, then the inputs in operand order.
        auto* block = new Block();
        block->addArgument(paddedWeight.getType(), loc);
        for (Value input : inputs)
          block->addArgument(input.getType(), loc);
        chunkCompute.getBody().push_back(block);
        rewriter.setInsertionPointToStart(block);

        auto buildChunk = [&]() -> LogicalResult {
          Value weightArg = block->getArgument(0);
          Value inputArg = block->getArgument(1);
          Value biasArg = state.hasBias ? block->getArgument(2) : Value();
          Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
          Value cBatchIndex = getOrCreateIndexConstant(rewriter, anchorOp, batchIndex);
          Value cZero = getOrCreateIndexConstant(rewriter, anchorOp, 0);
          Value cKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices);
          Value cOne = getOrCreateIndexConstant(rewriter, anchorOp, 1);
          Value cXbar = getOrCreateIndexConstant(rewriter, anchorOp, xbarDim);
          Value cInputHeightOffset =
              getOrCreateIndexConstant(rewriter, anchorOp, outHeightIndex * state.strideHeight);
          Value chunkRowsValue = tensor::EmptyOp::create(rewriter, loc, chunkRowsType.getShape(), elementType);

          // One iteration per output column of this chunk.
          auto widthLoop = buildNormalizedScfFor(
              rewriter,
              loc,
              cZero,
              getOrCreateIndexConstant(rewriter, anchorOp, chunkNumPatches),
              cOne,
              ValueRange {chunkRowsValue},
              [&](OpBuilder&, Location nestedLoc, Value widthIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
                Value laneWithChunkOffset = affineAddConst(rewriter, nestedLoc, widthIndex, outWidthChunkStart, anchorOp);
                Value inputWidthOffset = createOrFoldAffineApply(rewriter,
                                                                 nestedLoc,
                                                                 getAffineDimExpr(0, rewriter.getContext()) * state.strideWidth,
                                                                 ValueRange {laneWithChunkOffset},
                                                                 anchorOp);
                Value patch = createConvInputPatch(inputArg,
                                                   patchType,
                                                   cBatchIndex,
                                                   cZero,
                                                   cInputHeightOffset,
                                                   inputWidthOffset,
                                                   state.dilationHeight,
                                                   state.dilationWidth,
                                                   rewriter,
                                                   nestedLoc);
                Value patchRow = tensor::CollapseShapeOp::create(rewriter,
                                                                 nestedLoc,
                                                                 collapsedPatchType,
                                                                 patch,
                                                                 SmallVector<ReassociationIndices> {
                                                                     {0},
                                                                     {1, 2, 3}
                });
                Value paddedPatchRow = createZeroPaddedTensor(
                    patchRow, paddedChunkRowType, {0, 0}, {0, paddedK - geo.k}, rewriter, nestedLoc);

                // Accumulate one VMM per K slice, starting from a zero row.
                Value zeroRow = getOrCreateConstant(rewriter, anchorOp, zeroAttr, paddedRowType);
                auto kLoop = buildNormalizedScfFor(
                    rewriter,
                    nestedLoc,
                    cZero,
                    cKSlices,
                    cOne,
                    ValueRange {zeroRow},
                    [&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl<Value>& reduceYielded) {
                      Value acc = reduceIterArgs.front();
                      Value kOffset = arith::MulIOp::create(rewriter, reduceLoc, kSlice, cXbar);
                      SmallVector<OpFoldResult> aOffsets {rewriter.getIndexAttr(0), kOffset};
                      SmallVector<OpFoldResult> aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)};
                      SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
                      Value aTile = tensor::ExtractSliceOp::create(
                          rewriter, reduceLoc, paddedRowType, paddedPatchRow, aOffsets, aSizes, unitStrides);
                      SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
                      SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
                      Value bTile = extractStaticSliceOrIdentity(
                          rewriter, reduceLoc, weightArg, weightTileType, bOffsets, bSizes, unitStrides);
                      Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
                      reduceYielded.push_back(
                          spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, acc, piece).getResult());
                      return success();
                    });
                if (failed(kLoop))
                  return failure();

                Value reduced = kLoop->results.front();
                if (state.hasBias)
                  reduced = spatial::SpatVAddOp::create(rewriter, nestedLoc, paddedRowType, reduced, biasArg).getResult();

                // Keep only the first Cout lanes of the xbarDim-wide result.
                Value row = reduced;
                if (state.numChannelsOut != xbarDim) {
                  SmallVector<OpFoldResult> rowOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
                  SmallVector<OpFoldResult> rowSizes {
                      rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)};
                  row = tensor::ExtractSliceOp::create(
                      rewriter, nestedLoc, rowType, reduced, rowOffsets, rowSizes, getUnitStrides(rewriter, 2));
                }

                SmallVector<OpFoldResult> outputOffsets {widthIndex, rewriter.getIndexAttr(0)};
                SmallVector<OpFoldResult> outputSizes {
                    rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)};
                Value updatedRows = tensor::InsertSliceOp::create(
                    rewriter, nestedLoc, row, iterArgs.front(), outputOffsets, outputSizes, getUnitStrides(rewriter, 2));
                yielded.push_back(updatedRows);
                return success();
              });
          if (failed(widthLoop))
            return failure();
          spatial::SpatYieldOp::create(rewriter, loc, widthLoop->results.front());
          return success();
        };
        if (failed(buildChunk())) {
          rewriter.setInsertionPointAfter(chunkCompute);
          rewriter.eraseOp(chunkCompute);
          return failure();
        }
        rewriter.setInsertionPointAfter(chunkCompute);
        chunkRows.push_back(chunkCompute.getResult(0));
      }
    }
  }

  auto nhwcType = RankedTensorType::get({state.batchSize, state.outHeight, state.outWidth, state.numChannelsOut},
                                        elementType);
  return createCollectedConvOutput(
      chunkRows, state.outType, cast<RankedTensorType>(chunkRows.front().getType()), nhwcType, state.outType, totalPatches,
      state.numChannelsOut, /*packFactor=*/1, distributedConsumers, rewriter, loc);
}
|
|
|
|
/// Returns the GEMM weight operand for the standard lowering.
///
/// Without packing (`plan.effectiveMaxParallelPixels == 1`) this is `wTrans`
/// itself. Otherwise a block-diagonal constant of shape
/// [P * patchSize, P * Cout] is built, with P = effectiveMaxParallelPixels.
/// Copy `c` of the transposed weights occupies rows
/// [c * patchSize, (c + 1) * patchSize) and columns [c * Cout, (c + 1) * Cout),
/// matching the P * patchSize-wide packed GEMM input rows of the plan.
/// `loc` is unused.
static Value buildPackedWeights(DenseElementsAttr wDenseAttr,
                                Value wTrans,
                                const ConvLoweringState& state,
                                const ConvGemmPlan& plan,
                                PatternRewriter& rewriter,
                                Location loc) {
  const int64_t copies = plan.effectiveMaxParallelPixels;
  if (copies == 1)
    return wTrans;

  Type weightElemType = state.wType.getElementType();
  const int64_t numOut = state.numChannelsOut;
  auto packedWeightType = RankedTensorType::get({copies * plan.patchSize, copies * numOut}, weightElemType);
  const int64_t packedCols = packedWeightType.getDimSize(1);

  SmallVector<Attribute> sourceValues(wDenseAttr.getValues<Attribute>());
  SmallVector<Attribute> packedValues(packedWeightType.getNumElements(),
                                      cast<Attribute>(rewriter.getZeroAttr(weightElemType)));

  // Row-major [Cout, Cin, kH, kW] weights: element (o, p) of the flattened
  // [Cout, kernelVolume] view lives at o * kernelVolume + p, where
  // p = (inChannel * kH + kh) * kW + kw.
  const int64_t kernelVolume = state.numChannelsIn * state.wHeight * state.wWidth;
  for (int64_t copy = 0; copy < copies; ++copy) {
    const int64_t rowBase = copy * plan.patchSize;
    const int64_t colBase = copy * numOut;
    for (int64_t o = 0; o < numOut; ++o)
      for (int64_t p = 0; p < kernelVolume; ++p)
        packedValues[(rowBase + p) * packedCols + colBase + o] = sourceValues[o * kernelVolume + p];
  }

  auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
}
|
|
|
|
/// Returns the GEMM bias operand for the standard lowering.
///
/// The result depends on the case:
///   - no bias: `gemmBias` is returned unchanged;
///   - no packing: `biasMatrix` is returned;
///   - packing: a [1, P * Cout] constant holding P back-to-back copies of
///     `biasDenseAttr` is returned, with P = plan.effectiveMaxParallelPixels.
/// `loc` is unused.
static Value buildPackedBias(Value gemmBias,
                             Value biasMatrix,
                             DenseElementsAttr biasDenseAttr,
                             const ConvLoweringState& state,
                             const ConvGemmPlan& plan,
                             PatternRewriter& rewriter,
                             Location loc) {
  if (!state.hasBias)
    return gemmBias;

  const int64_t copies = plan.effectiveMaxParallelPixels;
  if (copies == 1)
    return biasMatrix;

  SmallVector<Attribute> biasValues(biasDenseAttr.getValues<Attribute>());
  SmallVector<Attribute> tiledValues;
  tiledValues.reserve(copies * state.numChannelsOut);
  for (int64_t remaining = copies; remaining > 0; --remaining)
    tiledValues.append(biasValues.begin(), biasValues.end());

  auto tiledType = RankedTensorType::get({1, copies * state.numChannelsOut}, state.outType.getElementType());
  auto tiledAttr = DenseElementsAttr::get(tiledType, tiledValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), tiledAttr, tiledType);
}
|
|
|
|
/// Computes the GEMM plan for the `chunkNumPatches` output patches that start
/// at global patch index `chunkStart`.
///
/// The pack factor (`maxParallelPixels`) is `forcedPackFactor` when given.
/// Otherwise it is the number of weight matrices that fit along the larger
/// weight dimension of one crossbar: crossbarSize / max(patchSize, Cout).
/// It is always clamped to at least 1, because a factor of 0 or below would
/// make the packed-row computation divide by zero. The factor only takes
/// effect (`effectiveMaxParallelPixels`) when both weights and bias can be
/// packed as constants; otherwise it is 1.
static ConvGemmPlan
buildConvGemmPlan(const ConvLoweringState& state,
                  bool canPackWeightsAsConstants,
                  bool canPackBiasAsConstants,
                  int64_t chunkStart,
                  int64_t chunkNumPatches,
                  std::optional<int64_t> forcedPackFactor) {
  ConvGemmPlan plan;
  plan.patchSize = state.numChannelsIn * state.wHeight * state.wWidth;
  plan.numPatchesPerBatch = state.outHeight * state.outWidth;
  plan.globalNumPatches = state.batchSize * plan.numPatchesPerBatch;
  plan.chunkStart = chunkStart;
  plan.chunkNumPatches = chunkNumPatches;

  // Clamp to 1 so the crossbar division is defined even for a degenerate
  // (zero-sized) weight matrix.
  const int64_t wMaxDim = std::max<int64_t>({1, plan.patchSize, state.numChannelsOut});
  plan.maxParallelPixels = forcedPackFactor
                             ? std::max<int64_t>(1, *forcedPackFactor)
                             : std::max<int64_t>(1, static_cast<int64_t>(crossbarSize.getValue()) / wMaxDim);
  plan.effectiveMaxParallelPixels =
      (canPackWeightsAsConstants && canPackBiasAsConstants) ? plan.maxParallelPixels : 1;
  plan.packedNumRows = ceilIntegerDivide(plan.chunkNumPatches, plan.effectiveMaxParallelPixels);

  auto elemType = state.xType.getElementType();
  auto outElemType = state.outType.getElementType();
  plan.im2colType = RankedTensorType::get({plan.chunkNumPatches, plan.patchSize}, elemType);
  plan.im2colRowType = RankedTensorType::get({1, plan.patchSize}, elemType);
  plan.gemmInputRowsType =
      RankedTensorType::get({plan.packedNumRows, plan.effectiveMaxParallelPixels * plan.patchSize}, elemType);
  plan.wFlatType = RankedTensorType::get({state.numChannelsOut, plan.patchSize}, state.wType.getElementType());
  plan.wTransType = RankedTensorType::get({plan.patchSize, state.numChannelsOut}, state.wType.getElementType());
  plan.gemmOutType = RankedTensorType::get({plan.chunkNumPatches, state.numChannelsOut}, outElemType);
  plan.gemmOutputRowsType =
      RankedTensorType::get({plan.packedNumRows, plan.effectiveMaxParallelPixels * state.numChannelsOut}, outElemType);
  plan.nhwcType =
      RankedTensorType::get({state.batchSize, state.outHeight, state.outWidth, state.numChannelsOut}, outElemType);
  return plan;
}
|
|
|
|
/// Builds a spatial compute op that materializes the im2col matrix for the
/// chunk of output positions described by `plan` and returns its single result
/// (typed `plan.gemmInputRowsType`).
///
/// The chunk covers `plan.chunkNumPatches` patches starting at global patch
/// index `plan.chunkStart`. Each patch becomes one [1, plan.patchSize] row of a
/// [plan.chunkNumPatches, plan.patchSize] tensor. When
/// `plan.effectiveMaxParallelPixels != 1` the rows are then packed with
/// packRowsForParallelGemm so that one GEMM row carries several patches.
///
/// @param state          conv shape/stride/dilation information.
/// @param preparedInput  the (already prepared) conv input and its type.
/// @param plan           the GEMM plan for this chunk.
static Value createIm2colRows(const ConvLoweringState& state,
                              const PreparedConvInput& preparedInput,
                              const ConvGemmPlan& plan,
                              PatternRewriter& rewriter,
                              Location loc) {
  constexpr size_t numInputs = 1;
  auto im2colComputeOp =
      createSpatCompute<numInputs>(rewriter, loc, TypeRange {plan.gemmInputRowsType}, {}, preparedInput.value, [&](Value xArg) {
        auto elemType = preparedInput.type.getElementType();
        // Keep the standard im2col view of convolution, flipped so filters sit in
        // B / crossbar columns:
        // A (im2col): [numPatches, patchSize] -- one row per output spatial position
        // B (weights): [patchSize, cOut]
        // Gemm output: [numPatches, cOut]
        Value im2colInit = tensor::EmptyOp::create(rewriter, loc, plan.im2colType.getShape(), elemType);
        // Index constants are created through the enclosing compute op.
        Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
        Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
        Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
        Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkNumPatches);

        // One iteration per patch of the chunk; the im2col tensor is threaded
        // through the loop as an iter_arg and filled row by row.
        auto im2colLoop = buildNormalizedScfFor(
            rewriter,
            loc,
            c0,
            cNumPatches,
            c1,
            ValueRange {im2colInit},
            [&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
              Value im2colAcc = iterArgs.front();
              // Split the chunk-relative patch index (offset by plan.chunkStart)
              // into a batch index and a position inside that batch; per the
              // helper names this is floordiv/mod by plan.numPatchesPerBatch.
              Value batchIndex =
                  affineAddFloorDivConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp);
              Value batchPatchIndex =
                  affineAddModConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp);
              // Output (h, w) of the patch, then the top-left input offset scaled by the strides.
              Value outHeightIndex = affineFloorDivConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
              Value outWidthIndex = affineModConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
              Value inputHeightOffset =
                  affineMulConst(rewriter, nestedLoc, outHeightIndex, state.strideHeight, anchorOp);
              Value inputWidthOffset =
                  affineMulConst(rewriter, nestedLoc, outWidthIndex, state.strideWidth, anchorOp);

              // Receptive field of one output position: [1, Cin, kH, kW]; the
              // dilations are forwarded to createConvInputPatch.
              auto patchType =
                  RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elemType);
              Value patch = createConvInputPatch(xArg,
                                                 patchType,
                                                 batchIndex,
                                                 c0,
                                                 inputHeightOffset,
                                                 inputWidthOffset,
                                                 state.dilationHeight,
                                                 state.dilationWidth,
                                                 rewriter,
                                                 nestedLoc);
              // Flatten [1, Cin, kH, kW] into one [1, patchSize] im2col row.
              Value row = tensor::CollapseShapeOp::create(rewriter,
                                                          nestedLoc,
                                                          plan.im2colRowType,
                                                          patch,
                                                          SmallVector<ReassociationIndices> {
                                                              {0},
                                                              {1, 2, 3}
              });

              // Write the row at position patchIndex of the accumulator.
              SmallVector<OpFoldResult> rowOffsets {patchIndex, rewriter.getIndexAttr(0)};
              SmallVector<OpFoldResult> rowSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(plan.patchSize)};
              Value next = tensor::InsertSliceOp::create(
                  rewriter, nestedLoc, row, im2colAcc, rowOffsets, rowSizes, getUnitStrides(rewriter, 2));
              yielded.push_back(next);
              return success();
            });
        if (failed(im2colLoop))
          return failure();

        Value gemmInputRows = im2colLoop->results.front();
        // Pack N old im2col rows into one longer row so one GEMM can cover N
        // pixels in parallel. The corresponding packed weight matrix contains N
        // block-diagonal copies of W^T, and the packed output must be unpacked
        // back to one row per spatial patch.
        if (plan.effectiveMaxParallelPixels != 1)
          gemmInputRows = packRowsForParallelGemm(gemmInputRows, plan.im2colType, plan.effectiveMaxParallelPixels, rewriter, loc);

        spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
        return success();
      });

  // NOTE(review): failure is only checked by this assert; in NDEBUG builds a
  // failed construction would be dereferenced below.
  assert(succeeded(im2colComputeOp) && "Conv im2col compute construction must succeed");
  return im2colComputeOp->getResult(0);
}
|
|
|
|
/// Undoes the row packing used for parallel GEMMs.
///
/// If the plan packs one patch per GEMM row (`effectiveMaxParallelPixels == 1`)
/// `gemmRows` already has one row per patch and is returned as-is. Otherwise a
/// spatial compute op calls unpackRowsFromParallelGemm to produce a
/// [plan.chunkNumPatches, plan.gemmOutType.getDimSize(1)] tensor carrying the
/// element type and encoding of `plan.gemmOutType`.
static Value maybeUnpackChunkRows(Value gemmRows,
                                  const ConvGemmPlan& plan,
                                  PatternRewriter& rewriter,
                                  Location loc) {
  const auto packFactor = plan.effectiveMaxParallelPixels;
  if (packFactor == 1)
    return gemmRows;

  RankedTensorType rowsOutType = plan.gemmOutType;
  const int64_t channelsPerRow = rowsOutType.getDimSize(1);
  auto unpackedType = RankedTensorType::get(
      {plan.chunkNumPatches, channelsPerRow}, rowsOutType.getElementType(), rowsOutType.getEncoding());

  auto unpackOp = createSpatCompute<1>(rewriter, loc, TypeRange {unpackedType}, {}, gemmRows, [&](Value packedArg) {
    auto packedType = cast<RankedTensorType>(packedArg.getType());
    Value perPatchRows = unpackRowsFromParallelGemm(
        packedArg, packedType, plan.chunkNumPatches, channelsPerRow, packFactor, rewriter, loc);
    spatial::SpatYieldOp::create(rewriter, loc, perPatchRows);
  });
  return unpackOp.getResult(0);
}
|
|
|
|
/// Lowers the convolution as a sequence of independent im2col GEMMs, each
/// covering at most `chunkPositions` consecutive output positions, and returns
/// their rows (one row per output position, Cout columns).
///
/// Every chunk gets its own ConvGemmPlan (built with `forcedPackFactor`), its
/// own im2col rows, packed weights/bias and ONNX Gemm. Packed chunk outputs are
/// unpacked back to one row per patch. With a single chunk its rows are
/// returned directly; otherwise the chunks are concatenated along axis 0 into
/// a [batch * outH * outW, Cout] tensor inside a spatial compute op.
///
/// @param weightMatrix / biasMatrix  prepared weight / bias matrices.
/// @param wDenseAttr / biasDenseAttr host constants (null when not constant).
/// @param forcedPackFactor           forwarded to buildConvGemmPlan.
/// @param chunkPositions             maximum patches per chunk; 0 is treated
///                                   as 1 so the chunk loop always advances.
static Value createChunkedConvRows(const ConvLoweringState& state,
                                   const PreparedConvInput& preparedInput,
                                   Value weightMatrix,
                                   Value biasMatrix,
                                   DenseElementsAttr wDenseAttr,
                                   DenseElementsAttr biasDenseAttr,
                                   int64_t forcedPackFactor,
                                   uint64_t chunkPositions,
                                   PatternRewriter& rewriter,
                                   Location loc) {
  const int64_t totalPatches = state.batchSize * state.outHeight * state.outWidth;
  const bool canPackWeights = static_cast<bool>(wDenseAttr);
  const bool canPackBias = !state.hasBias || static_cast<bool>(biasDenseAttr);

  // Clamp the chunk size to [1, totalPatches]: a zero step would never advance
  // `chunkStart`, and comparing in uint64_t avoids a negative step when
  // chunkPositions does not fit in int64_t.
  int64_t chunkStep = std::max<int64_t>(totalPatches, 1);
  if (chunkPositions < static_cast<uint64_t>(chunkStep))
    chunkStep = std::max<int64_t>(static_cast<int64_t>(chunkPositions), 1);

  SmallVector<Value> chunkRows;
  for (int64_t chunkStart = 0; chunkStart < totalPatches; chunkStart += chunkStep) {
    const int64_t chunkNumPatches = std::min<int64_t>(chunkStep, totalPatches - chunkStart);
    ConvGemmPlan chunkPlan =
        buildConvGemmPlan(state, canPackWeights, canPackBias, chunkStart, chunkNumPatches, forcedPackFactor);
    Value chunkInputRows = createIm2colRows(state, preparedInput, chunkPlan, rewriter, loc);
    Value chunkB = buildPackedWeights(wDenseAttr, weightMatrix, state, chunkPlan, rewriter, loc);

    // Only materialize a zero bias constant when the conv has no bias of its own.
    Value gemmBias;
    if (state.hasBias)
      gemmBias = state.b;
    else
      gemmBias = createZeroGemmBias(chunkPlan.gemmOutputRowsType, rewriter);
    Value chunkC = buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, chunkPlan, rewriter, loc);

    Value chunkGemmRows = ONNXGemmOp::create(rewriter,
                                             loc,
                                             chunkPlan.gemmOutputRowsType,
                                             chunkInputRows,
                                             chunkB,
                                             chunkC,
                                             APFloat(1.0f),
                                             APFloat(1.0f),
                                             /*transA=*/0,
                                             /*transB=*/0)
                              .getY();
    chunkRows.push_back(maybeUnpackChunkRows(chunkGemmRows, chunkPlan, rewriter, loc));
  }

  if (chunkRows.size() == 1)
    return chunkRows.front();

  auto rowType = RankedTensorType::get({totalPatches, state.numChannelsOut}, state.outType.getElementType());
  auto collectRows = createSpatCompute(rewriter, loc, TypeRange {rowType}, {}, chunkRows, [&](ValueRange rows) {
    spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, rows));
  });
  return collectRows.getResult(0);
}
|
|
|
|
/// Lowers the whole convolution as a single im2col GEMM.
///
/// The im2col matrix covers every output position (batch * outH * outW
/// patches). When the plan packs several patches per GEMM row, the weights and
/// bias are packed to match (buildPackedWeights / buildPackedBias), and the
/// pack factor is forwarded so createCollectedConvOutput can unpack the rows.
/// createCollectedConvOutput also receives `distributedConsumers`.
static Value rewritePackedIm2ColConv(const ConvLoweringState& state,
                                     ArrayRef<DistributedTensorStep> distributedConsumers,
                                     PatternRewriter& rewriter,
                                     Location loc) {
  auto wDenseAttr = getHostConstDenseElementsAttr(state.w);
  PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc);

  Value biasMatrix;
  DenseElementsAttr biasDenseAttr;
  if (state.hasBias) {
    biasDenseAttr = getHostConstDenseElementsAttr(state.b);
    biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc);
  }

  ConvGemmPlan plan =
      buildConvGemmPlan(state, static_cast<bool>(wDenseAttr), !state.hasBias || static_cast<bool>(biasDenseAttr), 0,
                        state.batchSize * state.outHeight * state.outWidth);

  // Prepare weight matrix W for crossbar storage:
  // W: [Cout, Cin, KH, KW] -> [Cout, patchSize] -> [patchSize, Cout]
  Value weightMatrix = createWeightMatrix(state.w, plan, rewriter, loc);
  Value gemmInputRows = createIm2colRows(state, preparedInput, plan, rewriter, loc);
  Value gemmB = buildPackedWeights(wDenseAttr, weightMatrix, state, plan, rewriter, loc);

  // Only materialize a zero bias constant when the conv has no bias of its own.
  Value gemmBias;
  if (state.hasBias)
    gemmBias = state.b;
  else
    gemmBias = createZeroGemmBias(plan.gemmOutputRowsType, rewriter);
  Value gemmC = buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, plan, rewriter, loc);

  Value gemmRows = ONNXGemmOp::create(rewriter,
                                      loc,
                                      plan.gemmOutputRowsType,
                                      gemmInputRows,
                                      gemmB,
                                      gemmC,
                                      APFloat(1.0f),
                                      APFloat(1.0f),
                                      /*transA=*/0,
                                      /*transB=*/0)
                       .getY();

  return createCollectedConvOutput(ValueRange {gemmRows},
                                   state.outType,
                                   plan.gemmOutType,
                                   plan.nhwcType,
                                   state.outType,
                                   plan.chunkNumPatches,
                                   state.numChannelsOut,
                                   plan.effectiveMaxParallelPixels,
                                   distributedConsumers,
                                   rewriter,
                                   loc);
}
|
|
|
|
/// Lowers the convolution by streaming im2col patches through chunked GEMMs
/// (createChunkedConvRows) and reassembling the collected per-position rows
/// into the convolution result.
///
/// The chunk size comes from chooseStreamChunkPositions. `forcedPackFactor` is
/// used both for the chunk-size choice and for every chunk's GEMM plan. The
/// collected rows are already unpacked, so the output stage runs with a pack
/// factor of 1.
static Value rewriteStreamedConv(const ConvLoweringState& state,
                                 ArrayRef<DistributedTensorStep> distributedConsumers,
                                 PatternRewriter& rewriter,
                                 Location loc,
                                 int64_t forcedPackFactor) {
  auto weightConst = getHostConstDenseElementsAttr(state.w);
  PreparedConvInput input = prepareInputForIm2Col(state, rewriter, loc);

  DenseElementsAttr biasConst;
  Value biasMatrix;
  if (state.hasBias) {
    biasConst = getHostConstDenseElementsAttr(state.b);
    biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc);
  }

  const bool weightsAreConst = static_cast<bool>(weightConst);
  const bool biasIsConstOrAbsent = !state.hasBias || static_cast<bool>(biasConst);
  // A one-patch plan is enough to lay out the shared weight matrix; it also
  // carries the [N, outH, outW, Cout] NHWC type used for the final output.
  ConvGemmPlan seedPlan = buildConvGemmPlan(state, weightsAreConst, biasIsConstOrAbsent, 0, 1, forcedPackFactor);
  Value weightMatrix = createWeightMatrix(state.w, seedPlan, rewriter, loc);

  ConvGeometry geometry = buildConvGeometry(state);
  const uint64_t positionsPerChunk = chooseStreamChunkPositions(geometry, forcedPackFactor);

  Value rows = createChunkedConvRows(state,
                                     input,
                                     weightMatrix,
                                     biasMatrix,
                                     weightConst,
                                     biasConst,
                                     forcedPackFactor,
                                     positionsPerChunk,
                                     rewriter,
                                     loc);
  auto rowsType = cast<RankedTensorType>(rows.getType());
  return createCollectedConvOutput(ValueRange {rows},
                                   state.outType,
                                   rowsType,
                                   seedPlan.nhwcType,
                                   state.outType,
                                   rowsType.getDimSize(0),
                                   state.numChannelsOut,
                                   /*packFactor=*/1,
                                   distributedConsumers,
                                   rewriter,
                                   loc);
}
|
|
|
|
} // namespace standard
|
|
|
|
/// Returns the type of a single row-strip fragment of `tensorType`
/// (read as [N, C, H, W]): N and C are kept, H becomes 1 and W becomes
/// `width`. Element type and encoding are carried over.
static RankedTensorType getRowStripFragmentType(RankedTensorType tensorType, int64_t width) {
  const int64_t batch = tensorType.getDimSize(0);
  const int64_t channels = tensorType.getDimSize(1);
  SmallVector<int64_t, 4> fragmentShape {batch, channels, 1, width};
  return RankedTensorType::get(fragmentShape, tensorType.getElementType(), tensorType.getEncoding());
}
|
|
|
|
/// Splits `tensorType` (read as [N, C, H, W]) into H row-strip fragments.
/// Fragment `r` starts at offsets [0, 0, r, 0], has sizes [1, C, 1, W] and unit
/// strides, and is tagged with index `r`.
/// NOTE(review): every fragment uses a batch extent of 1, so this presumably
/// assumes N == 1 — confirm with callers.
static SmallVector<DistributedFragmentInfo, 8> buildRowStripFragments(RankedTensorType tensorType) {
  const int64_t numChannels = tensorType.getDimSize(1);
  const int64_t numRows = tensorType.getDimSize(2);
  const int64_t numCols = tensorType.getDimSize(3);

  SmallVector<DistributedFragmentInfo, 8> strips;
  strips.reserve(numRows);
  for (int64_t r = 0; r < numRows; ++r) {
    DistributedFragmentInfo strip {
        {0, 0, r, 0},
        {1, numChannels, 1, numCols},
        {1, 1, 1, 1},
        r,
    };
    strips.push_back(strip);
  }
  return strips;
}
|
|
|
|
/// Wraps `storage` as a distributed tensor of logical type `logicalType`
/// (read as [N, C, H, W]) made of one row-strip fragment per H row; the lane
/// count equals the height.
static DistributedTensorInfo makeDistributedTensorInfo(Value storage, RankedTensorType logicalType) {
  const int64_t channels = logicalType.getDimSize(1);
  const int64_t height = logicalType.getDimSize(2);
  const int64_t width = logicalType.getDimSize(3);

  DistributedTensorInfo result;
  result.storage = storage;
  result.logicalType = logicalType;
  result.fragments = buildRowStripFragments(logicalType);
  result.laneCount = height;
  result.channels = channels;
  result.height = height;
  result.width = width;
  return result;
}
|
|
|
|
/// Materializes a constant of `fragmentType` (read as [N, C, H, W]) where every
/// element of channel `c` equals the `c`-th per-channel value of `denseAttr`.
///
/// Per-channel values: for rank-1 and rank-2 attributes the flattened element
/// list is used as-is; for other ranks the first `dim(1)` flattened elements
/// are taken, which matches a [1, C, 1, ...] layout.
/// NOTE(review): a rank-3 [C, 1, 1] attribute would only yield one channel
/// value here — confirm such shapes never reach this helper.
static Value createPerChannelConstantFragment(DenseElementsAttr denseAttr,
                                              RankedTensorType fragmentType,
                                              PatternRewriter& rewriter) {
  auto denseType = cast<RankedTensorType>(denseAttr.getType());
  SmallVector<Attribute> flattened(denseAttr.getValues<Attribute>());

  SmallVector<Attribute> perChannel;
  const int64_t rank = denseType.getRank();
  if (rank == 1 || rank == 2) {
    perChannel = flattened;
  } else {
    const int64_t sourceChannels = denseType.getDimSize(1);
    for (int64_t c = 0; c < sourceChannels; ++c)
      perChannel.push_back(flattened[c]);
  }

  const int64_t batch = fragmentType.getDimSize(0);
  const int64_t channels = fragmentType.getDimSize(1);
  const int64_t rows = fragmentType.getDimSize(2);
  const int64_t cols = fragmentType.getDimSize(3);

  // Row-major [N, C, H, W] fill: each channel value is repeated H * W times.
  SmallVector<Attribute> elements;
  elements.reserve(fragmentType.getNumElements());
  for (int64_t n = 0; n < batch; ++n) {
    for (int64_t c = 0; c < channels; ++c) {
      for (int64_t h = 0; h < rows; ++h) {
        for (int64_t w = 0; w < cols; ++w)
          elements.push_back(perChannel[c]);
      }
    }
  }

  auto constantAttr = DenseElementsAttr::get(fragmentType, elements);
  Operation* anchor = rewriter.getInsertionBlock()->getParentOp();
  return getOrCreateConstant(rewriter, anchor, constantAttr, fragmentType);
}
|
|
|
|
/// Returns a `gemmResultType` constant filled with the zero value of its
/// element type, created via getOrCreateConstant anchored at the parent op of
/// the current insertion block.
static Value createZeroGemmBias(RankedTensorType gemmResultType, PatternRewriter& rewriter) {
  auto zeroElement = rewriter.getZeroAttr(gemmResultType.getElementType());
  auto zeroSplat = DenseElementsAttr::get(gemmResultType, zeroElement);
  Operation* anchor = rewriter.getInsertionBlock()->getParentOp();
  return getOrCreateConstant(rewriter, anchor, zeroSplat, gemmResultType);
}
|
|
|
|
static bool canDirectLowerRowStripConv(const ConvLoweringState& state, StringRef& failureReason) {
|
|
if (!canConsumeRowStripHwcInput(state, failureReason))
|
|
return false;
|
|
|
|
ConvGeometry geometry = buildConvGeometry(state);
|
|
if (state.numChannelsOut > geometry.xbarSize) {
|
|
failureReason = "unsupported_output_channels";
|
|
return false;
|
|
}
|
|
|
|
failureReason = "";
|
|
return true;
|
|
}
|
|
|
|
/// Reshapes an [outH * outW, Cout] row tensor into an [outH, outW, Cout]
/// tensor inside a spatial compute op.
///
/// Fails without creating IR when `rows` is not a static rank-2 ranked tensor,
/// the batch size is not 1, the conv output is not a static rank-4 tensor, or
/// the row tensor does not match the output's [H * W, C] extent.
static FailureOr<Value> createRowStripPackedRows(Value rows,
                                                 const ConvLoweringState& state,
                                                 PatternRewriter& rewriter,
                                                 Location loc) {
  auto rowsType = dyn_cast<RankedTensorType>(rows.getType());
  if (!rowsType || !rowsType.hasStaticShape() || rowsType.getRank() != 2 || state.batchSize != 1
      || state.outType.getRank() != 4 || !state.outType.hasStaticShape())
    return failure();

  const int64_t channels = state.outType.getDimSize(1);
  const int64_t height = state.outType.getDimSize(2);
  const int64_t width = state.outType.getDimSize(3);
  if (rowsType.getDimSize(0) != height * width || rowsType.getDimSize(1) != channels)
    return failure();

  auto hwcType = RankedTensorType::get({height, width, channels}, rowsType.getElementType(), rowsType.getEncoding());
  auto reshapeOp = createSpatCompute<1>(rewriter, loc, TypeRange {hwcType}, {}, rows, [&](Value rowsArg) {
    // Split the H*W row dimension back into (H, W); channels stay innermost.
    SmallVector<ReassociationIndices> splitRows {{0, 1}, {2}};
    Value hwc = tensor::ExpandShapeOp::create(rewriter, loc, hwcType, rowsArg, splitRows);
    spatial::SpatYieldOp::create(rewriter, loc, hwc);
  });
  return reshapeOp.getResult(0);
}
|
|
|
|
/// Lowers the conv directly on an HWC input `inputHwc` ([xHeight, xWidth, Cin])
/// and returns an [outH, outW, Cout] result.
///
/// 1. The input is zero-padded by the conv pads inside a spatial compute op.
/// 2. A compute batch with one lane per output row is emitted. For every output
///    column the lane gathers the receptive field channel by channel into a
///    [1, Cin * kH * kW] row (channel-major), zero-pads it to a multiple of the
///    crossbar size, and accumulates SpatVMMOp products over the K slices
///    against the K-tiled constant weights.
/// 3. The padded bias is added when present, the crossbar padding columns are
///    dropped, and the [1, 1, Cout] result is inserted into the lane's row.
///
/// Receptive fields start at (row * strideHeight, col * strideWidth) and are
/// sampled with the conv dilations.
///
/// Fails without creating IR when the input shape does not match `state`,
/// canDirectLowerRowStripConv rejects the conv, the needed input rows are not
/// covered, or the weights are not host constants. May also fail while
/// building the loops.
static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
                                                        const ConvLoweringState& state,
                                                        PatternRewriter& rewriter,
                                                        Location loc) {
  auto inputType = dyn_cast<RankedTensorType>(inputHwc.getType());
  if (!inputType || !inputType.hasStaticShape() || inputType.getRank() != 3)
    return failure();
  if (inputType.getDimSize(0) != state.xHeight || inputType.getDimSize(1) != state.xWidth
      || inputType.getDimSize(2) != state.numChannelsIn)
    return failure();

  StringRef failureReason;
  if (!canDirectLowerRowStripConv(state, failureReason))
    return failure();

  ConvRowDemand demand = buildConvRowDemand(RowInterval {0, state.outHeight}, state);
  if (!covers(demand.acquiredInputRows, demand.neededInputRows))
    return failure();

  ConvGeometry geometry = buildConvGeometry(state);
  const int64_t xbarDim = geometry.xbarSize;
  const int64_t kernelArea = state.wHeight * state.wWidth;
  const int64_t patchSize = state.numChannelsIn * kernelArea;
  const int64_t numKSlices = ceilIntegerDivide(patchSize, xbarDim);
  const int64_t paddedK = numKSlices * xbarDim;
  auto elementType = inputType.getElementType();
  auto outElementType = state.outType.getElementType();

  auto paddedInputType = RankedTensorType::get({state.xHeight + state.padHeightBegin + state.padHeightEnd,
                                                state.xWidth + state.padWidthBegin + state.padWidthEnd,
                                                state.numChannelsIn},
                                               elementType,
                                               inputType.getEncoding());
  auto paddedPatchType = RankedTensorType::get({state.wHeight, state.wWidth, 1}, elementType, inputType.getEncoding());
  auto flatPatchType = RankedTensorType::get({kernelArea}, elementType, inputType.getEncoding());
  auto rowChunkType = RankedTensorType::get({1, kernelArea}, elementType, inputType.getEncoding());
  auto rowType = RankedTensorType::get({1, state.numChannelsOut}, outElementType);
  auto packedOutputType = RankedTensorType::get({state.outHeight, state.outWidth, state.numChannelsOut}, outElementType);
  auto packedOutputSliceType = RankedTensorType::get({1, 1, state.numChannelsOut}, outElementType);
  auto paddedRowType = RankedTensorType::get({1, xbarDim}, outElementType);
  auto paddedPatchRowType = RankedTensorType::get({1, paddedK}, elementType, inputType.getEncoding());
  auto paddedWeightTileType = RankedTensorType::get({xbarDim, xbarDim}, state.wType.getElementType());

  auto weightDenseAttr = getHostConstDenseElementsAttr(state.w);
  if (!weightDenseAttr)
    return failure();
  Value paddedWeights = standard::createPaddedInputKTiledWeightConstant(weightDenseAttr, state, paddedK, xbarDim, rewriter);

  Value paddedBias;
  if (state.hasBias) {
    Value biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc);
    auto biasMatrixType = cast<RankedTensorType>(biasMatrix.getType());
    auto paddedBiasType = RankedTensorType::get({1, xbarDim}, outElementType);
    if (auto biasDenseAttr = getHostConstDenseElementsAttr(state.b))
      paddedBias = standard::createPaddedConstantMatrix(biasDenseAttr, biasMatrixType, paddedBiasType, rewriter);
    else
      paddedBias = materializeOrComputeUnary(biasMatrix, paddedBiasType, rewriter, loc, [&](Value biasValue) {
        return standard::createPaddedConvMatrix(biasValue, biasMatrixType, paddedBiasType, rewriter, loc);
      });
  }

  auto paddedInputOp =
      createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, inputHwc, [&](Value hwcInputArg) {
        Value paddedInput = createZeroPaddedTensor(hwcInputArg,
                                                   paddedInputType,
                                                   {state.padHeightBegin, state.padWidthBegin, 0},
                                                   {state.padHeightEnd, state.padWidthEnd, 0},
                                                   rewriter,
                                                   loc);
        spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
      });

  SmallVector<Value> batchInputs {paddedInputOp.getResult(0)};
  if (state.hasBias)
    batchInputs.push_back(paddedBias);

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {packedOutputType},
      state.outHeight,
      ValueRange {paddedWeights},
      batchInputs,
      [&](detail::SpatComputeBatchBodyArgs args) {
        Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
        Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
        Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
        Value cNumKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices);
        Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth);
        Value cNumChannels = getOrCreateIndexConstant(rewriter, anchorOp, state.numChannelsIn);

        // Each lane produces one output row; its receptive field starts at
        // padded-input row lane * strideHeight (no extra op for stride 1).
        Value inputRowOffset = state.strideHeight == 1
                                   ? args.lane
                                   : affineMulConst(rewriter, loc, args.lane, state.strideHeight, anchorOp);
        // Sample the kH x kW taps with the conv dilations (unit strides when 1).
        SmallVector<OpFoldResult> patchStrides {
            rewriter.getIndexAttr(state.dilationHeight), rewriter.getIndexAttr(state.dilationWidth), rewriter.getIndexAttr(1)};

        // The packed row receives output-typed fragments, so it uses the output element type.
        Value packedRowInit = tensor::EmptyOp::create(
            rewriter, loc, ArrayRef<int64_t> {1, state.outWidth, state.numChannelsOut}, outElementType);
        auto widthLoop = buildNormalizedScfFor(
            rewriter,
            loc,
            c0,
            cOutWidth,
            c1,
            ValueRange {packedRowInit},
            [&](OpBuilder&, Location widthLoc, Value widthIndex, ValueRange widthIterArgs, SmallVectorImpl<Value>& widthYielded) {
              Value inputColOffset = state.strideWidth == 1
                                         ? widthIndex
                                         : affineMulConst(rewriter, widthLoc, widthIndex, state.strideWidth, anchorOp);
              Value rowInit = tensor::EmptyOp::create(rewriter, widthLoc, ArrayRef<int64_t> {1, patchSize}, elementType);

              // Gather the receptive field channel by channel into one
              // channel-major [1, Cin * kH * kW] im2col row.
              auto rowLoop = buildNormalizedScfFor(
                  rewriter,
                  widthLoc,
                  c0,
                  cNumChannels,
                  c1,
                  ValueRange {rowInit},
                  [&](OpBuilder&, Location rowLoc, Value channel, ValueRange rowIterArgs, SmallVectorImpl<Value>& rowYielded) {
                    SmallVector<OpFoldResult> patchOffsets {inputRowOffset, inputColOffset, channel};
                    SmallVector<OpFoldResult> patchSizes {
                        rewriter.getIndexAttr(state.wHeight), rewriter.getIndexAttr(state.wWidth), rewriter.getIndexAttr(1)};
                    Value channelPatch = tensor::ExtractSliceOp::create(
                        rewriter, rowLoc, paddedPatchType, args.inputs.front(), patchOffsets, patchSizes, patchStrides);
                    Value flatPatch = tensor::CollapseShapeOp::create(
                        rewriter, rowLoc, flatPatchType, channelPatch, SmallVector<ReassociationIndices> {{0, 1, 2}});
                    Value rowChunk = tensor::ExpandShapeOp::create(
                        rewriter, rowLoc, rowChunkType, flatPatch, SmallVector<ReassociationIndices> {{0, 1}});
                    Value flatOffset = affineMulConst(rewriter, rowLoc, channel, kernelArea, anchorOp);
                    SmallVector<OpFoldResult> rowOffsets {rewriter.getIndexAttr(0), flatOffset};
                    SmallVector<OpFoldResult> rowSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(kernelArea)};
                    Value nextRow = tensor::InsertSliceOp::create(
                        rewriter, rowLoc, rowChunk, rowIterArgs.front(), rowOffsets, rowSizes, getUnitStrides(rewriter, 2));
                    rowYielded.push_back(nextRow);
                    return success();
                  });
              if (failed(rowLoop))
                return failure();

              // Pad K up to numKSlices * xbarDim so every K slice is a full crossbar tile.
              Value paddedRow = rowLoop->results.front();
              if (patchSize != paddedK)
                paddedRow = createZeroPaddedTensor(
                    paddedRow, paddedPatchRowType, {0, 0}, {0, paddedK - patchSize}, rewriter, widthLoc);

              auto zeroAttr = DenseElementsAttr::get(paddedRowType, rewriter.getZeroAttr(outElementType));
              Value zeroRow = getOrCreateConstant(rewriter, anchorOp, zeroAttr, paddedRowType);
              // Accumulate the VMM partial products over all K slices.
              auto kLoop = buildNormalizedScfFor(
                  rewriter,
                  widthLoc,
                  c0,
                  cNumKSlices,
                  c1,
                  ValueRange {zeroRow},
                  [&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl<Value>& reduceYielded) {
                    Value kOffset = affineMulConst(rewriter, reduceLoc, kSlice, xbarDim, anchorOp);
                    SmallVector<OpFoldResult> aOffsets {rewriter.getIndexAttr(0), kOffset};
                    SmallVector<OpFoldResult> aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)};
                    Value aTile = tensor::ExtractSliceOp::create(
                        rewriter, reduceLoc, paddedRowType, paddedRow, aOffsets, aSizes, getUnitStrides(rewriter, 2));
                    SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
                    SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
                    Value bTile = extractStaticSliceOrIdentity(rewriter,
                                                               reduceLoc,
                                                               args.weights.front(),
                                                               paddedWeightTileType,
                                                               bOffsets,
                                                               bSizes,
                                                               getUnitStrides(rewriter, 2));
                    Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
                    reduceYielded.push_back(
                        spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, reduceIterArgs.front(), piece).getResult());
                    return success();
                  });
              if (failed(kLoop))
                return failure();

              Value rowResult = kLoop->results.front();
              if (state.hasBias)
                rowResult =
                    spatial::SpatVAddOp::create(rewriter, widthLoc, paddedRowType, rowResult, args.inputs[1]).getResult();

              // Drop the crossbar padding columns beyond Cout.
              Value outputRow = rowResult;
              if (state.numChannelsOut != xbarDim) {
                SmallVector<OpFoldResult> outputOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
                SmallVector<OpFoldResult> outputSizes {
                    rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)};
                outputRow = tensor::ExtractSliceOp::create(
                    rewriter, widthLoc, rowType, rowResult, outputOffsets, outputSizes, getUnitStrides(rewriter, 2));
              }

              Value outputFragment = tensor::ExpandShapeOp::create(rewriter,
                                                                   widthLoc,
                                                                   packedOutputSliceType,
                                                                   outputRow,
                                                                   SmallVector<ReassociationIndices> {{0}, {1, 2}});
              SmallVector<OpFoldResult> rowOffsets {rewriter.getIndexAttr(0), widthIndex, rewriter.getIndexAttr(0)};
              SmallVector<OpFoldResult> rowSizes {
                  rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)};
              Value nextPackedRow = tensor::InsertSliceOp::create(
                  rewriter, widthLoc, outputFragment, widthIterArgs.front(), rowOffsets, rowSizes, getUnitStrides(rewriter, 3));
              widthYielded.push_back(nextPackedRow);
              return success();
            });
        if (failed(widthLoop))
          return failure();

        SmallVector<OpFoldResult> batchOffsets {args.lane, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> batchSizes {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.outWidth), rewriter.getIndexAttr(state.numChannelsOut)};
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, widthLoop->results.front(), args.outputs.front(), batchOffsets, batchSizes, getUnitStrides(rewriter, 3));
        return success();
      });
  if (failed(batchOp))
    return failure();
  return batchOp->getResult(0);
}
|
|
|
|
/// Lowers a conv whose input is already available as a row-strip HWC tensor
/// by forwarding to createConvOutputFromRowStripHwc. `decision` is not used
/// by this path.
static FailureOr<Value> createConvRowsFromRowStripInput(const ConvLoweringState& state,
                                                        [[maybe_unused]] const ConvLoweringDecision& decision,
                                                        Value rowStripInput,
                                                        PatternRewriter& rewriter,
                                                        Location loc) {
  FailureOr<Value> hwcResult = createConvOutputFromRowStripHwc(rowStripInput, state, rewriter, loc);
  return hwcResult;
}
|
|
|
|
/// Builds the constant operand of a distributed elementwise step with type
/// `fragmentType`. Per-channel constants are expanded by
/// createPerChannelConstantFragment; every other kind is read as a splat and
/// broadcast over the whole fragment.
static Value createFragmentConstant(const DistributedTensorStep& step,
                                    RankedTensorType fragmentType,
                                    PatternRewriter& rewriter) {
  if (step.constantKind != DistributedTensorConstantKind::PerChannel) {
    Attribute element = step.constantAttr.getSplatValue<Attribute>();
    auto broadcast = DenseElementsAttr::get(fragmentType, element);
    Operation* anchor = rewriter.getInsertionBlock()->getParentOp();
    return getOrCreateConstant(rewriter, anchor, broadcast, fragmentType);
  }
  return createPerChannelConstantFragment(step.constantAttr, fragmentType, rewriter);
}
|
|
|
|
/// Builds a floating-point constant of `fragmentType` (read as [N, C, H, W])
/// holding 1 / c for the step's constant c, so a division by that constant can
/// be emitted as a multiplication.
///
/// Per-channel constants give one reciprocal per channel (channel `k` uses the
/// k-th flattened value), repeated over N, H and W. Any other kind is read as
/// an FP splat. Divisors must be finite and non-zero; this is checked with
/// assertions.
static Value createFragmentReciprocalConstant(const DistributedTensorStep& step,
                                              RankedTensorType fragmentType,
                                              PatternRewriter& rewriter) {
  // 1 / divisor, rounded to nearest-even in the divisor's semantics. Division
  // by zero reports opDivByZero (not opInvalidOp) and 1/inf quietly yields 0,
  // so both cases are rejected explicitly.
  auto reciprocalOf = [](const APFloat& divisor) {
    assert(divisor.isFinite() && !divisor.isZero() && "distributed conv div requires finite non-zero constant");
    APFloat result(divisor.getSemantics(), 1);
    [[maybe_unused]] APFloat::opStatus status = result.divide(divisor, APFloat::rmNearestTiesToEven);
    assert(!(status & (APFloat::opInvalidOp | APFloat::opDivByZero))
           && "distributed conv div requires finite non-zero constant");
    return result;
  };

  SmallVector<APFloat> values;
  if (step.constantKind == DistributedTensorConstantKind::PerChannel) {
    SmallVector<APFloat> channelValues;
    for (const APFloat& value : step.constantAttr.getValues<APFloat>())
      channelValues.push_back(value);

    // Compute each channel's reciprocal once instead of once per element.
    const int64_t numChannels = fragmentType.getDimSize(1);
    SmallVector<APFloat> channelReciprocals;
    channelReciprocals.reserve(numChannels);
    for (int64_t channel = 0; channel < numChannels; ++channel)
      channelReciprocals.push_back(reciprocalOf(channelValues[channel]));

    values.reserve(fragmentType.getNumElements());
    for (int64_t n = 0; n < fragmentType.getDimSize(0); ++n)
      for (int64_t channel = 0; channel < numChannels; ++channel)
        for (int64_t h = 0; h < fragmentType.getDimSize(2); ++h)
          for (int64_t w = 0; w < fragmentType.getDimSize(3); ++w)
            values.push_back(channelReciprocals[channel]);
  } else {
    APFloat divisor = cast<DenseFPElementsAttr>(step.constantAttr).getSplatValue<APFloat>();
    values.assign(fragmentType.getNumElements(), reciprocalOf(divisor));
  }

  return getOrCreateConstant(rewriter,
                             rewriter.getInsertionBlock()->getParentOp(),
                             DenseFPElementsAttr::get(fragmentType, values),
                             fragmentType);
}
|
|
|
|
/// Builds the GEMM row output of the conv (one row per output position after
/// unpacking) for the strategy chosen in `decision`.
///
/// Legacy / PackedIm2Col: one im2col GEMM over all output positions, unpacked
/// with maybeUnpackChunkRows. StreamedPatch / OutputChannelTiled / Tiled2D /
/// StreamedPacked: chunked GEMMs via createChunkedConvRows; only
/// StreamedPacked uses the ConvGeometry pack factor, the others use 1.
/// Any other strategy fails before any IR is created.
[[maybe_unused]] static FailureOr<Value> createConvRowsForStrategy(const ConvLoweringState& state,
                                                                  const ConvLoweringDecision& decision,
                                                                  PatternRewriter& rewriter,
                                                                  Location loc) {
  const PimConvLoweringType strategy = decision.strategy;
  const bool useSingleGemm = strategy == PimConvLoweringLegacy || strategy == PimConvLoweringPackedIm2Col;
  const bool useChunkedGemm = strategy == PimConvLoweringStreamedPatch || strategy == PimConvLoweringOutputChannelTiled
                              || strategy == PimConvLoweringTiled2D || strategy == PimConvLoweringStreamedPacked;
  // Reject unsupported strategies up front so a failed match leaves no
  // input-preparation / bias IR behind.
  if (!useSingleGemm && !useChunkedGemm)
    return failure();

  auto wDenseAttr = getHostConstDenseElementsAttr(state.w);
  PreparedConvInput preparedInput = standard::prepareInputForIm2Col(state, rewriter, loc);
  Value biasMatrix;
  DenseElementsAttr biasDenseAttr;
  if (state.hasBias) {
    biasDenseAttr = getHostConstDenseElementsAttr(state.b);
    biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc);
  }
  const bool canPackWeights = static_cast<bool>(wDenseAttr);
  const bool canPackBias = !state.hasBias || static_cast<bool>(biasDenseAttr);

  if (useSingleGemm) {
    standard::ConvGemmPlan plan = standard::buildConvGemmPlan(
        state, canPackWeights, canPackBias, 0, state.batchSize * state.outHeight * state.outWidth);
    Value weightMatrix = standard::createWeightMatrix(state.w, plan, rewriter, loc);
    Value gemmInputRows = standard::createIm2colRows(state, preparedInput, plan, rewriter, loc);
    Value gemmB = standard::buildPackedWeights(wDenseAttr, weightMatrix, state, plan, rewriter, loc);

    // Only materialize a zero bias constant when the conv has no bias of its own.
    Value gemmBias;
    if (state.hasBias)
      gemmBias = state.b;
    else
      gemmBias = createZeroGemmBias(plan.gemmOutputRowsType, rewriter);
    Value gemmC = standard::buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, plan, rewriter, loc);

    Value gemmRows = ONNXGemmOp::create(rewriter,
                                        loc,
                                        plan.gemmOutputRowsType,
                                        gemmInputRows,
                                        gemmB,
                                        gemmC,
                                        APFloat(1.0f),
                                        APFloat(1.0f),
                                        /*transA=*/0,
                                        /*transB=*/0)
                         .getY();
    return standard::maybeUnpackChunkRows(gemmRows, plan, rewriter, loc);
  }

  // Chunked strategies: one ConvGeometry is enough for both the pack factor
  // and the chunk-size choice.
  ConvGeometry geo = buildConvGeometry(state);
  const int64_t packFactor = strategy == PimConvLoweringStreamedPacked ? geo.pack : 1;
  standard::ConvGemmPlan seedPlan =
      standard::buildConvGemmPlan(state, canPackWeights, canPackBias, 0, 1, packFactor);
  Value weightMatrix = standard::createWeightMatrix(state.w, seedPlan, rewriter, loc);
  uint64_t chunkPositions = chooseStreamChunkPositions(geo, packFactor);
  return standard::createChunkedConvRows(state,
                                         preparedInput,
                                         weightMatrix,
                                         biasMatrix,
                                         wDenseAttr,
                                         biasDenseAttr,
                                         packFactor,
                                         chunkPositions,
                                         rewriter,
                                         loc);
}
|
|
|
|
/// Converts a [H * W, C] row tensor (one row per output position, row index
/// h * W + w) into a distributed tensor of logical type `logicalType`
/// (read as [N, C, H, W]).
///
/// One compute-batch lane per H row slices its W rows, transposes them to
/// [C, W], reshapes the result to a row-strip fragment and inserts it at row
/// `lane` of the batch output.
/// NOTE(review): the expand_shape from [C, W] and the unit batch extent of the
/// insert presumably require N == 1 — confirm with callers.
[[maybe_unused]] static FailureOr<DistributedTensorInfo> createDistributedTensorFromRows(Value rows,
                                                                                        RankedTensorType logicalType,
                                                                                        PatternRewriter& rewriter,
                                                                                        Location loc) {
  const int64_t channels = logicalType.getDimSize(1);
  const int64_t height = logicalType.getDimSize(2);
  const int64_t width = logicalType.getDimSize(3);
  Type elementType = logicalType.getElementType();
  Attribute rowsEncoding = cast<RankedTensorType>(rows.getType()).getEncoding();

  auto widthByChannelType = RankedTensorType::get({width, channels}, elementType, rowsEncoding);
  auto channelByWidthType = RankedTensorType::get({channels, width}, elementType, rowsEncoding);
  auto stripType = getRowStripFragmentType(logicalType, width);

  auto batchOp = createSpatComputeBatch(
      rewriter, loc, TypeRange {logicalType}, height, {}, ValueRange {rows}, [&](detail::SpatComputeBatchBodyArgs args) {
        Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
        // Rows [lane * W, lane * W + W) belong to output row `lane`.
        Value firstRow = affineMulConst(rewriter, loc, args.lane, width, anchorOp);
        SmallVector<OpFoldResult> sliceOffsets {firstRow, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> sliceSizes {rewriter.getIndexAttr(width), rewriter.getIndexAttr(channels)};
        Value laneRows = tensor::ExtractSliceOp::create(
            rewriter, loc, widthByChannelType, args.inputs.front(), sliceOffsets, sliceSizes, getUnitStrides(rewriter, 2));

        auto perm = rewriter.getI64ArrayAttr({1, 0});
        Value channelMajor = ONNXTransposeOp::create(rewriter, loc, channelByWidthType, laneRows, perm).getResult();

        SmallVector<ReassociationIndices> expandGroups {{0, 1}, {2, 3}};
        Value strip = tensor::ExpandShapeOp::create(rewriter, loc, stripType, channelMajor, expandGroups);

        SmallVector<OpFoldResult> insertOffsets {
            rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> insertSizes {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(channels), rewriter.getIndexAttr(1), rewriter.getIndexAttr(width)};
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, strip, args.outputs.front(), insertOffsets, insertSizes, getUnitStrides(rewriter, 4));
        return success();
      });
  if (failed(batchOp))
    return failure();
  return makeDistributedTensorInfo(batchOp->getResult(0), logicalType);
}
|
|
|
|
/// Applies one elementwise distributed step (Relu, Sigmoid, Add, Sub, Mul, Div)
/// to a row-distributed tensor and returns the resulting distributed tensor,
/// which keeps the input's logical type.
///
/// A compute batch with `inputInfo.laneCount` lanes is emitted. Each lane
/// extracts its [1, C, 1, W] fragment at offsets [0, 0, lane, 0] of the input
/// storage, applies the step, and inserts the fragment at the same offsets of
/// the batch output. Div is lowered as multiplication by a reciprocal constant.
///
/// Returns failure for Conv steps. That check happens before any IR is created,
/// so a rejected step leaves the IR untouched.
[[maybe_unused]] static FailureOr<DistributedTensorInfo> applyDistributedPreservingStep(
    const DistributedTensorInfo& inputInfo, const DistributedTensorStep& step, PatternRewriter& rewriter, Location loc) {
  // Conv steps are not fragment-wise elementwise ops; reject them before
  // starting to build the compute batch.
  if (step.kind == DistributedTensorOpKind::Conv)
    return failure();

  auto logicalType = inputInfo.logicalType;
  const int64_t width = logicalType.getDimSize(3);
  auto fragmentType = getRowStripFragmentType(logicalType, width);

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {logicalType},
      inputInfo.laneCount,
      {},
      ValueRange {inputInfo.storage},
      [&](detail::SpatComputeBatchBodyArgs args) {
        SmallVector<OpFoldResult> offsets {
            rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1),
                                         rewriter.getIndexAttr(logicalType.getDimSize(1)),
                                         rewriter.getIndexAttr(1),
                                         rewriter.getIndexAttr(width)};
        Value fragment = tensor::ExtractSliceOp::create(
            rewriter, loc, fragmentType, args.inputs.front(), offsets, sizes, getUnitStrides(rewriter, 4));

        switch (step.kind) {
        case DistributedTensorOpKind::Relu:
          fragment = spatial::SpatReluOp::create(rewriter, loc, fragmentType, fragment).getResult();
          break;
        case DistributedTensorOpKind::Sigmoid:
          fragment = spatial::SpatSigmoidOp::create(rewriter, loc, fragmentType, fragment).getResult();
          break;
        case DistributedTensorOpKind::Add: {
          Value constant = createFragmentConstant(step, fragmentType, rewriter);
          fragment = spatial::SpatVAddOp::create(rewriter, loc, fragmentType, fragment, constant).getResult();
          break;
        }
        case DistributedTensorOpKind::Sub: {
          Value constant = createFragmentConstant(step, fragmentType, rewriter);
          // Sub is not commutative: honor which side the fragment was on.
          Value lhs = step.fragmentOnLhs ? fragment : constant;
          Value rhs = step.fragmentOnLhs ? constant : fragment;
          fragment = spatial::SpatVSubOp::create(rewriter, loc, fragmentType, lhs, rhs).getResult();
          break;
        }
        case DistributedTensorOpKind::Mul: {
          Value constant = createFragmentConstant(step, fragmentType, rewriter);
          fragment = spatial::SpatVMulOp::create(rewriter, loc, fragmentType, fragment, constant).getResult();
          break;
        }
        case DistributedTensorOpKind::Div: {
          // x / c is emitted as x * (1 / c).
          Value constant = createFragmentReciprocalConstant(step, fragmentType, rewriter);
          fragment = spatial::SpatVMulOp::create(rewriter, loc, fragmentType, fragment, constant).getResult();
          break;
        }
        case DistributedTensorOpKind::Conv:
          // Rejected before the batch was created; kept for switch completeness.
          return failure();
        }

        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, fragment, args.outputs.front(), offsets, sizes, getUnitStrides(rewriter, 4));
        return success();
      });
  if (failed(batchOp))
    return failure();
  return makeDistributedTensorInfo(batchOp->getResult(0), logicalType);
}
|
|
|
|
/// Collects the per-row GEMM results of a convolution into one tensor of type
/// `outType`, built inside a single Spatial compute op whose result type is
/// `convType`.
///
/// Inside the compute body:
///   1. every incoming GEMM row is passed through `distributedConsumers`, in
///      order (Relu/Sigmoid, or Add/Sub/Mul/Div against a splat constant);
///   2. the rows are concatenated along axis 0 and, when `packFactor != 1`,
///      unpacked with `standard::unpackRowsFromParallelGemm`;
///   3. the [numPatches, numChannelsOut] matrix is expanded to `nhwcType` and
///      transposed with permutation [0, 3, 1, 2] to `outType`.
///
/// `gemmOutType` is not used by this implementation; it is kept so callers do
/// not need to change. Conv steps must not appear in `distributedConsumers`.
static Value createCollectedConvOutput(ValueRange gemmRows,
                                       Type convType,
                                       RankedTensorType gemmOutType,
                                       RankedTensorType nhwcType,
                                       RankedTensorType outType,
                                       int64_t numPatches,
                                       int64_t numChannelsOut,
                                       int64_t packFactor,
                                       ArrayRef<DistributedTensorStep> distributedConsumers,
                                       PatternRewriter& rewriter,
                                       Location loc) {
  (void)gemmOutType;

  // Re-materializes the splat value of `denseAttr` with the shape of `targetType`.
  auto materializeSplatTensor = [&](DenseElementsAttr denseAttr, RankedTensorType targetType) {
    Attribute splatValue = denseAttr.getSplatValue<Attribute>();
    auto targetAttr = DenseElementsAttr::get(targetType, splatValue);
    return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), targetAttr, targetType);
  };

  // Builds a splat tensor holding 1 / divisor, where divisor is the splat value
  // of `divisorAttr`. Division consumers are emitted as multiplication by it.
  auto materializeReciprocalSplatTensor = [&](DenseFPElementsAttr divisorAttr, RankedTensorType targetType) {
    APFloat divisor = divisorAttr.getSplatValue<APFloat>();
    // APFloat reports 1/0 as opDivByZero (not opInvalidOp), so check the
    // divisor explicitly rather than relying on the status alone.
    assert(divisor.isFinite() && !divisor.isZero()
           && "distributed conv div consumer requires finite non-zero scalar");
    APFloat reciprocal(divisor.getSemantics(), 1);
    [[maybe_unused]] APFloat::opStatus status = reciprocal.divide(divisor, APFloat::rmNearestTiesToEven);
    // opInexact is expected and fine; invalid or divide-by-zero is not.
    assert(!(status & (APFloat::opInvalidOp | APFloat::opDivByZero))
           && "distributed conv div consumer requires finite non-zero scalar");
    return getOrCreateConstant(rewriter,
                               rewriter.getInsertionBlock()->getParentOp(),
                               DenseFPElementsAttr::get(targetType, reciprocal),
                               targetType);
  };

  // Applies every distributed consumer, in order, to one GEMM row.
  auto applyDistributedConsumers = [&](Value fragment) {
    Value current = fragment;
    for (const DistributedTensorStep& step : distributedConsumers) {
      auto fragmentType = cast<RankedTensorType>(current.getType());
      switch (step.kind) {
      case DistributedTensorOpKind::Relu:
        current = spatial::SpatReluOp::create(rewriter, loc, fragmentType, current).getResult();
        break;
      case DistributedTensorOpKind::Sigmoid:
        current = spatial::SpatSigmoidOp::create(rewriter, loc, fragmentType, current).getResult();
        break;
      case DistributedTensorOpKind::Add: {
        Value splat = materializeSplatTensor(step.constantAttr, fragmentType);
        current = spatial::SpatVAddOp::create(rewriter, loc, fragmentType, current, splat).getResult();
        break;
      }
      case DistributedTensorOpKind::Sub: {
        Value splat = materializeSplatTensor(step.constantAttr, fragmentType);
        // Sub is not commutative: honor which side the fragment was on.
        Value lhs = step.fragmentOnLhs ? current : splat;
        Value rhs = step.fragmentOnLhs ? splat : current;
        current = spatial::SpatVSubOp::create(rewriter, loc, fragmentType, lhs, rhs).getResult();
        break;
      }
      case DistributedTensorOpKind::Mul: {
        Value splat = materializeSplatTensor(step.constantAttr, fragmentType);
        current = spatial::SpatVMulOp::create(rewriter, loc, fragmentType, current, splat).getResult();
        break;
      }
      case DistributedTensorOpKind::Div: {
        auto divisorAttr = cast<DenseFPElementsAttr>(step.constantAttr);
        Value reciprocal = materializeReciprocalSplatTensor(divisorAttr, fragmentType);
        current = spatial::SpatVMulOp::create(rewriter, loc, fragmentType, current, reciprocal).getResult();
        break;
      }
      case DistributedTensorOpKind::Conv:
        llvm_unreachable("conv-consuming distributed chains should not materialize through createCollectedConvOutput");
      }
    }
    return current;
  };

  auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
    SmallVector<Value> transformedRows;
    transformedRows.reserve(gemmRowArgs.size());
    for (Value row : gemmRowArgs)
      transformedRows.push_back(applyDistributedConsumers(row));

    // Concatenating the rows rebuilds the (possibly packed) GEMM output.
    Value gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, transformedRows);
    if (packFactor != 1)
      gemmOut = standard::unpackRowsFromParallelGemm(
          gemmOut, cast<RankedTensorType>(gemmOut.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc);

    // Restore output layout:
    //   [numPatches, numChannelsOut]
    //   -> [N, Hout, Wout, Cout]
    //   -> [N, Cout, Hout, Wout]
    Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
                                                  loc,
                                                  nhwcType,
                                                  gemmOut,
                                                  SmallVector<ReassociationIndices> {
                                                      {0, 1, 2},
                                                      {3}
    });
    Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
    spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
  });
  return collectComputeOp.getResult(0);
}
|
|
|
|
/// Validates an ONNX Conv for Spatial lowering and gathers its geometry.
///
/// `x`, `w` and `b` are the input, weight and bias operands to analyze. `b` may
/// be null, `none`-typed, or produced by an ONNXNoneOp when the conv has no
/// bias.
///
/// The following are checked:
///   - static rank-4 input, weight and result;
///   - group >= 1, and channel counts divisible by the group;
///   - weight dims consistent with the input and result channels;
///   - 2 strides, 2 dilations and 4 pads when those attributes are present;
///   - a supported auto_pad.
///
/// Explicit `pads` ([hBegin, wBegin, hEnd, wEnd]) take precedence. Otherwise
/// SAME_UPPER / SAME_LOWER padding is derived from the result shape, and
/// NOTSET / VALID mean zero padding.
///
/// On failure a diagnostic has been emitted on `convOp`.
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b) {
  ConvLoweringState state;
  state.x = x;
  state.w = w;
  state.b = b;
  state.xType = cast<RankedTensorType>(state.x.getType());
  state.wType = cast<RankedTensorType>(state.w.getType());
  state.outType = cast<RankedTensorType>(convOp.getY().getType());

  if (!state.xType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
    return failure();
  }
  if (!state.wType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
    return failure();
  }
  if (!state.outType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
    return failure();
  }
  if (state.xType.getRank() != 4) {
    pim::emitUnsupportedRankDiagnostic(convOp, "conv input", state.xType.getRank(), {4});
    return failure();
  }
  if (state.wType.getRank() != 4) {
    pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", state.wType.getRank(), {4});
    return failure();
  }
  if (state.outType.getRank() != 4) {
    pim::emitUnsupportedRankDiagnostic(convOp, "conv result", state.outType.getRank(), {4});
    return failure();
  }

  state.group = convOp.getGroup();
  if (state.group < 1) {
    convOp.emitOpError("requires group >= 1 for Spatial lowering");
    return failure();
  }

  state.batchSize = state.xType.getDimSize(0);
  state.numChannelsIn = state.xType.getDimSize(1);
  state.xHeight = state.xType.getDimSize(2);
  state.xWidth = state.xType.getDimSize(3);
  state.numChannelsOut = state.wType.getDimSize(0);
  state.wHeight = state.wType.getDimSize(2);
  state.wWidth = state.wType.getDimSize(3);
  state.outHeight = state.outType.getDimSize(2);
  state.outWidth = state.outType.getDimSize(3);

  // A bias is present unless it is null, `none`-typed, or an ONNXNoneOp result.
  // A bias that is a block argument has no defining op, and isa<> asserts on a
  // null pointer, so use isa_and_present for the defining-op test.
  state.hasBias = state.b && !isa<NoneType>(state.b.getType())
                  && !llvm::isa_and_present<ONNXNoneOp>(state.b.getDefiningOp());

  if (state.numChannelsIn % state.group != 0) {
    convOp.emitOpError() << "requires input channels " << state.numChannelsIn << " to be divisible by group "
                         << state.group << " for Spatial lowering";
    return failure();
  }
  if (state.numChannelsOut % state.group != 0) {
    convOp.emitOpError() << "requires output channels " << state.numChannelsOut << " to be divisible by group "
                         << state.group << " for Spatial lowering";
    return failure();
  }

  state.numChannelsInPerGroup = state.numChannelsIn / state.group;
  state.numChannelsOutPerGroup = state.numChannelsOut / state.group;
  if (state.wType.getDimSize(1) != state.numChannelsInPerGroup) {
    convOp.emitOpError() << "requires grouped conv weight input channels " << state.wType.getDimSize(1)
                         << " to match input channels per group " << state.numChannelsInPerGroup
                         << " for Spatial lowering";
    return failure();
  }
  // numChannelsOut is read from the weight, so it must be validated against the
  // result's channel dimension (comparing it with itself would never fail).
  const int64_t resultChannels = state.outType.getDimSize(1);
  if (state.wType.getDimSize(0) != resultChannels) {
    convOp.emitOpError() << "requires weight output channels " << state.wType.getDimSize(0)
                         << " to match result channels " << resultChannels << " for Spatial lowering";
    return failure();
  }

  const auto stridesAttr = convOp.getStrides();
  const auto dilationsAttr = convOp.getDilations();
  const auto padsAttr = convOp.getPads();

  if (stridesAttr && stridesAttr->size() != 2) {
    convOp.emitOpError("requires exactly two stride values for Spatial lowering");
    return failure();
  }
  if (dilationsAttr && dilationsAttr->size() != 2) {
    convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
    return failure();
  }
  if (padsAttr && padsAttr->size() != 4) {
    convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
    return failure();
  }

  state.strideHeight = getOptionalI64Attr(stridesAttr, 0, 1);
  state.strideWidth = getOptionalI64Attr(stridesAttr, 1, 1);
  state.dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1);
  state.dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1);
  state.padHeightBegin = 0;
  state.padHeightEnd = 0;
  state.padWidthBegin = 0;
  state.padWidthEnd = 0;

  // Explicit pads are laid out as [hBegin, wBegin, hEnd, wEnd].
  if (padsAttr) {
    state.padHeightBegin = getI64Attr(*padsAttr, 0);
    state.padWidthBegin = getI64Attr(*padsAttr, 1);
    state.padHeightEnd = getI64Attr(*padsAttr, 2);
    state.padWidthEnd = getI64Attr(*padsAttr, 3);
    return state;
  }

  const auto autoPad = convOp.getAutoPad();
  if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
    // Total padding needed so the dilated kernel, stepped by the stride,
    // produces exactly the declared result size.
    const int64_t effectiveKernelH = (state.wHeight - 1) * state.dilationHeight + 1;
    const int64_t effectiveKernelW = (state.wWidth - 1) * state.dilationWidth + 1;
    const int64_t totalPadH =
        std::max(static_cast<int64_t>(0), (state.outHeight - 1) * state.strideHeight + effectiveKernelH - state.xHeight);
    const int64_t totalPadW =
        std::max(static_cast<int64_t>(0), (state.outWidth - 1) * state.strideWidth + effectiveKernelW - state.xWidth);

    if (autoPad == "SAME_UPPER") {
      // Odd totals put the extra element at the end.
      state.padHeightBegin = totalPadH / 2;
      state.padHeightEnd = totalPadH - state.padHeightBegin;
      state.padWidthBegin = totalPadW / 2;
      state.padWidthEnd = totalPadW - state.padWidthBegin;
    }
    else {
      // SAME_LOWER: odd totals put the extra element at the beginning.
      state.padHeightEnd = totalPadH / 2;
      state.padHeightBegin = totalPadH - state.padHeightEnd;
      state.padWidthEnd = totalPadW / 2;
      state.padWidthBegin = totalPadW - state.padWidthEnd;
    }
    return state;
  }

  if (autoPad != "NOTSET" && autoPad != "VALID") {
    convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
    return failure();
  }

  return state;
}
|
|
|
|
/// Convenience overload: analyzes `convOp` using the operands carried by its
/// conversion adaptor (input X, weight W and optional bias B).
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, ONNXConvOpAdaptor convOpAdaptor) {
  Value input = convOpAdaptor.getX();
  Value weight = convOpAdaptor.getW();
  Value bias = convOpAdaptor.getB();
  return analyzeConvLoweringState(convOp, input, weight, bias);
}
|
|
|
|
/// Validates a Spatial Conv2D plan op and gathers its geometry.
///
/// The following are required:
///   - ranked, static, rank-4 input, weight and output;
///   - group >= 1, and channel counts divisible by the group;
///   - weight dims consistent with the per-group input channels and the output
///     channels;
///   - exactly 4 pads ([hBegin, wBegin, hEnd, wEnd]), 2 strides and 2 dilations.
///
/// The bias is considered present iff the plan has a bias operand. On failure
/// an op error has been emitted on `planOp`.
static FailureOr<ConvLoweringState> analyzeConvLoweringState(spatial::SpatConv2DPlanOp planOp) {
  ConvLoweringState state;
  state.x = planOp.getInput();
  state.w = planOp.getWeight();
  state.b = planOp.getBias();
  state.xType = dyn_cast<RankedTensorType>(state.x.getType());
  state.wType = dyn_cast<RankedTensorType>(state.w.getType());
  state.outType = dyn_cast<RankedTensorType>(planOp.getOutput().getType());

  if (!state.xType || !state.wType || !state.outType)
    return planOp.emitOpError("requires ranked tensor input, weight, and output"), failure();
  if (!state.xType.hasStaticShape() || !state.wType.hasStaticShape() || !state.outType.hasStaticShape())
    return planOp.emitOpError("requires static input, weight, and output shapes"), failure();
  if (state.xType.getRank() != 4 || state.wType.getRank() != 4 || state.outType.getRank() != 4)
    return planOp.emitOpError("requires rank-4 input, weight, and output tensors"), failure();

  state.group = planOp.getGroup();
  if (state.group < 1)
    return planOp.emitOpError("requires group >= 1"), failure();

  state.batchSize = state.xType.getDimSize(0);
  state.numChannelsIn = state.xType.getDimSize(1);
  state.xHeight = state.xType.getDimSize(2);
  state.xWidth = state.xType.getDimSize(3);
  state.numChannelsOut = state.wType.getDimSize(0);
  state.wHeight = state.wType.getDimSize(2);
  state.wWidth = state.wType.getDimSize(3);
  state.outHeight = state.outType.getDimSize(2);
  state.outWidth = state.outType.getDimSize(3);
  state.hasBias = static_cast<bool>(state.b);

  if (state.numChannelsIn % state.group != 0 || state.numChannelsOut % state.group != 0)
    return planOp.emitOpError("requires input and output channels divisible by group"), failure();

  state.numChannelsInPerGroup = state.numChannelsIn / state.group;
  state.numChannelsOutPerGroup = state.numChannelsOut / state.group;
  if (state.wType.getDimSize(1) != state.numChannelsInPerGroup)
    return planOp.emitOpError("requires grouped conv weight channels to match input channels per group"), failure();
  // numChannelsOut comes from the weight; make sure it agrees with the output.
  if (state.wType.getDimSize(0) != state.outType.getDimSize(1))
    return planOp.emitOpError("requires weight output channels to match output channels"), failure();

  auto pads = planOp.getPads();
  auto strides = planOp.getStrides();
  auto dilations = planOp.getDilations();
  if (pads.size() != 4 || strides.size() != 2 || dilations.size() != 2)
    return planOp.emitOpError("requires 4 pads, 2 strides, and 2 dilations"), failure();

  state.padHeightBegin = pads[0];
  state.padWidthBegin = pads[1];
  state.padHeightEnd = pads[2];
  state.padWidthEnd = pads[3];
  state.strideHeight = strides[0];
  state.strideWidth = strides[1];
  state.dilationHeight = dilations[0];
  state.dilationWidth = dilations[1];
  return state;
}
|
|
|
|
/// Resolves the Conv lowering strategy requested on the command line.
///
/// Without `--use-experimental-conv-impl` this is simply `--pim-conv-lowering`.
/// With it, the request is pinned to packed-im2col. That is accepted only when
/// `--pim-conv-lowering` is `auto` or already packed-im2col; any other explicit
/// choice is reported as a conflict on `op`.
static FailureOr<PimConvLoweringType> resolveRequestedConvLoweringStrategy(Operation* op) {
  const PimConvLoweringType requested = pimConvLowering.getValue();
  if (!useExperimentalConvImpl)
    return requested;

  const bool compatibleWithExperimental =
      requested == PimConvLoweringAuto || requested == PimConvLoweringPackedIm2Col;
  if (compatibleWithExperimental)
    return PimConvLoweringPackedIm2Col;

  op->emitOpError() << "--use-experimental-conv-impl conflicts with --pim-conv-lowering="
                    << stringifyConvLoweringStrategy(requested);
  return failure();
}
|
|
|
|
/// Checks that `strategy` is applicable to the convolution described by `geo`.
///
/// Auto and Legacy are always accepted. Every other strategy has preconditions
/// on `geo.k` and `geo.c` relative to the crossbar size `geo.xbarSize`, on the
/// pack factor and, for packed im2col, on the im2col element budget. When the
/// preconditions do not hold, an op error is emitted on `op` and failure is
/// returned.
static LogicalResult verifyForcedConvLoweringStrategy(Operation* op,
                                                      const ConvGeometry& geo,
                                                      PimConvLoweringType strategy) {
  const bool kFitsCrossbar = geo.k <= geo.xbarSize;
  const bool cFitsCrossbar = geo.c <= geo.xbarSize;
  const bool canPack = geo.pack >= 2;

  switch (strategy) {
  case PimConvLoweringAuto:
  case PimConvLoweringLegacy:
    return success();

  case PimConvLoweringDepthwise:
    if (!geo.isDepthwise)
      return op->emitOpError("forced depthwise Conv lowering requires a depthwise convolution");
    return success();

  case PimConvLoweringPackedIm2Col:
    if (!(kFitsCrossbar && cFitsCrossbar && canPack && geo.im2colElements <= pimConvIm2colMaxElements))
      return op->emitOpError("forced packed-im2col Conv lowering requires K/C to fit, pack >= 2, and im2col within budget");
    return success();

  case PimConvLoweringStreamedPatch:
    if (!(kFitsCrossbar && cFitsCrossbar))
      return op->emitOpError("forced streamed-patch Conv lowering requires K and C to each fit one crossbar");
    return success();

  case PimConvLoweringStreamedPacked:
    if (!(kFitsCrossbar && cFitsCrossbar && canPack))
      return op->emitOpError("forced streamed-packed Conv lowering requires K/C to fit and pack >= 2");
    return success();

  case PimConvLoweringOutputChannelTiled:
    if (!(kFitsCrossbar && !cFitsCrossbar))
      return op->emitOpError("forced output-channel-tiled Conv lowering requires K <= X and C > X");
    return success();

  case PimConvLoweringInputKTiled:
    if (!(!kFitsCrossbar && cFitsCrossbar))
      return op->emitOpError("forced input-k-tiled Conv lowering requires K > X and C <= X");
    return success();

  case PimConvLoweringTiled2D:
    if (!(!kFitsCrossbar && !cFitsCrossbar))
      return op->emitOpError("forced tiled-2d Conv lowering requires K > X and C > X");
    return success();
  }
  llvm_unreachable("unknown conv lowering strategy");
}
|
|
|
|
static FailureOr<Value> lowerDenseSelectedConvPlan(Operation* op,
|
|
const ConvLoweringState& state,
|
|
PimConvLoweringType strategy,
|
|
PatternRewriter& rewriter,
|
|
Location loc);
|
|
|
|
static ConvLoweringState makeGroupedConvLoweringState(const ConvLoweringState& parent,
|
|
Value groupX,
|
|
Value groupW,
|
|
Value groupB,
|
|
RankedTensorType groupOutType);
|
|
|
|
static FailureOr<Value> buildConvValueForStrategy(Operation* op,
|
|
Location loc,
|
|
const ConvLoweringState& state,
|
|
const ConvLoweringDecision& decision,
|
|
const DistributedConvAnalysis& analysis,
|
|
ArrayRef<DistributedTensorStep> distributedConsumers,
|
|
PatternRewriter& rewriter);
|
|
|
|
static FailureOr<Value> buildGroupedConvValue(Operation* op,
|
|
Location loc,
|
|
const ConvLoweringState& state,
|
|
const ConvLoweringDecision& decision,
|
|
PatternRewriter& rewriter);
|
|
|
|
/// Lowers a grouped conv plan with the given `strategy` by splitting it into one
/// convolution per group (see buildGroupedConvValue).
static FailureOr<Value> lowerGroupedSelectedConvPlan(Operation* op,
                                                     const ConvLoweringState& state,
                                                     PimConvLoweringType strategy,
                                                     PatternRewriter& rewriter,
                                                     Location loc) {
  // An explicitly selected strategy carries no reason or auto-selection metadata.
  const ConvLoweringDecision selected {strategy,
                                       /*reason=*/"",
                                       /*isAuto=*/false,
                                       /*fallbackReason=*/"",
                                       /*rejectedAutoStrategy=*/""};
  return buildGroupedConvValue(op, loc, state, selected, rewriter);
}
|
|
|
|
/// Lowers a conv plan as a single dense convolution with the given `strategy`.
///
/// No distributed consumers are fused on this path. The analysis only records
/// why, as an UnsupportedConsumer barrier tagged "selected dense layout".
static FailureOr<Value> lowerDenseSelectedConvPlan(Operation* op,
                                                   const ConvLoweringState& state,
                                                   PimConvLoweringType strategy,
                                                   PatternRewriter& rewriter,
                                                   Location loc) {
  DistributedConvAnalysis denseAnalysis;
  denseAnalysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer;
  denseAnalysis.barrierDetail = "selected dense layout";

  const ConvLoweringDecision selected {strategy,
                                       /*reason=*/"",
                                       /*isAuto=*/false,
                                       /*fallbackReason=*/"",
                                       /*rejectedAutoStrategy=*/""};
  return buildConvValueForStrategy(op, loc, state, selected, denseAnalysis, /*distributedConsumers=*/{}, rewriter);
}
|
|
|
|
/// Emits the lowered value of a convolution for `decision.strategy`.
///
/// Strategy dispatch:
///   - Depthwise                                 -> depthwise::rewriteConv
///   - Legacy / PackedIm2Col                     -> standard::rewritePackedIm2ColConv
///   - StreamedPatch / OutputChannelTiled / Tiled2D
///                                               -> standard::rewriteStreamedConv, pack factor 1
///   - InputKTiled                               -> standard::rewriteInputKTiledConv
///   - StreamedPacked                            -> standard::rewriteStreamedConv, geometry pack factor
///
/// `distributedConsumers` are passed to the standard rewrites. The depthwise
/// rewrite cannot take them, so a depthwise request with consumers is
/// rejected. The Auto strategy must have been resolved before this call;
/// reaching it is reported as an error. `analysis` is currently unused.
static FailureOr<Value> buildConvValueForStrategy(Operation* op,
                                                  Location loc,
                                                  const ConvLoweringState& state,
                                                  const ConvLoweringDecision& decision,
                                                  const DistributedConvAnalysis& analysis,
                                                  ArrayRef<DistributedTensorStep> distributedConsumers,
                                                  PatternRewriter& rewriter) {
  (void)analysis;
  switch (decision.strategy) {
  case PimConvLoweringDepthwise:
    // depthwise::rewriteConv has no way to apply distributed consumers.
    // rewriteSelectedConv erases the consumer ops after a successful rewrite,
    // so ignoring them here would silently drop their effect.
    if (!distributedConsumers.empty()) {
      op->emitOpError("depthwise Conv lowering cannot absorb distributed consumers");
      return failure();
    }
    return depthwise::rewriteConv(op, state, rewriter, loc);

  case PimConvLoweringLegacy:
  case PimConvLoweringPackedIm2Col:
    return standard::rewritePackedIm2ColConv(state, distributedConsumers, rewriter, loc);

  case PimConvLoweringStreamedPatch:
  case PimConvLoweringOutputChannelTiled:
  case PimConvLoweringTiled2D:
    return standard::rewriteStreamedConv(state, distributedConsumers, rewriter, loc, /*forcedPackFactor=*/1);

  case PimConvLoweringInputKTiled:
    return standard::rewriteInputKTiledConv(state, distributedConsumers, rewriter, loc);

  case PimConvLoweringStreamedPacked: {
    // Only this strategy needs the geometry (for its pack factor).
    const ConvGeometry geo = buildConvGeometry(state);
    return standard::rewriteStreamedConv(state, distributedConsumers, rewriter, loc, geo.pack);
  }

  case PimConvLoweringAuto:
    break;
  }

  op->emitOpError("unexpected auto strategy at Conv lowering dispatch");
  return failure();
}
|
|
|
|
/// Builds the lowered value of `convOp` for `decision.strategy` into `result`
/// and, when that succeeds, reports the decision with its batching shape
/// (batch size, number of batches, compute-batch usage and chunk width).
///
/// Returns failure, leaving `result` failed, when the value cannot be built.
/// Reaching the Auto strategy is reported as an error.
static LogicalResult
createConvValueForStrategy(ONNXConvOp convOp,
                           const ConvLoweringState& state,
                           const ConvLoweringDecision& decision,
                           const DistributedConvAnalysis& analysis,
                           ArrayRef<DistributedTensorStep> distributedConsumers,
                           PatternRewriter& rewriter,
                           FailureOr<Value>& result) {
  result = buildConvValueForStrategy(convOp, convOp.getLoc(), state, decision, analysis, distributedConsumers, rewriter);
  if (failed(result))
    return failure();

  const ConvGeometry geo = buildConvGeometry(state);
  const ConvStrategyEstimate estimate = estimateConvStrategy(geo, decision.strategy, analysis);

  // Shared report for the streamed compute-batch strategies. Positions are
  // processed in chunks chosen by chooseStreamChunkPositions for `packFactor`.
  auto reportStreamedChunks = [&](auto packFactor, auto batchSize) {
    const uint64_t chunkPositions = chooseStreamChunkPositions(geo, packFactor);
    const int64_t batches = ceilIntegerDivide(geo.p, static_cast<int64_t>(chunkPositions));
    reportConvLoweringDecision(convOp,
                               geo,
                               decision,
                               estimate,
                               batchSize,
                               batches,
                               /*usesComputeBatch=*/true,
                               /*usesBatchedInstructionEmission=*/true,
                               chunkPositions);
    return success();
  };

  switch (decision.strategy) {
  case PimConvLoweringDepthwise:
    reportConvLoweringDecision(convOp,
                               geo,
                               decision,
                               estimate,
                               /*batchSize=*/geo.p,
                               /*numberOfBatches=*/1,
                               /*usesComputeBatch=*/true,
                               /*usesBatchedInstructionEmission=*/true,
                               std::nullopt);
    return success();

  case PimConvLoweringLegacy:
  case PimConvLoweringPackedIm2Col:
    reportConvLoweringDecision(convOp,
                               geo,
                               decision,
                               estimate,
                               /*batchSize=*/geo.pack,
                               /*numberOfBatches=*/1,
                               /*usesComputeBatch=*/true,
                               /*usesBatchedInstructionEmission=*/true,
                               std::nullopt);
    return success();

  case PimConvLoweringStreamedPatch:
  case PimConvLoweringOutputChannelTiled:
  case PimConvLoweringTiled2D:
    return reportStreamedChunks(/*packFactor=*/1, /*batchSize=*/1);

  case PimConvLoweringStreamedPacked:
    return reportStreamedChunks(/*packFactor=*/geo.pack, /*batchSize=*/geo.pack);

  case PimConvLoweringInputKTiled: {
    // Each output row is split into chunks no wider than the streamed chunk,
    // the per-core lane budget, or the row itself (and at least 1).
    const int64_t kSliceCount = ceilIntegerDivide(geo.k, geo.xbarSize);
    const uint64_t laneBudget =
        std::max<uint64_t>(1,
                           static_cast<uint64_t>(crossbarCountInCore.getValue())
                               / static_cast<uint64_t>(std::max<int64_t>(1, kSliceCount * 4)));
    const uint64_t chunkWidth =
        std::max<uint64_t>(1,
                           std::min<uint64_t>({chooseStreamChunkPositions(geo, /*packFactor=*/1),
                                               laneBudget,
                                               static_cast<uint64_t>(state.outWidth)}));
    const int64_t rowBatches =
        state.batchSize * state.outHeight * ceilIntegerDivide(state.outWidth, static_cast<int64_t>(chunkWidth));
    reportConvLoweringDecision(convOp,
                               geo,
                               decision,
                               estimate,
                               /*batchSize=*/1,
                               rowBatches,
                               /*usesComputeBatch=*/false,
                               /*usesBatchedInstructionEmission=*/false,
                               chunkWidth);
    return success();
  }

  case PimConvLoweringAuto:
    break;
  }
  return convOp.emitOpError("unexpected auto strategy at Conv lowering dispatch");
}
|
|
|
|
/// Lowers `convOp` with the chosen strategy and splices the result into the IR.
///
/// The conv value is built with `analysis.steps` as distributed consumers.
/// Without local consumers the conv itself is replaced. With local consumers,
/// the lowered value replaces `analysis.replacementOp`; the remaining step ops
/// are then erased last-to-first, and finally the conv is erased.
static LogicalResult
rewriteSelectedConv(ONNXConvOp convOp,
                    const ConvLoweringState& state,
                    const ConvLoweringDecision& decision,
                    const DistributedConvAnalysis& analysis,
                    PatternRewriter& rewriter) {
  FailureOr<Value> lowered = failure();
  if (failed(createConvValueForStrategy(convOp, state, decision, analysis, analysis.steps, rewriter, lowered)))
    return failure();

  if (analysis.hasLocalConsumers()) {
    assert(analysis.replacementOp && "conv rewrite expects a replacement op");
    rewriter.replaceOp(analysis.replacementOp, *lowered);
    // Walk backwards so users are erased before the ops they use.
    for (const auto& step : llvm::reverse(analysis.steps)) {
      if (step.op == analysis.replacementOp)
        continue;
      rewriter.eraseOp(step.op);
    }
    rewriter.eraseOp(convOp);
    return success();
  }

  rewriter.replaceOp(convOp, *lowered);
  return success();
}
|
|
|
|
/// Entry point for ungrouped convolutions; forwards unchanged to
/// rewriteSelectedConv.
[[maybe_unused]] static LogicalResult rewriteUngroupedConv(ONNXConvOp convOp,
                                                           const ConvLoweringState& state,
                                                           const ConvLoweringDecision& decision,
                                                           const DistributedConvAnalysis& analysis,
                                                           PatternRewriter& rewriter) {
  return rewriteSelectedConv(convOp, state, decision, analysis, rewriter);
}
|
|
|
|
/// Forward declaration: replaces a grouped `convOp` with its per-group lowering.
/// Defined further down in this file.
static LogicalResult
rewriteGroupedConv(ONNXConvOp convOp,
                   const ConvLoweringState& state,
                   const ConvLoweringDecision& decision,
                   PatternRewriter& rewriter);
|
|
|
|
/// Derives the lowering state of a single group of a grouped convolution.
///
/// Starts from a copy of `parent`, so strides, dilations and paddings are
/// kept. It then swaps in the per-group operands and result type, re-reads
/// every dimension from the new types, and marks the state as an ungrouped
/// (group == 1) convolution. The bias is present iff `groupB` is non-null.
static ConvLoweringState makeGroupedConvLoweringState(
    const ConvLoweringState& parent, Value groupX, Value groupW, Value groupB, RankedTensorType groupOutType) {
  auto inputType = cast<RankedTensorType>(groupX.getType());
  auto weightType = cast<RankedTensorType>(groupW.getType());

  ConvLoweringState groupState = parent;
  groupState.x = groupX;
  groupState.w = groupW;
  groupState.b = groupB;
  groupState.xType = inputType;
  groupState.wType = weightType;
  groupState.outType = groupOutType;

  ArrayRef<int64_t> inputShape = inputType.getShape();
  ArrayRef<int64_t> weightShape = weightType.getShape();
  ArrayRef<int64_t> outputShape = groupOutType.getShape();

  groupState.batchSize = inputShape[0];
  groupState.numChannelsIn = inputShape[1];
  groupState.xHeight = inputShape[2];
  groupState.xWidth = inputShape[3];
  groupState.numChannelsOut = weightShape[0];
  groupState.wHeight = weightShape[2];
  groupState.wWidth = weightShape[3];
  groupState.outHeight = outputShape[2];
  groupState.outWidth = outputShape[3];

  // A single group is an ordinary (ungrouped) convolution.
  groupState.group = 1;
  groupState.numChannelsInPerGroup = groupState.numChannelsIn;
  groupState.numChannelsOutPerGroup = groupState.numChannelsOut;
  groupState.hasBias = static_cast<bool>(groupB);
  return groupState;
}
|
|
|
|
/// Lowers a grouped convolution as `state.group` independent convolutions.
///
/// The operands are split per group:
///   - the input along axis 1, in chunks of `numChannelsInPerGroup`;
///   - the weight along axis 0, in chunks of `numChannelsOutPerGroup`;
///   - the bias (rank 1, or rank 2 with the channel axis picked from the shape)
///     along its channel axis.
///
/// Each group is lowered with `decision.strategy` as an ungrouped conv without
/// distributed consumers. The per-group results are concatenated along axis 1:
/// directly when every group result is compile-time computable, otherwise
/// inside a Spatial compute op producing `state.outType`.
///
/// Fails with an op error on an unsupported bias rank, if the operands cannot
/// be partitioned into `state.group` pieces, or if any group fails to lower.
static FailureOr<Value> buildGroupedConvValue(Operation* op,
                                              Location loc,
                                              const ConvLoweringState& state,
                                              const ConvLoweringDecision& decision,
                                              PatternRewriter& rewriter) {
  // Validate the bias before calling sliceTensor (which is given the rewriter
  // and presumably emits ops), so an unsupported bias is rejected before any
  // operand slices are created.
  int64_t biasAxis = -1;
  if (state.hasBias) {
    auto biasType = cast<RankedTensorType>(state.b.getType());
    if (biasType.getRank() == 1) {
      biasAxis = 0;
    }
    else if (biasType.getRank() == 2) {
      biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1;
    }
    else {
      op->emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
                        << biasType.getRank();
      return failure();
    }
  }

  SmallVector<Value> xSlices = sliceTensor(state.x, /*axis=*/1, state.numChannelsInPerGroup, rewriter, loc);
  SmallVector<Value> wSlices = sliceTensor(state.w, /*axis=*/0, state.numChannelsOutPerGroup, rewriter, loc);
  SmallVector<Value> bSlices;
  if (state.hasBias)
    bSlices = sliceTensor(state.b, biasAxis, state.numChannelsOutPerGroup, rewriter, loc);

  const auto groupCount = static_cast<size_t>(state.group);
  if (xSlices.size() != groupCount || wSlices.size() != groupCount
      || (state.hasBias && bSlices.size() != groupCount)) {
    op->emitOpError("failed to partition grouped convolution operands for Spatial lowering");
    return failure();
  }

  // Per-group result type: the conv result restricted to this group's output
  // channels. The encoding is carried over so the per-group results agree with
  // the final `state.outType` they are concatenated into.
  auto groupOutType =
      RankedTensorType::get({state.batchSize, state.numChannelsOutPerGroup, state.outHeight, state.outWidth},
                            state.outType.getElementType(),
                            state.outType.getEncoding());

  SmallVector<Value> groupResults;
  groupResults.reserve(state.group);
  for (int64_t groupId = 0; groupId < state.group; groupId++) {
    Value groupX = xSlices[groupId];
    Value groupW = wSlices[groupId];
    Value groupB = state.hasBias ? bSlices[groupId] : Value();
    ConvLoweringState groupState = makeGroupedConvLoweringState(state, groupX, groupW, groupB, groupOutType);

    DistributedConvAnalysis groupAnalysis;
    groupAnalysis.barrierKind = DistributedConvBarrierKind::GroupedConv;
    groupAnalysis.barrierDetail = "grouped convolution still materializes densely";
    FailureOr<Value> groupResult =
        buildConvValueForStrategy(op, loc, groupState, decision, groupAnalysis, {}, rewriter);
    if (failed(groupResult))
      return failure();
    groupResults.push_back(*groupResult);
  }

  if (llvm::all_of(groupResults, isCompileTimeComputable))
    return createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);

  auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {state.outType}, {}, groupResults, [&](ValueRange args) {
    spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
  });
  return concatCompute.getResult(0);
}
|
|
|
|
/// Replaces `convOp` with its per-group lowering (see buildGroupedConvValue).
/// On failure the conv is left in place.
[[maybe_unused]] static LogicalResult
rewriteGroupedConv(ONNXConvOp convOp,
                   const ConvLoweringState& state,
                   const ConvLoweringDecision& decision,
                   PatternRewriter& rewriter) {
  FailureOr<Value> grouped = buildGroupedConvValue(convOp.getOperation(), convOp.getLoc(), state, decision, rewriter);
  if (succeeded(grouped)) {
    rewriter.replaceOp(convOp, *grouped);
    return success();
  }
  return failure();
}
|
|
|
|
} // namespace
|
|
|
|
/// Rewrites an ONNX Conv into a Spatial Conv2D plan op.
///
/// The conv is validated by analyzeConvLoweringState. The plan records:
///   - the resolved paddings [hBegin, wBegin, hEnd, wEnd];
///   - strides and dilations;
///   - the group count;
///   - the "nchw" layout tag.
/// The bias operand is forwarded only when the analysis found one.
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
                                          ONNXConvOpAdaptor convOpAdaptor,
                                          ConversionPatternRewriter& rewriter) const {
  FailureOr<ConvLoweringState> analyzed = analyzeConvLoweringState(convOp, convOpAdaptor);
  if (failed(analyzed))
    return failure();
  const ConvLoweringState& conv = *analyzed;

  const SmallVector<int64_t, 4> padValues {
      conv.padHeightBegin, conv.padWidthBegin, conv.padHeightEnd, conv.padWidthEnd};
  const SmallVector<int64_t, 2> strideValues {conv.strideHeight, conv.strideWidth};
  const SmallVector<int64_t, 2> dilationValues {conv.dilationHeight, conv.dilationWidth};
  Value optionalBias = conv.hasBias ? convOpAdaptor.getB() : Value();

  auto plan = spatial::SpatConv2DPlanOp::create(rewriter,
                                                convOp.getLoc(),
                                                convOp.getY().getType(),
                                                convOpAdaptor.getX(),
                                                convOpAdaptor.getW(),
                                                optionalBias,
                                                rewriter.getDenseI64ArrayAttr(padValues),
                                                rewriter.getDenseI64ArrayAttr(strideValues),
                                                rewriter.getDenseI64ArrayAttr(dilationValues),
                                                rewriter.getI64IntegerAttr(conv.group),
                                                rewriter.getStringAttr("nchw"));
  rewriter.replaceOp(convOp, plan.getResult());
  return success();
}
|
|
|
|
/// Registers the ONNX Conv -> SpatConv2DPlanOp rewrite (ConvToGemm) in `patterns`.
void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ConvToGemm>(ctx);
}
|
|
|
|
/// Returns success when `planOp` may be lowered to the row-strip output layout.
///
/// Requirements checked here:
///  - analyzeConvLoweringState accepts the plan;
///  - group == 1 and batch size == 1;
///  - the result type is rank 4 with a static shape;
///  - the requested strategy resolves, and the chosen strategy (after the
///    auto depthwise -> legacy fallback) passes verifyForcedConvLoweringStrategy;
///  - the chosen strategy is one of the strategies accepted by the switch below.
LogicalResult canLowerConvPlanToRowStrip(spatial::SpatConv2DPlanOp planOp) {
  Operation* const planOperation = planOp.getOperation();

  FailureOr<ConvLoweringState> analyzed = analyzeConvLoweringState(planOp);
  if (failed(analyzed))
    return failure();
  ConvLoweringState& conv = *analyzed;

  // Only ungrouped, single-batch convolutions with a static 4-D result qualify.
  const bool shapeSupported = conv.group == 1 && conv.batchSize == 1 && conv.outType.getRank() == 4
                              && conv.outType.hasStaticShape();
  if (!shapeSupported)
    return failure();

  FailureOr<PimConvLoweringType> requested = resolveRequestedConvLoweringStrategy(planOperation);
  if (failed(requested))
    return failure();

  DistributedConvAnalysis rowStripAnalysis;
  rowStripAnalysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer;
  rowStripAnalysis.barrierDetail = "selected row-strip layout";

  ConvGeometry geometry = buildConvGeometry(conv);
  ConvLoweringDecision decision = chooseConvLoweringStrategy(geometry, *requested, rowStripAnalysis);

  // In auto mode, a depthwise choice that has no structured rewrite falls back
  // to the legacy strategy.
  if (decision.strategy == PimConvLoweringDepthwise && !depthwise::canUseStructuredRewrite(conv)
      && *requested == PimConvLoweringAuto) {
    decision.strategy = PimConvLoweringLegacy;
    decision.reason = "depthwise auto fallback when structured depthwise lowering is not representable";
    decision.isAuto = true;
    decision.fallbackReason.clear();
    decision.rejectedAutoStrategy.clear();
  }

  if (failed(verifyForcedConvLoweringStrategy(planOperation, geometry, decision.strategy)))
    return failure();

  switch (decision.strategy) {
  // Strategies this check rejects for the row-strip layout.
  case PimConvLoweringAuto:
  case PimConvLoweringDepthwise:
  case PimConvLoweringInputKTiled:
    return failure();
  // Strategies this check accepts for the row-strip layout.
  case PimConvLoweringLegacy:
  case PimConvLoweringPackedIm2Col:
  case PimConvLoweringStreamedPatch:
  case PimConvLoweringOutputChannelTiled:
  case PimConvLoweringTiled2D:
  case PimConvLoweringStreamedPacked:
    return success();
  }
  llvm_unreachable("unknown conv lowering strategy");
}
|
|
|
|
/// Returns success when canDirectLowerRowStripConv accepts the analyzed state
/// of `planOp`. lowerSelectedConv2DPlan uses this check to gate its row-strip
/// input/output path.
///
/// Fails when the plan cannot be analyzed. The textual failure reason produced
/// by canDirectLowerRowStripConv is intentionally discarded here.
LogicalResult canConsumeAndProduceRowStrip(spatial::SpatConv2DPlanOp planOp) {
  FailureOr<ConvLoweringState> analyzed = analyzeConvLoweringState(planOp);
  if (succeeded(analyzed)) {
    StringRef ignoredReason;
    if (canDirectLowerRowStripConv(*analyzed, ignoredReason))
      return success();
  }
  return failure();
}
|
|
|
|
/// Lowers `planOp` using the conv lowering strategy selected for it.
///
/// \param planOp              The SpatConv2DPlanOp to lower.
/// \param rowStripInput       Optional input value already in row-strip layout;
///                            only consulted when `emitRowStripLayout` is true.
/// \param emitRowStripLayout  When true, produce the row-strip layout; otherwise
///                            lower via the dense or grouped helpers.
/// \param rewriter            Rewriter used to materialize the lowering.
/// \returns the lowered value, or failure. An op error is emitted when the
///          row-strip layout is unsupported for this plan or row packing fails.
FailureOr<Value> lowerSelectedConv2DPlan(spatial::SpatConv2DPlanOp planOp,
                                         std::optional<Value> rowStripInput,
                                         bool emitRowStripLayout,
                                         PatternRewriter& rewriter) {
  Operation* const planOperation = planOp.getOperation();
  Location planLoc = planOp.getLoc();

  FailureOr<ConvLoweringState> analyzed = analyzeConvLoweringState(planOp);
  if (failed(analyzed))
    return failure();
  ConvLoweringState& conv = *analyzed;

  FailureOr<PimConvLoweringType> requested = resolveRequestedConvLoweringStrategy(planOperation);
  if (failed(requested))
    return failure();

  DistributedConvAnalysis layoutAnalysis;
  layoutAnalysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer;
  if (emitRowStripLayout)
    layoutAnalysis.barrierDetail = "selected row-strip layout";
  else
    layoutAnalysis.barrierDetail = "selected dense layout";

  ConvGeometry geometry = buildConvGeometry(conv);
  ConvLoweringDecision decision = chooseConvLoweringStrategy(geometry, *requested, layoutAnalysis);

  // In auto mode, a depthwise choice that has no structured rewrite falls back
  // to the legacy strategy.
  if (decision.strategy == PimConvLoweringDepthwise && !depthwise::canUseStructuredRewrite(conv)
      && *requested == PimConvLoweringAuto) {
    decision.strategy = PimConvLoweringLegacy;
    decision.reason = "depthwise auto fallback when structured depthwise lowering is not representable";
    decision.isAuto = true;
    decision.fallbackReason.clear();
    decision.rejectedAutoStrategy.clear();
  }

  if (failed(verifyForcedConvLoweringStrategy(planOperation, geometry, decision.strategy)))
    return failure();

  if (!emitRowStripLayout) {
    // Depthwise always goes through the dense helper, even when group != 1.
    if (conv.group != 1 && decision.strategy != PimConvLoweringDepthwise)
      return lowerGroupedSelectedConvPlan(planOperation, conv, decision.strategy, rewriter, planLoc);
    return lowerDenseSelectedConvPlan(planOperation, conv, decision.strategy, rewriter, planLoc);
  }

  // Row-strip output fed by an input that is already in row-strip layout.
  if (rowStripInput.has_value()) {
    if (failed(canConsumeAndProduceRowStrip(planOp)))
      return planOp.emitOpError("selected row-strip input/output layout is not supported for this Conv plan");
    return createConvRowsFromRowStripInput(conv, decision, *rowStripInput, rewriter, planLoc);
  }

  // Row-strip output from the plan's own input: compute rows, then pack them.
  if (failed(canLowerConvPlanToRowStrip(planOp)))
    return planOp.emitOpError("selected row-strip layout is not supported for this Conv plan");

  FailureOr<Value> rows = createConvRowsForStrategy(conv, decision, rewriter, planLoc);
  if (failed(rows))
    return failure();

  FailureOr<Value> packed = createRowStripPackedRows(*rows, conv, rewriter, planLoc);
  if (failed(packed))
    return planOp.emitOpError("failed to pack Conv rows into the selected row-strip physical layout");
  return *packed;
}
|
|
|
|
} // namespace onnx_mlir
|