// NOTE(review): this copy of the file appears to have lost every
// angle-bracketed token sequence that resembles a markup tag: the three bare
// `#include` directives below have no target, and template argument lists are
// missing throughout (e.g. `mlir::cast(...)`, `std::function`,
// `llvm::SmallVector`, `rewriter.create(...)`). Those tokens are kept exactly
// as found, because the intended types are declared outside this file; restore
// them from the original source before building.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"

#include "llvm/Support/raw_ostream.h"

#include
#include
#include

#include "SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

// Accessors for the two halves of a (compute op, result number) pair.
#define GET_COMP(computeOpAndResNum) std::get<0>(computeOpAndResNum)
#define GET_RES_NUM(computeOpAndResNum) std::get<1>(computeOpAndResNum)

namespace onnx_mlir {

// Operations that updateResultsOfCompute() has already replaced and erased.
// It is a static member, so it is shared by every SpatialReducer instance.
llvm::SmallPtrSet onnx_mlir::SpatialReducer::oldComputeOpsReplaced;

// Applies `processFun` to one value yielded by a compute op.
//
// The value is the operand number `resultNum` of the terminator of the compute
// op's first block. `processFun` is invoked with the rewriter's insertion point
// placed right after that value.
//
// Returns the yield operand number that holds the processed value: the
// original number when `processFun` hands back the very same value, otherwise
// the number of an operand freshly appended to the yield.
ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum,
    std::function processFun,
    mlir::ConversionPatternRewriter& rewriter) {
  assert(processFun);

  auto computeOp = GET_COMP(computeOpAndResNum);
  auto resultNum = GET_RES_NUM(computeOpAndResNum);

  spatial::SpatYieldOp yieldOp = mlir::cast(computeOp.getBody().front().getTerminator());

  mlir::Value result = yieldOp->getOperand(resultNum);
  rewriter.setInsertionPointAfterValue(result);
  mlir::Value processedResult = processFun(result);
  if (processedResult == result) {
    // Sometimes we want processFun to return the same value but do something
    // else with it (e.g. in softmax we want to broadcast the value using a
    // channel). In this case, we can just return the same result number.
    return resultNum;
  }

  // Append rather than overwrite, so existing operand numbers stay valid.
  yieldOp->insertOperands(yieldOp->getNumOperands(), processedResult);
  return yieldOp.getNumOperands() - 1;
}

