#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"

// NOTE(review): this copy of the file lost every angle-bracket span: the four
// system-header names below and all template argument lists (e.g.
// `rewriter.create(...)`, `vector> ...`, `mlir::cast(...)`). They are kept
// verbatim here rather than guessed; restore them from version control before
// building.
#include
#include
#include
#include

#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;
using namespace std;

namespace onnx_mlir {

// NOTE:
// This might be useful to re-implement this considering for loops.
// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount;

/**
 * @brief A momentary representation of a core, to be used within the tiling of
 * a convolution operation.
 *
 * Operations are first built inside a detached block owned by the Core
 * (`block`). `createWComputeOp` later moves that block into a freshly created
 * spatial::SpatWeightedCompute and terminates it with a yield of `results`.
 */
class Core {
public:
  Core(const size_t coreId, ConversionPatternRewriter& rewriter)
      : coreId(coreId), rewriter(rewriter) {}

  /**
   * @brief Add a MVM operation to the core.
   *
   * If an MVM for the same `outputTileId` was already added to this core, the
   * new result is summed with the previous one, so the returned value is the
   * running within-core reduction for that output tile.
   *
   * @param inputTile The input tile to the MVM operation.
   * @param xbarIndex The index of the crossbar weight to use.
   * @param outputTileId The id of the output tile.
   * @param mvmOutType The result's shape.
   * @return Value The result of the MVM operation (or of the reduction).
   */
  Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId,
      Type mvmOutType) {
    // Use the inputTile as the reference location for the MVM operation.
    Location loc = inputTile.getLoc();

    // Move the insertion point to the end of the block.
    rewriter.setInsertionPointToEnd(block.get());

    // Add the inputTile to the block arguments, and to the operands (only once
    // per distinct input tile).
    Value operand = operandMap.lookupOrNull(inputTile);
    if (not operand) {
      operand = block->addArgument(inputTile.getType(), loc);
      operands.push_back(inputTile);
      operandMap.map(inputTile, operand);
    }

    // TODO: Compute the output type using the matrix, and check if `mvmOutType`
    // is correct.

    // Construct the MVM operation.
    Value result = rewriter.create(loc, mvmOutType, xbarIndex, operand);

    // Since we are within the same core and no computation can happen in
    // parallel, we can just apply a linear reduction in case we have multiple
    // MVM operations for the same outputTile.
    auto lastMVM = outputTileToMVM.find(outputTileId);

    // If an entry for this outputTile already exists, apply reduction.
    if (lastMVM != outputTileToMVM.end()) {
      // MVM results should have the same type for reduction.
      assert(lastMVM->second.getType() == result.getType());
      result = rewriter.create(loc, mvmOutType, lastMVM->second, result);
    }

    outputTileToMVM[outputTileId] = result;
    return result;
  }

  /**
   * @brief Mark a result as remappable, and return a shared pointer to it.
   *
   * The value is recorded as a result of the future WComputeOp (it will be
   * yielded), and the returned pointer is overwritten by `remapResults` with
   * the corresponding WComputeOp result once that op exists.
   *
   * @param result A value defined in this core's block, to track for later
   * remapping.
   * @return shared_ptr A shared pointer to the result.
   */
  shared_ptr makeResultRemappable(Value result) {
    // Verify that the result is present in the block.
    assert(result.getDefiningOp()->getBlock() == block.get());

    shared_ptr remappableResult = make_shared(result);

    resultsToRemap.push_back(remappableResult);
    results.push_back(result);

    return remappableResult;
  }

  /**
   * @brief Add a remappable operand to the core, to merge partial results
   * inter-core.
   *
   * A new block argument is created now; the matching WComputeOp operand is
   * appended later by `addRemappedOperands`, after the producing core has
   * remapped `*operand` to its own WComputeOp result.
   *
   * @param operand The (shared) producer value to add.
   * @return Value The block argument representing the operand.
   */
  Value addRemappableOperand(std::shared_ptr operand) {
    // Check that the operand is not already there.
    assert(not operandMap.contains(*operand));

    Value argument = block->addArgument(operand->getType(), operand->getLoc());
    remappableOperands.push_back(operand);
    return argument;
  }

