diff --git a/onnx-mlir b/onnx-mlir index 82018d7..eb54c2a 160000 --- a/onnx-mlir +++ b/onnx-mlir @@ -1 +1 @@ -Subproject commit 82018d7ce59c94bfbe9479b16538224969fa45a0 +Subproject commit eb54c2afc46d00c6b196d1f275b6bfee17e12f69 diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index c0c48ac..7e0ad60 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -3,7 +3,9 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") add_public_tablegen_target(ONNXToSpatialIncGen) add_onnx_mlir_library(OMONNXToSpatial + Math/Gemm.hpp Math/Gemm.cpp + Math/Conv.hpp Math/Conv.cpp Math/ExperimentalConv.cpp Math/ExperimentalGemm.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp index 08e6837..b34aebb 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp @@ -1,583 +1,247 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/LogicalResult.h" -#include -#include -#include -#include +#include -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "Conv.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" using 
namespace mlir; -using namespace std; namespace onnx_mlir { -// NOTE: -// This might be useful to re-implement this considering for loops. -// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount; +LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, + ONNXConvOpAdaptor convOpAdaptor, + ConversionPatternRewriter& rewriter) const { + Location loc = convOp.getLoc(); + Value x = convOpAdaptor.getX(); + Value w = convOpAdaptor.getW(); + Value b = convOpAdaptor.getB(); -/** - * @brief A momentary representation of a core, to be used within the tiling of - * a convolution operation. - */ -class Core { -public: - Core(const size_t coreId, ConversionPatternRewriter& rewriter) - : coreId(coreId), rewriter(rewriter) {} + auto xType = cast(x.getType()); + auto wType = cast(w.getType()); + auto outType = cast(convOp.getY().getType()); - /** - * @brief Add a MVM operation to the core. - * - * @param inputTile The input tile to the MVM operation. - * @param xbarIndex The index of the crossbar weight to use. - * @param outputTileId The id of the output tile. - * @param mvmOutType The result's shape. - * @return Value The result of the MVM operation. - */ - Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) { - // Use the inputTile as the reference location for the MVM operation. - Location loc = inputTile.getLoc(); + assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape()); + assert("Only support 2D convolution" && xType.getRank() == 4); + assert("Only support batch size 1 for input" && xType.getDimSize(0) == 1); - // Move the insertion point to the end of the block. - rewriter.setInsertionPointToEnd(block.get()); + // We need to understand what is group + assert("Only support group=1" && convOp.getGroup() == 1); - // Add the inputTile to the block arguments, and to the operands. 
- Value operand = operandMap.lookupOrNull(inputTile); - if (not operand) { - operand = block->addArgument(inputTile.getType(), loc); - operands.push_back(inputTile); - operandMap.map(inputTile, operand); - } + const int64_t numChannelsIn = xType.getDimSize(1); + const int64_t xHeight = xType.getDimSize(2); + const int64_t xWidth = xType.getDimSize(3); + const int64_t numChannelsOut = wType.getDimSize(0); + const int64_t wHeight = wType.getDimSize(2); + const int64_t wWidth = wType.getDimSize(3); + const int64_t outHeight = outType.getDimSize(2); + const int64_t outWidth = outType.getDimSize(3); - // TODO: Compute the output type using the matrix, and check if `mvmOutType` - // is correct. + // Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0) + auto getI64 = [](ArrayAttr arr, size_t idx) -> int64_t { return cast(arr[idx]).getInt(); }; - // Construct the MVM operation - Value result = rewriter.create(loc, mvmOutType, xbarIndex, operand); + const auto stridesAttr = convOp.getStrides(); + const auto dilationsAttr = convOp.getDilations(); + const auto padsAttr = convOp.getPads(); - // Since we are within the same core and no computation can happen in - // paralllel, we can just apply a linear reduction in case we have multiple - // MVM operations for the same outputTile. - auto lastMVM = outputTileToMVM.find(outputTileId); + const int64_t strideHeight = stridesAttr ? getI64(*stridesAttr, 0) : 1; + const int64_t strideWidth = stridesAttr ? getI64(*stridesAttr, 1) : 1; + const int64_t dilationHeight = dilationsAttr ? getI64(*dilationsAttr, 0) : 1; + const int64_t dilationWidth = dilationsAttr ? getI64(*dilationsAttr, 1) : 1; - // If an entry for this outputTile already exists, apply reduction. - if (lastMVM != outputTileToMVM.end()) { - // MVM results should have the same type for reduction. 
- assert(lastMVM->second.getType() == result.getType()); - result = rewriter.create(loc, mvmOutType, lastMVM->second, result); - } + int64_t padHeightBegin = 0; + int64_t padHeightEnd = 0; + int64_t padWidthBegin = 0; + int64_t padWidthEnd = 0; - outputTileToMVM[outputTileId] = result; - return result; + if (padsAttr) { + padHeightBegin = getI64(*padsAttr, 0); + padWidthBegin = getI64(*padsAttr, 1); + padHeightEnd = getI64(*padsAttr, 2); + padWidthEnd = getI64(*padsAttr, 3); } + else { + // Compute padding from auto_pad attribute + const auto autoPad = convOp.getAutoPad(); + if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { + const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1; + const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1; + const int64_t totalPadH = + std::max(static_cast(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight); + const int64_t totalPadW = + std::max(static_cast(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth); - /** - * @brief Mark a result as remappable, and return a shared pointer to it. - * - * This function marks a result as remappable, and returns a shared pointer to - * it. We need to keep track of these values to generate the YieldOp at a - * later stage. - * - * @param result A result to track, for later remapping. - * @return shared_ptr A shared pointer to the result. - */ - shared_ptr makeResultRemappable(Value result) { - // Verify that the result is present in the block. - assert(result.getDefiningOp()->getBlock() == block.get()); - - shared_ptr remappableResult = make_shared(result); - - resultsToRemap.push_back(remappableResult); - results.push_back(result); - - return remappableResult; - } - - /** - * @brief Add a remappable operand to the core, to merge partial results - * inter-core. - * - * @param remappableOperand The operand to add. - * @return Value The block argument representing the operand. 
- */ - Value addRemappableOperand(std::shared_ptr operand) { - // Check that the operand is not already there. - assert(not operandMap.contains(*operand)); - - Value argument = block->addArgument(operand->getType(), operand->getLoc()); - remappableOperands.push_back(operand); - return argument; - } - - /** - * @brief Generate a spatial::SpatWeightedCompute operation from the core. - * - * @param loc The location of the operation. - * @return spatial::SpatWeightedCompute - */ - spatial::SpatWeightedCompute createWComputeOp(Location loc) { - // Get the shape of the results. - SmallVector resultTypes; - for (const auto& value : results) - resultTypes.push_back(value.getType()); - - // Create the WComputeOp, with non-remappable operands only. - wcomputeOp = rewriter.create(loc, resultTypes, xbarWeights, operands); - - // Add the body to the WComputeOp. - Block* releasedBlock = block.release(); - wcomputeOp.getBody().push_back(releasedBlock); - - // Add the `yieldOp` at the end, with the results. - rewriter.setInsertionPointToEnd(releasedBlock); - rewriter.create(loc, results); - - return wcomputeOp; - } - - /** - * @brief Remap the results to the WComputeOp results. - */ - void remapResults() { - // Remap all the results to the WComputeOp results. 
- assert(resultsToRemap.size() == wcomputeOp->getNumResults()); - for (size_t i = 0; i < resultsToRemap.size(); i++) - *resultsToRemap[i] = wcomputeOp.getResult(i); - } - - void addRemappedOperands() { - // Insert the remappableOperands (which were remapped in - // `addRemappableOperand` of another Core) - for (auto remappedValue : remappableOperands) - wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue); - - // Update the wcomputeOp operandSegmentSize - incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast(remappableOperands.size())); - } - - size_t addXbarWeight(Value weight) { - assert(!isXbarsFull()); - xbarWeights.push_back(weight); - return xbarWeights.size() - 1; - } - - bool isXbarsFull() { - assert(xbarWeights.size() <= crossbarCountInCore); - return xbarWeights.size() == crossbarCountInCore; - } - - bool isCoreEmpty() { return block->empty(); } - - void dump() { - // Print the coreId - llvm::outs() << "Core " << coreId << ":\n"; - // Print the weights - llvm::outs() << "Xbar Weights:\n"; - for (auto weight : xbarWeights) - weight.dump(); - // Print the operands - llvm::outs() << "Operands:\n"; - for (auto operand : operands) - llvm::outs() << operand << "\n"; - - // Dump the body block - for (auto& op : block->getOperations()) - op.dump(); - - // Print the results - llvm::outs() << "Results:\n"; - for (auto result : results) - llvm::outs() << result << "\n"; - } - - const size_t coreId; - -private: - ConversionPatternRewriter& rewriter; - - // Should these be set instead? 
But I need to keep the order - vector operands; - vector> remappableOperands; - - vector results; - vector> resultsToRemap; - - // Maps from input tiles to the block operand - IRMapping operandMap; - - // Map from outputTileId to MVM operation producing it - unordered_map outputTileToMVM; - - vector xbarWeights; - - unique_ptr block = make_unique(); - - spatial::SpatWeightedCompute wcomputeOp; -}; - -struct ONNXConvOpTile : public OpConversionPattern { - ONNXConvOpTile(MLIRContext* ctx) - : OpConversionPattern(ctx) {} - - struct Producer_t { - Value value; - shared_ptr core; - }; - - LogicalResult - matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final { - ShapedType xShape = mlir::cast(convAdaptor.getX().getType()); - ShapedType wShape = mlir::cast(convAdaptor.getW().getType()); - ShapedType bShape = mlir::cast(convAdaptor.getB().getType()); - ShapedType yShape = mlir::cast(conv.getY().getType()); - - size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y; - unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y); - unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y); - - auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y); - if (padUnpackError.has_value()) - return rewriter.notifyMatchFailure(conv, padUnpackError.value()); - - // TODO: Pad value at beginning and end of each dimension could be - // different. We should handle this case. 
- - // MapOperations mapOperation = MapOperations::None; - // - // // If we have just one user, and it is an activation funcion (or more in - // // general a mapping operation) just inline it in the computeOps - // auto firstUserOp = *conv->getUsers().begin(); - // if (conv->hasOneUse()) { - // mapOperation = mlirOpToMapOperationEnum(firstUserOp); - // - // if (mapOperation == MapOperations::ONNXSoftmaxOp) { - // return rewriter.notifyMatchFailure( - // conv, "Softmax not supported as activation for convolutions."); - // } - // } - - size_t input_h = GET_IMAGE_HEIGHT(xShape); - size_t input_w = GET_IMAGE_WIDTH(xShape); - size_t output_h = GET_IMAGE_HEIGHT(yShape); - size_t output_w = GET_IMAGE_WIDTH(yShape); - size_t krn_h = GET_KERNEL_HEIGHT(wShape); - size_t krn_w = GET_KERNEL_WIDTH(wShape); - - Location loc = conv.getLoc(); - - size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); - size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize; - size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue()); - size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize; - - // Tile the input tensor - // Input tiles need to be indexed by: - // a. Channel Tile - // b. Pixel `x` position - // c. 
Pixel `y` position - // For example: inputTiles[channelTile][x][y] - // Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH) - SmallVector>> inputTiles( - inputTileCount, SmallVector>(input_w, SmallVector(input_h))); - - auto resolveErrorOpt = resolveImgInputTiles( - convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter); - if (resolveErrorOpt.has_value()) - return rewriter.notifyMatchFailure(conv, *resolveErrorOpt); - - SmallVector strides = SmallVector(4, rewriter.getIndexAttr(1)); - SmallVector offsets = SmallVector(4, rewriter.getIndexAttr(0)); - SmallVector sizes = SmallVector {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(crossbarSize), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; - - // Tile the weight tensor - // Weight tiles need to be indexed by: - // a. Filter Tile - // b. Channel Tile - // c. Kernel `x` position - // d. Kernel `y` position - // For example: weightTiles[filterTile][channelTile][x][y] - // Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH) - SmallVector>>> weightTiles( - outputTileCount, - SmallVector>>(inputTileCount, - SmallVector>(krn_w, SmallVector(krn_h)))); - strides = SmallVector(4, rewriter.getIndexAttr(1)); - offsets = SmallVector(4, rewriter.getIndexAttr(0)); - sizes = {rewriter.getIndexAttr(crossbarSize), - rewriter.getIndexAttr(crossbarSize), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; - for (size_t i = 0; i < outputTileCount; i++) { - if (i == outputTileCount - 1 && outputTileRemainder != 0) - sizes[0] = rewriter.getIndexAttr(outputTileRemainder); - sizes[1] = rewriter.getIndexAttr(crossbarSize); - offsets[0] = rewriter.getIndexAttr(i * crossbarSize); - for (size_t j = 0; j < inputTileCount; j++) { - if (j == inputTileCount - 1 && inputTileRemainder != 0) - sizes[1] = rewriter.getIndexAttr(inputTileRemainder); - for (size_t x = 0; x < krn_w; x++) { - for (size_t y = 0; y < krn_h; y++) { - offsets[1] = rewriter.getIndexAttr(j * 
crossbarSize); - offsets[2] = rewriter.getIndexAttr(x); - offsets[3] = rewriter.getIndexAttr(y); - weightTiles[i][j][x][y] = - rewriter.create(loc, convAdaptor.getW(), offsets, sizes, strides); - } - } + if (autoPad == "SAME_UPPER") { + padHeightBegin = totalPadH / 2; + padHeightEnd = totalPadH - padHeightBegin; + padWidthBegin = totalPadW / 2; + padWidthEnd = totalPadW - padWidthBegin; + } + else { // SAME_LOWER + padHeightEnd = totalPadH / 2; + padHeightBegin = totalPadH - padHeightEnd; + padWidthEnd = totalPadW / 2; + padWidthBegin = totalPadW - padWidthEnd; } } - - /* Distribute the computation among many compute cores - * Try to compute in-core the computation for each output tile, and reduce - * over as few cores as possible - */ - - // Tile the output tensor - // Output tiles need to be indexed by: - // a. Filter Tile - // b. Pixel `x` position - // c. Pixel `y` position - // For example: outputTiles[filterTile][x][y] - // Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH) - SmallVector>>> outputTiles( - outputTileCount, - SmallVector>>(output_w, SmallVector>(output_h, nullptr))); - - size_t replicationFactor; - if (!conv->hasAttr(REPLICATION_ATTR_NAME)) - replicationFactor = 1; - else - replicationFactor = conv->getAttrOfType(REPLICATION_ATTR_NAME).getInt(); - // producers[outTile][out_x][out_y][producerIndex] - vector>>> producers = vector>>>( - outputTileCount, - vector>>(output_w, vector>(output_h, vector()))); - - // Schedule in cores - size_t coreId = 0; - vector> curCores(replicationFactor); - for (size_t i = 0; i < replicationFactor; i++) - curCores[i] = make_shared(coreId++, rewriter); - - vector> cores; - - const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor); - - for (size_t krn_x = 0; krn_x < krn_h; krn_x++) { - for (size_t krn_y = 0; krn_y < krn_w; krn_y++) { - - RankedTensorType mvmOutType = - RankedTensorType::get({1, static_cast(crossbarSize), 1, 1}, bShape.getElementType()); - - for (size_t outTile = 
0; outTile < outputTileCount; outTile++) { - - if (outTile == outputTileCount - 1 && outputTileRemainder != 0) - mvmOutType = mvmOutType.clone({1, static_cast(outputTileRemainder), 1, 1}); - - for (size_t inTile = 0; inTile < inputTileCount; inTile++) { - - vector xbarIndexes(replicationFactor); - for (size_t i = 0; i < replicationFactor; i++) - xbarIndexes[i] = curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]); - - size_t out_x = 0; - for (size_t in_x = 0; in_x < input_w; in_x += stride_x) { - size_t out_y = 0; - - // I use `replicationFactor` cores. I divide the input_w into - // `replicationFactor` slices, and each slice is distributed to a - // core. `coreIndex` is the index of the core that will be used - // for this slice - size_t coreIndex = in_x / replicationSliceSize; - assert(coreIndex < replicationFactor); - - for (size_t in_y = 0; in_y < input_h; in_y += stride_y) { - // Adjust the input based on the kernel - int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x; - int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y; - - // Check if we are within the input image - if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) { - out_y++; - continue; - } - - size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y; - auto mvm = curCores[coreIndex]->addMVM( - inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType); - - producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]}); - - out_y++; - } - out_x++; - } - - // Computations for these crossbars are done, check if the cores - // crossbars are fully used. 
If full, swap with new core - for (size_t i = 0; i < replicationFactor; i++) { - if (curCores[i]->isXbarsFull()) { - cores.emplace_back(std::move(curCores[i])); - curCores[i] = make_shared(coreId++, rewriter); - } - } - } - } - } - } - - for (auto& curCore : curCores) - if (curCore->isCoreEmpty() == false) - cores.emplace_back(std::move(curCore)); - curCores.clear(); - // Now, do the reduction of each output pixel tile - for (size_t outTile = 0; outTile < outputTileCount; outTile++) { - for (size_t out_x = 0; out_x < output_w; out_x++) { - for (size_t out_y = 0; out_y < output_h; out_y++) { - // First, check if some producers are within the same core. If this is - // true, `Core::addMVM` have already done the reduction within-core. - // This means that we only need to consider the last producer for that - // core. - - std::unordered_map withinCoreReducedProducers; - for (auto producer : producers[outTile][out_x][out_y]) - withinCoreReducedProducers[producer.core->coreId] = producer; - - // Now, we need to apply inter-core reduction - - // Base case with one producer - if (withinCoreReducedProducers.size() == 1) { - // TODO: Add the bias and apply mapping (if present) - - auto singleProducer = withinCoreReducedProducers.begin()->second; - // Use last producer as the final result - auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value); - outputTiles[outTile][out_x][out_y] = reducedValue; - continue; - } - - // TODO: This is a linear reduction, not a tree reduction. We can do - // better: a tree reduction would make more computations happen in - // parallel. 
- - Producer_t lastProducer = withinCoreReducedProducers.begin()->second; - - auto it = withinCoreReducedProducers.begin(); - it++; - while (it != withinCoreReducedProducers.end()) { - - Producer_t curProducer = it->second; - - shared_ptr core1; - shared_ptr core2; - Value core1Value; - Value core2Value; - - auto lastProducerCoreId = lastProducer.core->coreId; - auto curProducerCoreId = curProducer.core->coreId; - - assert(lastProducerCoreId != curProducerCoreId - && "We should have already applied within-core reduction, how " - "could we have same cores here?"); - - // Sort the cores by coreId - if (curProducerCoreId < lastProducerCoreId) { - core1 = curProducer.core; - core1Value = curProducer.value; - core2 = lastProducer.core; - core2Value = lastProducer.value; - } - else { - core1 = lastProducer.core; - core1Value = lastProducer.value; - core2 = curProducer.core; - core2Value = curProducer.value; - } - - auto newCoreRes = core1->makeResultRemappable(core1Value); - auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes); - - rewriter.setInsertionPointAfterValue(core2Value); - Value vaddRes = rewriter.create( - core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg); - - lastProducer = {vaddRes, core2}; - - it++; - } - - // TODO: Add the bias and apply mapping (if present) - - // Use last producer as the final result - auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value); - outputTiles[outTile][out_x][out_y] = reducedValue; - } - } - } - - // Now, we need to turn the cores into a spatial::SpatWeightedCompute. - rewriter.setInsertionPointAfter(conv); - spatial::SpatWeightedCompute lastWComputeOp; - for (auto& core : cores) { - lastWComputeOp = core->createWComputeOp(loc); - core->remapResults(); - rewriter.setInsertionPointAfter(lastWComputeOp); - } - - for (auto& core : cores) - core->addRemappedOperands(); - - // Set the insertion point after the last WComputeOp. 
- rewriter.setInsertionPointAfter(lastWComputeOp); - SmallVector tilesToConcat; - tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize); - for (size_t outX = 0; outX < output_h; outX++) - for (size_t outY = 0; outY < output_w; outY++) - for (size_t outTile = 0; outTile < outputTileCount; outTile++) - tilesToConcat.push_back(*outputTiles[outTile][outX][outY]); - - Value outputImage = rewriter.create(loc, conv.getY().getType(), tilesToConcat); - - // Value outputImage = - // createImgConcatOp(outputTiles, rewriter, loc, Y.getType()); - - // If no mapping (activation) was applied, just replace ConvOp - // if (mapOperation == MapOperations::None) { - // rewriter.replaceOp(conv, outputImage); - // } else { - // // If mapping was applied, erase ConvOp and replace the mapping op - // rewriter.eraseOp(conv); - // rewriter.replaceOp(firstUserOp, outputImage); - // } - - return success(); + // "NOTSET" or "VALID" -> all pads stay 0 } -}; -void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert(ctx); + // im2col layout (flipped with respect to the standard, so filters sit in B = crossbar): + // A (im2col): [numPatches, patchSize] -- one row per output spatial position + // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns + // Gemm output: [numPatches, cOut] + const int64_t patchSize = numChannelsIn * wHeight * wWidth; + const int64_t numPatches = outHeight * outWidth; + + auto elemType = xType.getElementType(); + + // Pad input with zeros if needed: + // [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth] + if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) { + const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd; + const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd; + auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0)); + auto paddedType = 
RankedTensorType::get({1, numChannelsIn, paddedHeight, paddedWidth}, elemType); + SmallVector lowPads = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(padHeightBegin), + rewriter.getIndexAttr(padWidthBegin)}; + SmallVector highPads = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(padHeightEnd), + rewriter.getIndexAttr(padWidthEnd)}; + auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, x, lowPads, highPads); + auto* padBlock = new Block(); + for (int i = 0; i < 4; i++) + padBlock->addArgument(rewriter.getIndexType(), loc); + padOp.getRegion().push_back(padBlock); + rewriter.setInsertionPointToStart(padBlock); + tensor::YieldOp::create(rewriter, loc, zero.getResult()); + rewriter.setInsertionPointAfter(padOp); + x = padOp.getResult(); + } + + // Build im2col [numPatches, patchSize]: + // For each output position (oh, ow), extract the patch from x + auto rowType = RankedTensorType::get({1, patchSize}, elemType); + SmallVector im2colRows; + im2colRows.reserve(numPatches); + + for (int64_t oh = 0; oh < outHeight; oh++) { + for (int64_t ow = 0; ow < outWidth; ow++) { + SmallVector offsets = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(oh * strideHeight), + rewriter.getIndexAttr(ow * strideWidth)}; + SmallVector sizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(numChannelsIn), + rewriter.getIndexAttr(wHeight), + rewriter.getIndexAttr(wWidth)}; + SmallVector strides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(dilationHeight), + rewriter.getIndexAttr(dilationWidth)}; + auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); + Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, x, offsets, sizes, strides); + + // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize] + Value row = tensor::CollapseShapeOp::create(rewriter, + loc, + rowType, + patch, + SmallVector { + 
{0}, + {1, 2, 3} + }); + im2colRows.push_back(row); + } + } + + // Concatenate all rows: [numPatches, patchSize] + Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows); + + // Prepare weight matrix W for crossbar storage: + // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut] + auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType()); + auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType()); + Value wFlat = tensor::CollapseShapeOp::create(rewriter, + loc, + wFlatType, + w, + SmallVector { + {0}, + {1, 2, 3} + }); + Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})); + + // Reshape bias [numChannelsOut] -> [1, numChannelsOut] for Gemm C row-broadcasting, or use none + bool hasB = !isa(b.getDefiningOp()); + Value gemmC; + if (hasB) { + auto biasType = RankedTensorType::get({1, numChannelsOut}, cast(b.getType()).getElementType()); + gemmC = tensor::ExpandShapeOp::create(rewriter, + loc, + biasType, + b, + SmallVector { + {0, 1} + }); + } + else + gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); + + // Gemm: A @ B + C = im2col @ W^T + b + // [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut] + auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType()); + auto gemmOp = ONNXGemmOp::create(rewriter, + loc, + gemmOutType, + im2col, + wTrans, + gemmC, + rewriter.getF32FloatAttr(1.0f), + rewriter.getF32FloatAttr(1.0f), + rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)); + + Value gemmOut = gemmOp.getY(); + auto collectComputeOp = + spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector(), gemmOut); + + auto* collectBlock = new Block(); + collectBlock->addArgument(gemmOut.getType(), loc); + 
collectComputeOp.getBody().push_back(collectBlock); + rewriter.setInsertionPointToStart(collectBlock); + + auto gemmOutArg = collectBlock->getArguments().front(); + + // Restore to NCHW layout: + // [numPatches, numChannelsOut] + // -> [1, outHeight, outWidth, numChannelsOut] + // -> [1, numChannelsOut, outHeight, outWidth] + auto nhwcType = RankedTensorType::get({1, outHeight, outWidth, numChannelsOut}, outType.getElementType()); + Value nhwcOut = tensor::ExpandShapeOp::create(rewriter, + loc, + nhwcType, + gemmOutArg, + SmallVector { + {0, 1, 2}, + {3} + }); + Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2})); + + spatial::SpatYieldOp::create(rewriter, loc, nchwOut); + + rewriter.replaceOp(convOp, collectComputeOp); + return success(); } +void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.hpp b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.hpp new file mode 100644 index 0000000..1ef9566 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/Support/LogicalResult.h" + +#include "src/Dialect/ONNX/ONNXOps.hpp" + +namespace onnx_mlir { + +struct ConvToGemm : mlir::OpConversionPattern { + ConvToGemm(mlir::MLIRContext* ctx) + : OpConversionPattern(ctx) {} + + mlir::LogicalResult matchAndRewrite(mlir::ONNXConvOp convOp, + mlir::ONNXConvOpAdaptor convOpAdaptor, + mlir::ConversionPatternRewriter& rewriter) const override; +}; + +void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/ConvOld.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/ConvOld.cpp new file mode 100644 index 0000000..990bfd7 --- /dev/null +++ 
b/src/PIM/Conversion/ONNXToSpatial/Math/ConvOld.cpp @@ -0,0 +1,583 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" + +#include +#include +#include +#include + +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; +using namespace std; + +namespace onnx_mlir { + +// NOTE: +// This might be useful to re-implement this considering for loops. +// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount; + +/** + * @brief A momentary representation of a core, to be used within the tiling of + * a convolution operation. + */ +class Core { +public: + Core(const size_t coreId, ConversionPatternRewriter& rewriter) + : coreId(coreId), rewriter(rewriter) {} + + /** + * @brief Add a MVM operation to the core. + * + * @param inputTile The input tile to the MVM operation. + * @param xbarIndex The index of the crossbar weight to use. + * @param outputTileId The id of the output tile. + * @param mvmOutType The result's shape. + * @return Value The result of the MVM operation. + */ + Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) { + // Use the inputTile as the reference location for the MVM operation. 
+ Location loc = inputTile.getLoc(); + + // Move the insertion point to the end of the block. + rewriter.setInsertionPointToEnd(block.get()); + + // Add the inputTile to the block arguments, and to the operands. + Value operand = operandMap.lookupOrNull(inputTile); + if (not operand) { + operand = block->addArgument(inputTile.getType(), loc); + operands.push_back(inputTile); + operandMap.map(inputTile, operand); + } + + // TODO: Compute the output type using the matrix, and check if `mvmOutType` + // is correct. + + // Construct the MVM operation + Value result = rewriter.create(loc, mvmOutType, xbarIndex, operand); + + // Since we are within the same core and no computation can happen in + // parallel, we can just apply a linear reduction in case we have multiple + // MVM operations for the same outputTile. + auto lastMVM = outputTileToMVM.find(outputTileId); + + // If an entry for this outputTile already exists, apply reduction. + if (lastMVM != outputTileToMVM.end()) { + // MVM results should have the same type for reduction. + assert(lastMVM->second.getType() == result.getType()); + result = rewriter.create(loc, mvmOutType, lastMVM->second, result); + } + + outputTileToMVM[outputTileId] = result; + return result; + } + + /** + * @brief Mark a result as remappable, and return a shared pointer to it. + * + * This function marks a result as remappable, and returns a shared pointer to + * it. We need to keep track of these values to generate the YieldOp at a + * later stage. + * + * @param result A result to track, for later remapping. + * @return shared_ptr A shared pointer to the result. + */ + shared_ptr makeResultRemappable(Value result) { + // Verify that the result is present in the block. 
+ assert(result.getDefiningOp()->getBlock() == block.get()); + + shared_ptr remappableResult = make_shared(result); + + resultsToRemap.push_back(remappableResult); + results.push_back(result); + + return remappableResult; + } + + /** + * @brief Add a remappable operand to the core, to merge partial results + * inter-core. + * + * @param remappableOperand The operand to add. + * @return Value The block argument representing the operand. + */ + Value addRemappableOperand(std::shared_ptr operand) { + // Check that the operand is not already there. + assert(not operandMap.contains(*operand)); + + Value argument = block->addArgument(operand->getType(), operand->getLoc()); + remappableOperands.push_back(operand); + return argument; + } + + /** + * @brief Generate a spatial::SpatWeightedCompute operation from the core. + * + * @param loc The location of the operation. + * @return spatial::SpatWeightedCompute + */ + spatial::SpatWeightedCompute createWComputeOp(Location loc) { + // Get the shape of the results. + SmallVector resultTypes; + for (const auto& value : results) + resultTypes.push_back(value.getType()); + + // Create the WComputeOp, with non-remappable operands only. + wcomputeOp = rewriter.create(loc, resultTypes, xbarWeights, operands); + + // Add the body to the WComputeOp. + Block* releasedBlock = block.release(); + wcomputeOp.getBody().push_back(releasedBlock); + + // Add the `yieldOp` at the end, with the results. + rewriter.setInsertionPointToEnd(releasedBlock); + rewriter.create(loc, results); + + return wcomputeOp; + } + + /** + * @brief Remap the results to the WComputeOp results. + */ + void remapResults() { + // Remap all the results to the WComputeOp results. 
+ assert(resultsToRemap.size() == wcomputeOp->getNumResults()); + for (size_t i = 0; i < resultsToRemap.size(); i++) + *resultsToRemap[i] = wcomputeOp.getResult(i); + } + + void addRemappedOperands() { + // Insert the remappableOperands (which were remapped in + // `addRemappableOperand` of another Core) + for (auto remappedValue : remappableOperands) + wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue); + + // Update the wcomputeOp operandSegmentSize + incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast(remappableOperands.size())); + } + + size_t addXbarWeight(Value weight) { + assert(!isXbarsFull()); + xbarWeights.push_back(weight); + return xbarWeights.size() - 1; + } + + bool isXbarsFull() { + assert(xbarWeights.size() <= crossbarCountInCore); + return xbarWeights.size() == crossbarCountInCore; + } + + bool isCoreEmpty() { return block->empty(); } + + void dump() { + // Print the coreId + llvm::outs() << "Core " << coreId << ":\n"; + // Print the weights + llvm::outs() << "Xbar Weights:\n"; + for (auto weight : xbarWeights) + weight.dump(); + // Print the operands + llvm::outs() << "Operands:\n"; + for (auto operand : operands) + llvm::outs() << operand << "\n"; + + // Dump the body block + for (auto& op : block->getOperations()) + op.dump(); + + // Print the results + llvm::outs() << "Results:\n"; + for (auto result : results) + llvm::outs() << result << "\n"; + } + + const size_t coreId; + +private: + ConversionPatternRewriter& rewriter; + + // Should these be set instead? 
But I need to keep the order + vector operands; + vector> remappableOperands; + + vector results; + vector> resultsToRemap; + + // Maps from input tiles to the block operand + IRMapping operandMap; + + // Map from outputTileId to MVM operation producing it + unordered_map outputTileToMVM; + + vector xbarWeights; + + unique_ptr block = make_unique(); + + spatial::SpatWeightedCompute wcomputeOp; +}; + +struct ConvToManyGemms : public OpConversionPattern { + ConvToManyGemms(MLIRContext* ctx) + : OpConversionPattern(ctx) {} + + struct Producer_t { + Value value; + shared_ptr core; + }; + + LogicalResult + matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final { + ShapedType xShape = mlir::cast(convAdaptor.getX().getType()); + ShapedType wShape = mlir::cast(convAdaptor.getW().getType()); + ShapedType bShape = mlir::cast(convAdaptor.getB().getType()); + ShapedType yShape = mlir::cast(conv.getY().getType()); + + size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y; + unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y); + unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y); + + auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y); + if (padUnpackError.has_value()) + return rewriter.notifyMatchFailure(conv, padUnpackError.value()); + + // TODO: Pad value at beginning and end of each dimension could be + // different. We should handle this case. 
+ + // MapOperations mapOperation = MapOperations::None; + // + // // If we have just one user, and it is an activation funcion (or more in + // // general a mapping operation) just inline it in the computeOps + // auto firstUserOp = *conv->getUsers().begin(); + // if (conv->hasOneUse()) { + // mapOperation = mlirOpToMapOperationEnum(firstUserOp); + // + // if (mapOperation == MapOperations::ONNXSoftmaxOp) { + // return rewriter.notifyMatchFailure( + // conv, "Softmax not supported as activation for convolutions."); + // } + // } + + size_t input_h = GET_IMAGE_HEIGHT(xShape); + size_t input_w = GET_IMAGE_WIDTH(xShape); + size_t output_h = GET_IMAGE_HEIGHT(yShape); + size_t output_w = GET_IMAGE_WIDTH(yShape); + size_t krn_h = GET_KERNEL_HEIGHT(wShape); + size_t krn_w = GET_KERNEL_WIDTH(wShape); + + Location loc = conv.getLoc(); + + size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); + size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize; + size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue()); + size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize; + + // Tile the input tensor + // Input tiles need to be indexed by: + // a. Channel Tile + // b. Pixel `x` position + // c. 
Pixel `y` position + // For example: inputTiles[channelTile][x][y] + // Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH) + SmallVector>> inputTiles( + inputTileCount, SmallVector>(input_w, SmallVector(input_h))); + + auto resolveErrorOpt = resolveImgInputTiles( + convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter); + if (resolveErrorOpt.has_value()) + return rewriter.notifyMatchFailure(conv, *resolveErrorOpt); + + SmallVector strides = SmallVector(4, rewriter.getIndexAttr(1)); + SmallVector offsets = SmallVector(4, rewriter.getIndexAttr(0)); + SmallVector sizes = SmallVector {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(crossbarSize), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + + // Tile the weight tensor + // Weight tiles need to be indexed by: + // a. Filter Tile + // b. Channel Tile + // c. Kernel `x` position + // d. Kernel `y` position + // For example: weightTiles[filterTile][channelTile][x][y] + // Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH) + SmallVector>>> weightTiles( + outputTileCount, + SmallVector>>(inputTileCount, + SmallVector>(krn_w, SmallVector(krn_h)))); + strides = SmallVector(4, rewriter.getIndexAttr(1)); + offsets = SmallVector(4, rewriter.getIndexAttr(0)); + sizes = {rewriter.getIndexAttr(crossbarSize), + rewriter.getIndexAttr(crossbarSize), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + for (size_t i = 0; i < outputTileCount; i++) { + if (i == outputTileCount - 1 && outputTileRemainder != 0) + sizes[0] = rewriter.getIndexAttr(outputTileRemainder); + sizes[1] = rewriter.getIndexAttr(crossbarSize); + offsets[0] = rewriter.getIndexAttr(i * crossbarSize); + for (size_t j = 0; j < inputTileCount; j++) { + if (j == inputTileCount - 1 && inputTileRemainder != 0) + sizes[1] = rewriter.getIndexAttr(inputTileRemainder); + for (size_t x = 0; x < krn_w; x++) { + for (size_t y = 0; y < krn_h; y++) { + offsets[1] = rewriter.getIndexAttr(j * 
crossbarSize); + offsets[2] = rewriter.getIndexAttr(x); + offsets[3] = rewriter.getIndexAttr(y); + weightTiles[i][j][x][y] = + rewriter.create(loc, convAdaptor.getW(), offsets, sizes, strides); + } + } + } + } + + /* Distribute the computation among many compute cores + * Try to compute in-core the computation for each output tile, and reduce + * over as few cores as possible + */ + + // Tile the output tensor + // Output tiles need to be indexed by: + // a. Filter Tile + // b. Pixel `x` position + // c. Pixel `y` position + // For example: outputTiles[filterTile][x][y] + // Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH) + SmallVector>>> outputTiles( + outputTileCount, + SmallVector>>(output_w, SmallVector>(output_h, nullptr))); + + size_t replicationFactor; + if (!conv->hasAttr(REPLICATION_ATTR_NAME)) + replicationFactor = 1; + else + replicationFactor = conv->getAttrOfType(REPLICATION_ATTR_NAME).getInt(); + // producers[outTile][out_x][out_y][producerIndex] + vector>>> producers = vector>>>( + outputTileCount, + vector>>(output_w, vector>(output_h, vector()))); + + // Schedule in cores + size_t coreId = 0; + vector> curCores(replicationFactor); + for (size_t i = 0; i < replicationFactor; i++) + curCores[i] = make_shared(coreId++, rewriter); + + vector> cores; + + const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor); + + for (size_t krn_x = 0; krn_x < krn_h; krn_x++) { + for (size_t krn_y = 0; krn_y < krn_w; krn_y++) { + + RankedTensorType mvmOutType = + RankedTensorType::get({1, static_cast(crossbarSize), 1, 1}, bShape.getElementType()); + + for (size_t outTile = 0; outTile < outputTileCount; outTile++) { + + if (outTile == outputTileCount - 1 && outputTileRemainder != 0) + mvmOutType = mvmOutType.clone({1, static_cast(outputTileRemainder), 1, 1}); + + for (size_t inTile = 0; inTile < inputTileCount; inTile++) { + + vector xbarIndexes(replicationFactor); + for (size_t i = 0; i < replicationFactor; i++) + xbarIndexes[i] 
= curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]); + + size_t out_x = 0; + for (size_t in_x = 0; in_x < input_w; in_x += stride_x) { + size_t out_y = 0; + + // I use `replicationFactor` cores. I divide the input_w into + // `replicationFactor` slices, and each slice is distributed to a + // core. `coreIndex` is the index of the core that will be used + // for this slice + size_t coreIndex = in_x / replicationSliceSize; + assert(coreIndex < replicationFactor); + + for (size_t in_y = 0; in_y < input_h; in_y += stride_y) { + // Adjust the input based on the kernel + int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x; + int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y; + + // Check if we are within the input image + if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) { + out_y++; + continue; + } + + size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y; + auto mvm = curCores[coreIndex]->addMVM( + inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType); + + producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]}); + + out_y++; + } + out_x++; + } + + // Computations for these crossbars are done, check if the cores + // crossbars are fully used. If full, swap with new core + for (size_t i = 0; i < replicationFactor; i++) { + if (curCores[i]->isXbarsFull()) { + cores.emplace_back(std::move(curCores[i])); + curCores[i] = make_shared(coreId++, rewriter); + } + } + } + } + } + } + + for (auto& curCore : curCores) + if (curCore->isCoreEmpty() == false) + cores.emplace_back(std::move(curCore)); + curCores.clear(); + // Now, do the reduction of each output pixel tile + for (size_t outTile = 0; outTile < outputTileCount; outTile++) { + for (size_t out_x = 0; out_x < output_w; out_x++) { + for (size_t out_y = 0; out_y < output_h; out_y++) { + // First, check if some producers are within the same core. 
If this is + // true, `Core::addMVM` have already done the reduction within-core. + // This means that we only need to consider the last producer for that + // core. + + std::unordered_map withinCoreReducedProducers; + for (auto producer : producers[outTile][out_x][out_y]) + withinCoreReducedProducers[producer.core->coreId] = producer; + + // Now, we need to apply inter-core reduction + + // Base case with one producer + if (withinCoreReducedProducers.size() == 1) { + // TODO: Add the bias and apply mapping (if present) + + auto singleProducer = withinCoreReducedProducers.begin()->second; + // Use last producer as the final result + auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value); + outputTiles[outTile][out_x][out_y] = reducedValue; + continue; + } + + // TODO: This is a linear reduction, not a tree reduction. We can do + // better: a tree reduction would make more computations happen in + // parallel. + + Producer_t lastProducer = withinCoreReducedProducers.begin()->second; + + auto it = withinCoreReducedProducers.begin(); + it++; + while (it != withinCoreReducedProducers.end()) { + + Producer_t curProducer = it->second; + + shared_ptr core1; + shared_ptr core2; + Value core1Value; + Value core2Value; + + auto lastProducerCoreId = lastProducer.core->coreId; + auto curProducerCoreId = curProducer.core->coreId; + + assert(lastProducerCoreId != curProducerCoreId + && "We should have already applied within-core reduction, how " + "could we have same cores here?"); + + // Sort the cores by coreId + if (curProducerCoreId < lastProducerCoreId) { + core1 = curProducer.core; + core1Value = curProducer.value; + core2 = lastProducer.core; + core2Value = lastProducer.value; + } + else { + core1 = lastProducer.core; + core1Value = lastProducer.value; + core2 = curProducer.core; + core2Value = curProducer.value; + } + + auto newCoreRes = core1->makeResultRemappable(core1Value); + auto secondCoreBlockArg = 
core2->addRemappableOperand(newCoreRes); + + rewriter.setInsertionPointAfterValue(core2Value); + Value vaddRes = rewriter.create( + core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg); + + lastProducer = {vaddRes, core2}; + + it++; + } + + // TODO: Add the bias and apply mapping (if present) + + // Use last producer as the final result + auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value); + outputTiles[outTile][out_x][out_y] = reducedValue; + } + } + } + + // Now, we need to turn the cores into a spatial::SpatWeightedCompute. + rewriter.setInsertionPointAfter(conv); + spatial::SpatWeightedCompute lastWComputeOp; + for (auto& core : cores) { + lastWComputeOp = core->createWComputeOp(loc); + core->remapResults(); + rewriter.setInsertionPointAfter(lastWComputeOp); + } + + for (auto& core : cores) + core->addRemappedOperands(); + + // Set the insertion point after the last WComputeOp. + rewriter.setInsertionPointAfter(lastWComputeOp); + SmallVector tilesToConcat; + tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize); + for (size_t outX = 0; outX < output_h; outX++) + for (size_t outY = 0; outY < output_w; outY++) + for (size_t outTile = 0; outTile < outputTileCount; outTile++) + tilesToConcat.push_back(*outputTiles[outTile][outX][outY]); + + Value outputImage = rewriter.create(loc, conv.getY().getType(), tilesToConcat); + + // Value outputImage = + // createImgConcatOp(outputTiles, rewriter, loc, Y.getType()); + + // If no mapping (activation) was applied, just replace ConvOp + // if (mapOperation == MapOperations::None) { + // rewriter.replaceOp(conv, outputImage); + // } else { + // // If mapping was applied, erase ConvOp and replace the mapping op + // rewriter.eraseOp(conv); + // rewriter.replaceOp(firstUserOp, outputImage); + // } + + return success(); + } +}; + +void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.insert(ctx); +} + +} // namespace 
onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp index cd63e75..1e13f8e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp @@ -2,18 +2,16 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" -#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" #include +#include "Gemm.hpp" #include "src/Accelerators/PIM/Common/PIMCommon.hpp" -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -23,394 +21,368 @@ using namespace mlir; namespace onnx_mlir { -const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor"; +LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, + ONNXGemmOpAdaptor gemmOpAdaptor, + ConversionPatternRewriter& rewriter) const { + Location loc = gemmOp.getLoc(); + Value a = gemmOpAdaptor.getA(); + Value b = gemmOpAdaptor.getB(); + Value c = gemmOpAdaptor.getC(); -struct GemmToManyGemv : OpConversionPattern { - GemmToManyGemv(MLIRContext* ctx) - : OpConversionPattern(ctx, 2) {} + assert("A should have been transposed already" && !gemmOpAdaptor.getTransA()); - LogicalResult - matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { - Location loc = gemmOp.getLoc(); - Value a = adaptor.getA(); - Value b = adaptor.getB(); - Value c = adaptor.getC(); + bool hasC = !isa(c.getDefiningOp()); - assert("A should have been transposed already" && !adaptor.getTransA()); + auto aType = 
cast(a.getType()); + auto outType = cast(gemmOp.getY().getType()); + assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape()); - bool hasC = !isa(c.getDefiningOp()); + const int64_t numOutRows = aType.getDimSize(0); - auto aType = cast(a.getType()); - auto outType = cast(gemmOp.getY().getType()); - assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape()); + // Only decompose when there are multiple rows to split + if (numOutRows <= 1) + return failure(); - const int64_t numOutRows = aType.getDimSize(0); + RankedTensorType cType = nullptr; + bool cHasNumOutRows = false; + if (hasC) { + cType = cast(c.getType()); + assert("Only support rank 2 tensor for C" && cType.getRank() == 2); + cHasNumOutRows = cType.getDimSize(0) == numOutRows; + } - // Only decompose when there are multiple rows to split - if (numOutRows <= 1) - return failure(); + auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType()); - RankedTensorType cType = nullptr; - bool cHasNumOutRows = false; + SmallVector gemvOps; + gemvOps.reserve(numOutRows); + for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) { + SmallVector offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)}; + SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))}; + SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType()); + auto aSlice = rewriter.create(loc, aSliceType, a, offsets, sizes, strides).getResult(); + + Value cSlice = c; if (hasC) { - cType = cast(c.getType()); - assert("Only support rank 2 tensor for C" && cType.getRank() == 2); - cHasNumOutRows = cType.getDimSize(0) == numOutRows; - } - - auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType()); - - SmallVector gemvOps; - gemvOps.reserve(numOutRows); - for (int64_t rowIdx 
= 0; rowIdx < numOutRows; rowIdx++) { - SmallVector offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)}; - SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))}; - SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType()); - auto aSlice = rewriter.create(loc, aSliceType, a, offsets, sizes, strides).getResult(); - - Value cSlice = c; - if (hasC) { - if (cHasNumOutRows) { - SmallVector offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)}; - SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))}; - SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType()); - cSlice = rewriter.create(loc, cSliceType, c, offsets, sizes, strides).getResult(); - } - else - assert("C should be a vector" && isVectorShape(getTensorShape(c))); + if (cHasNumOutRows) { + SmallVector offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)}; + SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))}; + SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType()); + cSlice = rewriter.create(loc, cSliceType, c, offsets, sizes, strides).getResult(); } - - auto gemvOp = rewriter.create(loc, - outRowType, - aSlice, - b, - cSlice, - gemmOp.getAlphaAttr(), - gemmOp.getBetaAttr(), - gemmOp.getTransAAttr(), - gemmOp.getTransBAttr()); - gemvOps.push_back(gemvOp.getY()); + else + assert("C should be a vector" && isVectorShape(getTensorShape(c))); } - auto concatComputeOp = - rewriter.create(loc, gemmOp.getType(), SmallVector(), gemvOps); - - auto* concatBlock = new Block(); - for (auto gemvOp : gemvOps) - 
concatBlock->addArgument(gemvOp.getType(), loc); - concatComputeOp.getBody().push_back(concatBlock); - rewriter.setInsertionPointToStart(concatBlock); - - auto blockArgs = concatBlock->getArguments(); - auto concatOp = rewriter.create(loc, /*axis=*/0, blockArgs); - rewriter.create(loc, concatOp.getResult()); - - rewriter.replaceOp(gemmOp, concatComputeOp); - return success(); + auto gemvOp = rewriter.create(loc, + outRowType, + aSlice, + b, + cSlice, + gemmOp.getAlphaAttr(), + gemmOp.getBetaAttr(), + gemmOp.getTransAAttr(), + gemmOp.getTransBAttr()); + gemvOps.push_back(gemvOp.getY()); } -}; -struct GemvToSpatialCompute : OpConversionPattern { - GemvToSpatialCompute(MLIRContext* ctx) - : OpConversionPattern(ctx, 1) {} + auto concatComputeOp = + rewriter.create(loc, gemmOp.getType(), SmallVector(), gemvOps); - LogicalResult - matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { - Location gemmLoc = gemmOp.getLoc(); - Value a = adaptor.getA(); - Value b = adaptor.getB(); - Value c = adaptor.getC(); - Value out = gemmOp.getY(); + auto* concatBlock = new Block(); + for (auto gemvOp : gemvOps) + concatBlock->addArgument(gemvOp.getType(), loc); + concatComputeOp.getBody().push_back(concatBlock); + rewriter.setInsertionPointToStart(concatBlock); - float alpha = adaptor.getAlpha().convertToFloat(); - float beta = adaptor.getBeta().convertToFloat(); - bool transA = adaptor.getTransA(); - bool transB = adaptor.getTransB(); + auto blockArgs = concatBlock->getArguments(); + auto concatOp = rewriter.create(loc, /*axis=*/0, blockArgs); + rewriter.create(loc, concatOp.getResult()); - auto aType = cast(a.getType()); - auto bType = cast(b.getType()); - auto outType = cast(out.getType()); + rewriter.replaceOp(gemmOp, concatComputeOp); + return success(); +} + +LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, + ONNXGemmOpAdaptor gemmOpAdaptor, + ConversionPatternRewriter& rewriter) const { + Location 
gemmLoc = gemmOp.getLoc(); + Value a = gemmOpAdaptor.getA(); + Value b = gemmOpAdaptor.getB(); + Value c = gemmOpAdaptor.getC(); + Value out = gemmOp.getY(); + + float alpha = gemmOpAdaptor.getAlpha().convertToFloat(); + float beta = gemmOpAdaptor.getBeta().convertToFloat(); + bool transA = gemmOpAdaptor.getTransA(); + bool transB = gemmOpAdaptor.getTransB(); + + auto aType = cast(a.getType()); + auto bType = cast(b.getType()); + auto outType = cast(out.getType()); + + RankedTensorType cType = nullptr; + bool hasC = !isa(c.getDefiningOp()); + if (hasC) { + cType = cast(c.getType()); + assert("Only support rank 2 tensor for C" && cType.getRank() == 2); + } + + assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() + && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape()); + + if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape()))) + // Not a gemv (C is optional: cType stays null when C is absent, so only inspect it when hasC) + return failure(); + + if (transA) { + auto aShape = aType.getShape(); + auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType()); + a = rewriter.create(gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})); + } + if (transB) { + auto bShape = bType.getShape(); + auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType()); + b = rewriter.create(gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); + } + + if (alpha != 1.0f) { + auto alphaTensorType = RankedTensorType::get({1, 1}, cast(a.getType()).getElementType()); + auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha}); + auto alphaTensor = rewriter.create(gemmLoc, alphaTensorType, alphaTensorValue); + a = rewriter.create(gemmLoc, a.getType(), a, alphaTensor); + } + if (hasC && beta != 1.0f) { + auto betaTensorType = RankedTensorType::get({1, 1}, cast(c.getType()).getElementType()); + auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta}); + auto betaTensor = rewriter.create(gemmLoc, 
betaTensorType, betaTensorValue); + c = rewriter.create(gemmLoc, c.getType(), c, betaTensor); + } + + auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue()); + auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue()); + auto bNumVSlices = aNumHSlices; + auto bLastVSliceSize = aLastHSliceSize; + auto cNumHSlices = bNumHSlices; + auto cLastHSliceSize = bLastHSliceSize; + auto outNumHSlices = cNumHSlices; + auto outLastHSliceSize = cLastHSliceSize; + + const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue()); + + DenseMap> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc); + + DenseMap>> bTiles = + tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc); + + SmallVector cHSlices; + if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1) + c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc); + if (hasC) + cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc); + + RankedTensorType outHSliceType = + RankedTensorType::get({1, static_cast(crossbarSize)}, outType.getElementType()); + RankedTensorType outLastHSliceType = + RankedTensorType::get({1, static_cast(bLastHSliceSize)}, outType.getElementType()); + + SmallVector outHSlices; + outHSlices.reserve(outNumHSlices); + for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) { + RankedTensorType currOutHSliceType = outHSliceType; + if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0) + currOutHSliceType = outLastHSliceType; + + SmallVector partialResults; + partialResults.reserve(coresPerVSlice); + for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) { + SmallVector weights; + weights.reserve(aHSlices[coreId].size()); + + for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++) + weights.push_back(bTiles[outSliceId][coreId][aSliceId]); + + auto computeOp = + 
rewriter.create(gemmLoc, currOutHSliceType, weights, aHSlices[coreId]); + + auto* computeBlock = new Block(); + for (auto aHSlice : aHSlices[coreId]) + computeBlock->addArgument(aHSlice.getType(), gemmLoc); + computeOp.getBody().push_back(computeBlock); + rewriter.setInsertionPointToStart(computeBlock); + + auto computeArgs = computeBlock->getArguments(); + SmallVector vmmOutputs; + vmmOutputs.reserve(computeArgs.size()); + for (size_t aHSliceId = 0; aHSliceId < computeArgs.size(); aHSliceId++) + vmmOutputs.push_back( + rewriter.create(gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId])); + assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty"); + + Value partialVmmSum = sumTensors(vmmOutputs, rewriter); + rewriter.create(gemmLoc, partialVmmSum); + rewriter.setInsertionPointAfter(computeOp); + + partialResults.push_back(computeOp.getResult(0)); + } - RankedTensorType cType = nullptr; - bool hasC = !isa(c.getDefiningOp()); if (hasC) { - cType = cast(c.getType()); - assert("Only support rank 2 tensor for C" && cType.getRank() == 2); + Value cHSlice = cHSlices[outSliceId]; + partialResults.push_back(cHSlice); } - assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() - && (!hasC || cType.hasStaticShape()) && outType.hasStaticShape()); + auto reduceComputeOp = + rewriter.create(gemmLoc, currOutHSliceType, SmallVector(), partialResults); - if (!isVectorShape(aType.getShape()) || !isVectorShape(aType.getShape())) - // Not a gemv - return failure(); + auto* reduceBlock = new Block(); + for (auto partialResult : partialResults) + reduceBlock->addArgument(partialResult.getType(), gemmLoc); + reduceComputeOp.getBody().push_back(reduceBlock); + rewriter.setInsertionPointToStart(reduceBlock); - if (transA) { - auto aShape = aType.getShape(); - auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType()); - a = rewriter.create(gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})); - } - if 
(transB) { - auto bShape = bType.getShape(); - auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType()); - b = rewriter.create(gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); - } + auto blockArgs = reduceBlock->getArguments(); + Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter); + rewriter.create(gemmLoc, outHSlice); + rewriter.setInsertionPointAfter(reduceComputeOp); - if (alpha != 1.0f) { - auto alphaTensorType = RankedTensorType::get({1, 1}, cast(a.getType()).getElementType()); - auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha}); - auto alphaTensor = rewriter.create(gemmLoc, alphaTensorType, alphaTensorValue); - a = rewriter.create(gemmLoc, a.getType(), a, alphaTensor); - } - if (hasC && beta != 1.0f) { - auto betaTensorType = RankedTensorType::get({1, 1}, cast(c.getType()).getElementType()); - auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta}); - auto betaTensor = rewriter.create(gemmLoc, betaTensorType, betaTensorValue); - c = rewriter.create(gemmLoc, c.getType(), c, betaTensor); - } - - auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue()); - auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue()); - auto bNumVSlices = aNumHSlices; - auto bLastVSliceSize = aLastHSliceSize; - auto cNumHSlices = bNumHSlices; - auto cLastHSliceSize = bLastHSliceSize; - auto outNumHSlices = cNumHSlices; - auto outLastHSliceSize = cLastHSliceSize; - - const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue()); - - DenseMap> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc); - - DenseMap>> bTiles = - tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc); - - SmallVector cHSlices; - if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1) - c = broadcastToVector(c, 
bType.getDimSize(1), rewriter, gemmLoc); - if (hasC) - cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc); - - RankedTensorType outHSliceType = - RankedTensorType::get({1, static_cast(crossbarSize)}, outType.getElementType()); - RankedTensorType outLastHSliceType = - RankedTensorType::get({1, static_cast(bLastHSliceSize)}, outType.getElementType()); - - SmallVector outHSlices; - outHSlices.reserve(outNumHSlices); - for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) { - RankedTensorType currOutHSliceType = outHSliceType; - if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0) - currOutHSliceType = outLastHSliceType; - - SmallVector partialResults; - partialResults.reserve(coresPerVSlice); - for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) { - SmallVector weights; - weights.reserve(aHSlices[coreId].size()); - - for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++) - weights.push_back(bTiles[outSliceId][coreId][aSliceId]); - - auto computeOp = - rewriter.create(gemmLoc, currOutHSliceType, weights, aHSlices[coreId]); - - auto* computeBlock = new Block(); - for (auto aHSlice : aHSlices[coreId]) - computeBlock->addArgument(aHSlice.getType(), gemmLoc); - computeOp.getBody().push_back(computeBlock); - rewriter.setInsertionPointToStart(computeBlock); - - auto computeArgs = computeBlock->getArguments(); - SmallVector vmmOutputs; - vmmOutputs.reserve(computeArgs.size()); - for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++) - vmmOutputs.push_back( - rewriter.create(gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId])); - assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty"); - - Value partialVmmSum = sumTensors(vmmOutputs, rewriter); - rewriter.create(gemmLoc, partialVmmSum); - rewriter.setInsertionPointAfter(computeOp); - - partialResults.push_back(computeOp.getResult(0)); - } - - if (hasC) { - Value cHSlice = cHSlices[outSliceId]; - partialResults.push_back(cHSlice); - } - - 
auto reduceComputeOp = - rewriter.create(gemmLoc, currOutHSliceType, SmallVector(), partialResults); - - auto* reduceBlock = new Block(); - for (auto partialResult : partialResults) - reduceBlock->addArgument(partialResult.getType(), gemmLoc); - reduceComputeOp.getBody().push_back(reduceBlock); - rewriter.setInsertionPointToStart(reduceBlock); - - auto blockArgs = reduceBlock->getArguments(); - Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter); - rewriter.create(gemmLoc, outHSlice); - rewriter.setInsertionPointAfter(reduceComputeOp); - - outHSlices.push_back(reduceComputeOp.getResult(0)); - } - - auto concatComputeOp = - rewriter.create(gemmLoc, gemmOp.getType(), SmallVector(), outHSlices); - - auto* concatBlock = new Block(); - for (auto outHSlice : outHSlices) - concatBlock->addArgument(outHSlice.getType(), gemmLoc); - concatComputeOp.getBody().push_back(concatBlock); - rewriter.setInsertionPointToStart(concatBlock); - - auto blockArgs = concatBlock->getArguments(); - auto concatOp = rewriter.create(gemmLoc, /*axis=*/1, blockArgs); - rewriter.create(gemmLoc, concatOp.getResult()); - - rewriter.replaceOp(gemmOp, concatComputeOp); - return success(); + outHSlices.push_back(reduceComputeOp.getResult(0)); } -private: - /** - * Resolves the ONNXExpOp from the use chain of the given start value. - * - * This function traverses the use chain of the start value until it finds an - * ONNXExpOp. It returns the value of the ONNXExpOp. - * - * @param startValue The starting value of the use chain. - * @return The value of the ONNXExpOp found in the use chain. 
- */ - static Value resolveONNXExpOpFromUseChain(Value startValue) { - Value walker = startValue; + auto concatComputeOp = + rewriter.create(gemmLoc, gemmOp.getType(), SmallVector(), outHSlices); - while (!llvm::isa(walker.getDefiningOp())) { - walker = walker.getDefiningOp()->getOperand(0); + auto* concatBlock = new Block(); + for (auto outHSlice : outHSlices) + concatBlock->addArgument(outHSlice.getType(), gemmLoc); + concatComputeOp.getBody().push_back(concatBlock); + rewriter.setInsertionPointToStart(concatBlock); - assert(walker && walker.getDefiningOp() - && "Unwinded the whole chain of operations while trying to " - "find ONNXExpOp, but did not find it"); - } + auto blockArgs = concatBlock->getArguments(); + auto concatOp = rewriter.create(gemmLoc, /*axis=*/1, blockArgs); + rewriter.create(gemmLoc, concatOp.getResult()); - // Make sure the dividend is actually produced by an ONNXExpOp - assert(llvm::isa(walker.getDefiningOp()) - && "Old output tile (softmax reducer) is not produced by an " - "ONNXExpOp"); + rewriter.replaceOp(gemmOp, concatComputeOp); + return success(); +} - return walker; +Value GemvToSpatialCompute::resolveONNXExpOpFromUseChain(Value startValue) { + Value walker = startValue; + + while (!llvm::isa(walker.getDefiningOp())) { + walker = walker.getDefiningOp()->getOperand(0); + + assert(walker && walker.getDefiningOp() + && "Unwinded the whole chain of operations while trying to " + "find ONNXExpOp, but did not find it"); } - // Softmax is a special case, as it requires another reduction after the - // first one. In the cores, `applyReducePattern` already applied - // f(x) = exp(x) to each tile. This mean that now we just need to - // reduce-sum these tiles, and then divide each tile by the reduced sum, - // which is propagated back to the cores via a broadcast channel. 
- LogicalResult softmaxReductionApplication(SmallVector& outputOpsAndResNums, - Value& softmaxChannel, - ConversionPatternRewriter& rewriter, - SpatialReducer& reducer, - ONNXGemmOp& gemmOp, - Location& loc) const { + // Make sure the dividend is actually produced by an ONNXExpOp + assert(llvm::isa(walker.getDefiningOp()) + && "Old output tile (softmax reducer) is not produced by an " + "ONNXExpOp"); - // TODO: Check case with one compute op + return walker; +} - // Cast vector of Value into vector of ComputeOp - SmallVector softmaxOpsToReduce = - llvm::to_vector(llvm::map_range(outputOpsAndResNums, [&](OpAndResNum computeAndResNum) { - return std::make_pair(cast(computeAndResNum.first), computeAndResNum.second); - })); +LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector& outputOpsAndResNums, + Value& softmaxChannel, + ConversionPatternRewriter& rewriter, + SpatialReducer& reducer, + ONNXGemmOp& gemmOp, + Location& loc) { + // TODO: Check case with one compute op - RankedTensorType::Builder tensorTypeBuilder({1}, Float32Type::get(rewriter.getContext()), nullptr); - const TensorType scalarTensorType = tensorTypeBuilder; + // Cast vector of Value into vector of ComputeOp + SmallVector softmaxOpsToReduce = + llvm::to_vector(llvm::map_range(outputOpsAndResNums, [&](OpAndResNum computeAndResNum) { + return std::make_pair(cast(computeAndResNum.first), computeAndResNum.second); + })); - reducer.applyReducePattern( - softmaxOpsToReduce, - [&](Value a, Value b) { return rewriter.create(loc, scalarTensorType, a, b); }, - /* preprocess = */ - [&](Value a) { return rewriter.create(loc, scalarTensorType, a); }, - [&](Value softmaxDivisor) { - // Signal that this is the compute with the softmax divisor - auto computeOp = cast(softmaxDivisor.getDefiningOp()->getParentOp()); - computeOp->setAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME, rewriter.getUnitAttr()); + RankedTensorType::Builder tensorTypeBuilder({1}, Float32Type::get(rewriter.getContext()), 
nullptr); + const TensorType scalarTensorType = tensorTypeBuilder; - // Broadcast the divisor to all the cores - rewriter.setInsertionPointAfterValue(softmaxDivisor); - rewriter.create(loc, softmaxChannel, softmaxDivisor); + reducer.applyReducePattern( + softmaxOpsToReduce, + [&](Value a, Value b) { return rewriter.create(loc, scalarTensorType, a, b); }, + /* preprocess = */ + [&](Value a) { return rewriter.create(loc, scalarTensorType, a); }, + [&](Value softmaxDivisor) { + // Signal that this is the compute with the softmax divisor + auto computeOp = cast(softmaxDivisor.getDefiningOp()->getParentOp()); + computeOp->setAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME, rewriter.getUnitAttr()); - /* - * softmaxDividend = onnx.exp (...) - * sum = spat.SumOp(softmaxDividend) - * [following can be repeated N times, thus walk the use chain] - * softmaxDivisor = spat.sadd(sum, ...) - */ - Value softmaxDividend = resolveONNXExpOpFromUseChain(softmaxDivisor.getDefiningOp()->getOperand(0)); + // Broadcast the divisor to all the cores + rewriter.setInsertionPointAfterValue(softmaxDivisor); + rewriter.create(loc, softmaxChannel, softmaxDivisor); - // Make sure the dividend is actually produced by an ONNXExpOp - assert(llvm::isa(softmaxDividend.getDefiningOp()) - && "Dividend of softmax reduction is not an ONNXExpOp"); + /* + * softmaxDividend = onnx.exp (...) + * sum = spat.SumOp(softmaxDividend) + * [following can be repeated N times, thus walk the use chain] + * softmaxDivisor = spat.sadd(sum, ...) + */ + Value softmaxDividend = resolveONNXExpOpFromUseChain(softmaxDivisor.getDefiningOp()->getOperand(0)); - // Do not divide here, divide after this - return softmaxDivisor; - }); + // Make sure the dividend is actually produced by an ONNXExpOp + assert(llvm::isa(softmaxDividend.getDefiningOp()) + && "Dividend of softmax reduction is not an ONNXExpOp"); - // In all the cores, insert a ChannelRecvOp and divide the output tile by - // the reduced denominator. 
- outputOpsAndResNums.clear(); - outputOpsAndResNums.reserve(softmaxOpsToReduce.size()); - for (auto& computeToDivideOpAndResNum : softmaxOpsToReduce) { + // Do not divide here, divide after this + return softmaxDivisor; + }); - auto yieldOp = cast(computeToDivideOpAndResNum.first.getBody().front().getTerminator()); + // In all the cores, insert a ChannelRecvOp and divide the output tile by + // the reduced denominator. + outputOpsAndResNums.clear(); + outputOpsAndResNums.reserve(softmaxOpsToReduce.size()); + for (auto& computeToDivideOpAndResNum : softmaxOpsToReduce) { - Value divisor; + auto yieldOp = cast(computeToDivideOpAndResNum.first.getBody().front().getTerminator()); - // Check if this compute contains the softmax divisor: if so, find the - // ChannelBroadcastSendOp, otherwise receive the value from the channel - // using ChannelBroadcastReceiveOp - if (computeToDivideOpAndResNum.first->hasAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME)) { + Value divisor; - bool found = false; - for (auto broadcastOp : - computeToDivideOpAndResNum.first.getBody().front().getOps()) { - assert(found == false - && "More than one ChannelBroadcastSendOp in " - "compute? How is this possible?"); - found = true; + // Check if this compute contains the softmax divisor: if so, find the + // ChannelBroadcastSendOp, otherwise receive the value from the channel + // using ChannelBroadcastReceiveOp + if (computeToDivideOpAndResNum.first->hasAttr(COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME)) { - divisor = broadcastOp.getData(); - } + bool found = false; + for (auto broadcastOp : + computeToDivideOpAndResNum.first.getBody().front().getOps()) { + assert(found == false + && "More than one ChannelBroadcastSendOp in " + "compute? 
How is this possible?"); + found = true; - assert(found - && "No ChannelBroadcastSendOp in compute where softmax " - "divisor was specified to be?"); - } - else { - rewriter.setInsertionPoint(yieldOp); - divisor = rewriter.create(loc, scalarTensorType, softmaxChannel); + divisor = broadcastOp.getData(); } - // Walk the chain of operations until we find the ONNXExpOp: this is - // needed because some some may have a different amount of `VAddOp`s due - // to the tree reduction (e.g. some may have no VAddOp, some may have - // multiples) - Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second)); - + assert(found + && "No ChannelBroadcastSendOp in compute where softmax " + "divisor was specified to be?"); + } + else { rewriter.setInsertionPoint(yieldOp); - Value newOutputTile = rewriter.create(loc, oldOutputTile.getType(), oldOutputTile, divisor); - auto yieldOperandNum = yieldOp->getNumOperands(); - yieldOp->insertOperands(yieldOperandNum, newOutputTile); - - outputOpsAndResNums.push_back({computeToDivideOpAndResNum.first, yieldOperandNum}); + divisor = rewriter.create(loc, scalarTensorType, softmaxChannel); } - return success(); + // Walk the chain of operations until we find the ONNXExpOp: this is + // needed because some some may have a different amount of `VAddOp`s due + // to the tree reduction (e.g. 
some may have no VAddOp, some may have + // multiples) + Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second)); + + rewriter.setInsertionPoint(yieldOp); + Value newOutputTile = rewriter.create(loc, oldOutputTile.getType(), oldOutputTile, divisor); + auto yieldOperandNum = yieldOp->getNumOperands(); + yieldOp->insertOperands(yieldOperandNum, newOutputTile); + + outputOpsAndResNums.push_back({computeToDivideOpAndResNum.first, yieldOperandNum}); } -}; + + return success(); +} void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.hpp b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.hpp new file mode 100644 index 0000000..2853674 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include "Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +namespace onnx_mlir { + +constexpr mlir::StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor"; + +struct GemmToManyGemv : mlir::OpConversionPattern { + GemmToManyGemv(mlir::MLIRContext* ctx) + : OpConversionPattern(ctx, 2) {} + + mlir::LogicalResult matchAndRewrite(mlir::ONNXGemmOp gemmOp, + mlir::ONNXGemmOpAdaptor gemmOpAdaptor, + mlir::ConversionPatternRewriter& rewriter) const override; +}; + +struct GemvToSpatialCompute : mlir::OpConversionPattern { + GemvToSpatialCompute(mlir::MLIRContext* ctx) + : OpConversionPattern(ctx, 1) {} + + llvm::LogicalResult matchAndRewrite(mlir::ONNXGemmOp gemmOp, + mlir::ONNXGemmOpAdaptor gemmOpAdaptor, + mlir::ConversionPatternRewriter& rewriter) const override; + +private: + /** + * Resolves the ONNXExpOp from the use chain of the given start value. + * + * This function traverses the use chain of the start value until it finds an + * ONNXExpOp. It returns the value of the ONNXExpOp. 
+ * + * @param startValue The starting value of the use chain. + * @return The value of the ONNXExpOp found in the use chain. + */ + static mlir::Value resolveONNXExpOpFromUseChain(mlir::Value startValue); + + // Softmax is a special case, as it requires another reduction after the + // first one. In the cores, `applyReducePattern` already applied + // f(x) = exp(x) to each tile. This mean that now we just need to + // reduce-sum these tiles, and then divide each tile by the reduced sum, + // which is propagated back to the cores via a broadcast channel. + static llvm::LogicalResult softmaxReductionApplication(llvm::SmallVector& outputOpsAndResNums, + Value& softmaxChannel, + ConversionPatternRewriter& rewriter, + SpatialReducer& reducer, + ONNXGemmOp& gemmOp, + Location& loc); +}; + +void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 182c715..af75fd9 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -10,6 +10,7 @@ #include "Common/PIMCommon.hpp" #include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" +#include "Math/Conv.hpp" #include "ONNXToSpatialPass.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp index b496dfd..416f096 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp @@ -6,7 +6,6 @@ namespace onnx_mlir { void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void 
populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp index 9a8ed76..a64c26b 100644 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp @@ -1,6 +1,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" @@ -29,12 +30,12 @@ void SpatialToPIMPass::runOnOperation() { MLIRContext* ctx = moduleOp.getContext(); ConversionTarget target(*ctx); - target.addLegalDialect(); + target.addLegalDialect(); RewritePatternSet patterns(ctx); populateWithGenerated(patterns); - if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) { signalPassFailure(); return; } diff --git a/validation/operations/conv/batch_64/conv_batch_64.onnx b/validation/operations/conv/batch_64/conv_batch_64.onnx new file mode 100644 index 0000000..d5775f7 Binary files /dev/null and b/validation/operations/conv/batch_64/conv_batch_64.onnx differ diff --git a/validation/operations/conv/simple/conv.onnx b/validation/operations/conv/simple/conv.onnx new file mode 100644 index 0000000..99a5da2 Binary files /dev/null and b/validation/operations/conv/simple/conv.onnx differ diff --git a/validation/operations/conv/with_constant/conv_with_constant.onnx b/validation/operations/conv/with_constant/conv_with_constant.onnx new file mode 100644 index 0000000..c126a92 Binary files /dev/null and 
b/validation/operations/conv/with_constant/conv_with_constant.onnx differ