// Reduces the listed compute results to a single one by repeated pairwise
// application of `reduce`.
//
// Steps:
//   1. If `preprocess` is set, run it on every entry.
//   2. Entries that belong to the same compute op are reduced inside that op.
//   3. The remaining entries are reduced two at a time across compute ops. The
//      operand/result changes on the compute ops themselves are only recorded
//      here (`reducerChanges`, `computeOpNeedingResUpdate`) and are applied
//      later by finalizeReduceUpdates().
//   4. If `postprocess` is set, run it on the final entry.
//
// `computeOpsAndResNum` is rewritten in place during step 2 and must not be
// empty. Returns the operation and yield operand number of the final value.
OpAndResNum SpatialReducer::applyReducePattern(llvm::SmallVector& computeOpsAndResNum,
    std::function reduce,
    std::function preprocess,
    std::function postprocess) {
  // The code below unconditionally reads element 0 of the rebuilt list.
  assert(!computeOpsAndResNum.empty() && "Cannot reduce an empty list of compute results.");

  if (preprocess)
    for (auto& computeOpAndResNum : computeOpsAndResNum)
      GET_RES_NUM(computeOpAndResNum) = applyResultProcessing(computeOpAndResNum, preprocess, rewriter);

  // It is possible that `computeOpsAndResNum` contains two entries for the same
  // computeOp. In this case, we need to apply the reduction within-compute.

  // Keep a map between a computeOp and the last Value for this reduction
  std::unordered_map lastValueForCompute;
  for (auto& computeOpAndResNum : computeOpsAndResNum) {
    auto computeOp = GET_COMP(computeOpAndResNum);
    auto yieldOp = mlir::cast(computeOp.getBody().front().getTerminator());
    mlir::Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));

    auto it = lastValueForCompute.find(computeOp.getOperation());
    if (it != lastValueForCompute.end()) {
      // If we have already seen this computeOp, apply the reduction
      // within-compute
      mlir::Value lastWithinComputeValue = it->second;

      assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp());

      // Insert after whichever of the two values is defined later, so that
      // both are available to the reduction.
      if (valueWithinCompute.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
        rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
      else
        rewriter.setInsertionPointAfterValue(valueWithinCompute);
      valueWithinCompute = reduce(lastWithinComputeValue, valueWithinCompute);
    }

    lastValueForCompute[computeOp.getOperation()] = valueWithinCompute;
  }

  // Now, reconstruct from the map the computeOpsAndResNum list.
  // NOTE(review): std::unordered_map iteration order is unspecified, so the
  // order of the rebuilt list (and with it the pairing below and the location
  // taken from element 0) is not tied to the order of the input list.
  computeOpsAndResNum.clear();
  computeOpsAndResNum.reserve(lastValueForCompute.size());
  for (auto& entry : lastValueForCompute) {
    auto computeOp = mlir::cast(entry.first);
    auto valueWithinCompute = entry.second;

    // We check if `valueWithinCompute` is already used by the yieldOp, in that
    // case no need to add it
    auto yieldOp = mlir::cast(computeOp.getBody().front().getTerminator());
    bool yieldOpUseFound = false;
    for (auto& use : valueWithinCompute.getUses()) {
      if (use.getOwner() == yieldOp.getOperation()) {
        // If the value is already used by the yieldOp, we can just use it
        computeOpsAndResNum.push_back({computeOp, use.getOperandNumber()});
        yieldOpUseFound = true;
        break;
      }
    }
    if (yieldOpUseFound)
      continue;

    // If this result is not used within a yieldOp, then add it
    auto resultNum = yieldOp->getNumOperands();
    yieldOp->insertOperands(resultNum, valueWithinCompute);
    computeOpsAndResNum.push_back({computeOp, resultNum});
  }

  mlir::Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc();

  // Iterative algorithm to reduce the inputs to a single one:
  // - Take two inputs at a time, and reduce them into a single one, updating
  //   the working list which becomes half the size.
  // - Repeat until there is only one input left.
  llvm::OwningArrayRef computeOpsRef(computeOpsAndResNum);
  while (computeOpsRef.size() > 1) {
    llvm::SmallVector nextComputeOps;
    nextComputeOps.reserve(computeOpsRef.size() / 2);
    // size() >= 2 here, so `size() - 1` cannot wrap around.
    for (size_t i = 0; i < computeOpsRef.size() - 1; i += 2) {
      auto [firstCompute, firstResultNum] = computeOpsRef[i];
      auto [secondCompute, secondResultNum] = computeOpsRef[i + 1];

      // Make `firstCompute` the one that comes first in the block: its result
      // is going to be consumed by `secondCompute`.
      if (secondCompute->isBeforeInBlock(firstCompute)) {
        std::swap(firstCompute, secondCompute);
        std::swap(firstResultNum, secondResultNum);
      }

      // We do not immediately alter the computeOps results/operands, instead we
      // do it in a delayed manner, to avoid invalidating the references to the
      // computeOps (which must be replaced by a cloned ComputeOp when changing
      // the number of results)
      // See below `reducerChanges.push_back` and `finalizeReduceUpdates`

      auto yieldOpFirstCompute = mlir::cast(firstCompute.getBody().front().getTerminator());

      // Add a new argument to the block of the second computeOp
      mlir::Block& secondBlock = secondCompute.getBody().front();
      mlir::Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);

      // Operand position on the compute op that matches the block argument
      // just added: the first operand-segment size plus the index of the new
      // (last) block argument.
      auto secondComputeWeightsNum =
          secondCompute->getAttrOfType(secondCompute.getOperandSegmentSizesAttrName())[0];
      auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1;

      // Take the "former-result" from the second computeOp
      spatial::SpatYieldOp secondYield = mlir::cast(secondBlock.getTerminator());
      mlir::Value formerRes2 = secondYield.getOperand(secondResultNum);

      // Apply reduction operation
      rewriter.setInsertionPoint(secondYield);
      mlir::Value reduced = reduce(formerRes2, formerRes1);

      // Unfortunately, it is not possible to update the result in place,
      // because it may already have been referenced from outside of this
      // function, thus replacing it would invalidate the reference.
      // Therefore, we need to append a new result to the yieldOp, and then at
      // a later stage update the computeOp accordingly.

      // Add `reduced` to the second yieldOp
      auto secondYieldOperandNum = secondYield.getNumOperands();
      secondYield->insertOperands(secondYieldOperandNum, reduced);
      secondResultNum = secondYieldOperandNum;

      // Record the pending connection first -> second. The last operation
      // (the one which never becomes a `firstCompute`) is not tracked by
      // reducerChanges as `fromOp`; it is registered separately after the
      // loop through `computeOpNeedingResUpdate`.
      reducerChanges.push_back(
          {firstCompute.getOperation(), firstResultNum, secondCompute.getOperation(), secondComputeOperandNum});
      nextComputeOps.push_back(std::make_pair(secondCompute, secondResultNum));
    }

    // If we have an odd number of inputs, we need to carry the last one over
    // to the next round.
    if (computeOpsRef.size() % 2 == 1)
      nextComputeOps.push_back(computeOpsRef.back());

    // Replace the working list with the new one.
    computeOpsRef = llvm::OwningArrayRef(std::move(nextComputeOps));
  }

  assert(computeOpsRef.size() == 1 && "Internal error: expected a single input at this point.");

  auto finalComputeAndResNum = computeOpsRef[0];

  // Force the update of the results of this computeOp, when finalizing
  computeOpNeedingResUpdate.push_back(GET_COMP(finalComputeAndResNum));

  if (postprocess)
    GET_RES_NUM(finalComputeAndResNum) = applyResultProcessing(finalComputeAndResNum, postprocess, rewriter);

  return std::make_pair(GET_COMP(finalComputeAndResNum).getOperation(), GET_RES_NUM(finalComputeAndResNum));
}

