#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

#include <algorithm>
#include <cstdint>

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

/// Rewrites ONNXConvOp into an im2col + ONNXGemmOp sequence wrapped in
/// spatial compute regions, so the GEMM can later be mapped onto a PIM
/// crossbar (weights stored as the B operand).
struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(ONNXConvOp convOp,
      ONNXConvOpAdaptor convOpAdaptor,
      ConversionPatternRewriter &rewriter) const override;
};

/// Returns the dense constant backing `value`, or nullptr if `value` is not
/// produced by a recognized constant op (or the attribute is absent/not dense).
// NOTE(review): the defining-op types were lost in a bad paste; arith and ONNX
// constants match the getValue()/getValueAttr() accessors used here — confirm
// against the original revision.
static DenseElementsAttr getDenseConstantAttr(Value value) {
  if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
    return dyn_cast<DenseElementsAttr>(constantOp.getValue());
  if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
    // The ONNX value attribute is optional, hence dyn_cast_or_null.
    return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
  return nullptr;
}

/// Reads element `idx` of an integer ArrayAttr as int64_t.
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) {
  return cast<IntegerAttr>(arr[idx]).getInt();
}

/// Promotes a rank-1 bias tensor [cOut] to the rank-2 GEMM C operand
/// [1, cOut]; passes rank-2 (or higher) bias through unchanged.
static Value expandBiasIfNeeded(
    Value bias, ConversionPatternRewriter &rewriter, Location loc) {
  auto biasType = cast<RankedTensorType>(bias.getType());
  if (biasType.getRank() != 1)
    return bias;
  auto expandedBiasType = RankedTensorType::get(
      {1, biasType.getDimSize(0)}, biasType.getElementType());
  return tensor::ExpandShapeOp::create(rewriter, loc, expandedBiasType, bias,
      SmallVector<ReassociationIndices>{{0, 1}});
}