  /**
   * @brief Generate a spatial::SpatWeightedCompute operation from the core.
   *
   * Ownership of the core's block is transferred to the new op; the core must
   * not add operations afterwards.
   *
   * @param loc The location of the operation.
   * @return spatial::SpatWeightedCompute
   */
  spatial::SpatWeightedCompute createWComputeOp(Location loc) {
    // Get the shape of the results.
    SmallVector resultTypes;
    for (const auto& value : results)
      resultTypes.push_back(value.getType());

    // Create the WComputeOp, with non-remappable operands only.
    wcomputeOp = rewriter.create(loc, resultTypes, xbarWeights, operands);

    // Add the body to the WComputeOp.
    Block* releasedBlock = block.release();
    wcomputeOp.getBody().push_back(releasedBlock);

    // Add the `yieldOp` at the end, with the results.
    rewriter.setInsertionPointToEnd(releasedBlock);
    rewriter.create(loc, results);

    return wcomputeOp;
  }

  /**
   * @brief Remap the results to the WComputeOp results.
   */
  void remapResults() {
    // Remap all the results to the WComputeOp results.
    assert(resultsToRemap.size() == wcomputeOp->getNumResults());
    for (size_t i = 0; i < resultsToRemap.size(); i++)
      *resultsToRemap[i] = wcomputeOp.getResult(i);
  }

  /**
   * @brief Append the inter-core operands to the WComputeOp.
   *
   * Must run after every core has executed `remapResults`, so that each shared
   * pointer already refers to a WComputeOp result.
   */
  void addRemappedOperands() {
    // Insert the remappableOperands (which were remapped in
    // `addRemappableOperand` of another Core)
    for (const auto& remappedValue : remappableOperands)
      wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue);

    // Update the wcomputeOp operandSegmentSize
    incrementWeightedComputeInputsSegmentSize(
        wcomputeOp, static_cast(remappableOperands.size()));
  }

  /// Register a crossbar weight and return its index within this core.
  size_t addXbarWeight(Value weight) {
    assert(!isXbarsFull());
    xbarWeights.push_back(weight);
    return xbarWeights.size() - 1;
  }

  /// True when every crossbar of this core holds a weight.
  bool isXbarsFull() {
    assert(xbarWeights.size() <= crossbarCountInCore);
    return xbarWeights.size() == crossbarCountInCore;
  }

  /// True when no operation has been added to the core's block yet.
  bool isCoreEmpty() { return block->empty(); }

  /// Debug helper: print the weights, operands, body and results.
  void dump() {
    // Print the coreId
    llvm::outs() << "Core " << coreId << ":\n";
    // Print the weights
    llvm::outs() << "Xbar Weights:\n";
    for (auto weight : xbarWeights)
      weight.dump();
    // Print the operands
    llvm::outs() << "Operands:\n";
    for (auto operand : operands)
      llvm::outs() << operand << "\n";

    // Dump the body block
    for (auto& op : block->getOperations())
      op.dump();

    // Print the results
    llvm::outs() << "Results:\n";
    for (auto result : results)
      llvm::outs() << result << "\n";
  }

  const size_t coreId;

private:
  ConversionPatternRewriter& rewriter;

  // Should these be set instead? But I need to keep the order
  vector operands;
  vector> remappableOperands;
  vector results;
  vector> resultsToRemap;

  // Maps from input tiles to the block operand
  IRMapping operandMap;

  // Map from outputTileId to MVM operation producing it
  unordered_map outputTileToMVM;

  vector xbarWeights;

  unique_ptr block = make_unique();

  spatial::SpatWeightedCompute wcomputeOp;
};

/**
 * @brief Lower an ONNXConvOp into crossbar-sized MVMs distributed over
 * spatial::SpatWeightedCompute cores, followed by an image concatenation of the
 * per-pixel output tiles.
 */
struct ONNXConvOpTile : public OpConversionPattern {
  ONNXConvOpTile(MLIRContext* ctx) : OpConversionPattern(ctx) {}

  /// A partial result for one output tile, and the core that computes it.
  struct Producer_t {
    Value value;
    shared_ptr core;
  };

