fix reshape lowering add support for grouped-convolution lowering quieter verifier with capped error messages
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -27,16 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||
|
||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
||||
|
||||
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
@@ -355,49 +346,22 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
return collectComputeOp.getResult(0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
if (!xType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||
return failure();
|
||||
}
|
||||
if (!wType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||
return failure();
|
||||
}
|
||||
if (xType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (wType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (outType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (convOp.getGroup() != 1) {
|
||||
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
static Value lowerSingleConvGroup(Value x,
|
||||
Value w,
|
||||
Value b,
|
||||
RankedTensorType xType,
|
||||
RankedTensorType wType,
|
||||
RankedTensorType outType,
|
||||
int64_t padHeightBegin,
|
||||
int64_t padHeightEnd,
|
||||
int64_t padWidthBegin,
|
||||
int64_t padWidthEnd,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
@@ -408,71 +372,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
if (stridesAttr && stridesAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (padsAttr && padsAttr->size() != 4) {
|
||||
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||
@@ -492,7 +391,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
||||
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
||||
auto wDenseAttr = getDenseConstantAttr(w);
|
||||
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
|
||||
|
||||
// Prepare weight matrix W for crossbar storage:
|
||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||
@@ -513,7 +412,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
DenseElementsAttr biasDenseAttr;
|
||||
if (hasB) {
|
||||
gemmBias = b;
|
||||
biasDenseAttr = getDenseConstantAttr(b);
|
||||
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
|
||||
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||
}
|
||||
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
||||
@@ -589,17 +488,246 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
|
||||
rewriter.replaceOp(convOp,
|
||||
createCollectedConvOutput(ValueRange {gemmRows},
|
||||
convOp.getType(),
|
||||
gemmOutType,
|
||||
nhwcType,
|
||||
outType,
|
||||
numPatches,
|
||||
numChannelsOut,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc));
|
||||
return createCollectedConvOutput(ValueRange {gemmRows},
|
||||
outType,
|
||||
gemmOutType,
|
||||
nhwcType,
|
||||
outType,
|
||||
numPatches,
|
||||
numChannelsOut,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
if (!xType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||
return failure();
|
||||
}
|
||||
if (!wType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||
return failure();
|
||||
}
|
||||
if (xType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (wType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (outType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (convOp.getGroup() < 1) {
|
||||
convOp.emitOpError("requires group >= 1 for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
const int64_t xWidth = xType.getDimSize(3);
|
||||
const int64_t numChannelsOut = wType.getDimSize(0);
|
||||
const int64_t wHeight = wType.getDimSize(2);
|
||||
const int64_t wWidth = wType.getDimSize(3);
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
const int64_t group = convOp.getGroup();
|
||||
const bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||
|
||||
if (numChannelsIn % group != 0) {
|
||||
convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group
|
||||
<< " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
if (numChannelsOut % group != 0) {
|
||||
convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group
|
||||
<< " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t numChannelsInPerGroup = numChannelsIn / group;
|
||||
const int64_t numChannelsOutPerGroup = numChannelsOut / group;
|
||||
if (wType.getDimSize(1) != numChannelsInPerGroup) {
|
||||
convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1)
|
||||
<< " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
if (wType.getDimSize(0) != numChannelsOut) {
|
||||
convOp.emitOpError() << "requires weight output channels " << wType.getDimSize(0) << " to match result channels "
|
||||
<< numChannelsOut << " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
if (stridesAttr && stridesAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (padsAttr && padsAttr->size() != 4) {
|
||||
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
if (group == 1) {
|
||||
rewriter.replaceOp(convOp,
|
||||
lowerSingleConvGroup(x,
|
||||
w,
|
||||
b,
|
||||
xType,
|
||||
wType,
|
||||
outType,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
rewriter,
|
||||
loc));
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value> xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc);
|
||||
SmallVector<Value> wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc);
|
||||
SmallVector<Value> bSlices;
|
||||
if (hasB) {
|
||||
auto biasType = cast<RankedTensorType>(b.getType());
|
||||
int64_t biasAxis = -1;
|
||||
if (biasType.getRank() == 1)
|
||||
biasAxis = 0;
|
||||
else if (biasType.getRank() == 2)
|
||||
biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1;
|
||||
else {
|
||||
convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
|
||||
<< biasType.getRank();
|
||||
return failure();
|
||||
}
|
||||
bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc);
|
||||
}
|
||||
|
||||
if (xSlices.size() != static_cast<size_t>(group) || wSlices.size() != static_cast<size_t>(group)
|
||||
|| (hasB && bSlices.size() != static_cast<size_t>(group))) {
|
||||
convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<Value> groupResults;
|
||||
groupResults.reserve(group);
|
||||
auto groupOutType =
|
||||
RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType());
|
||||
Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
for (int64_t groupId = 0; groupId < group; groupId++) {
|
||||
Value groupX = xSlices[groupId];
|
||||
Value groupW = wSlices[groupId];
|
||||
Value groupB = hasB ? bSlices[groupId] : noBias;
|
||||
groupResults.push_back(lowerSingleConvGroup(groupX,
|
||||
groupW,
|
||||
groupB,
|
||||
cast<RankedTensorType>(groupX.getType()),
|
||||
cast<RankedTensorType>(groupW.getType()),
|
||||
groupOutType,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
rewriter,
|
||||
loc));
|
||||
}
|
||||
|
||||
Value result;
|
||||
if (llvm::all_of(groupResults, isHostFoldableValue)) {
|
||||
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
|
||||
}
|
||||
else {
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
|
||||
});
|
||||
result = concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(convOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user