/// Zero-pads a rank-2 tensor at the bottom so its row count becomes
/// `paddedRows`. Returns the input unchanged when no padding is needed.
static Value createPaddedRows(Value tensorValue, RankedTensorType tensorType,
    int64_t paddedRows, ConversionPatternRewriter &rewriter, Location loc) {
  if (tensorType.getDimSize(0) == paddedRows)
    return tensorValue;
  auto paddedType = RankedTensorType::get(
      {paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType());
  SmallVector<OpFoldResult> lowPads = {
      rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> highPads = {
      rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
      rewriter.getIndexAttr(0)};
  auto padOp = tensor::PadOp::create(
      rewriter, loc, paddedType, tensorValue, lowPads, highPads);
  // tensor.pad requires a body yielding the padding value; build it by hand
  // (one index block argument per tensor dimension).
  auto *padBlock = new Block();
  for (int i = 0; i < 2; i++)
    padBlock->addArgument(rewriter.getIndexType(), loc);
  padOp.getRegion().push_back(padBlock);
  rewriter.setInsertionPointToStart(padBlock);
  auto zero = arith::ConstantOp::create(rewriter, loc,
      tensorType.getElementType(),
      rewriter.getZeroAttr(tensorType.getElementType()));
  tensor::YieldOp::create(rewriter, loc, zero.getResult());
  rewriter.setInsertionPointAfter(padOp);
  return padOp.getResult();
}

/// Builds the packed (block-diagonal) weight constant used when packFactor>1:
/// `packFactor` copies of W^T placed on the diagonal of a
/// [packFactor*patchSize, packFactor*cOut] matrix, zeros elsewhere.
/// For packFactor==1 the plain transposed weight `wTrans` is returned.
static Value buildPackedWeight(DenseElementsAttr wDenseAttr, Value wTrans,
    RankedTensorType wType, int64_t numChannelsIn, int64_t numChannelsOut,
    int64_t wHeight, int64_t wWidth, int64_t patchSize, int64_t packFactor,
    ConversionPatternRewriter &rewriter, Location loc) {
  if (packFactor == 1)
    return wTrans;
  auto packedWeightType = RankedTensorType::get(
      {packFactor * patchSize, packFactor * numChannelsOut},
      wType.getElementType());
  SmallVector<Attribute> sourceValues(wDenseAttr.getValues<Attribute>());
  // Start from an all-zero matrix; only the diagonal blocks get filled in.
  SmallVector<Attribute> packedValues(packedWeightType.getNumElements(),
      cast<Attribute>(rewriter.getZeroAttr(wType.getElementType())));
  for (int64_t copyId = 0; copyId < packFactor; copyId++) {
    for (int64_t outChannel = 0; outChannel < numChannelsOut; outChannel++) {
      for (int64_t inChannel = 0; inChannel < numChannelsIn; inChannel++) {
        for (int64_t kernelH = 0; kernelH < wHeight; kernelH++) {
          for (int64_t kernelW = 0; kernelW < wWidth; kernelW++) {
            // Flat index into W laid out [cOut, cIn, kH, kW].
            const int64_t sourceFlatIndex =
                (((outChannel * numChannelsIn) + inChannel) * wHeight +
                    kernelH) *
                    wWidth +
                kernelW;
            // Position of this weight inside one im2col patch row.
            const int64_t patchIndex =
                ((inChannel * wHeight) + kernelH) * wWidth + kernelW;
            const int64_t targetRow = copyId * patchSize + patchIndex;
            const int64_t targetCol = copyId * numChannelsOut + outChannel;
            packedValues[targetRow * (packFactor * numChannelsOut) +
                targetCol] = sourceValues[sourceFlatIndex];
          }
        }
      }
    }
  }
  auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
  return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
}

/// Builds the GEMM C operand matching the packed layout: `packFactor`
/// horizontal copies of the bias row in a [1, packFactor*cOut] constant.
/// Returns the none value (`gemmBias`) when there is no bias, or the plain
/// rank-2 bias when no packing is applied.
static Value buildPackedBias(bool hasBias, Value gemmBias, Value biasMatrix,
    DenseElementsAttr biasDenseAttr, RankedTensorType outType,
    int64_t numChannelsOut, int64_t packFactor,
    ConversionPatternRewriter &rewriter, Location loc) {
  if (!hasBias)
    return gemmBias;
  if (packFactor == 1)
    return biasMatrix;
  SmallVector<Attribute> sourceValues(biasDenseAttr.getValues<Attribute>());
  SmallVector<Attribute> packedValues;
  packedValues.reserve(packFactor * numChannelsOut);
  for (int64_t copyId = 0; copyId < packFactor; copyId++)
    packedValues.append(sourceValues.begin(), sourceValues.end());
  auto packedBiasType = RankedTensorType::get(
      {1, packFactor * numChannelsOut}, outType.getElementType());
  auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues);
  return arith::ConstantOp::create(
      rewriter, loc, packedBiasType, packedBiasAttr)
      .getResult();
}

/// Emits a spatial compute region that turns the NCHW input into packed
/// im2col GEMM rows:
///   1. optionally zero-pads the spatial dims,
///   2. loops over all output positions, extracting one [1, patchSize] row
///      per position into an im2col matrix [numPatches, patchSize],
///   3. when packFactor > 1, pads the row count to a multiple of packFactor
///      and reshapes groups of packFactor rows into one longer row.
/// Returns the region result of type `gemmInputRowsType`.
static Value createIm2colRowComputes(Value x, RankedTensorType xType,
    RankedTensorType im2colType, RankedTensorType im2colRowType,
    RankedTensorType gemmInputRowsType, int64_t batchSize,
    int64_t numChannelsIn, int64_t xHeight, int64_t xWidth, int64_t wHeight,
    int64_t wWidth, int64_t padHeightBegin, int64_t padHeightEnd,
    int64_t padWidthBegin, int64_t padWidthEnd, int64_t strideHeight,
    int64_t strideWidth, int64_t dilationHeight, int64_t dilationWidth,
    int64_t outWidth, int64_t patchSize, int64_t numPatches,
    int64_t numPatchesPerBatch, int64_t packFactor,
    ConversionPatternRewriter &rewriter, Location loc) {
  auto elemType = xType.getElementType();
  const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
  auto im2colComputeOp = createSpatCompute(rewriter, loc,
      TypeRange{gemmInputRowsType}, {}, x, [&](Value xArg) {
        Value paddedInput = xArg;
        // Pad input with zeros if needed:
        // [1, numChannelsIn, xHeight, xWidth]
        //   -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
        if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
          const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
          const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
          auto paddedType = RankedTensorType::get(
              {batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
          SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
              rewriter.getIndexAttr(0), rewriter.getIndexAttr(padHeightBegin),
              rewriter.getIndexAttr(padWidthBegin)};
          SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
              rewriter.getIndexAttr(0), rewriter.getIndexAttr(padHeightEnd),
              rewriter.getIndexAttr(padWidthEnd)};
          auto padOp = tensor::PadOp::create(
              rewriter, loc, paddedType, paddedInput, lowPads, highPads);
          auto *padBlock = new Block();
          for (int i = 0; i < 4; i++)
            padBlock->addArgument(rewriter.getIndexType(), loc);
          padOp.getRegion().push_back(padBlock);
          rewriter.setInsertionPointToStart(padBlock);
          auto zero = arith::ConstantOp::create(
              rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
          tensor::YieldOp::create(rewriter, loc, zero.getResult());
          rewriter.setInsertionPointAfter(padOp);
          paddedInput = padOp.getResult();
        }

        // Build im2col [numPatches, patchSize] incrementally to keep the IR
        // small until the late PIM unrolling step.
        Value im2colInit = tensor::EmptyOp::create(
            rewriter, loc, im2colType.getShape(), elemType);
        auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
        auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
        auto cNumPatches =
            arith::ConstantIndexOp::create(rewriter, loc, numPatches);
        auto cNumPatchesPerBatch =
            arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
        auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
        auto cStrideHeight =
            arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
        auto cStrideWidth =
            arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
        auto im2colLoop = scf::ForOp::create(
            rewriter, loc, c0, cNumPatches, c1, ValueRange{im2colInit});
        rewriter.setInsertionPointToStart(im2colLoop.getBody());
        Value patchIndex = im2colLoop.getInductionVar();
        Value im2colAcc = im2colLoop.getRegionIterArgs().front();
        // Decompose the linear patch index into (batch, outH, outW).
        Value batchIndex = arith::DivUIOp::create(
            rewriter, loc, patchIndex, cNumPatchesPerBatch);
        Value batchPatchIndex = arith::RemUIOp::create(
            rewriter, loc, patchIndex, cNumPatchesPerBatch);
        Value outHeightIndex = arith::DivUIOp::create(
            rewriter, loc, batchPatchIndex, cOutWidth);
        Value outWidthIndex = arith::RemUIOp::create(
            rewriter, loc, batchPatchIndex, cOutWidth);
        Value inputHeightOffset = arith::MulIOp::create(
            rewriter, loc, outHeightIndex, cStrideHeight);
        Value inputWidthOffset = arith::MulIOp::create(
            rewriter, loc, outWidthIndex, cStrideWidth);
        // Extract one receptive field; dilation becomes the slice stride.
        SmallVector<OpFoldResult> offsets = {batchIndex,
            rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
        SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
            rewriter.getIndexAttr(numChannelsIn),
            rewriter.getIndexAttr(wHeight), rewriter.getIndexAttr(wWidth)};
        SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(dilationHeight),
            rewriter.getIndexAttr(dilationWidth)};
        auto patchType = RankedTensorType::get(
            {1, numChannelsIn, wHeight, wWidth}, elemType);
        Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType,
            paddedInput, offsets, sizes, strides);
        Value row = tensor::CollapseShapeOp::create(rewriter, loc,
            im2colRowType, patch,
            SmallVector<ReassociationIndices>{{0}, {1, 2, 3}});
        SmallVector<OpFoldResult> rowOffsets = {
            patchIndex, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> rowSizes = {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
        SmallVector<OpFoldResult> rowStrides = {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        Value updatedIm2col = tensor::InsertSliceOp::create(rewriter, loc, row,
            im2colAcc, rowOffsets, rowSizes, rowStrides);
        scf::YieldOp::create(rewriter, loc, updatedIm2col);
        rewriter.setInsertionPointAfter(im2colLoop);
        Value im2col = im2colLoop.getResult(0);

        Value gemmInputRows = im2col;
        if (packFactor != 1) {
          // Pad rows up to a multiple of packFactor, then fold packFactor
          // consecutive rows into a single longer row.
          const int64_t paddedNumPatches = packedNumRows * packFactor;
          auto groupedType = RankedTensorType::get(
              {packedNumRows, packFactor, patchSize}, elemType);
          auto packedType = RankedTensorType::get(
              {packedNumRows, packFactor * patchSize}, elemType);
          Value paddedIm2col = createPaddedRows(
              im2col, im2colType, paddedNumPatches, rewriter, loc);
          Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter, loc,
              groupedType, paddedIm2col,
              SmallVector<ReassociationIndices>{{0, 1}, {2}});
          gemmInputRows = tensor::CollapseShapeOp::create(rewriter, loc,
              packedType, groupedIm2col,
              SmallVector<ReassociationIndices>{{0}, {1, 2}});
        }
        spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
      });
  return im2colComputeOp.getResult(0);
}

/// Emits a spatial compute region that gathers the (possibly packed) GEMM
/// output rows back into the NCHW convolution result:
/// unpack rows -> drop padding rows -> [1, outH, outW, cOut] -> NCHW.
static Value createCollectedConvOutput(ValueRange gemmRows, Type convType,
    RankedTensorType gemmOutType, RankedTensorType nhwcType,
    RankedTensorType outType, int64_t numPatches, int64_t numChannelsOut,
    int64_t packFactor, ConversionPatternRewriter &rewriter, Location loc) {
  const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
  const int64_t paddedNumPatches = packedNumRows * packFactor;
  auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {},
      gemmRows, [&](ValueRange gemmRowArgs) {
        Value gemmOut;
        if (packFactor == 1) {
          gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
        } else {
          // Undo the row packing: [packedNumRows, packFactor*cOut]
          //   -> [packedNumRows, packFactor, cOut]
          //   -> [paddedNumPatches, cOut], then slice off padding rows.
          auto expandedType = RankedTensorType::get(
              {packedNumRows, packFactor, numChannelsOut},
              outType.getElementType());
          auto paddedType = RankedTensorType::get(
              {paddedNumPatches, numChannelsOut}, outType.getElementType());
          Value packedOutput =
              createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
          Value expandedOutput = tensor::ExpandShapeOp::create(rewriter, loc,
              expandedType, packedOutput,
              SmallVector<ReassociationIndices>{{0}, {1, 2}});
          Value paddedOutput = tensor::CollapseShapeOp::create(rewriter, loc,
              paddedType, expandedOutput,
              SmallVector<ReassociationIndices>{{0, 1}, {2}});
          gemmOut = paddedOutput;
          if (paddedNumPatches != numPatches) {
            SmallVector<OpFoldResult> offsets = {
                rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
            SmallVector<OpFoldResult> sizes = {
                rewriter.getIndexAttr(numPatches),
                rewriter.getIndexAttr(numChannelsOut)};
            SmallVector<OpFoldResult> strides = {
                rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
            gemmOut = tensor::ExtractSliceOp::create(rewriter, loc,
                gemmOutType, paddedOutput, offsets, sizes, strides);
          }
        }
        // Restore to NCHW layout:
        // [numPatches, numChannelsOut]
        //   -> [1, outHeight, outWidth, numChannelsOut]
        //   -> [1, numChannelsOut, outHeight, outWidth]
        Value nhwcOut = tensor::ExpandShapeOp::create(rewriter, loc, nhwcType,
            gemmOut, SmallVector<ReassociationIndices>{{0, 1, 2}, {3}});
        Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType,
            nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
        spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
      });
  return collectComputeOp.getResult(0);
}

} // namespace

LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
    ONNXConvOpAdaptor convOpAdaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = convOp.getLoc();
  Value x = convOpAdaptor.getX();
  Value w = convOpAdaptor.getW();
  Value b = convOpAdaptor.getB();
  auto xType = cast<RankedTensorType>(x.getType());
  auto wType = cast<RankedTensorType>(w.getType());
  auto outType = cast<RankedTensorType>(convOp.getY().getType());
  assert(xType.hasStaticShape() && wType.hasStaticShape() &&
         outType.hasStaticShape() && "Only support static shapes");
  assert(xType.getRank() == 4 && "Only support 2D convolution");
  // TODO: grouped convolution (group > 1) is not handled yet.
  assert(convOp.getGroup() == 1 && "Only support group=1");

  const int64_t batchSize = xType.getDimSize(0);
  const int64_t numChannelsIn = xType.getDimSize(1);
  const int64_t xHeight = xType.getDimSize(2);
  const int64_t xWidth = xType.getDimSize(3);
  const int64_t numChannelsOut = wType.getDimSize(0);
  const int64_t wHeight = wType.getDimSize(2);
  const int64_t wWidth = wType.getDimSize(3);
  const int64_t outHeight = outType.getDimSize(2);
  const int64_t outWidth = outType.getDimSize(3);

  // Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
  const auto stridesAttr = convOp.getStrides();
  const auto dilationsAttr = convOp.getDilations();
  const auto padsAttr = convOp.getPads();
  const int64_t strideHeight =
      stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
  const int64_t strideWidth =
      stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
  const int64_t dilationHeight =
      dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
  const int64_t dilationWidth =
      dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
  int64_t padHeightBegin = 0;
  int64_t padHeightEnd = 0;
  int64_t padWidthBegin = 0;
  int64_t padWidthEnd = 0;
  if (padsAttr) {
    // ONNX pads layout: [hBegin, wBegin, hEnd, wEnd].
    padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
    padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
    padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
    padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
  } else {
    // Compute padding from auto_pad attribute
    const auto autoPad = convOp.getAutoPad();
    if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
      const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
      const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
      const int64_t totalPadH = std::max(static_cast<int64_t>(0),
          (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
      const int64_t totalPadW = std::max(static_cast<int64_t>(0),
          (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
      if (autoPad == "SAME_UPPER") {
        padHeightBegin = totalPadH / 2;
        padHeightEnd = totalPadH - padHeightBegin;
        padWidthBegin = totalPadW / 2;
        padWidthEnd = totalPadW - padWidthBegin;
      } else { // SAME_LOWER
        padHeightEnd = totalPadH / 2;
        padHeightBegin = totalPadH - padHeightEnd;
        padWidthEnd = totalPadW / 2;
        padWidthBegin = totalPadW - padWidthEnd;
      }
    }
    // "NOTSET" or "VALID" -> all pads stay 0
  }

  // im2col layout (flipped with respect to the standard, so filters sit in
  // B = crossbar):
  //   A (im2col):  [numPatches, patchSize] -- one row per output position
  //   B (weights): [patchSize, cOut]       -- W^T, stored in crossbar columns
  //   Gemm output: [numPatches, cOut]
  const int64_t patchSize = numChannelsIn * wHeight * wWidth;
  const int64_t numPatchesPerBatch = outHeight * outWidth;
  const int64_t numPatches = batchSize * numPatchesPerBatch;
  auto elemType = xType.getElementType();
  auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
  auto rowType = RankedTensorType::get({1, patchSize}, elemType);
  auto wFlatType = RankedTensorType::get(
      {numChannelsOut, patchSize}, wType.getElementType());
  auto wTransType = RankedTensorType::get(
      {patchSize, numChannelsOut}, wType.getElementType());
  auto gemmOutType = RankedTensorType::get(
      {numPatches, numChannelsOut}, outType.getElementType());
  auto nhwcType = RankedTensorType::get(
      {batchSize, outHeight, outWidth, numChannelsOut},
      outType.getElementType());

  // How many output pixels fit in the crossbar at once given the larger of
  // the two weight-matrix dimensions.
  const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
  const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
  const int64_t maxParallelPixels =
      std::max<int64_t>(1, xbarSize / wMaxDim);
  auto wDenseAttr = getDenseConstantAttr(w);

  // Prepare weight matrix W for crossbar storage:
  // W: [numChannelsOut, numChannelsIn, wHeight, wWidth]
  //   -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
  Value wFlat = tensor::CollapseShapeOp::create(rewriter, loc, wFlatType, w,
      SmallVector<ReassociationIndices>{{0}, {1, 2, 3}});
  Value wTrans = ONNXTransposeOp::create(
      rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));

  // Pass bias through directly; Gemm handles rank-1 C canonicalization.
  bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
  Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
  Value biasMatrix;
  DenseElementsAttr biasDenseAttr;
  if (hasB) {
    gemmBias = b;
    biasDenseAttr = getDenseConstantAttr(b);
    biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
  }

  // Packing requires materializing new weight/bias constants, which is only
  // possible when both are compile-time dense constants.
  const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
  const bool canPackBiasAsConstants =
      !hasB || static_cast<bool>(biasDenseAttr);
  const int64_t effectiveMaxParallelPixels =
      (canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels
                                                            : 1;

  // Keep the standard im2col view of convolution:
  //   A (im2col):  [numPatches, patchSize] -- one row per output position
  //   B (weights): [patchSize, cOut]       -- W^T, stored in crossbar columns
  // and optionally repack several old rows into one GEMM row to use the
  // available crossbar size better.
  //
  // We want to process N pixels at the same time. Instead of doing N separate
  // operations of (1 x patchSize) x (patchSize x cOut), we construct a
  // block-diagonal weight matrix containing N copies of W^T and concatenate N
  // im2col rows into one longer row:
  //   A_packed: [ceil(numPatches / N), N * patchSize]
  //   B_packed: [N * patchSize, N * cOut]
  //   Y_packed: [ceil(numPatches / N), N * cOut]
  const int64_t packedNumRows =
      ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
  auto gemmInputRowsType = RankedTensorType::get(
      {packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType);
  auto gemmOutputRowsType = RankedTensorType::get(
      {packedNumRows, effectiveMaxParallelPixels * numChannelsOut},
      outType.getElementType());

  Value gemmInputRows = createIm2colRowComputes(x, xType, im2colType, rowType,
      gemmInputRowsType, batchSize, numChannelsIn, xHeight, xWidth, wHeight,
      wWidth, padHeightBegin, padHeightEnd, padWidthBegin, padWidthEnd,
      strideHeight, strideWidth, dilationHeight, dilationWidth, outWidth,
      patchSize, numPatches, numPatchesPerBatch, effectiveMaxParallelPixels,
      rewriter, loc);
  Value gemmB = buildPackedWeight(wDenseAttr, wTrans, wType, numChannelsIn,
      numChannelsOut, wHeight, wWidth, patchSize, effectiveMaxParallelPixels,
      rewriter, loc);
  Value gemmC = buildPackedBias(hasB, gemmBias, biasMatrix, biasDenseAttr,
      outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
  Value gemmRows = ONNXGemmOp::create(rewriter, loc, gemmOutputRowsType,
      gemmInputRows, gemmB, gemmC, rewriter.getF32FloatAttr(1.0f),
      rewriter.getF32FloatAttr(1.0f), rewriter.getBoolAttr(false),
      rewriter.getBoolAttr(false))
          .getY();
  rewriter.replaceOp(convOp,
      createCollectedConvOutput(ValueRange{gemmRows}, convOp.getType(),
          gemmOutType, nhwcType, outType, numPatches, numChannelsOut,
          effectiveMaxParallelPixels, rewriter, loc));
  return success();
}

void populateConvPatterns(RewritePatternSet &patterns, MLIRContext *ctx) {
  patterns.insert<ConvToGemm>(ctx);
}

} // namespace onnx_mlir