  LogicalResult matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor,
      ConversionPatternRewriter& rewriter) const final {
    ShapedType xShape = mlir::cast(convAdaptor.getX().getType());
    ShapedType wShape = mlir::cast(convAdaptor.getW().getType());
    ShapedType yShape = mlir::cast(conv.getY().getType());
    // NOTE: the optional bias `B` is not cast to ShapedType: when it is absent
    // its type is not shaped. Its element type equals Y's per the ONNX Conv
    // spec, so the MVM element type is taken from Y below.

    size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y;
    unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y);
    unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y);

    auto padUnpackError =
        unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
    if (padUnpackError.has_value())
      return rewriter.notifyMatchFailure(conv, padUnpackError.value());

    // TODO: Pad value at beginning and end of each dimension could be
    // different. We should handle this case.

    // MapOperations mapOperation = MapOperations::None;
    //
    // // If we have just one user, and it is an activation function (or more in
    // // general a mapping operation) just inline it in the computeOps
    // auto firstUserOp = *conv->getUsers().begin();
    // if (conv->hasOneUse()) {
    //   mapOperation = mlirOpToMapOperationEnum(firstUserOp);
    //
    //   if (mapOperation == MapOperations::ONNXSoftmaxOp) {
    //     return rewriter.notifyMatchFailure(
    //         conv, "Softmax not supported as activation for convolutions.");
    //   }
    // }

    size_t input_h = GET_IMAGE_HEIGHT(xShape);
    size_t input_w = GET_IMAGE_WIDTH(xShape);
    size_t output_h = GET_IMAGE_HEIGHT(yShape);
    size_t output_w = GET_IMAGE_WIDTH(yShape);
    size_t krn_h = GET_KERNEL_HEIGHT(wShape);
    size_t krn_w = GET_KERNEL_WIDTH(wShape);

    // The scheduling loop below visits input positions in_x = 0, stride_x, ...
    // (< input_w) and maps the k-th visited position to output column k (same
    // for y). Bail out before creating any IR when the number of visited
    // positions does not match the result shape: otherwise `producers` would
    // be indexed out of range, or some output pixels would get no producer.
    if (ceilIntegerDivide(input_w, stride_x) != output_w ||
        ceilIntegerDivide(input_h, stride_y) != output_h)
      return rewriter.notifyMatchFailure(conv,
          "output spatial shape does not match the positions visited by the "
          "convolution tiling");

    Location loc = conv.getLoc();

    size_t inputTileCount =
        ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
    size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
    size_t outputTileCount =
        ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
    size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize;

    // Tile the input tensor
    // Input tiles need to be indexed by:
    //   a. Channel Tile
    //   b. Pixel `x` position
    //   c. Pixel `y` position
    // For example: inputTiles[channelTile][x][y]
    // Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH)
    SmallVector>> inputTiles(inputTileCount,
        SmallVector>(input_w, SmallVector(input_h)));

    // `inputTiles` is sized [..][input_w][input_h], so pass the width first
    // (the original passed `input_h` twice).
    auto resolveErrorOpt = resolveImgInputTiles(convAdaptor.getX(), inputTiles,
        inputTileCount, inputTileRemainder, input_w, input_h, rewriter);
    if (resolveErrorOpt.has_value())
      return rewriter.notifyMatchFailure(conv, *resolveErrorOpt);

    // Tile the weight tensor
    // Weight tiles need to be indexed by:
    //   a. Filter Tile
    //   b. Channel Tile
    //   c. Kernel `x` position
    //   d. Kernel `y` position
    // For example: weightTiles[filterTile][channelTile][x][y]
    // Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH)
    SmallVector>>> weightTiles(outputTileCount,
        SmallVector>>(inputTileCount,
            SmallVector>(krn_w, SmallVector(krn_h))));

    SmallVector strides = SmallVector(4, rewriter.getIndexAttr(1));
    SmallVector offsets = SmallVector(4, rewriter.getIndexAttr(0));
    SmallVector sizes = SmallVector {rewriter.getIndexAttr(crossbarSize),
        rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1),
        rewriter.getIndexAttr(1)};
    for (size_t i = 0; i < outputTileCount; i++) {
      // The last filter tile may be partial.
      if (i == outputTileCount - 1 && outputTileRemainder != 0)
        sizes[0] = rewriter.getIndexAttr(outputTileRemainder);
      sizes[1] = rewriter.getIndexAttr(crossbarSize);
      offsets[0] = rewriter.getIndexAttr(i * crossbarSize);
      for (size_t j = 0; j < inputTileCount; j++) {
        // The last channel tile may be partial.
        if (j == inputTileCount - 1 && inputTileRemainder != 0)
          sizes[1] = rewriter.getIndexAttr(inputTileRemainder);
        for (size_t x = 0; x < krn_w; x++) {
          for (size_t y = 0; y < krn_h; y++) {
            offsets[1] = rewriter.getIndexAttr(j * crossbarSize);
            offsets[2] = rewriter.getIndexAttr(x);
            offsets[3] = rewriter.getIndexAttr(y);
            weightTiles[i][j][x][y] = rewriter.create(
                loc, convAdaptor.getW(), offsets, sizes, strides);
          }
        }
      }
    }

    /* Distribute the computation among many compute cores
     * Try to compute in-core the computation for each output tile, and reduce
     * over as few cores as possible
     */

    // Tile the output tensor
    // Output tiles need to be indexed by:
    //   a. Filter Tile
    //   b. Pixel `x` position
    //   c. Pixel `y` position
    // For example: outputTiles[filterTile][x][y]
    // Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH)
    SmallVector>>> outputTiles(outputTileCount,
        SmallVector>>(
            output_w, SmallVector>(output_h, nullptr)));

    size_t replicationFactor;
    if (!conv->hasAttr(REPLICATION_ATTR_NAME))
      replicationFactor = 1;
    else
      replicationFactor =
          conv->getAttrOfType(REPLICATION_ATTR_NAME).getInt();

    // producers[outTile][out_x][out_y][producerIndex]
    vector>>> producers = vector>>>(
        outputTileCount, vector>>(output_w,
                             vector>(output_h, vector())));

    // Schedule in cores
    size_t coreId = 0;
    vector> curCores(replicationFactor);
    for (size_t i = 0; i < replicationFactor; i++)
      curCores[i] = make_shared(coreId++, rewriter);

    vector> cores;

    const size_t replicationSliceSize =
        ceilIntegerDivide(input_w, replicationFactor);

    // `krn_x` walks the kernel width and `krn_y` its height, matching the
    // `weightTiles[..][krn_w][krn_h]` layout (the original bounds were
    // swapped, which indexed out of range for non-square kernels).
    for (size_t krn_x = 0; krn_x < krn_w; krn_x++) {
      for (size_t krn_y = 0; krn_y < krn_h; krn_y++) {

        RankedTensorType mvmOutType =
            RankedTensorType::get({1, static_cast(crossbarSize), 1, 1},
                yShape.getElementType());

        for (size_t outTile = 0; outTile < outputTileCount; outTile++) {

          if (outTile == outputTileCount - 1 && outputTileRemainder != 0)
            mvmOutType = mvmOutType.clone(
                {1, static_cast(outputTileRemainder), 1, 1});

          for (size_t inTile = 0; inTile < inputTileCount; inTile++) {

            vector xbarIndexes(replicationFactor);
            for (size_t i = 0; i < replicationFactor; i++)
              xbarIndexes[i] = curCores[i]->addXbarWeight(
                  weightTiles[outTile][inTile][krn_x][krn_y]);

            size_t out_x = 0;
            for (size_t in_x = 0; in_x < input_w; in_x += stride_x) {
              size_t out_y = 0;

              // I use `replicationFactor` cores. I divide the input_w into
              // `replicationFactor` slices, and each slice is distributed to a
              // core. `coreIndex` is the index of the core that will be used
              // for this slice
              size_t coreIndex = in_x / replicationSliceSize;
              assert(coreIndex < replicationFactor);

              for (size_t in_y = 0; in_y < input_h; in_y += stride_y) {
                // Adjust the input based on the kernel. The tap is taken
                // relative to a window centred on (in_x, in_y); signed
                // arithmetic because the result may be negative.
                int actual_in_x = static_cast<int>(in_x + krn_x * dilation_x) -
                                  static_cast<int>(krn_w / 2);
                int actual_in_y = static_cast<int>(in_y + krn_y * dilation_y) -
                                  static_cast<int>(krn_h / 2);

                // Check if we are within the input image
                if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x,
                        actual_in_y, pad_x, pad_y)
                        .failed()) {
                  out_y++;
                  continue;
                }

                size_t outTileId =
                    outTile * output_w * output_h + out_x * output_h + out_y;
                auto mvm = curCores[coreIndex]->addMVM(
                    inputTiles[inTile][actual_in_x][actual_in_y],
                    xbarIndexes[coreIndex], outTileId, mvmOutType);

                producers[outTile][out_x][out_y].push_back(
                    {mvm, curCores[coreIndex]});

                out_y++;
              }
              out_x++;
            }

            // Computations for these crossbars are done, check if the cores
            // crossbars are fully used. If full, swap with new core
            for (size_t i = 0; i < replicationFactor; i++) {
              if (curCores[i]->isXbarsFull()) {
                cores.emplace_back(std::move(curCores[i]));
                curCores[i] = make_shared(coreId++, rewriter);
              }
            }
          }
        }
      }
    }

    for (auto& curCore : curCores)
      if (!curCore->isCoreEmpty())
        cores.emplace_back(std::move(curCore));
    curCores.clear();

    // Now, do the reduction of each output pixel tile
    for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
      for (size_t out_x = 0; out_x < output_w; out_x++) {
        for (size_t out_y = 0; out_y < output_h; out_y++) {
          // First, check if some producers are within the same core. If this is
          // true, `Core::addMVM` has already done the reduction within-core.
          // This means that we only need to consider the last producer for that
          // core.

          std::unordered_map withinCoreReducedProducers;
          for (const auto& producer : producers[outTile][out_x][out_y])
            withinCoreReducedProducers[producer.core->coreId] = producer;

          // Every output pixel needs at least one producer: the code below
          // dereferences `begin()`, which is undefined on an empty map.
          assert(!withinCoreReducedProducers.empty() &&
                 "output pixel has no producer");

          // Now, we need to apply inter-core reduction

          // Base case with one producer
          if (withinCoreReducedProducers.size() == 1) {
            // TODO: Add the bias and apply mapping (if present)
            // NOTE(review): `B` is currently never added to the output.

            auto singleProducer = withinCoreReducedProducers.begin()->second;
            // Use the single producer as the final result
            auto reducedValue =
                singleProducer.core->makeResultRemappable(singleProducer.value);

            outputTiles[outTile][out_x][out_y] = reducedValue;
            continue;
          }

          // TODO: This is a linear reduction, not a tree reduction. We can do
          // better: a tree reduction would make more computations happen in
          // parallel.

          Producer_t lastProducer = withinCoreReducedProducers.begin()->second;

          auto it = withinCoreReducedProducers.begin();
          ++it;
          while (it != withinCoreReducedProducers.end()) {

            Producer_t curProducer = it->second;

            shared_ptr core1;
            shared_ptr core2;
            Value core1Value;
            Value core2Value;

            auto lastProducerCoreId = lastProducer.core->coreId;
            auto curProducerCoreId = curProducer.core->coreId;

            assert(lastProducerCoreId != curProducerCoreId &&
                   "We should have already applied within-core reduction, how "
                   "could we have same cores here?");

            // Sort the cores by coreId, so partial sums always flow from the
            // lower-id core into the higher-id one.
            if (curProducerCoreId < lastProducerCoreId) {
              core1 = curProducer.core;
              core1Value = curProducer.value;
              core2 = lastProducer.core;
              core2Value = lastProducer.value;
            } else {
              core1 = lastProducer.core;
              core1Value = lastProducer.value;
              core2 = curProducer.core;
              core2Value = curProducer.value;
            }

            auto newCoreRes = core1->makeResultRemappable(core1Value);
            auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes);

            rewriter.setInsertionPointAfterValue(core2Value);
            Value vaddRes = rewriter.create(core2Value.getLoc(),
                core2Value.getType(), core2Value, secondCoreBlockArg);

            lastProducer = {vaddRes, core2};

            ++it;
          }

          // TODO: Add the bias and apply mapping (if present)
          // NOTE(review): `B` is currently never added to the output.

          // Use last producer as the final result
          auto reducedValue =
              lastProducer.core->makeResultRemappable(lastProducer.value);
          outputTiles[outTile][out_x][out_y] = reducedValue;
        }
      }
    }

    // Now, we need to turn the cores into a spatial::SpatWeightedCompute.
    rewriter.setInsertionPointAfter(conv);
    spatial::SpatWeightedCompute lastWComputeOp;
    for (auto& core : cores) {
      lastWComputeOp = core->createWComputeOp(loc);
      core->remapResults();
      rewriter.setInsertionPointAfter(lastWComputeOp);
    }

    // Only now are all cross-core values remapped to WComputeOp results.
    for (auto& core : cores)
      core->addRemappedOperands();

    // Set the insertion point after the last WComputeOp.
    rewriter.setInsertionPointAfter(lastWComputeOp);
    SmallVector tilesToConcat;
    tilesToConcat.reserve(output_w * output_h * outputTileCount);
    // `outputTiles` is sized [..][output_w][output_h]; the original loop bounds
    // were swapped, which indexed out of range for non-square outputs.
    for (size_t outX = 0; outX < output_w; outX++)
      for (size_t outY = 0; outY < output_h; outY++)
        for (size_t outTile = 0; outTile < outputTileCount; outTile++)
          tilesToConcat.push_back(*outputTiles[outTile][outX][outY]);

    Value outputImage =
        rewriter.create(loc, conv.getY().getType(), tilesToConcat);

    // Value outputImage =
    //     createImgConcatOp(outputTiles, rewriter, loc, Y.getType());

    // The conversion must replace the root op; previously this call was only
    // present in the commented-out sketch below, leaving `conv` in place and
    // `outputImage` dead.
    rewriter.replaceOp(conv, outputImage);

    // TODO: fuse a single activation user. Sketch:
    // if (mapOperation == MapOperations::None) {
    //   rewriter.replaceOp(conv, outputImage);
    // } else {
    //   // If mapping was applied, erase ConvOp and replace the mapping op
    //   rewriter.eraseOp(conv);
    //   rewriter.replaceOp(firstUserOp, outputImage);
    // }

    return success();
  }
};

void populateTilingConvOpPattern(
    RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert(ctx);
}

} // namespace onnx_mlir