// Applies every change recorded by applyReducePattern().
//
// First the result lists of all recorded compute ops are brought in line with
// their yield ops (which may replace the ops), then each recorded
// (fromOp result -> toOp operand) connection is materialised. May only be
// called once.
void SpatialReducer::finalizeReduceUpdates() {
  assert(reducesFinalized == false && "Cannot finalize two times.");

  reducesFinalized = true;

  // First, add the results to the computeOps
  for (auto& reduceChange : reducerChanges)
    updateResultsOfCompute(reduceChange.fromOp);

  for (auto& c : computeOpNeedingResUpdate)
    updateResultsOfCompute(c.getOperation());

  for (auto& reducerChange : this->reducerChanges) {
    auto fromOp = reducerChange.fromOp;
    auto toOp = reducerChange.toOp;
    auto fromOpResNum = reducerChange.fromOpResNum;
    auto toOpOperandNum = reducerChange.toOpOperandNum;

    // Look the ops up with find(): operator[] would default-insert a null
    // entry on a miss, and resolveValueFromOpAndResNum() would later pick
    // that entry up and try to cast it.
    auto fromIt = opToReplacedCompute.find(fromOp);
    assert(fromIt != opToReplacedCompute.end() && fromIt->second && "fromOp should have been mapped before!");
    auto fromComputeOp = fromIt->second;

    // toComputeOp could be the existing pointer, or we have to remap it with
    // `opToReplacedCompute`
    auto toIt = opToReplacedCompute.find(toOp);
    auto toComputeOp = (toIt != opToReplacedCompute.end() && toIt->second) ? toIt->second : mlir::cast(toOp);

    assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!");

    assert(toComputeOp->getNumOperands() == toOpOperandNum
           && "toOpOperandNum should be the last operand of toComputeOp, are the "
              "operations in the right order?");

    // Add the new operand to `toComputeOp`
    auto fromResult = fromComputeOp.getResult(fromOpResNum);
    toComputeOp->insertOperands(toOpOperandNum, fromResult);
    incrementWeightedComputeInputsSegmentSize(toComputeOp, 1);
  }
}

// Turns an (operation, result number) pair into the corresponding result
// value, going through `opToReplacedCompute` when the operation has an entry
// there. Only valid after finalizeReduceUpdates().
mlir::Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) {
  assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates.");

  // Default to the recorded operation; prefer its replacement if one exists.
  mlir::Operation* opToCast = opAndResNum.first;
  auto it = opToReplacedCompute.find(opAndResNum.first);
  if (it != opToReplacedCompute.end())
    opToCast = it->second;

  auto computeOp = mlir::cast(opToCast);
  return computeOp.getResult(opAndResNum.second);
}

