This commit is contained in:
@@ -26,6 +26,7 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -42,59 +43,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
struct ConvLoweringState {
|
||||
Value x;
|
||||
Value w;
|
||||
Value b;
|
||||
RankedTensorType xType;
|
||||
RankedTensorType wType;
|
||||
RankedTensorType outType;
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t padHeightBegin;
|
||||
int64_t padHeightEnd;
|
||||
int64_t padWidthBegin;
|
||||
int64_t padWidthEnd;
|
||||
int64_t strideHeight;
|
||||
int64_t strideWidth;
|
||||
int64_t dilationHeight;
|
||||
int64_t dilationWidth;
|
||||
bool hasBias;
|
||||
};
|
||||
|
||||
struct ConvGeometry {
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t k;
|
||||
int64_t c;
|
||||
int64_t p;
|
||||
int64_t xbarSize;
|
||||
int64_t pack;
|
||||
uint64_t im2colElements;
|
||||
bool hasBias;
|
||||
bool isDepthwise;
|
||||
};
|
||||
|
||||
struct ConvLoweringDecision {
|
||||
PimConvLoweringType strategy;
|
||||
std::string reason;
|
||||
@@ -108,19 +56,6 @@ struct PreparedConvInput {
|
||||
RankedTensorType type;
|
||||
};
|
||||
|
||||
struct RowInterval {
|
||||
int64_t begin = 0;
|
||||
int64_t end = 0;
|
||||
};
|
||||
|
||||
struct ConvRowDemand {
|
||||
RowInterval outputRows;
|
||||
RowInterval neededInputRows;
|
||||
RowInterval acquiredInputRows;
|
||||
int64_t topHaloRows = 0;
|
||||
int64_t bottomHaloRows = 0;
|
||||
};
|
||||
|
||||
struct ConvStrategyEstimate {
|
||||
uint64_t estimatedMvmCount = 0;
|
||||
uint64_t estimatedReductionVAddCount = 0;
|
||||
@@ -291,9 +226,6 @@ static FailureOr<Value> createRowStripPackedRows(Value rows,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc);
|
||||
|
||||
static bool
|
||||
isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup);
|
||||
static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor);
|
||||
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b);
|
||||
|
||||
static StringRef stringifyDistributedConvBarrierKind(DistributedConvBarrierKind kind) {
|
||||
@@ -391,34 +323,6 @@ static ConvStrategyEstimate estimateConvStrategy(const ConvGeometry& geo,
|
||||
return estimate;
|
||||
}
|
||||
|
||||
static ConvGeometry buildConvGeometry(const ConvLoweringState& state) {
|
||||
ConvGeometry geo {
|
||||
state.batchSize,
|
||||
state.numChannelsIn,
|
||||
state.xHeight,
|
||||
state.xWidth,
|
||||
state.numChannelsOut,
|
||||
state.wHeight,
|
||||
state.wWidth,
|
||||
state.outHeight,
|
||||
state.outWidth,
|
||||
state.group,
|
||||
state.numChannelsInPerGroup,
|
||||
state.numChannelsOutPerGroup,
|
||||
state.numChannelsInPerGroup * state.wHeight * state.wWidth,
|
||||
state.numChannelsOutPerGroup,
|
||||
state.batchSize * state.outHeight * state.outWidth,
|
||||
static_cast<int64_t>(crossbarSize.getValue()),
|
||||
1,
|
||||
0,
|
||||
state.hasBias,
|
||||
isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup),
|
||||
};
|
||||
geo.pack = std::max<int64_t>(1, geo.xbarSize / std::max<int64_t>(geo.k, geo.c));
|
||||
geo.im2colElements = static_cast<uint64_t>(std::max<int64_t>(0, geo.p)) * static_cast<uint64_t>(std::max<int64_t>(0, geo.k));
|
||||
return geo;
|
||||
}
|
||||
|
||||
static std::string formatShape(ArrayRef<int64_t> dims) {
|
||||
std::string text;
|
||||
llvm::raw_string_ostream os(text);
|
||||
@@ -563,36 +467,10 @@ classifyDistributedBinaryConsumer(Operation* user,
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows,
|
||||
int64_t inputHeight,
|
||||
int64_t kernelH,
|
||||
int64_t strideH,
|
||||
int64_t dilationH,
|
||||
int64_t padTop) {
|
||||
const int64_t rawBegin = outputRows.begin * strideH - padTop;
|
||||
const int64_t rawEnd = (outputRows.end - 1) * strideH - padTop + dilationH * (kernelH - 1) + 1;
|
||||
return {std::max<int64_t>(0, rawBegin), std::min<int64_t>(inputHeight, rawEnd)};
|
||||
}
|
||||
|
||||
static bool covers(RowInterval acquired, RowInterval needed) {
|
||||
return acquired.begin <= needed.begin && acquired.end >= needed.end;
|
||||
}
|
||||
|
||||
static ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) {
|
||||
const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
|
||||
const int64_t rawEnd =
|
||||
(outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;
|
||||
RowInterval neededInputRows = computeConvInputRowsForOutputRows(
|
||||
outputRows, state.xHeight, state.wHeight, state.strideHeight, state.dilationHeight, state.padHeightBegin);
|
||||
ConvRowDemand demand;
|
||||
demand.outputRows = outputRows;
|
||||
demand.neededInputRows = neededInputRows;
|
||||
demand.acquiredInputRows = neededInputRows;
|
||||
demand.topHaloRows = std::max<int64_t>(0, -rawBegin);
|
||||
demand.bottomHaloRows = std::max<int64_t>(0, rawEnd - state.xHeight);
|
||||
return demand;
|
||||
}
|
||||
|
||||
static bool canConsumeRowStripHwcInput(const ConvLoweringState& state, StringRef& failureReason) {
|
||||
if (state.batchSize != 1) {
|
||||
failureReason = "unsupported_batch";
|
||||
@@ -1250,19 +1128,6 @@ static void reportConvLoweringDecision(ONNXConvOp convOp,
|
||||
rewriteConvLoweringReport(reportEntries);
|
||||
}
|
||||
|
||||
static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) {
|
||||
const uint64_t patchElements = static_cast<uint64_t>(std::max<int64_t>(1, geo.k));
|
||||
uint64_t chunkPositions = std::max<uint64_t>(1, pimConvIm2colMaxElements / patchElements);
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, static_cast<uint64_t>(std::max<int64_t>(1, geo.p)));
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, std::max<uint64_t>(1, pimConvStreamChunkPositions));
|
||||
|
||||
if (packFactor > 1 && chunkPositions > static_cast<uint64_t>(packFactor)) {
|
||||
chunkPositions -= chunkPositions % static_cast<uint64_t>(packFactor);
|
||||
chunkPositions = std::max<uint64_t>(chunkPositions, static_cast<uint64_t>(packFactor));
|
||||
}
|
||||
return std::max<uint64_t>(1, chunkPositions);
|
||||
}
|
||||
|
||||
static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location loc) {
|
||||
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||
if (biasType.getRank() != 1)
|
||||
@@ -1278,11 +1143,6 @@ static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location
|
||||
});
|
||||
}
|
||||
|
||||
static bool
|
||||
isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) {
|
||||
return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0;
|
||||
}
|
||||
|
||||
static int64_t findLargestDivisorAtMost(int64_t value, int64_t limit) {
|
||||
assert(value > 0 && "expected positive value");
|
||||
limit = std::min(value, limit);
|
||||
|
||||
Reference in New Issue
Block a user