// Makes the results of `computeOp` match the operands of its yield op.
//
// When the yield has as many operands as the op has results, the op is mapped
// to itself in `opToReplacedCompute`. Otherwise a new compute op is created
// with the yield operand types as result types and the same weights and
// inputs, the body is moved over, the uses of the old results are redirected,
// and the old op is erased. Ops already present in `opToReplacedCompute` are
// left untouched.
void SpatialReducer::updateResultsOfCompute(mlir::Operation* computeOp) {
  if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) {
    // If we have already replaced the fromOp, we do not need to do it again
    return;
  }
  auto oldComputeOp = mlir::cast(computeOp);

  auto oldComputeOpNum = oldComputeOp->getNumOperands();

  auto yieldOp = mlir::cast(oldComputeOp.getBody().front().getTerminator());

  if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) {
    // No result was added, just add itself to the map
    opToReplacedCompute[oldComputeOp.getOperation()] = oldComputeOp;
    return;
  }

  // Add the results by inspecting its YieldOp
  auto newResultTypes = yieldOp.getOperandTypes();

  // Create a new ComputeOp with the new result type, but same operands
  rewriter.setInsertionPoint(oldComputeOp);
  auto newComputeOp = rewriter.create(
      oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());

  newComputeOp.getBody().takeBody(oldComputeOp.getBody());

  auto newComputeOpNum = newComputeOp->getNumOperands();

  assert(oldComputeOpNum == newComputeOpNum);

  // Since we replaced the old ComputeOp with a new one, we need to replace
  // all its results' uses
  for (size_t i = 0; i < oldComputeOp.getNumResults(); i++) {
    mlir::Value oldResult = oldComputeOp.getResult(i);
    mlir::Value newResult = newComputeOp.getResult(i);

    // Replace the uses, except the uses of the compute ops which got deleted
    // previously
    rewriter.replaceAllUsesExcept(oldResult, newResult, oldComputeOpsReplaced);
  }

  // Finally, erase the old computeOp and update the map
  opToReplacedCompute[oldComputeOp.getOperation()] = newComputeOp;
  oldComputeOpsReplaced.insert(oldComputeOp.getOperation());
  rewriter.setInsertionPoint(oldComputeOp);
  rewriter.eraseOp(oldComputeOp);
}

// Resolves every (operation, result number) tile to its value and forwards the
// resolved tiles to the free function ::onnx_mlir::createImgConcatOp.
//
// `outputTiles` is indexed as [channelTile][x][y]. The x and y extents are
// read from the first tile and applied to all of them, so the container must
// be non-empty and rectangular. Only valid after finalizeReduceUpdates().
mlir::Value SpatialReducer::createImgConcatOp(llvm::SmallVector>>& outputTiles,
    mlir::Location& loc,
    mlir::Type outputType) {

  assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates.");
  // The extents below are taken from outputTiles[0] and outputTiles[0][0].
  assert(!outputTiles.empty() && "Expected at least one channel tile.");
  assert(!outputTiles[0].empty() && "Expected a non-empty tile grid.");

  // outputTiles are indexed like this: [channelTile][x][y]
  auto tilesCount = outputTiles.size();
  auto width = outputTiles[0].size();
  auto height = outputTiles[0][0].size();

  llvm::SmallVector>> remappedOutputTiles(
      tilesCount, llvm::SmallVector>(width, llvm::SmallVector(height)));

  for (size_t t = 0; t < tilesCount; t++)
    for (size_t x = 0; x < width; x++)
      for (size_t y = 0; y < height; y++)
        remappedOutputTiles[t][x][y] = resolveValueFromOpAndResNum(outputTiles[t][x][y]);

  return ::onnx_mlir::createImgConcatOp(remappedOutputTiles, rewriter, loc, outputType);
}

// Reduces `computeOps` through applyReducePattern() with a binary op built
// from (a, b) as the reduction and no preprocessing.
//
// When `mapOp` is not `MapOperations::None`, the final value is postprocessed:
// `biasTile` (if set) is first combined with it, then the map operation is
// created on the outcome.
// NOTE(review): when `mapOp` is `None` no postprocessing is installed, so a
// non-null `biasTile` is ignored - confirm that callers rely on this.
OpAndResNum SpatialReducer::applyAddMapReduction(llvm::SmallVector& computeOps,
    mlir::ConversionPatternRewriter& rewriter,
    mlir::Value biasTile,
    MapOperations mapOp) {

  std::function postprocessing = nullptr;

  if (mapOp != MapOperations::None) {
    // Capturing by reference is fine here: the callable is only handed to the
    // applyReducePattern() call below, within this function's lifetime.
    postprocessing = [&](const mlir::Value a) {
      mlir::Value mapOperand = a;
      if (biasTile)
        mapOperand = rewriter.create(a.getLoc(), a.getType(), a, biasTile);
      return createMapOperation(rewriter, mapOp, mapOperand);
    };
  }

  return this->applyReducePattern(
      computeOps,
      [&](mlir::Value a, mlir::Value b) { return rewriter.create(a.getLoc(), a.getType(), a, b); },
      /* preprocess = */ nullptr,
      postprocessing);
}

} // namespace onnx_mlir