diff --git a/src/PIM/CMakeLists.txt b/src/PIM/CMakeLists.txt index d909e05..e443e17 100644 --- a/src/PIM/CMakeLists.txt +++ b/src/PIM/CMakeLists.txt @@ -20,6 +20,8 @@ add_onnx_mlir_library(OMPIMAccel Pass/CountInstructionPass.cpp Pass/EmitPimJsonPass.cpp Pass/MessagePass.cpp + Pass/PimFoldHostConstantsPass.cpp + Pass/PimHostVerificationPass.cpp EXCLUDE_FROM_OM_LIBS @@ -43,4 +45,5 @@ add_onnx_mlir_library(OMPIMAccel OMSpatialToGraphviz OMSpatialToPIM OMPIMCommon -) \ No newline at end of file + MLIRTensorInferTypeOpInterfaceImpl +) diff --git a/src/PIM/Common/PIMCommon.cpp b/src/PIM/Common/PIMCommon.cpp index ada0745..f648c57 100644 --- a/src/PIM/Common/PIMCommon.cpp +++ b/src/PIM/Common/PIMCommon.cpp @@ -5,6 +5,7 @@ #include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Compiler/CompilerOptions.hpp" using namespace mlir; @@ -30,6 +31,60 @@ void dumpModule(ModuleOp moduleOp, const std::string& name) { file.close(); } +FailureOr getPimEntryFunc(ModuleOp moduleOp) { + if (!moduleOp) + return failure(); + + SmallVector entryPoints(moduleOp.getOps()); + if (entryPoints.size() > 1) { + moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") + << entryPoints.size(); + return failure(); + } + if (!entryPoints.empty()) { + auto entryPointAttr = + entryPoints.front()->getAttrOfType(ONNXEntryPointOp::getEntryPointFuncAttrName()); + if (!entryPointAttr) { + entryPoints.front().emitOpError("is missing the entry point function attribute"); + return failure(); + } + auto entryFunc = moduleOp.lookupSymbol(entryPointAttr.getLeafReference().getValue()); + if (!entryFunc) { + entryPoints.front().emitOpError("references an unknown entry function ") + << entryPointAttr.getLeafReference().getValue(); + return failure(); + } + return entryFunc; + } + + if (auto mainGraphFunc = moduleOp.lookupSymbol("main_graph")) + return mainGraphFunc; + + 
SmallVector nonExternalFuncs; + for (auto funcOp : moduleOp.getOps()) { + if (!funcOp.isExternal()) + nonExternalFuncs.push_back(funcOp); + } + if (nonExternalFuncs.size() == 1) + return nonExternalFuncs.front(); + + moduleOp.emitError("could not resolve a unique PIM entry function"); + return failure(); +} + +bool hasWeightAlways(Operation* op) { return op && op->getAttr(PIM_WEIGHT_ALWAYS_ATTR_NAME) != nullptr; } + +void markWeightAlways(Operation* op) { + assert(op && "expected valid op"); + op->setAttr(PIM_WEIGHT_ALWAYS_ATTR_NAME, UnitAttr::get(op->getContext())); +} + +memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) { + if (!moduleOp || !getGlobalOp) + return {}; + return moduleOp.lookupSymbol(getGlobalOp.getName()); +} + FailureOr getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) { auto channelNewOp = op->getOperand(0).getDefiningOp(); diff --git a/src/PIM/Common/PIMCommon.hpp b/src/PIM/Common/PIMCommon.hpp index 530b973..153d89a 100644 --- a/src/PIM/Common/PIMCommon.hpp +++ b/src/PIM/Common/PIMCommon.hpp @@ -1,6 +1,8 @@ #pragma once +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Operation.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" @@ -9,6 +11,7 @@ #include "src/Compiler/CompilerOptions.hpp" const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate"; +inline constexpr llvm::StringRef PIM_WEIGHT_ALWAYS_ATTR_NAME = "weightAlways"; namespace onnx_mlir { @@ -18,6 +21,14 @@ void createDirectory(const std::string& directory); void dumpModule(mlir::ModuleOp moduleOp, const std::string& name); +llvm::FailureOr getPimEntryFunc(mlir::ModuleOp moduleOp); + +bool hasWeightAlways(mlir::Operation* op); + +void markWeightAlways(mlir::Operation* op); + +mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp); + llvm::FailureOr 
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter); diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 362403d..a89eae1 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -13,11 +13,12 @@ #include #include +#include "Common/PIMCommon.hpp" #include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Compiler/CompilerPasses.hpp" #include "src/Compiler/CompilerUtils.hpp" @@ -49,8 +50,8 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { // Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others SmallDenseMap globalConstants; funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { - if (!getGlobalOp->hasAttr("weightAlways")) { - auto globalMemrefOp = moduleOp.lookupSymbol(getGlobalOp.getName()); + if (!hasWeightAlways(getGlobalOp)) { + auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); auto iter = globalConstants.find(globalMemrefOp); if (iter == globalConstants.end()) globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp); @@ -81,7 +82,7 @@ MemEntry PimMemory::getMemEntry(mlir::Value value) const { return iter->second; } -PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) { +PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) { return deviceMem.try_emplace(id, memEntriesMap).first->second; } @@ -112,10 +113,33 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const { } value = source; } + else if (auto castOp = dyn_cast(definingOp)) { + value = castOp.getSource(); + } + else if (auto collapseOp = dyn_cast(definingOp)) { + value = 
collapseOp.getSrc(); + } + else if (auto expandOp = dyn_cast(definingOp)) { + value = expandOp.getSrc(); + } else break; } - return memEntriesMap.at(value).address + offset; + + auto iter = memEntriesMap.find(value); + if (iter == memEntriesMap.end()) { + errs() << "Missing mem entry for value: "; + value.print(errs()); + errs() << "\n"; + if (auto* definingOp = value.getDefiningOp()) { + errs() << "Defining op:\n"; + definingOp->print(errs()); + errs() << "\n"; + } + llvm::report_fatal_error("Missing mem entry"); + } + + return iter->second.address + offset; } json::Object PimCodeGen::createEmptyOffset() { @@ -348,6 +372,55 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co } } +void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const { + auto srcAddr = memory.getValueAddress(transposeOp.getData()); + auto dstAddr = memory.getValueAddress(transposeOp.getOutBuf()); + + auto srcType = cast(transposeOp.getData().getType()); + auto srcShape = srcType.getShape(); + size_t rank = srcShape.size(); + size_t elementSize = srcType.getElementTypeBitWidth() / 8; + size_t totalElements = srcType.getNumElements(); + + // Read permutation (dst axis i reads src axis perm[i]) and compute its inverse (src axis -> dst axis) + SmallVector perm = + map_to_vector(transposeOp.getPerms().getAsRange(), [](auto attr) -> int64_t { return attr.getInt(); }); + SmallVector permInv(rank); + for (size_t i = 0; i < rank; i++) + permInv[perm[i]] = i; + + // Destination shape: dstShape[i] = srcShape[perm[i]] + SmallVector dstShape(rank); + for (size_t i = 0; i < rank; i++) + dstShape[i] = srcShape[perm[i]]; + + // Row-major strides for source and destination + SmallVector srcStrides(rank, 1); + SmallVector dstStrides(rank, 1); + for (int64_t i = rank - 2; i >= 0; i--) { + srcStrides[i] = srcStrides[i + 1] * srcShape[i + 1]; + dstStrides[i] = dstStrides[i + 1] * dstShape[i + 1]; + } + + // Emit element-by-element copy with transposed addressing + for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) { + // 
Decompose flat source index into multi-dimensional index + SmallVector srcIdx(rank); + size_t remaining = srcFlat; + for (size_t d = 0; d < rank; d++) { + srcIdx[d] = remaining / srcStrides[d]; + remaining %= srcStrides[d]; + } + + // Compute flat destination index: source axis d lands on destination axis permInv[d] + size_t dstFlat = 0; + for (size_t d = 0; d < rank; d++) + dstFlat += srcIdx[d] * dstStrides[permInv[d]]; + + emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len"); + } +} + size_t getMatrixSize(ShapedType matrixShape) { if (matrixShape.getRank() != 2 && matrixShape.getRank() != 4) assert(false && "Unsupported matrix shape"); @@ -378,9 +451,9 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& std::vector memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0); funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { - if (getGlobalOp->hasAttr("weightAlways")) + if (hasWeightAlways(getGlobalOp)) return; - auto globalOp = moduleOp.lookupSymbol(getGlobalOp.getName()); + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); if (!globalOp) return; auto initialValue = globalOp.getInitialValue(); @@ -416,7 +489,7 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) { size_t processedOperations = 0; for (auto& op : coreOp.getBody().front()) { - if (isa(op)) + if (isa(op)) continue; if (auto loadOp = dyn_cast(op)) @@ -435,6 +508,8 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenMVMLikeOp(mvmOp.getWeightIndex(), mvmOp, false); else if (auto applyFiltersOp = dyn_cast(op)) coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp); + else if (auto transposeOp = dyn_cast(op)) + coreCodeGen.codeGenTransposeOp(transposeOp); else if (auto vaddOp = dyn_cast(op)) coreCodeGen.codeGenVAddOp(vaddOp); else if (auto vmaxOp = dyn_cast(op)) @@ -475,7 +550,7 
@@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp, continue; } - auto globalOp = SymbolTable::lookupNearestSymbolFrom(moduleOp, getGlobalOp.getNameAttr()); + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); if (!globalOp) { coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex)); weightIndex++; @@ -589,9 +664,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: } } - auto funcOps = moduleOp.getOps(); - assert(!funcOps.empty() && "No function found in the module"); - auto funcOp = *funcOps.begin(); + auto entryFunc = getPimEntryFunc(moduleOp); + if (failed(entryFunc)) + return CompilerFailure; + auto funcOp = *entryFunc; PimAcceleratorMemory memory; memory.hostMem.allocateHost(moduleOp, funcOp); diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index 24de8c7..7d520b1 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -5,7 +5,7 @@ #include "Common/ValueMap.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" namespace onnx_mlir { @@ -49,7 +49,7 @@ public: PimAcceleratorMemory() : hostMem(memEntriesMap) {} - PimMemory getOrCreateDeviceMem(size_t id); + PimMemory& getOrCreateDeviceMem(size_t id); size_t getValueAddress(mlir::Value value) const; }; @@ -95,6 +95,7 @@ public: void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const; void codeGenVReluOp(pim::PimVReluOp vreluOp) const; void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const; + void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const; }; OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName); diff --git a/src/PIM/Compiler/PimCompilerOptions.hpp b/src/PIM/Compiler/PimCompilerOptions.hpp index cd95654..ab79c1f 100644 --- 
a/src/PIM/Compiler/PimCompilerOptions.hpp +++ b/src/PIM/Compiler/PimCompilerOptions.hpp @@ -25,7 +25,6 @@ extern llvm::cl::opt pimEmissionTarget; extern llvm::cl::opt pimOnlyCodegen; extern llvm::cl::opt useExperimentalConvImpl; -extern llvm::cl::opt exportCrossbarWeights; extern llvm::cl::opt crossbarSize; extern llvm::cl::opt crossbarCountInCore; diff --git a/src/PIM/Compiler/PimCompilerUtils.cpp b/src/PIM/Compiler/PimCompilerUtils.cpp index c41b6ad..69199a2 100644 --- a/src/PIM/Compiler/PimCompilerUtils.cpp +++ b/src/PIM/Compiler/PimCompilerUtils.cpp @@ -2,7 +2,7 @@ #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Compiler/CompilerPasses.hpp" #include "src/Compiler/CompilerUtils.hpp" @@ -46,6 +46,10 @@ void addPassesPim(OwningOpRef& module, } if (pimEmissionTarget >= EmitPimCodegen) { + pm.addPass(createPimFoldHostConstantsPass()); + pm.addPass(createMessagePass("Pim host constants folded")); + pm.addPass(createPimHostVerificationPass()); + pm.addPass(createMessagePass("Pim host verified")); pm.addPass(createEmitPimJsonPass()); // pm.addPass(createCountInstructionPass()); pm.addPass(createMessagePass("Pim json code emitted")); diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index 7e0ad60..e105edd 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -3,21 +3,15 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") add_public_tablegen_target(ONNXToSpatialIncGen) add_onnx_mlir_library(OMONNXToSpatial - Math/Gemm.hpp Math/Gemm.cpp - Math/Conv.hpp Math/Conv.cpp - Math/ExperimentalConv.cpp - Math/ExperimentalGemm.cpp NN/Pooling.cpp - NN/ExperimentalPooling.cpp 
NN/ReduceMean.cpp Tensor/ONNXConcatToTensorConcat.cpp Tensor/RemoveUnusedHelperOps.cpp Utils/SpatialReducer.cpp Utils/WeightSubdivider.cpp Utils/AnnotateReplication.cpp - ONNXToSpatialPass.hpp ONNXToSpatialPass.cpp ONNXToSpatialCommon.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp index b34aebb..4c5b1ac 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp @@ -242,6 +242,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, return success(); } -void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } +void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.hpp b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.hpp index 1ef9566..b29ace7 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.hpp @@ -18,6 +18,6 @@ struct ConvToGemm : mlir::OpConversionPattern { mlir::ConversionPatternRewriter& rewriter) const override; }; -void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/ConvOld.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/ConvOld.cpp deleted file mode 100644 index 990bfd7..0000000 --- a/src/PIM/Conversion/ONNXToSpatial/Math/ConvOld.cpp +++ /dev/null @@ -1,583 +0,0 @@ -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/PatternMatch.h" -#include 
"mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/LogicalResult.h" - -#include -#include -#include -#include - -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" - -using namespace mlir; -using namespace std; - -namespace onnx_mlir { - -// NOTE: -// This might be useful to re-implement this considering for loops. -// neededXbars = krn_h * krn_w * inputTileCount * outputTileCount; - -/** - * @brief A momentary representation of a core, to be used within the tiling of - * a convolution operation. - */ -class Core { -public: - Core(const size_t coreId, ConversionPatternRewriter& rewriter) - : coreId(coreId), rewriter(rewriter) {} - - /** - * @brief Add a MVM operation to the core. - * - * @param inputTile The input tile to the MVM operation. - * @param xbarIndex The index of the crossbar weight to use. - * @param outputTileId The id of the output tile. - * @param mvmOutType The result's shape. - * @return Value The result of the MVM operation. - */ - Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) { - // Use the inputTile as the reference location for the MVM operation. - Location loc = inputTile.getLoc(); - - // Move the insertion point to the end of the block. - rewriter.setInsertionPointToEnd(block.get()); - - // Add the inputTile to the block arguments, and to the operands. 
- Value operand = operandMap.lookupOrNull(inputTile); - if (not operand) { - operand = block->addArgument(inputTile.getType(), loc); - operands.push_back(inputTile); - operandMap.map(inputTile, operand); - } - - // TODO: Compute the output type using the matrix, and check if `mvmOutType` - // is correct. - - // Construct the MVM operation - Value result = rewriter.create(loc, mvmOutType, xbarIndex, operand); - - // Since we are within the same core and no computation can happen in - // paralllel, we can just apply a linear reduction in case we have multiple - // MVM operations for the same outputTile. - auto lastMVM = outputTileToMVM.find(outputTileId); - - // If an entry for this outputTile already exists, apply reduction. - if (lastMVM != outputTileToMVM.end()) { - // MVM results should have the same type for reduction. - assert(lastMVM->second.getType() == result.getType()); - result = rewriter.create(loc, mvmOutType, lastMVM->second, result); - } - - outputTileToMVM[outputTileId] = result; - return result; - } - - /** - * @brief Mark a result as remappable, and return a shared pointer to it. - * - * This function marks a result as remappable, and returns a shared pointer to - * it. We need to keep track of these values to generate the YieldOp at a - * later stage. - * - * @param result A result to track, for later remapping. - * @return shared_ptr A shared pointer to the result. - */ - shared_ptr makeResultRemappable(Value result) { - // Verify that the result is present in the block. - assert(result.getDefiningOp()->getBlock() == block.get()); - - shared_ptr remappableResult = make_shared(result); - - resultsToRemap.push_back(remappableResult); - results.push_back(result); - - return remappableResult; - } - - /** - * @brief Add a remappable operand to the core, to merge partial results - * inter-core. - * - * @param remappableOperand The operand to add. - * @return Value The block argument representing the operand. 
- */ - Value addRemappableOperand(std::shared_ptr operand) { - // Check that the operand is not already there. - assert(not operandMap.contains(*operand)); - - Value argument = block->addArgument(operand->getType(), operand->getLoc()); - remappableOperands.push_back(operand); - return argument; - } - - /** - * @brief Generate a spatial::SpatWeightedCompute operation from the core. - * - * @param loc The location of the operation. - * @return spatial::SpatWeightedCompute - */ - spatial::SpatWeightedCompute createWComputeOp(Location loc) { - // Get the shape of the results. - SmallVector resultTypes; - for (const auto& value : results) - resultTypes.push_back(value.getType()); - - // Create the WComputeOp, with non-remappable operands only. - wcomputeOp = rewriter.create(loc, resultTypes, xbarWeights, operands); - - // Add the body to the WComputeOp. - Block* releasedBlock = block.release(); - wcomputeOp.getBody().push_back(releasedBlock); - - // Add the `yieldOp` at the end, with the results. - rewriter.setInsertionPointToEnd(releasedBlock); - rewriter.create(loc, results); - - return wcomputeOp; - } - - /** - * @brief Remap the results to the WComputeOp results. - */ - void remapResults() { - // Remap all the results to the WComputeOp results. 
- assert(resultsToRemap.size() == wcomputeOp->getNumResults()); - for (size_t i = 0; i < resultsToRemap.size(); i++) - *resultsToRemap[i] = wcomputeOp.getResult(i); - } - - void addRemappedOperands() { - // Insert the remappableOperands (which were remapped in - // `addRemappableOperand` of another Core) - for (auto remappedValue : remappableOperands) - wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue); - - // Update the wcomputeOp operandSegmentSize - incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast(remappableOperands.size())); - } - - size_t addXbarWeight(Value weight) { - assert(!isXbarsFull()); - xbarWeights.push_back(weight); - return xbarWeights.size() - 1; - } - - bool isXbarsFull() { - assert(xbarWeights.size() <= crossbarCountInCore); - return xbarWeights.size() == crossbarCountInCore; - } - - bool isCoreEmpty() { return block->empty(); } - - void dump() { - // Print the coreId - llvm::outs() << "Core " << coreId << ":\n"; - // Print the weights - llvm::outs() << "Xbar Weights:\n"; - for (auto weight : xbarWeights) - weight.dump(); - // Print the operands - llvm::outs() << "Operands:\n"; - for (auto operand : operands) - llvm::outs() << operand << "\n"; - - // Dump the body block - for (auto& op : block->getOperations()) - op.dump(); - - // Print the results - llvm::outs() << "Results:\n"; - for (auto result : results) - llvm::outs() << result << "\n"; - } - - const size_t coreId; - -private: - ConversionPatternRewriter& rewriter; - - // Should these be set instead? 
But I need to keep the order - vector operands; - vector> remappableOperands; - - vector results; - vector> resultsToRemap; - - // Maps from input tiles to the block operand - IRMapping operandMap; - - // Map from outputTileId to MVM operation producing it - unordered_map outputTileToMVM; - - vector xbarWeights; - - unique_ptr block = make_unique(); - - spatial::SpatWeightedCompute wcomputeOp; -}; - -struct ConvToManyGemms : public OpConversionPattern { - ConvToManyGemms(MLIRContext* ctx) - : OpConversionPattern(ctx) {} - - struct Producer_t { - Value value; - shared_ptr core; - }; - - LogicalResult - matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final { - ShapedType xShape = mlir::cast(convAdaptor.getX().getType()); - ShapedType wShape = mlir::cast(convAdaptor.getW().getType()); - ShapedType bShape = mlir::cast(convAdaptor.getB().getType()); - ShapedType yShape = mlir::cast(conv.getY().getType()); - - size_t stride_x, stride_y, dilation_x, dilation_y, pad_x, pad_y; - unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y); - unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y); - - auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y); - if (padUnpackError.has_value()) - return rewriter.notifyMatchFailure(conv, padUnpackError.value()); - - // TODO: Pad value at beginning and end of each dimension could be - // different. We should handle this case. 
- - // MapOperations mapOperation = MapOperations::None; - // - // // If we have just one user, and it is an activation funcion (or more in - // // general a mapping operation) just inline it in the computeOps - // auto firstUserOp = *conv->getUsers().begin(); - // if (conv->hasOneUse()) { - // mapOperation = mlirOpToMapOperationEnum(firstUserOp); - // - // if (mapOperation == MapOperations::ONNXSoftmaxOp) { - // return rewriter.notifyMatchFailure( - // conv, "Softmax not supported as activation for convolutions."); - // } - // } - - size_t input_h = GET_IMAGE_HEIGHT(xShape); - size_t input_w = GET_IMAGE_WIDTH(xShape); - size_t output_h = GET_IMAGE_HEIGHT(yShape); - size_t output_w = GET_IMAGE_WIDTH(yShape); - size_t krn_h = GET_KERNEL_HEIGHT(wShape); - size_t krn_w = GET_KERNEL_WIDTH(wShape); - - Location loc = conv.getLoc(); - - size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); - size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize; - size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue()); - size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize; - - // Tile the input tensor - // Input tiles need to be indexed by: - // a. Channel Tile - // b. Pixel `x` position - // c. 
Pixel `y` position - // For example: inputTiles[channelTile][x][y] - // Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH) - SmallVector>> inputTiles( - inputTileCount, SmallVector>(input_w, SmallVector(input_h))); - - auto resolveErrorOpt = resolveImgInputTiles( - convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter); - if (resolveErrorOpt.has_value()) - return rewriter.notifyMatchFailure(conv, *resolveErrorOpt); - - SmallVector strides = SmallVector(4, rewriter.getIndexAttr(1)); - SmallVector offsets = SmallVector(4, rewriter.getIndexAttr(0)); - SmallVector sizes = SmallVector {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(crossbarSize), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; - - // Tile the weight tensor - // Weight tiles need to be indexed by: - // a. Filter Tile - // b. Channel Tile - // c. Kernel `x` position - // d. Kernel `y` position - // For example: weightTiles[filterTile][channelTile][x][y] - // Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH) - SmallVector>>> weightTiles( - outputTileCount, - SmallVector>>(inputTileCount, - SmallVector>(krn_w, SmallVector(krn_h)))); - strides = SmallVector(4, rewriter.getIndexAttr(1)); - offsets = SmallVector(4, rewriter.getIndexAttr(0)); - sizes = {rewriter.getIndexAttr(crossbarSize), - rewriter.getIndexAttr(crossbarSize), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; - for (size_t i = 0; i < outputTileCount; i++) { - if (i == outputTileCount - 1 && outputTileRemainder != 0) - sizes[0] = rewriter.getIndexAttr(outputTileRemainder); - sizes[1] = rewriter.getIndexAttr(crossbarSize); - offsets[0] = rewriter.getIndexAttr(i * crossbarSize); - for (size_t j = 0; j < inputTileCount; j++) { - if (j == inputTileCount - 1 && inputTileRemainder != 0) - sizes[1] = rewriter.getIndexAttr(inputTileRemainder); - for (size_t x = 0; x < krn_w; x++) { - for (size_t y = 0; y < krn_h; y++) { - offsets[1] = rewriter.getIndexAttr(j * 
crossbarSize); - offsets[2] = rewriter.getIndexAttr(x); - offsets[3] = rewriter.getIndexAttr(y); - weightTiles[i][j][x][y] = - rewriter.create(loc, convAdaptor.getW(), offsets, sizes, strides); - } - } - } - } - - /* Distribute the computation among many compute cores - * Try to compute in-core the computation for each output tile, and reduce - * over as few cores as possible - */ - - // Tile the output tensor - // Output tiles need to be indexed by: - // a. Filter Tile - // b. Pixel `x` position - // c. Pixel `y` position - // For example: outputTiles[filterTile][x][y] - // Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH) - SmallVector>>> outputTiles( - outputTileCount, - SmallVector>>(output_w, SmallVector>(output_h, nullptr))); - - size_t replicationFactor; - if (!conv->hasAttr(REPLICATION_ATTR_NAME)) - replicationFactor = 1; - else - replicationFactor = conv->getAttrOfType(REPLICATION_ATTR_NAME).getInt(); - // producers[outTile][out_x][out_y][producerIndex] - vector>>> producers = vector>>>( - outputTileCount, - vector>>(output_w, vector>(output_h, vector()))); - - // Schedule in cores - size_t coreId = 0; - vector> curCores(replicationFactor); - for (size_t i = 0; i < replicationFactor; i++) - curCores[i] = make_shared(coreId++, rewriter); - - vector> cores; - - const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor); - - for (size_t krn_x = 0; krn_x < krn_h; krn_x++) { - for (size_t krn_y = 0; krn_y < krn_w; krn_y++) { - - RankedTensorType mvmOutType = - RankedTensorType::get({1, static_cast(crossbarSize), 1, 1}, bShape.getElementType()); - - for (size_t outTile = 0; outTile < outputTileCount; outTile++) { - - if (outTile == outputTileCount - 1 && outputTileRemainder != 0) - mvmOutType = mvmOutType.clone({1, static_cast(outputTileRemainder), 1, 1}); - - for (size_t inTile = 0; inTile < inputTileCount; inTile++) { - - vector xbarIndexes(replicationFactor); - for (size_t i = 0; i < replicationFactor; i++) - xbarIndexes[i] 
= curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]); - - size_t out_x = 0; - for (size_t in_x = 0; in_x < input_w; in_x += stride_x) { - size_t out_y = 0; - - // I use `replicationFactor` cores. I divide the input_w into - // `replicationFactor` slices, and each slice is distributed to a - // core. `coreIndex` is the index of the core that will be used - // for this slice - size_t coreIndex = in_x / replicationSliceSize; - assert(coreIndex < replicationFactor); - - for (size_t in_y = 0; in_y < input_h; in_y += stride_y) { - // Adjust the input based on the kernel - int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x; - int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y; - - // Check if we are within the input image - if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) { - out_y++; - continue; - } - - size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y; - auto mvm = curCores[coreIndex]->addMVM( - inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType); - - producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]}); - - out_y++; - } - out_x++; - } - - // Computations for these crossbars are done, check if the cores - // crossbars are fully used. If full, swap with new core - for (size_t i = 0; i < replicationFactor; i++) { - if (curCores[i]->isXbarsFull()) { - cores.emplace_back(std::move(curCores[i])); - curCores[i] = make_shared(coreId++, rewriter); - } - } - } - } - } - } - - for (auto& curCore : curCores) - if (curCore->isCoreEmpty() == false) - cores.emplace_back(std::move(curCore)); - curCores.clear(); - // Now, do the reduction of each output pixel tile - for (size_t outTile = 0; outTile < outputTileCount; outTile++) { - for (size_t out_x = 0; out_x < output_w; out_x++) { - for (size_t out_y = 0; out_y < output_h; out_y++) { - // First, check if some producers are within the same core. 
If this is - // true, `Core::addMVM` have already done the reduction within-core. - // This means that we only need to consider the last producer for that - // core. - - std::unordered_map withinCoreReducedProducers; - for (auto producer : producers[outTile][out_x][out_y]) - withinCoreReducedProducers[producer.core->coreId] = producer; - - // Now, we need to apply inter-core reduction - - // Base case with one producer - if (withinCoreReducedProducers.size() == 1) { - // TODO: Add the bias and apply mapping (if present) - - auto singleProducer = withinCoreReducedProducers.begin()->second; - // Use last producer as the final result - auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value); - outputTiles[outTile][out_x][out_y] = reducedValue; - continue; - } - - // TODO: This is a linear reduction, not a tree reduction. We can do - // better: a tree reduction would make more computations happen in - // parallel. - - Producer_t lastProducer = withinCoreReducedProducers.begin()->second; - - auto it = withinCoreReducedProducers.begin(); - it++; - while (it != withinCoreReducedProducers.end()) { - - Producer_t curProducer = it->second; - - shared_ptr core1; - shared_ptr core2; - Value core1Value; - Value core2Value; - - auto lastProducerCoreId = lastProducer.core->coreId; - auto curProducerCoreId = curProducer.core->coreId; - - assert(lastProducerCoreId != curProducerCoreId - && "We should have already applied within-core reduction, how " - "could we have same cores here?"); - - // Sort the cores by coreId - if (curProducerCoreId < lastProducerCoreId) { - core1 = curProducer.core; - core1Value = curProducer.value; - core2 = lastProducer.core; - core2Value = lastProducer.value; - } - else { - core1 = lastProducer.core; - core1Value = lastProducer.value; - core2 = curProducer.core; - core2Value = curProducer.value; - } - - auto newCoreRes = core1->makeResultRemappable(core1Value); - auto secondCoreBlockArg = 
core2->addRemappableOperand(newCoreRes); - - rewriter.setInsertionPointAfterValue(core2Value); - Value vaddRes = rewriter.create( - core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg); - - lastProducer = {vaddRes, core2}; - - it++; - } - - // TODO: Add the bias and apply mapping (if present) - - // Use last producer as the final result - auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value); - outputTiles[outTile][out_x][out_y] = reducedValue; - } - } - } - - // Now, we need to turn the cores into a spatial::SpatWeightedCompute. - rewriter.setInsertionPointAfter(conv); - spatial::SpatWeightedCompute lastWComputeOp; - for (auto& core : cores) { - lastWComputeOp = core->createWComputeOp(loc); - core->remapResults(); - rewriter.setInsertionPointAfter(lastWComputeOp); - } - - for (auto& core : cores) - core->addRemappedOperands(); - - // Set the insertion point after the last WComputeOp. - rewriter.setInsertionPointAfter(lastWComputeOp); - SmallVector tilesToConcat; - tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize); - for (size_t outX = 0; outX < output_h; outX++) - for (size_t outY = 0; outY < output_w; outY++) - for (size_t outTile = 0; outTile < outputTileCount; outTile++) - tilesToConcat.push_back(*outputTiles[outTile][outX][outY]); - - Value outputImage = rewriter.create(loc, conv.getY().getType(), tilesToConcat); - - // Value outputImage = - // createImgConcatOp(outputTiles, rewriter, loc, Y.getType()); - - // If no mapping (activation) was applied, just replace ConvOp - // if (mapOperation == MapOperations::None) { - // rewriter.replaceOp(conv, outputImage); - // } else { - // // If mapping was applied, erase ConvOp and replace the mapping op - // rewriter.eraseOp(conv); - // rewriter.replaceOp(firstUserOp, outputImage); - // } - - return success(); - } -}; - -void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert(ctx); -} - -} // namespace 
onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalConv.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalConv.cpp deleted file mode 100644 index fae01b7..0000000 --- a/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalConv.cpp +++ /dev/null @@ -1,400 +0,0 @@ -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "llvm/ADT/SmallVector.h" - -#include -#include -#include - -#include "Compiler/PimCompilerOptions.hpp" -#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" -#include "Dialect/Spatial/SpatialOps.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" - -using namespace mlir; -using namespace std; - -namespace onnx_mlir { - -/** - * @brief A pattern to tile the convolution operation into a series of compute - * units, each one of which applies filters to a subset of the input - * tensor. Results are also reduced and concatenated to form the final - * output tensor. - */ -struct ExperimentalONNXConvOpTile : public OpConversionPattern { - ExperimentalONNXConvOpTile(MLIRContext* ctx) - : OpConversionPattern(ctx) {} - - LogicalResult - matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final { - - // --------------------------------- // - // --- READ OPERATION PARAMETERS --- // - // --------------------------------- // - - // To get each crossbar's weights, we need to slice the weights tensor. - // - Along the input tiles. - // - Along the output tiles. - // - Along the filter x position. - // - Along the filter y position. 
- ShapedType inputType = cast(convAdaptor.getX().getType()); - ShapedType outputType = cast(conv.getY().getType()); - ShapedType weightsType = cast(convAdaptor.getW().getType()); - - // TODO: Address bigger batches. - assert(GET_IMAGE_N(inputType) == 1 - && "Batch size must be 1" - "for convolution."); - - // TODO: Address replication. - assert(coresCount.getValue() == -1 && "Replication is not yet supported for convolution."); - - // TODO: Address bias addition. - - ldiv_t inputTileCount = div(GET_IMAGE_CHANNEL(inputType), crossbarSize); - ldiv_t outputTileCount = div(GET_IMAGE_CHANNEL(outputType), crossbarSize); - size_t kernelWidth = GET_KERNEL_WIDTH(weightsType); - size_t kernelHeight = GET_KERNEL_HEIGHT(weightsType); - - // Assert that the kernel is square. - assert(kernelWidth == kernelHeight && "Only square kernels are supported."); - - // -------------------------------- // - // --- SLICE THE WEIGHTS TENSOR --- // - // -------------------------------- // - - // The core idea of this stage is classifying the weights by input and - // output tile. This is because we want the applyFilters operations to be - // tile agnostic, to keep the subsequent lowering stages as simple as - // possible. This data structure does this weight classification: - // - The outer map is indexed by input tile. - // - The inner map is indexed by output tile. - // - The SmallVector contains the weights for the filter. - map>> weightsGroups; - - // During all slicing operations within this stage, we'll use the same - // strides for all dimensions. - SmallVector slicingStrides(4, rewriter.getIndexAttr(1)); - - ldiv_t itc = inputTileCount; - ldiv_t otc = outputTileCount; - - // - Slicing along the input tiles. - // - Slicing along the output tiles. - for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { - long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize; - for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) { - long crossbarHeight = ot == otc.quot ? 
otc.rem : crossbarSize; - - // The loop above also sets the crossbar's used width and height, - // checking if we're at the last crossbar and if it's incomplete. - - long outputTile = ot; - long inputTile = it; - - // Create the slicing sizes. - SmallVector slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight), - /* 1 */ rewriter.getIndexAttr(crossbarWidth), - /* 2 */ rewriter.getIndexAttr(1), - /* 3 */ rewriter.getIndexAttr(1)}; - - // - Slicing along the filter x position. - // - Slicing along the filter y position. - for (size_t filterX = 0; filterX < kernelWidth; ++filterX) { - for (size_t filterY = 0; filterY < kernelHeight; ++filterY) { - - // Create the slicing offsets. - SmallVector slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize), - /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize), - /* 2 */ rewriter.getIndexAttr(filterX), - /* 3 */ rewriter.getIndexAttr(filterY)}; - - // Create the slice extraction operation. - auto extractSliceOp = rewriter.create( - conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, slicingStrides); - - // Add a note to the extractSliceOp, with the filterX and filterY. - weightsGroups[inputTile][outputTile].push_back(extractSliceOp); - } - } - } - } - - // TODO: Tree reduction for compute reduction should be implemented. - - // -------------------------------- // - // --- CREATE ALL COMPUTE UNITS --- // - // -------------------------------- // - - // Keep track of input slicing operations to avoid duplication across - // all compute units (global slices). - map globalSlices; - - // Keep track of all partial compute results. - map globalPartialResults; - - // Use a weight subdivider to extract groups of weights for each compute - // unit. We'll keep extracting groups until no more weights are left. 
- WeightSubdivider weightSubdivider(weightsGroups); - while (!weightSubdivider.isEmpty()) { - - // -------------------------------- // - // --- BEGIN A NEW COMPUTE UNIT --- // - // -------------------------------- // - - // Get the next group of weights for the compute unit. - SmallVector weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue()); - - SmallVector computeWeights; - SmallVector computeOperands; - - // ------------------------------ // - // --- SLICE THE INPUT TENSOR --- // - // ------------------------------ // - - // Note each tile's index in the compute unit arguments. - map inputTileIndices; - map outputTileIndices; - map reductionTileIndices; // Incoming partial results. - - // Iterate over all weights groups for this compute unit. - map localSlices; // WRT the current compute unit. - for (auto group : weightsGroups) { - for (Value weight : group.weights) - computeWeights.push_back(weight); - - // There might be multiple weight groups for the same input tile, so if - // we've already added the input tile, skip it. - if (localSlices.find(group.inputTile) != localSlices.end()) - continue; - - // We might have already sliced the input tensor for some other compute - // unit, so if we have, reuse the slicing operation without creating a - // new one. - if (globalSlices.find(group.inputTile) != globalSlices.end()) { - computeOperands.push_back(globalSlices[group.inputTile]); - localSlices[group.inputTile] = globalSlices[group.inputTile]; - continue; - } - - // Create the input tensor slicing offsets. - SmallVector slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis. - /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize), - /* 2 */ rewriter.getIndexAttr(0), - /* 3 */ rewriter.getIndexAttr(0)}; - - // Create the input tensor slicing sizes. - size_t tilingSize = group.inputTile == inputTileCount.quot ? 
inputTileCount.rem : crossbarSize; - SmallVector slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. - /* 1 */ rewriter.getIndexAttr(tilingSize), - /* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), - /* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))}; - - // Create the slice extraction operation. - auto extractSliceOp = rewriter.create( - conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, slicingStrides); - - computeOperands.push_back(extractSliceOp); - - // Update slicing maps. - globalSlices[group.inputTile] = extractSliceOp; - localSlices[group.inputTile] = extractSliceOp; - - // Update the input tile index. - inputTileIndices[group.inputTile] = computeOperands.size() - 1; - } - - // ------------------------------- // - // --- PREPARE THE OUTPUT TYPE --- // - // ------------------------------- // - - // Fill the compute output's type by looking at the output tiles. - SmallVector computeOutputType; - for (TaggedWeights group : weightsGroups) { - - // There might be multiple weight groups for the same output tile, so if - // we've already added the output tile, skip it. - if (outputTileIndices.find(group.outputTile) != outputTileIndices.end()) - continue; - - // Additionally, after adding the input slices as operands, also add any - // compatible partial results from previous compute units. - if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) { - computeOperands.push_back(globalPartialResults[group.outputTile]); - reductionTileIndices[group.outputTile] = computeOperands.size() - 1; - } - - // Define the output shape for this group. - long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize; - - // TODO: Address non-same padding. - SmallVector outputShapeArray {/* 0 */ 1, // Batch size is always 1. - /* 1 */ outputTileSize, - /* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed. 
- /* 3 */ GET_IMAGE_HEIGHT(outputType)}; - - auto elementType = dyn_cast(conv.getY().getType()).getElementType(); - - computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType)); - - outputTileIndices[group.outputTile] = computeOutputType.size() - 1; - } - - // ----------------------------- // - // --- FILL THE COMPUTE UNIT --- // - // ----------------------------- // - - // Create the compute unit. - spatial::SpatWeightedCompute currentCompute = rewriter.create( - conv.getLoc(), computeOutputType, computeWeights, computeOperands); - - // Create a new block for the compute unit and add the operands. - Block* block = rewriter.createBlock(¤tCompute.getRegion()); - rewriter.setInsertionPointToStart(block); - for (Value operand : computeOperands) - block->addArgument(operand.getType(), conv->getLoc()); - - // Initialize a map of local partial results. - map localPartialResults; // WRT the current compute unit. - - // If we have any reduction tiles, add them to the local partial results. - for (auto reductionTileIndex : reductionTileIndices) - localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second); - - // Add all the applyFilters operations to the block. - for (TaggedWeights group : weightsGroups) { - - // Get the outputType for this group. - Type outputType = computeOutputType[outputTileIndices[group.outputTile]]; - - // Create an apply filters operation. - BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]); - - // The list of weight indices is group.startingCrossbarIndex + 0, 1, 2, - // ... As many weights as the size of group.weights. - SmallVector weightIndices; - for (size_t i = 0; i < group.weights.size(); ++i) - weightIndices.push_back(group.startingCrossbarIndex + i); - - SmallVector xKerPos; - SmallVector yKerPos; - for (auto weight : group.weights) { - // Assert that the weight is an extract_slice operation. 
- auto extractSliceOp = weight.getDefiningOp(); - assert(extractSliceOp && "Weight is not an extract_slice operation."); - - // Get the filter x and y positions from the extract_slice operation. - auto offsets = extractSliceOp.getStaticOffsets(); - xKerPos.push_back(offsets[2]); - yKerPos.push_back(offsets[3]); - } - - ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices); - ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos); - ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos); - - Value result = rewriter.create( - conv.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument); - - // Perform local reduction if necessary. - if (localPartialResults.find(group.outputTile) != localPartialResults.end()) { - - result = rewriter.create( - conv.getLoc(), result.getType(), localPartialResults[group.outputTile], result); - } - - // Update the partial results map. - localPartialResults[group.outputTile] = result; - } - - // Add a yield operation to the block by concatenating the partial - // results. - SmallVector applyFiltersResults; - for (size_t i = 0; i < computeOutputType.size(); ++i) { - long outputTile; - - // Given an output tile index, find the corresponding output tile. - for (auto outputTileIndex : outputTileIndices) { - if (outputTileIndex.second == i) { - outputTile = outputTileIndex.first; - break; - } - } - - // Get that tile's partial result and add it to the list. - applyFiltersResults.push_back(localPartialResults[outputTile]); - } - - // Create the yield operation with the given results. - rewriter.create(conv.getLoc(), applyFiltersResults); - - // Update the global partial results map. - for (size_t i = 0; i < applyFiltersResults.size(); ++i) { - long outputTile; - - // Given an output tile index, find the corresponding output tile. 
- for (auto outputTileIndex : outputTileIndices) { - if (outputTileIndex.second == i) { - outputTile = outputTileIndex.first; - break; - } - } - - globalPartialResults[outputTile] = currentCompute.getResult(i); - } - - // Move the rewrite cursor out of the block. - rewriter.setInsertionPointAfter(currentCompute); - } - - // ------------------------------ // - // --- CONCATENATE THE OUTPUT --- // - // ------------------------------ // - - // Turn the values into a SmallVector. - SmallVector outputValues; - for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i) - outputValues.push_back(globalPartialResults[i]); - - // Assert that the number of output values is correct. - assert(outputValues.size() > 0 && "No output values were generated for the convolution."); - - // If the conv's user is a ReLU... - if (conv->hasOneUse()) { - Operation* user = *conv->getUsers().begin(); - if (auto relu = dyn_cast(user)) { - // ...then we can just replace the ReLU with the concatenation. - rewriter.replaceOp(relu, rewriter.create(conv.getLoc(), 1, outputValues)); - - // And erase the convolution. - rewriter.eraseOp(conv); - return success(); - } - } - - // Return the final output. - rewriter.replaceOp(conv, rewriter.create(conv.getLoc(), 1, outputValues)); - - return success(); - } -}; - -/** - * @brief Populate the tiling pattern for a convolution operation. - * - * @param patterns The pattern set to populate. - * @param ctx The MLIR context. 
- */ -void populateExperimentalTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert(ctx); -} - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalGemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalGemm.cpp deleted file mode 100644 index 6dc34ab..0000000 --- a/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalGemm.cpp +++ /dev/null @@ -1,365 +0,0 @@ -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/Transforms/DialectConversion.h" - -#include - -#include "Compiler/PimCompilerOptions.hpp" -#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" -#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" -#include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" - -using namespace mlir; -using namespace std; - -namespace onnx_mlir { - -struct ExperimentalGemmConversionPattern : public OpConversionPattern { - ExperimentalGemmConversionPattern(MLIRContext* ctx) - : OpConversionPattern(ctx) {} - - LogicalResult - matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { - - // --------------------------------- // - // --- READ OPERATION PARAMETERS --- // - // --------------------------------- // - - // To get each crossbar's weights, we need to slice the weights tensor. - // - Along the input tiles. - // - Along the output tiles. - // - Along the filter x position. - // - Along the filter y position. - ShapedType inputType = cast(adaptor.getA().getType()); - ShapedType outputType = cast(gemmOp.getY().getType()); - ShapedType matrixType = cast(adaptor.getB().getType()); - - // TODO: Address bigger batches. - assert(inputType.getShape()[0] == 1 && "Only batch size of 1 is supported for GEMM."); - - // TODO: Address replication. - assert(coresCount.getValue() == -1 && "Replication is not yet supported for GEMM."); - - // TODO: Address bias addition. 
- - assert(inputType.getShape()[1] == matrixType.getShape()[0] && "Input tile size must match the matrix's row size."); - - ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize); - ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize); - size_t kernelWidth = 1; - size_t kernelHeight = 1; - - // Assert that the kernel is square. - assert(kernelWidth == kernelHeight && "Only square kernels are supported."); - - // -------------------------------- // - // --- SLICE THE WEIGHTS TENSOR --- // - // -------------------------------- // - - // The core idea of this stage is classifying the weights by input and - // output tile. This is because we want the applyFilters operations to be - // tile agnostic, to keep the subsequent lowering stages as simple as - // possible. This data structure does this weight classification: - // - The outer map is indexed by input tile. - // - The inner map is indexed by output tile. - // - The SmallVector contains the weights for the filter. - map>> weightsGroups; - - // During all slicing operations within this stage, we'll use the same - // strides for all dimensions. - SmallVector slicingStrides(2, rewriter.getIndexAttr(1)); - - ldiv_t itc = inputTileCount; - ldiv_t otc = outputTileCount; - - // - Slicing along the input tiles. - // - Slicing along the output tiles. - for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { - long crossbarWidth = it == itc.quot ? itc.rem : crossbarSize; - for (long ot = 0; ot < otc.quot + (otc.rem > 0); ++ot) { - long crossbarHeight = ot == otc.quot ? otc.rem : crossbarSize; - - // The loop above also sets the crossbar's used width and height, - // checking if we're at the last crossbar and if it's incomplete. - - long outputTile = ot; - long inputTile = it; - - // Create the slicing sizes. 
- SmallVector slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight), - /* 1 */ rewriter.getIndexAttr(crossbarWidth), - /* 2 */ /* rewriter.getIndexAttr(1), */ - /* 3 */ /* rewriter.getIndexAttr(1) */}; - - // - Slicing along the filter x position. - // - Slicing along the filter y position. - for (size_t filterX = 0; filterX < kernelWidth; ++filterX) { - for (size_t filterY = 0; filterY < kernelHeight; ++filterY) { - - // Create the slicing offsets. - SmallVector slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize), - /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize), - /* 2 */ /* rewriter.getIndexAttr(filterX), */ - /* 3 */ /* rewriter.getIndexAttr(filterY) */}; - - // Create the slice extraction operation. - auto extractSliceOp = rewriter.create( - gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, slicingStrides); - - // Add a note to the extractSliceOp, with the filterX and filterY. - weightsGroups[inputTile][outputTile].push_back(extractSliceOp); - } - } - } - } - - // TODO: Tree reduction for compute reduction should be implemented. - - // -------------------------------- // - // --- CREATE ALL COMPUTE UNITS --- // - // -------------------------------- // - - // Keep track of input slicing operations to avoid duplication across - // all compute units (global slices). - map globalSlices; - - // Keep track of all partial compute results. - map globalPartialResults; - - // Use a weight subdivider to extract groups of weights for each compute - // unit. We'll keep extracting groups until no more weights are left. - WeightSubdivider weightSubdivider(weightsGroups); - while (!weightSubdivider.isEmpty()) { - - // -------------------------------- // - // --- BEGIN A NEW COMPUTE UNIT --- // - // -------------------------------- // - - // Get the next group of weights for the compute unit. 
- SmallVector weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue()); - - SmallVector computeWeights; - SmallVector computeOperands; - - // ------------------------------ // - // --- SLICE THE INPUT TENSOR --- // - // ------------------------------ // - - // Note each tile's index in the compute unit arguments. - map inputTileIndices; - map outputTileIndices; - map reductionTileIndices; // Incoming partial results. - - // Iterate over all weights groups for this compute unit. - map localSlices; // WRT the current compute unit. - for (auto group : weightsGroups) { - for (Value weight : group.weights) - computeWeights.push_back(weight); - - // There might be multiple weight groups for the same input tile, so if - // we've already added the input tile, skip it. - if (localSlices.find(group.inputTile) != localSlices.end()) - continue; - - // We might have already sliced the input tensor for some other compute - // unit, so if we have, reuse the slicing operation without creating a - // new one. - if (globalSlices.find(group.inputTile) != globalSlices.end()) { - computeOperands.push_back(globalSlices[group.inputTile]); - localSlices[group.inputTile] = globalSlices[group.inputTile]; - continue; - } - - // Create the input tensor slicing offsets. - SmallVector slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis. - /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize), - /* 2 */ /* rewriter.getIndexAttr(0), */ - /* 3 */ /* rewriter.getIndexAttr(0) */}; - - // Create the input tensor slicing sizes. - size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize; - SmallVector slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. - /* 1 */ rewriter.getIndexAttr(tilingSize), - /* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */ - /* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */}; - - // Create the slice extraction operation. 
- auto extractSliceOp = rewriter.create( - gemmOp.getLoc(), adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides); - - computeOperands.push_back(extractSliceOp); - - // Update slicing maps. - globalSlices[group.inputTile] = extractSliceOp; - localSlices[group.inputTile] = extractSliceOp; - - // Update the input tile index. - inputTileIndices[group.inputTile] = computeOperands.size() - 1; - } - - // ------------------------------- // - // --- PREPARE THE OUTPUT TYPE --- // - // ------------------------------- // - - // Fill the compute output's type by looking at the output tiles. - SmallVector computeOutputType; - for (TaggedWeights group : weightsGroups) { - - // There might be multiple weight groups for the same output tile, so if - // we've already added the output tile, skip it. - if (outputTileIndices.find(group.outputTile) != outputTileIndices.end()) - continue; - - // Additionally, after adding the input slices as operands, also add any - // compatible partial results from previous compute units. - if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) { - computeOperands.push_back(globalPartialResults[group.outputTile]); - reductionTileIndices[group.outputTile] = computeOperands.size() - 1; - } - - // Define the output shape for this group. - long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize; - - // TODO: Address non-same padding. - SmallVector outputShapeArray {/* 0 */ 1, // Batch size is always 1. - /* 1 */ outputTileSize, - /* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed. 
- /* 3 */ /* GET_IMAGE_HEIGHT(outputType) */}; - - auto elementType = dyn_cast(gemmOp.getY().getType()).getElementType(); - - computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType)); - - outputTileIndices[group.outputTile] = computeOutputType.size() - 1; - } - - // ----------------------------- // - // --- FILL THE COMPUTE UNIT --- // - // ----------------------------- // - - // Create the compute unit. - spatial::SpatWeightedCompute currentCompute = rewriter.create( - gemmOp.getLoc(), computeOutputType, computeWeights, computeOperands); - - // Create a new block for the compute unit and add the operands. - Block* block = rewriter.createBlock(¤tCompute.getRegion()); - rewriter.setInsertionPointToStart(block); - for (Value operand : computeOperands) - block->addArgument(operand.getType(), gemmOp->getLoc()); - - // Initialize a map of local partial results. - map localPartialResults; // WRT the current compute unit. - - // If we have any reduction tiles, add them to the local partial results. - for (auto reductionTileIndex : reductionTileIndices) - localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second); - - // Add all the applyFilters operations to the block. - for (TaggedWeights group : weightsGroups) { - - // Get the outputType for this group. - Type outputType = computeOutputType[outputTileIndices[group.outputTile]]; - - // Create an apply filters operation. - BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]); - - // The list of weight indices is group.startingCrossbarIndex + 0, 1, 2, - // ... As many weights as the size of group.weights. - SmallVector weightIndices; - for (size_t i = 0; i < group.weights.size(); ++i) - weightIndices.push_back(group.startingCrossbarIndex + i); - - SmallVector xKerPos; - SmallVector yKerPos; - for (auto weight : group.weights) { - // Assert that the weight is an extract_slice operation. 
- auto extractSliceOp = weight.getDefiningOp(); - assert(extractSliceOp && "Weight is not an extract_slice operation."); - - // Get the filter x and y positions from the extract_slice operation. - xKerPos.push_back(0); - yKerPos.push_back(0); - } - - ArrayAttr weightIndicesAttr = rewriter.getI64ArrayAttr(weightIndices); - ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos); - ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos); - - Value result = rewriter.create( - gemmOp.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument); - - // Perform local reduction if necessary. - if (localPartialResults.find(group.outputTile) != localPartialResults.end()) { - - result = rewriter.create( - gemmOp.getLoc(), result.getType(), localPartialResults[group.outputTile], result); - } - - // Update the partial results map. - localPartialResults[group.outputTile] = result; - } - - // Add a yield operation to the block by concatenating the partial - // results. - SmallVector applyFiltersResults; - for (size_t i = 0; i < computeOutputType.size(); ++i) { - long outputTile; - - // Given an output tile index, find the corresponding output tile. - for (auto outputTileIndex : outputTileIndices) { - if (outputTileIndex.second == i) { - outputTile = outputTileIndex.first; - break; - } - } - - // Get that tile's partial result and add it to the list. - applyFiltersResults.push_back(localPartialResults[outputTile]); - } - - // Create the yield operation with the given results. - rewriter.create(gemmOp.getLoc(), applyFiltersResults); - - // Update the global partial results map. - for (size_t i = 0; i < applyFiltersResults.size(); ++i) { - long outputTile; - - // Given an output tile index, find the corresponding output tile. 
- for (auto outputTileIndex : outputTileIndices) { - if (outputTileIndex.second == i) { - outputTile = outputTileIndex.first; - break; - } - } - - globalPartialResults[outputTile] = currentCompute.getResult(i); - } - - // Move the rewrite cursor out of the block. - rewriter.setInsertionPointAfter(currentCompute); - } - - // ------------------------------ // - // --- CONCATENATE THE OUTPUT --- // - // ------------------------------ // - - // Turn the values into a SmallVector. - SmallVector outputValues; - for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i) - outputValues.push_back(globalPartialResults[i]); - - // Assert that the number of output values is correct. - assert(outputValues.size() > 0 && "No output values were generated for the GEMM operation."); - - // Return the final output. - rewriter.replaceOp(gemmOp, rewriter.create(gemmOp.getLoc(), 1, outputValues)); - - return success(); - } -}; - -void populateGemmToConvConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert(ctx); -} - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp index 1e13f8e..f4acf18 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp @@ -10,7 +10,6 @@ #include -#include "Gemm.hpp" #include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" @@ -20,6 +19,38 @@ using namespace mlir; namespace onnx_mlir { +namespace { + +constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor"; + +struct GemmToManyGemv : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, + ONNXGemmOpAdaptor gemmOpAdaptor, + ConversionPatternRewriter& rewriter) const 
override; +}; + +struct GemvToSpatialCompute : OpConversionPattern { + GemvToSpatialCompute(MLIRContext* ctx) + : OpConversionPattern(ctx, 1) {} + + LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, + ONNXGemmOpAdaptor gemmOpAdaptor, + ConversionPatternRewriter& rewriter) const override; + +private: + static Value resolveONNXExpOpFromUseChain(Value startValue); + + static LogicalResult softmaxReductionApplication(SmallVector& outputOpsAndResNums, + Value& softmaxChannel, + ConversionPatternRewriter& rewriter, + SpatialReducer& reducer, + ONNXGemmOp& gemmOp, + Location& loc); +}; + +} // namespace LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor gemmOpAdaptor, diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.hpp b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.hpp deleted file mode 100644 index 2853674..0000000 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.hpp +++ /dev/null @@ -1,54 +0,0 @@ -#pragma once - -#include "Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" - -namespace onnx_mlir { - -constexpr mlir::StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor"; - -struct GemmToManyGemv : mlir::OpConversionPattern { - GemmToManyGemv(mlir::MLIRContext* ctx) - : OpConversionPattern(ctx, 2) {} - - mlir::LogicalResult matchAndRewrite(mlir::ONNXGemmOp gemmOp, - mlir::ONNXGemmOpAdaptor gemmOpAdaptor, - mlir::ConversionPatternRewriter& rewriter) const override; -}; - -struct GemvToSpatialCompute : mlir::OpConversionPattern { - GemvToSpatialCompute(mlir::MLIRContext* ctx) - : OpConversionPattern(ctx, 1) {} - - llvm::LogicalResult matchAndRewrite(mlir::ONNXGemmOp gemmOp, - mlir::ONNXGemmOpAdaptor gemmOpAdaptor, - mlir::ConversionPatternRewriter& rewriter) const override; - -private: - /** - * Resolves the ONNXExpOp from the use chain of the given start value. - * - * This function traverses the use chain of the start value until it finds an - * ONNXExpOp. 
It returns the value of the ONNXExpOp. - * - * @param startValue The starting value of the use chain. - * @return The value of the ONNXExpOp found in the use chain. - */ - static mlir::Value resolveONNXExpOpFromUseChain(mlir::Value startValue); - - // Softmax is a special case, as it requires another reduction after the - // first one. In the cores, `applyReducePattern` already applied - // f(x) = exp(x) to each tile. This mean that now we just need to - // reduce-sum these tiles, and then divide each tile by the reduced sum, - // which is propagated back to the cores via a broadcast channel. - static llvm::LogicalResult softmaxReductionApplication(llvm::SmallVector& outputOpsAndResNums, - Value& softmaxChannel, - ConversionPatternRewriter& rewriter, - SpatialReducer& reducer, - ONNXGemmOp& gemmOp, - Location& loc); -}; - -void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx); - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/NN/ExperimentalPooling.cpp b/src/PIM/Conversion/ONNXToSpatial/NN/ExperimentalPooling.cpp deleted file mode 100644 index eae57f5..0000000 --- a/src/PIM/Conversion/ONNXToSpatial/NN/ExperimentalPooling.cpp +++ /dev/null @@ -1,300 +0,0 @@ -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" - -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" - -#include -#include -#include - -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" -#include 
"src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" - -using namespace mlir; - -namespace onnx_mlir { - -template -bool hasPostProcessExperimentalPoolingWindow() { - return false; -} - -template <> -bool hasPostProcessExperimentalPoolingWindow() { - return true; -} - -template -Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter& rewriter, - Location loc, - PoolOp poolOp, - Value valueToDivide, - size_t krn_size, - size_t tilesSkippedByPadding) { - return nullptr; -} - -template <> -Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter& rewriter, - Location loc, - ONNXAveragePoolOp poolOp, - Value valueToDivide, - size_t krn_size, - size_t tilesSkippedByPadding) { - bool countIncludePad = poolOp.getCountIncludePad() == 1; - - size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding; - - RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type()); - - // Put a spat.const before the computeOp, and use its value. We do this to be - // compatible with the current code generation, which assumes constant to be - // loaded in global memory, which is allocated by adding a spat.const OP - // directly under func.func (i.e. 
alongside ComputeOps) - auto computeOp = cast(valueToDivide.getDefiningOp()->getParentOp()); - rewriter.setInsertionPoint(computeOp); - auto divisorValue = rewriter.create(loc, - scalarTensor, - rewriter.getI64IntegerAttr(divisorNumber), - /* should_allocate = */ rewriter.getBoolAttr(true)); - - rewriter.setInsertionPointAfterValue(valueToDivide); - return rewriter.create(loc, valueToDivide.getType(), valueToDivide, divisorValue); -} - -template -Value reduceInputTiles(SmallVector& inputTiles, ConversionPatternRewriter& rewriter) { - if (inputTiles.size() == 1) - return inputTiles[0]; - - if (inputTiles.size() == 2) { - return rewriter.create( - inputTiles[0].getLoc(), inputTiles[0].getType(), inputTiles[0], inputTiles[1]); - } - - SmallVector left(inputTiles.begin(), inputTiles.begin() + inputTiles.size() / 2); - SmallVector right(inputTiles.begin() + inputTiles.size() / 2, inputTiles.end()); - - Value leftReduced = reduceInputTiles(left, rewriter); - Value rightReduced = reduceInputTiles(right, rewriter); - - return rewriter.create(inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced); -} - -template -struct ExperimentalPoolingBaseConverter : public OpConversionPattern { - ExperimentalPoolingBaseConverter(MLIRContext* ctx) - : OpConversionPattern(ctx) {} - - LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { - Value X = adaptor.getX(); - ShapedType xShape = mlir::cast(X.getType()); - Value Y = poolOp.getResult(); - ShapedType yShape = mlir::cast(Y.getType()); - - size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h; - unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y); - unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y); - unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h); - - if (adaptor.getAutoPad() != "NOTSET") - return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated."); - - size_t pad_x, 
pad_y; - auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y); - if (padUnpackError.has_value()) - return rewriter.notifyMatchFailure(poolOp, padUnpackError.value()); - - Location loc = poolOp.getLoc(); - - size_t input_h = GET_IMAGE_HEIGHT(xShape); - size_t input_w = GET_IMAGE_WIDTH(xShape); - size_t output_h = GET_IMAGE_HEIGHT(yShape); - size_t output_w = GET_IMAGE_WIDTH(yShape); - - ldiv_t tileCount = std::div(GET_IMAGE_CHANNEL(xShape), crossbarSize); - - // Assert that the input is a tensor.ConcatOp. - auto concat = X.getDefiningOp(); - if (!concat) - return rewriter.notifyMatchFailure(poolOp, "Expected input to be a tensor.ConcatOp"); - - // Create a [channel_tile][x][y] array to store the input tiles. - std::map>> inputTiles; - - // For each argument of the tensor.ConcatOp, resolve the input tiles. - for (size_t y = 0; y < input_h; ++y) { - for (size_t x = 0; x < input_w; ++x) { - for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) { - size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize; - - SmallVector strides(4, rewriter.getIndexAttr(1)); - SmallVector offsets = {/* 0 */ rewriter.getIndexAttr(0), - /* 1 */ rewriter.getIndexAttr(0), - /* 2 */ rewriter.getIndexAttr(x), - /* 3 */ rewriter.getIndexAttr(y)}; - SmallVector sizes = {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. - /* 1 */ rewriter.getIndexAttr(tilingSize), - /* 2 */ rewriter.getIndexAttr(1), - /* 3 */ rewriter.getIndexAttr(1)}; - - // Get the concat's operand that we want to slice. - Value concatInput = concat.getOperand(it); - Value slicedTile = rewriter.create(loc, concatInput, offsets, sizes, strides); - - inputTiles[it][x][y] = slicedTile; - } - } - } - - // Prepare the shape of the compute's output. 
- ldiv_t itc = tileCount; - SmallVector outputTileTypes; - for (size_t y = 0; y < output_h; ++y) { - for (size_t x = 0; x < output_w; ++x) { - for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { - SmallVector outputShapeArray {/* 0 */ 1, // Batch size is always 1. - /* 1 */ - cast(inputTiles[it][0][0].getType()).getShape()[1], - /* 2 */ 1, - /* 3 */ 1}; - - auto elementType = dyn_cast(xShape).getElementType(); - - outputTileTypes.push_back(RankedTensorType::get(outputShapeArray, elementType)); - } - } - } - - // Create a plain value list of the input tiles. - SmallVector inputTilesList; - for (size_t y = 0; y < input_h; ++y) { - for (size_t x = 0; x < input_w; ++x) - for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) - inputTilesList.push_back(inputTiles[it][y][x]); - } - - // Create a single compute to calculate the output. - auto computeOp = - rewriter.create(loc, outputTileTypes, SmallVector(), inputTilesList); - - // Create a new block for the compute unit and add the operands. - Block* block = rewriter.createBlock(&computeOp.getRegion()); - - // Fill the block arguments and keep a reference to them. - std::map>> inputTilesArgs; - for (size_t y = 0; y < input_h; ++y) { - for (size_t x = 0; x < input_w; ++x) { - for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { - auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it; - inputTilesArgs[it][y][x] = block->addArgument(computeOp->getOperand(tileIndex).getType(), loc); - } - } - } - - // Begin writing in the block. - rewriter.setInsertionPointToStart(block); - - // Go through all pooling blocks. 
- SmallVector outputTiles; - for (size_t y = 0; y < output_h; ++y) { - for (size_t x = 0; x < output_w; ++x) { - for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { - size_t start_x = x * stride_x; - size_t start_y = y * stride_y; - size_t end_x = std::min(start_x + krn_w, input_w); - size_t end_y = std::min(start_y + krn_h, input_h); - - SmallVector inputTilesToReduce; - for (size_t ky = start_y; ky < end_y; ++ky) - for (size_t kx = start_x; kx < end_x; ++kx) - inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]); - - auto reduceResult = reduceInputTiles(inputTilesToReduce, rewriter); - - // If the reduce op is add, we need to divide the result by the - // number of elements in the pooling window. - if (hasPostProcessExperimentalPoolingWindow()) { - // Add a spat.const before the computeOp. - rewriter.setInsertionPoint(computeOp); - auto divisorValue = - rewriter.create(loc, - RankedTensorType::get({1}, rewriter.getF32Type()), - rewriter.getI64IntegerAttr(krn_w * krn_h), - rewriter.getBoolAttr(true)); - - rewriter.setInsertionPointAfter(reduceResult.getDefiningOp()); - reduceResult = - rewriter.create(loc, reduceResult.getType(), reduceResult, divisorValue); - } - outputTiles.push_back(reduceResult); - } - } - } - - // Create a YieldOp to return the output tiles. - rewriter.create(loc, outputTiles); - - // Set the rewrite cursor right after the computeOp. - rewriter.setInsertionPointAfter(computeOp); - - std::map>> computeOutput; - for (size_t y = 0; y < output_h; ++y) { - for (size_t x = 0; x < output_w; ++x) { - for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { - auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it; - computeOutput[it][y][x] = computeOp.getResult(tileIndex); - } - } - } - - // We'll now create spat.img.concat ops to concatenate the output tiles. 
- SmallVector outputTilesList; - for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { - SmallVector imgConcatTiles; - for (size_t y = 0; y < output_h; ++y) - for (size_t x = 0; x < output_w; ++x) - imgConcatTiles.push_back(computeOutput[it][y][x]); - - size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize; - - SmallVector outputShapeArray {/* 0 */ 1, // Batch size is always 1. - /* 1 */ (long) tilingSize, - /* 2 */ (long) output_w, - /* 3 */ (long) output_h}; - - auto elementType = dyn_cast(xShape).getElementType(); - - outputTilesList.push_back(rewriter.create( - loc, RankedTensorType::get(outputShapeArray, elementType), imgConcatTiles)); - } - - // Create a new tensor.ConcatOp to concatenate the output tiles. - Value outputTensor = rewriter.create(loc, 1, outputTilesList); - - rewriter.replaceOp(poolOp, outputTensor); - - return success(); - } -}; - -void populateExperimentalPoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert< - ExperimentalPoolingBaseConverter>(ctx); - patterns.insert>( - ctx); -} - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp b/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp index e6995bb..c440d80 100644 --- a/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp @@ -26,8 +26,6 @@ using namespace mlir; namespace onnx_mlir { -llvm::SmallPtrSet oldComputeOpsReplaced; - Value applyReducePatternNew(SmallVector& valuesToReduce, ConversionPatternRewriter& rewriter, std::function reduce, @@ -225,12 +223,12 @@ struct PoolingBaseConverter : public OpConversionPattern { Location loc = poolOp.getLoc(); - size_t input_h = GET_IMAGE_HEIGHT(xShape); - size_t input_w = GET_IMAGE_WIDTH(xShape); - size_t output_h = GET_IMAGE_HEIGHT(yShape); - size_t output_w = GET_IMAGE_WIDTH(yShape); - size_t channelTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); - size_t channelTileRest = 
GET_IMAGE_CHANNEL(xShape) % crossbarSize; + size_t input_h = getImageHeight(xShape); + size_t input_w = getImageWidth(xShape); + size_t output_h = getImageHeight(yShape); + size_t output_w = getImageWidth(yShape); + size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue()); + size_t channelTileRest = getImageChannel(xShape) % crossbarSize; // 1: Tile the input tensor // Input tiles need to be indexed by: diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td index 529d579..a4e0bc9 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td @@ -13,9 +13,7 @@ def onnxToArithConstantOp : Pat< (Arith_ConstantOp $value) >; -//===----------------------------------------------------------------------===// // ONNXMatMulOp to ONNXGemmOp patterns -//===----------------------------------------------------------------------===// def matMulAddToGemmPattern : Pat< (ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C), @@ -39,9 +37,7 @@ def matMulToGemmPattern : Pat< ) >; -//===----------------------------------------------------------------------===// // ONNXConvOp + ONNXAddOp to ONNXConvOp pattern -//===----------------------------------------------------------------------===// // This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single // ONNXConvOp with a bias. @@ -55,9 +51,7 @@ def convAddToConvWithBiasPatternRight : Pat< (ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides) >; -//===----------------------------------------------------------------------===// // Operation to ignore (i.e. 
remove) -//===----------------------------------------------------------------------===// def replaceWithOperationOfValue : NativeCodeCall<"$0">; diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp index bd1ecac..6d8a849 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.cpp @@ -180,10 +180,10 @@ void tileImageTensorByChannel(Value imageTensor, ConversionPatternRewriter& rewriter) { ShapedType imageShape = mlir::cast(imageTensor.getType()); - size_t input_h = GET_IMAGE_HEIGHT(imageShape); - size_t input_w = GET_IMAGE_WIDTH(imageShape); - size_t tileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(imageShape), tileSize); - size_t tileRest = GET_IMAGE_CHANNEL(imageShape) % tileSize; + size_t input_h = getImageHeight(imageShape); + size_t input_w = getImageWidth(imageShape); + size_t tileCount = ceilIntegerDivide(getImageChannel(imageShape), tileSize); + size_t tileRest = getImageChannel(imageShape) % tileSize; SmallVector strides(4, rewriter.getIndexAttr(1)); SmallVector offsets(4, rewriter.getIndexAttr(0)); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp index 52ec881..2676698 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp @@ -9,24 +9,55 @@ #include "llvm/Support/LogicalResult.h" +#include +#include +#include +#include +#include + #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" #define DEFINE_MAP_OP(opname) opname, -#define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2) -#define GET_IMAGE_HEIGHT(shapedType) shapedType.getDimSize(3) -#define GET_IMAGE_CHANNEL(shapedType) shapedType.getDimSize(1) -#define GET_IMAGE_N(shapedType) shapedType.getDimSize(0) -#define GET_KERNEL_WIDTH(shapedType) 
shapedType.getDimSize(2) -#define GET_KERNEL_HEIGHT(shapedType) shapedType.getDimSize(3) -#define GET_FILTER_COUNT(shapedType) shapedType.getDimSize(0) - -using namespace mlir; - namespace onnx_mlir { -const StringRef REPLICATION_ATTR_NAME = "replication_factor"; +template +inline auto getImageWidth(const ShapedType& shapedType) { + return shapedType.getDimSize(2); +} + +template +inline auto getImageHeight(const ShapedType& shapedType) { + return shapedType.getDimSize(3); +} + +template +inline auto getImageChannel(const ShapedType& shapedType) { + return shapedType.getDimSize(1); +} + +template +inline auto getImageN(const ShapedType& shapedType) { + return shapedType.getDimSize(0); +} + +template +inline auto getKernelWidth(const ShapedType& shapedType) { + return shapedType.getDimSize(2); +} + +template +inline auto getKernelHeight(const ShapedType& shapedType) { + return shapedType.getDimSize(3); +} + +template +inline auto getFilterCount(const ShapedType& shapedType) { + return shapedType.getDimSize(0); +} + +inline constexpr mlir::StringRef REPLICATION_ATTR_NAME = "replication_factor"; using HSliceId = size_t; using CoreId = size_t; @@ -58,51 +89,64 @@ constexpr std::pair ceilIntegerDivideWithRemainder(A a, B b) { } template -bool isVectorShape(const ArrayRef shape) { +bool isVectorShape(mlir::ArrayRef shape) { return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1); } template -bool isMatrixShape(const ArrayRef shape) { +bool isMatrixShape(mlir::ArrayRef shape) { return shape.size() == 2; } template -bool isHVectorShape(const ArrayRef shape) { +bool isHVectorShape(mlir::ArrayRef shape) { return shape.size() == 2 && shape[0] == 1; } template -bool isVVectorShape(const ArrayRef shape) { +bool isVVectorShape(mlir::ArrayRef shape) { return shape.size() == 2 && shape[1] == 1; } template -T getVectorLength(const ArrayRef shape) { +T getVectorLength(mlir::ArrayRef shape) { assert(isVectorShape(shape)); return shape[0] != 1 ? 
shape[0] : shape[1]; } -inline auto getTensorShape(const Value tensor) { return cast(tensor.getType()).getShape(); } +inline auto getTensorShape(mlir::Value tensor) { + return mlir::cast(tensor.getType()).getShape(); +} -SmallVector sliceTensor( - const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc); +llvm::SmallVector sliceTensor(const mlir::Value& tensorToSlice, + size_t axis, + int64_t sliceSize, + mlir::ConversionPatternRewriter& rewriter, + mlir::Location loc); -SmallVector -sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc); +llvm::SmallVector sliceVector(const mlir::Value& vectorToSlice, + int64_t sliceSize, + mlir::ConversionPatternRewriter& rewriter, + mlir::Location loc); -DenseMap> -sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc); +llvm::DenseMap> sliceVectorPerCrossbarPerCore( + const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); -DenseMap>> tileMatrix( - Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc); +llvm::DenseMap>> +tileMatrix(mlir::Value& matrixToTile, + int64_t hSliceSize, + int64_t vSliceSize, + mlir::ConversionPatternRewriter& rewriter, + mlir::Location& loc); -tensor::SplatOp -broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc); +mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast, + int64_t length, + mlir::ConversionPatternRewriter& rewriter, + mlir::Location loc); -Value sumTensors(ArrayRef tensors, ConversionPatternRewriter& rewriter); +mlir::Value sumTensors(mlir::ArrayRef tensors, mlir::ConversionPatternRewriter& rewriter); -Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input); +mlir::Value createMapOperation(mlir::PatternRewriter& 
rewriter, MapOperations mapOp, const mlir::Value& input); /** * Unpacks an optional pair vector into two size_t values. @@ -126,7 +170,8 @@ void unpackOptionalPairVector(std::optional valuesArray, size_t * * @return llvm::Optional The error message if the pads are invalid */ -std::optional unpackOptionalPadsVector(std::optional valuesArray, size_t& pad_x, size_t& pad_y); +std::optional +unpackOptionalPadsVector(std::optional valuesArray, size_t& pad_x, size_t& pad_y); /** * Tiles the image tensor by channel. @@ -140,10 +185,10 @@ std::optional unpackOptionalPadsVector(std::optional val * @param tileSize The size of each tile. * @param rewriter The ConversionPatternRewriter used for creating operations. */ -void tileImageTensorByChannel(Value imageTensor, - SmallVector>>& tiles, +void tileImageTensorByChannel(mlir::Value imageTensor, + llvm::SmallVector>>& tiles, size_t tileSize, - ConversionPatternRewriter& rewriter); + mlir::ConversionPatternRewriter& rewriter); /** * Creates an ImgConcatOp based on the given tiles. @@ -159,10 +204,10 @@ void tileImageTensorByChannel(Value imageTensor, * * @return The created ImgConcatOp. */ -Value createImgConcatOp(SmallVector>>& outputTiles, - ConversionPatternRewriter& rewriter, - Location& loc, - Type outputType); +mlir::Value createImgConcatOp(llvm::SmallVector>>& outputTiles, + mlir::ConversionPatternRewriter& rewriter, + mlir::Location& loc, + mlir::Type outputType); /** * @brief Verifies if the given input coordinates and padding values are within @@ -177,7 +222,7 @@ Value createImgConcatOp(SmallVector>>& outputTile * @return LogicalResult Returns success if the coordinates and padding are * within bounds, failure otherwise. 
*/ -LogicalResult +mlir::LogicalResult verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, size_t pad_x, size_t pad_y); /** @@ -207,13 +252,14 @@ verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h, int inX, int inY, * @return std::optional An error message if the input tensor could * not be resolved into tiles. */ -std::optional resolveImgInputTiles(Value wholeInputTensor, - SmallVector>>& inputTiles, - size_t channelTileCount, - size_t channelTileRest, - size_t input_w, - size_t input_h, - mlir::ConversionPatternRewriter& rewriter); +std::optional +resolveImgInputTiles(mlir::Value wholeInputTensor, + llvm::SmallVector>>& inputTiles, + size_t channelTileCount, + size_t channelTileRest, + size_t input_w, + size_t input_h, + mlir::ConversionPatternRewriter& rewriter); /** * Computes the boundaries of an image kernel application. @@ -258,6 +304,6 @@ void incrementWeightedComputeInputsSegmentSize(spatial::SpatWeightedCompute wcom * @return The index of the result of the operation that produces the specified * value. 
*/ -int getResultIndex(Operation* op, Value v); +int getResultIndex(mlir::Operation* op, mlir::Value v); }; // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index af75fd9..17ee2f1 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -1,3 +1,4 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Pass/Pass.h" @@ -10,19 +11,39 @@ #include "Common/PIMCommon.hpp" #include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" -#include "Math/Conv.hpp" -#include "ONNXToSpatialPass.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Compiler/CompilerOptions.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; namespace onnx_mlir { -namespace spatial { +bool haveSameStaticShape(Value lhs, Value rhs); + +namespace { + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc" + +struct ONNXToSpatialPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass) + StringRef getArgument() const override { return "convert-onnx-to-spatial"; } + StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; } + + ONNXToSpatialPass() = default; + ONNXToSpatialPass(const ONNXToSpatialPass& pass) {} + + void runOnOperation() override; + +private: + void annotateWeightsConstants(func::FuncOp funcOp) const; +}; + +} // namespace void ONNXToSpatialPass::runOnOperation() { ModuleOp moduleOp = 
getOperation(); @@ -40,15 +61,19 @@ void ONNXToSpatialPass::runOnOperation() { llvm::dbgs() << "Failed to merge activation patterns, continuing...\n"; IRRewriter rewriter(moduleOp); - func::FuncOp funcOp = *moduleOp.getOps().begin(); - if (annotateReplication(funcOp, rewriter).failed()) { + auto entryFunc = getPimEntryFunc(moduleOp); + if (failed(entryFunc)) { + signalPassFailure(); + return; + } + if (annotateReplication(*entryFunc, rewriter).failed()) { llvm::dbgs() << "Failed during annotation for replication analysis\n"; signalPassFailure(); return; } ConversionTarget target(*ctx); - target.addLegalDialect(); + target.addLegalDialect(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -62,16 +87,9 @@ void ONNXToSpatialPass::runOnOperation() { RewritePatternSet patterns(ctx); patterns.add(ctx); - if (useExperimentalConvImpl) { - populateExperimentalTilingConvOpPattern(patterns, ctx); - populateExperimentalPoolingTilingPattern(patterns, ctx); - populateGemmToConvConversionPattern(patterns, ctx); - } - else { - populateTilingConvOpPattern(patterns, ctx); - populatePoolingTilingPattern(patterns, ctx); - populateOnnxGemmOpPatterns(patterns, ctx); - } + populateConvOpPatterns(patterns, ctx); + populatePoolingTilingPattern(patterns, ctx); + populateOnnxGemmOpPatterns(patterns, ctx); populateONNXConcatToTensorConcatPattern(patterns, ctx); populateReduceMeanConversionPattern(patterns, ctx); @@ -84,8 +102,8 @@ void ONNXToSpatialPass::runOnOperation() { // Count the number of compute ops and check they do not exceed the core count if (coresCount != -1) { int computeOpsCount = 0; - for (auto& op : funcOp.getFunctionBody().front().getOperations()) - if (isa(op)) + for (auto& op : entryFunc->getFunctionBody().front().getOperations()) + if (isa(op)) computeOpsCount++; if (computeOpsCount > coresCount) { @@ -102,22 +120,21 @@ void ONNXToSpatialPass::runOnOperation() { if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns)))) 
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n"; - annotateWeightsConstants(funcOp); + annotateWeightsConstants(*entryFunc); // Dump to file for debug dumpModule(moduleOp, "spatial"); } void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { - MLIRContext* ctx = funcOp.getContext(); funcOp.walk([&](arith::ConstantOp constantOp) { bool isAlwaysWeight = - llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa(user); }); + llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa(user); }); if (isAlwaysWeight) - constantOp->setAttr("weightAlways", UnitAttr::get(ctx)); + markWeightAlways(constantOp); }); } -} // namespace spatial +std::unique_ptr createONNXToSpatialPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.hpp deleted file mode 100644 index 8fc2c82..0000000 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.hpp +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include "mlir/Pass/Pass.h" - -#include "src/Dialect/ONNX/ONNXOps.hpp" - -namespace onnx_mlir { - -using namespace mlir; -extern bool haveSameStaticShape(Value lhs, Value rhs); - -namespace spatial { - -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc" - -struct ONNXToSpatialPass : PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass) - StringRef getArgument() const override { return "convert-onnx-to-spatial"; } - StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; } - - ONNXToSpatialPass() = default; - ONNXToSpatialPass(const ONNXToSpatialPass& pass) {} - - void runOnOperation() override; - -private: - void annotateWeightsConstants(func::FuncOp funcOp) const; -}; - -} // namespace spatial - -std::unique_ptr createONNXToSpatialPass() { return std::make_unique(); } - -} // namespace onnx_mlir diff --git 
a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp index 416f096..1a1a3f6 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp @@ -1,27 +1,20 @@ #pragma once -#include "mlir/IR/PatternMatch.h" + +#include "mlir/IR/MLIRContext.h" +#include "mlir/Transforms/DialectConversion.h" namespace onnx_mlir { -void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populateDistributeReducePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - -void populateFoldComputePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -// Experimental patterns. 
-void populateExperimentalTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populateGemmToConvConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populateExperimentalPoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp b/src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp index ec7a874..6fb9bf7 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp @@ -10,7 +10,7 @@ using namespace mlir; namespace onnx_mlir { template -struct RemoveUnusedHelperOps : public OpRewritePattern { +struct RemoveUnusedHelperOps : OpRewritePattern { RemoveUnusedHelperOps(MLIRContext* ctx) : OpRewritePattern(ctx) {} diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp b/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp index a9b0c70..c9092b1 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp @@ -49,11 +49,11 @@ LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& r ShapedType xShape = mlir::cast(X.getType()); ShapedType wShape = mlir::cast(W.getType()); - size_t input_w = GET_IMAGE_WIDTH(xShape); - size_t krn_h = GET_KERNEL_HEIGHT(wShape); - size_t krn_w = GET_KERNEL_WIDTH(wShape); + size_t input_w = getImageWidth(xShape); + size_t krn_h = getKernelHeight(wShape); + size_t krn_w = getKernelWidth(wShape); - size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); + size_t inputTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue()); size_t outputTileCount = ceilIntegerDivide(wShape.getDimSize(0), crossbarSize.getValue()); auto neededXbars = krn_h * krn_w * inputTileCount 
* outputTileCount; diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.cpp b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.cpp index faed00d..3187300 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.cpp @@ -15,21 +15,21 @@ namespace onnx_mlir { -llvm::SmallPtrSet onnx_mlir::SpatialReducer::oldComputeOpsReplaced; +llvm::SmallPtrSet onnx_mlir::SpatialReducer::oldComputeOpsReplaced; ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum, - std::function processFun, - ConversionPatternRewriter& rewriter) { + std::function processFun, + mlir::ConversionPatternRewriter& rewriter) { assert(processFun); auto computeOp = GET_COMP(computeOpAndResNum); auto resultNum = GET_RES_NUM(computeOpAndResNum); - spatial::SpatYieldOp yieldOp = cast(computeOp.getBody().front().getTerminator()); + spatial::SpatYieldOp yieldOp = mlir::cast(computeOp.getBody().front().getTerminator()); - Value result = yieldOp->getOperand(resultNum); + mlir::Value result = yieldOp->getOperand(resultNum); rewriter.setInsertionPointAfterValue(result); - Value processedResult = processFun(result); + mlir::Value processedResult = processFun(result); if (processedResult == result) { // Sometimes we want processedResult to return the same value but do // something else with it (e.g. 
in softmax we want to broadcast the value @@ -42,10 +42,11 @@ ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum return yieldOp.getNumOperands() - 1; } -OpAndResNum SpatialReducer::applyReducePattern(SmallVector& computeOpsAndResNum, - std::function reduce, - std::function preprocess, - std::function postprocess) { +OpAndResNum +SpatialReducer::applyReducePattern(llvm::SmallVector& computeOpsAndResNum, + std::function reduce, + std::function preprocess, + std::function postprocess) { if (preprocess) for (auto& computeOpAndResNum : computeOpsAndResNum) @@ -55,18 +56,18 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector& co // computeOp. In this case, we need to apply the reduction within-computef // Keep a map between a computeOp and the last Value for this reduction - std::unordered_map lastValueForCompute; + std::unordered_map lastValueForCompute; for (auto& computeOpAndResNum : computeOpsAndResNum) { auto computeOp = GET_COMP(computeOpAndResNum); - auto yieldOp = cast(computeOp.getBody().front().getTerminator()); - Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum)); + auto yieldOp = mlir::cast(computeOp.getBody().front().getTerminator()); + mlir::Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum)); auto it = lastValueForCompute.find(computeOp.getOperation()); if (it != lastValueForCompute.end()) { // If we have already seen this computeOp, apply the reduction // within-compute - Value lastWithinComputeValue = it->second; + mlir::Value lastWithinComputeValue = it->second; assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp()); @@ -85,12 +86,12 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector& co computeOpsAndResNum.clear(); computeOpsAndResNum.reserve(lastValueForCompute.size()); for (auto& entry : lastValueForCompute) { - auto computeOp = cast(entry.first); + auto computeOp = mlir::cast(entry.first); auto valueWithinCompute 
= entry.second; // We check if `valueWithinCompute` is already used by the yieldOp, in that // case no need to add it - auto yieldOp = cast(computeOp.getBody().front().getTerminator()); + auto yieldOp = mlir::cast(computeOp.getBody().front().getTerminator()); bool yieldOpUseFound = false; for (auto& use : valueWithinCompute.getUses()) { if (use.getOwner() == yieldOp.getOperation()) { @@ -110,7 +111,7 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector& co computeOpsAndResNum.push_back({computeOp, resultNum}); } - Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc(); + mlir::Location loc = GET_COMP(computeOpsAndResNum[0])->getLoc(); // Recursive algorithm to reduce the inputs to a single one: // - Take two inputs at a time, and reduce them into a single one, updating @@ -118,7 +119,7 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector& co // - Repeat until there is only one input left. llvm::OwningArrayRef computeOpsRef(computeOpsAndResNum); while (computeOpsRef.size() > 1) { - SmallVector nextComputeOps; + llvm::SmallVector nextComputeOps; nextComputeOps.reserve(computeOpsRef.size() / 2); for (size_t i = 0; i < computeOpsRef.size() - 1; i += 2) { auto [firstCompute, firstResultNum] = computeOpsRef[i]; @@ -135,23 +136,23 @@ OpAndResNum SpatialReducer::applyReducePattern(SmallVector& co // the number of results) // See below `reducerChanges.push_back` and `finalizeReduceUpdates` - auto yieldOpFirstCompute = cast(firstCompute.getBody().front().getTerminator()); + auto yieldOpFirstCompute = mlir::cast(firstCompute.getBody().front().getTerminator()); // Add a new operand to the block of the second computeOp - Block& secondBlock = secondCompute.getBody().front(); - Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc); + mlir::Block& secondBlock = secondCompute.getBody().front(); + mlir::Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc); 
auto secondComputeWeightsNum = - secondCompute->getAttrOfType(secondCompute.getOperandSegmentSizesAttrName())[0]; + secondCompute->getAttrOfType(secondCompute.getOperandSegmentSizesAttrName())[0]; auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1; // Take the "former-result" from the second computeOp - spatial::SpatYieldOp secondYield = cast(secondBlock.getTerminator()); - Value formerRes2 = secondYield.getOperand(secondResultNum); + spatial::SpatYieldOp secondYield = mlir::cast(secondBlock.getTerminator()); + mlir::Value formerRes2 = secondYield.getOperand(secondResultNum); // Apply reduction operation rewriter.setInsertionPoint(secondYield); - Value reduced = reduce(formerRes2, formerRes1); + mlir::Value reduced = reduce(formerRes2, formerRes1); // Unfortunately, it is not possible to update the result in place, // because we may have already referenced it by @@ -219,7 +220,7 @@ void SpatialReducer::finalizeReduceUpdates() { // `opToReplacedCompute` auto toComputeOp = opToReplacedCompute[toOp]; if (!toComputeOp) - toComputeOp = cast(toOp); + toComputeOp = mlir::cast(toOp); assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!"); @@ -234,31 +235,31 @@ void SpatialReducer::finalizeReduceUpdates() { } } -Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) { +mlir::Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) { assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates."); - Operation* opToCast; + mlir::Operation* opToCast; auto it = opToReplacedCompute.find(opAndResNum.first); if (it != opToReplacedCompute.end()) opToCast = it->second; else opToCast = opAndResNum.first; - auto computeOp = cast(opToCast); + auto computeOp = mlir::cast(opToCast); return computeOp.getResult(opAndResNum.second); } -void SpatialReducer::updateResultsOfCompute(Operation* computeOp) { +void 
SpatialReducer::updateResultsOfCompute(mlir::Operation* computeOp) { if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) { // If we have already replaced the fromOp, we do not need to do it again return; } - auto oldComputeOp = cast(computeOp); + auto oldComputeOp = mlir::cast(computeOp); auto oldComputeOpNum = oldComputeOp->getNumOperands(); - auto yieldOp = cast(oldComputeOp.getBody().front().getTerminator()); + auto yieldOp = mlir::cast(oldComputeOp.getBody().front().getTerminator()); if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) { // No result was added, just add itself to the map @@ -283,8 +284,8 @@ void SpatialReducer::updateResultsOfCompute(Operation* computeOp) { // Since we replaced the old ComputeOp with a new one, we need to replace // all its results' uses for (size_t i = 0; i < oldComputeOp.getNumResults(); i++) { - Value oldResult = oldComputeOp.getResult(i); - Value newResult = newComputeOp.getResult(i); + mlir::Value oldResult = oldComputeOp.getResult(i); + mlir::Value newResult = newComputeOp.getResult(i); // Replace the uses, except the uses of the compute ops which got deleted // previously @@ -298,9 +299,10 @@ void SpatialReducer::updateResultsOfCompute(Operation* computeOp) { rewriter.eraseOp(oldComputeOp); } -Value SpatialReducer::createImgConcatOp(SmallVector>>& outputTiles, - Location& loc, - Type outputType) { +mlir::Value +SpatialReducer::createImgConcatOp(llvm::SmallVector>>& outputTiles, + mlir::Location& loc, + mlir::Type outputType) { assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates."); @@ -309,8 +311,8 @@ Value SpatialReducer::createImgConcatOp(SmallVector>> remappedOutputTiles( - tilesCount, SmallVector>(width, SmallVector(height))); + llvm::SmallVector>> remappedOutputTiles( + tilesCount, llvm::SmallVector>(width, llvm::SmallVector(height))); for (size_t t = 0; t < tilesCount; t++) for (size_t x = 0; x < width; x++) @@ -320,16 +322,16 @@ Value 
SpatialReducer::createImgConcatOp(SmallVector& computeOps, - ConversionPatternRewriter& rewriter, - Value biasTile, +OpAndResNum SpatialReducer::applyAddMapReduction(llvm::SmallVector& computeOps, + mlir::ConversionPatternRewriter& rewriter, + mlir::Value biasTile, MapOperations mapOp) { - std::function postprocessing = nullptr; + std::function postprocessing = nullptr; if (mapOp != MapOperations::None) { - postprocessing = [&](const Value a) { - Value mapOperand = a; + postprocessing = [&](const mlir::Value a) { + mlir::Value mapOperand = a; if (biasTile) mapOperand = rewriter.create(a.getLoc(), a.getType(), a, biasTile); return createMapOperation(rewriter, mapOp, mapOperand); @@ -338,7 +340,7 @@ OpAndResNum SpatialReducer::applyAddMapReduction(SmallVector& return this->applyReducePattern( computeOps, - [&](Value a, Value b) { return rewriter.create(a.getLoc(), a.getType(), a, b); }, + [&](mlir::Value a, mlir::Value b) { return rewriter.create(a.getLoc(), a.getType(), a, b); }, /* preprocess = */ nullptr, postprocessing); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp index f724691..6739d48 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp @@ -3,6 +3,10 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" +#include +#include +#include + #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -13,28 +17,28 @@ using ResNum = unsigned int; using ComputeAndResNum = std::pair; struct SpatialReducerChange { - Operation* fromOp; + mlir::Operation* fromOp; unsigned int fromOpResNum; - Operation* toOp; + mlir::Operation* toOp; unsigned int toOpOperandNum; }; -using OpAndResNum = std::pair; +using OpAndResNum = std::pair; class SpatialReducer { public: - SpatialReducer(ConversionPatternRewriter& 
rewriter) + SpatialReducer(mlir::ConversionPatternRewriter& rewriter) : rewriter(rewriter) {} - OpAndResNum applyReducePattern(SmallVector& computeOpsAndResNum, - std::function reduce, - std::function preprocess, - std::function postprocess); + OpAndResNum applyReducePattern(llvm::SmallVector& computeOpsAndResNum, + std::function reduce, + std::function preprocess, + std::function postprocess); - OpAndResNum applyAddMapReduction(SmallVector& computeOps, - ConversionPatternRewriter& rewriter, - Value biasTile, + OpAndResNum applyAddMapReduction(llvm::SmallVector& computeOps, + mlir::ConversionPatternRewriter& rewriter, + mlir::Value biasTile, MapOperations mapOp); void finalizeReduceUpdates(); @@ -44,17 +48,17 @@ public: finalizeReduceUpdates(); } - Value createImgConcatOp(llvm::SmallVector>>& outputTiles, - Location& loc, - Type outputType); + mlir::Value createImgConcatOp(llvm::SmallVector>>& outputTiles, + mlir::Location& loc, + mlir::Type outputType); - Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum); + mlir::Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum); private: [[nodiscard("computeOp result number gets updated")]] ResNum applyResultProcessing(ComputeAndResNum computeOpAndResNum, - std::function processFun, - ConversionPatternRewriter& rewriter); + std::function processFun, + mlir::ConversionPatternRewriter& rewriter); /** * @brief Update the results of a ComputeOp. @@ -66,19 +70,19 @@ private: * * @param computeOp The ComputeOp to update the results of. 
*/ - void updateResultsOfCompute(Operation* computeOp); + void updateResultsOfCompute(mlir::Operation* computeOp); - ConversionPatternRewriter& rewriter; + mlir::ConversionPatternRewriter& rewriter; bool reducesFinalized = false; // List of changes to be applied after the reduction is finalized - SmallVector reducerChanges; + llvm::SmallVector reducerChanges; // List of computeOps that need to be replaced with new results - SmallVector computeOpNeedingResUpdate; + llvm::SmallVector computeOpNeedingResUpdate; - std::unordered_map opToReplacedCompute; + std::unordered_map opToReplacedCompute; - static llvm::SmallPtrSet oldComputeOpsReplaced; + static llvm::SmallPtrSet oldComputeOpsReplaced; }; } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.cpp b/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.cpp index e2466ec..6affb23 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.cpp @@ -4,7 +4,7 @@ namespace onnx_mlir { -WeightSubdivider::WeightSubdivider(map>> weights) +WeightSubdivider::WeightSubdivider(std::map>> weights) : weights(std::move(weights)) {} bool WeightSubdivider::isEmpty() const { return weights.empty(); } @@ -13,7 +13,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) { assert(!weights.empty() && "No weights to extract."); auto it = weights.begin(); - SmallVector& values = it->second.begin()->second; + llvm::SmallVector& values = it->second.begin()->second; long inputTile = it->first; long outputTile = it->second.begin()->first; @@ -21,7 +21,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) { size_t n = std::min(amount, values.size()); crossbarsUsed += n; - SmallVector result; + llvm::SmallVector result; result.assign(values.begin(), values.begin() + n); if (n < values.size()) { @@ -36,9 +36,9 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) { return {inputTile, outputTile, crossbarsUsed - n, 
result}; } -SmallVector WeightSubdivider::popGroups(size_t n) { +llvm::SmallVector WeightSubdivider::popGroups(size_t n) { crossbarsUsed = 0; - SmallVector result; + llvm::SmallVector result; size_t remaining = n; while (remaining > 0 && !weights.empty()) { diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp b/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp index 7c71986..eaa8320 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp @@ -4,11 +4,9 @@ #include "llvm/ADT/SmallVector.h" +#include #include -using namespace mlir; -using namespace std; - namespace onnx_mlir { /** @@ -19,7 +17,7 @@ struct TaggedWeights { long inputTile; long outputTile; size_t startingCrossbarIndex; - SmallVector weights; + llvm::SmallVector weights; }; /** @@ -33,16 +31,16 @@ struct TaggedWeights { */ class WeightSubdivider { private: - map>> weights; + std::map>> weights; size_t crossbarsUsed = 0; TaggedWeights popGroup(size_t amount); public: - WeightSubdivider(map>> weights); + WeightSubdivider(std::map>> weights); bool isEmpty() const; - SmallVector popGroups(size_t n); + llvm::SmallVector popGroups(size_t n); }; } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp index 0cab2b9..d331444 100644 --- a/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp +++ b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp @@ -10,6 +10,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/Format.h" +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -199,12 +200,12 @@ private: void SpatialToGraphvizPass::runOnOperation() { ModuleOp module = getOperation(); - // Get the first OP, must be a FuncOp - 
func::FuncOp func = *module.getOps().begin(); - if (!func) { - module->emitError("No FuncOp found in the begin of module"); + auto entryFunc = getPimEntryFunc(module); + if (failed(entryFunc)) { signalPassFailure(); + return; } + func::FuncOp func = *entryFunc; os << "digraph G {\n" << "\tnode [style=filled,color=white];\n"; diff --git a/src/PIM/Conversion/SpatialToPIM/CMakeLists.txt b/src/PIM/Conversion/SpatialToPIM/CMakeLists.txt index 1e66336..0fcd649 100644 --- a/src/PIM/Conversion/SpatialToPIM/CMakeLists.txt +++ b/src/PIM/Conversion/SpatialToPIM/CMakeLists.txt @@ -3,7 +3,6 @@ mlir_tablegen(SpatialToPIM.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") add_public_tablegen_target(SpatialToPIMIncGen) add_onnx_mlir_library(OMSpatialToPIM - SpatialToPIMPass.hpp SpatialToPIMPass.cpp SpatialToPIMCommon.cpp diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIM.td b/src/PIM/Conversion/SpatialToPIM/SpatialToPIM.td index d6e2e1a..4dc16eb 100644 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIM.td +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIM.td @@ -3,10 +3,18 @@ #ifndef OP_BASE include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Tensor/IR/TensorOps.td" +include "src/Dialect/ONNX/ONNX.td" include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td" -include "src/Accelerators/PIM/Dialect/PIM/Pim.td" +include "src/Accelerators/PIM/Dialect/Pim/Pim.td" #endif // OP_BASE +def onnxToPimTransposeOp : Pat< + (ONNXTransposeOp:$srcOpRes $data, $perms), + (PimTransposeOp $data, $perms, + (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) +>; + def spatToPimVMMOp : Pat< (SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector), (PimVMMOp $weightIndex, $vector, @@ -25,4 +33,4 @@ def spatToPimVAddOp : Pat< (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; -#endif // SPATIAL_TO_PIM \ No newline at end of file +#endif // SPATIAL_TO_PIM diff --git 
a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.cpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.cpp index 6d2e319..958e6c8 100644 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.cpp +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.cpp @@ -8,6 +8,7 @@ #include "SpatialToPIMCommon.hpp" using namespace llvm; +using namespace mlir; namespace onnx_mlir { @@ -90,8 +91,8 @@ Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Opera auto resultShapedType = cast(resultType); rewriter.setInsertionPoint(operation); - return rewriter.create( - operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType()); + return tensor::EmptyOp::create( + rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType()); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp index fcca029..8398a79 100644 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp @@ -4,8 +4,6 @@ #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -using namespace mlir; - namespace onnx_mlir { /** @@ -20,10 +18,10 @@ namespace onnx_mlir { * \param inputShape The ShapedType of the ExtractSliceOp's input tensor * \return The actual offset of the ExtractSliceOp. */ -size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape); +size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape); template -size_t rangeLength(const iterator_range range) { +size_t rangeLength(const mlir::iterator_range range) { return std::distance(range.begin(), range.end()); } @@ -35,16 +33,16 @@ size_t rangeLength(const iterator_range range) { * @return The earliest user operation that uses the given value within the * current block. 
*/ -Operation* getEarliestUserWithinBlock(Value value); +mlir::Operation* getEarliestUserWithinBlock(mlir::Value value); -SmallVector getOpOperandsSortedByUses(Operation* operation); +mlir::SmallVector getOpOperandsSortedByUses(mlir::Operation* operation); -Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation); +mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation); -static bool isMemoryContiguous(const ArrayRef srcShape, - const ArrayRef offsets, - const ArrayRef sizes, - const ArrayRef strides) { +static bool isMemoryContiguous(const mlir::ArrayRef srcShape, + const mlir::ArrayRef offsets, + const mlir::ArrayRef sizes, + const mlir::ArrayRef strides) { // Check that all strides are 1 if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; })) return false; @@ -99,10 +97,13 @@ static bool isMemoryContiguous(const ArrayRef srcShape, return true; } -inline tensor::EmptyOp createEmptyTensorFromShaped(IRRewriter& rewriter, Location loc, ShapedType shapedType) { - return rewriter.create(loc, shapedType.getShape(), shapedType.getElementType()); +inline mlir::tensor::EmptyOp +createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) { + return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType()); } -inline bool isAConcatOp(Operation* op) { return isa(op) || isa(op); } +inline bool isAConcatOp(mlir::Operation* op) { + return isa(op) || isa(op); +} } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp index a64c26b..13fd2f7 100644 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp @@ -17,12 +17,64 @@ #include #include -#include "SpatialToPIMPass.hpp" +#include 
"Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Pass/PimPasses.hpp" +#include "src/Compiler/CompilerOptions.hpp" using namespace mlir; using namespace onnx_mlir; using namespace pim; -using namespace spat_to_pim; + +namespace onnx_mlir { + +namespace { + +#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc" + +struct SpatialToPIMPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass) + StringRef getArgument() const override { return "convert-spatial-to-pim"; } + StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; } + + SpatialToPIMPass() = default; + SpatialToPIMPass(const SpatialToPIMPass& pass) {} + + void runOnOperation() final; + +private: + SmallVector outputTensors; + size_t coreId = 0; + SmallVector operationsToRemove; + + void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter); + + void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); + + void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter); + void addReceiveOps(Value& channelSourceOp, + spatial::SpatChannelNewOp& channel, + Type& channelTensorType, + bool& useBroadcastOp, + IRRewriter& rewriter); + void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, + unsigned int argIndex, + spatial::SpatChannelNewOp& channel, + Type& tensorType, + bool useBroadcastOp, + IRRewriter& rewriter); + + void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter); + + void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter); + + void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& 
rewriter); +}; + +} // namespace void SpatialToPIMPass::runOnOperation() { coreId = 1; @@ -35,14 +87,17 @@ void SpatialToPIMPass::runOnOperation() { RewritePatternSet patterns(ctx); populateWithGenerated(patterns); - if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) { + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { signalPassFailure(); return; } - func::FuncOp funcOp = *moduleOp.getOps().begin(); - if (!funcOp) - llvm_unreachable("No FuncOp found in the begin of module"); + auto entryFunc = getPimEntryFunc(moduleOp); + if (failed(entryFunc)) { + signalPassFailure(); + return; + } + func::FuncOp funcOp = *entryFunc; IRRewriter rewriter(&getContext()); auto returnOp = cast(funcOp.front().getTerminator()); @@ -260,7 +315,7 @@ void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew for (auto returnValue : returnOp->getOperands()) { Operation* returnValueDefiningOp = returnValue.getDefiningOp(); if (returnValueDefiningOp->hasTrait()) { - assert(!returnValueDefiningOp->hasAttr("weightAlways")); + assert(!hasWeightAlways(returnValueDefiningOp)); outputTensors.push_back(returnValue); } else { @@ -487,3 +542,7 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I rewriter.replaceOpWithNewOp(sendOp, sendOp.getChannel(), sendOp.getData()); } } + +std::unique_ptr createSpatialToPIMPass() { return std::make_unique(); } + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.hpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.hpp deleted file mode 100644 index 31c88a2..0000000 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.hpp +++ /dev/null @@ -1,60 +0,0 @@ -#pragma once - -#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp" -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" -#include 
"src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Accelerators/PIM/Pass/PimPasses.hpp" -#include "src/Compiler/CompilerOptions.hpp" - -namespace onnx_mlir { - -namespace spat_to_pim { - -#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc" - -struct SpatialToPIMPass : PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass) - StringRef getArgument() const override { return "convert-spatial-to-pim"; } - StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; } - - SpatialToPIMPass() = default; - SpatialToPIMPass(const SpatialToPIMPass& pass) {} - - void runOnOperation() final; - -private: - SmallVector outputTensors; - size_t coreId = 0; - SmallVector operationsToRemove; - - void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter); - - void allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); - - void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter); - void addReceiveOps(Value& channelSourceOp, - spatial::SpatChannelNewOp& channel, - Type& channelTensorType, - bool& useBroadcastOp, - IRRewriter& rewriter); - void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, - unsigned int argIndex, - spatial::SpatChannelNewOp& channel, - Type& tensorType, - bool useBroadcastOp, - IRRewriter& rewriter); - - void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter); - - void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter); - - void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter); -}; - -} // namespace spat_to_pim - -std::unique_ptr createSpatialToPIMPass() { return std::make_unique(); } - -} // namespace onnx_mlir diff --git a/src/PIM/Dialect/CMakeLists.txt b/src/PIM/Dialect/CMakeLists.txt index 15cf316..a45bf7a 100644 --- a/src/PIM/Dialect/CMakeLists.txt +++ b/src/PIM/Dialect/CMakeLists.txt @@ -1,2 +1,2 @@ 
-add_subdirectory(PIM) -add_subdirectory(Spatial) \ No newline at end of file +add_subdirectory(Pim) +add_subdirectory(Spatial) diff --git a/src/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp b/src/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp deleted file mode 100644 index 16a3a42..0000000 --- a/src/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include "mlir/IR/DialectRegistry.h" - -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" - -using namespace mlir; - -namespace onnx_mlir { -namespace pim { - -void registerOpBufferizationInterfaces(DialectRegistry& registry); - -} // namespace pim -} // namespace onnx_mlir diff --git a/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferizationPass.hpp b/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferizationPass.hpp deleted file mode 100644 index 3e5e327..0000000 --- a/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferizationPass.hpp +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include "mlir/Pass/Pass.h" - -#include "Dialect/PIM/PimOps.hpp" -#include "Dialect/PIM/Transforms/Bufferization/Common.hpp" -#include "src/Accelerators/PIM/Pass/PimPasses.hpp" -#include "src/Compiler/CompilerOptions.hpp" - -namespace onnx_mlir { - -namespace pim { - -#include "Dialect/PIM/Transforms/Bufferization/PimBufferization.hpp.inc" - -struct PimBufferizationPass : PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass) - StringRef getArgument() const override { return "bufferize-pim"; } - StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; } - - PimBufferizationPass() = default; - PimBufferizationPass(const PimBufferizationPass& pass) {} - - void runOnOperation() final; - -private: - void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const; -}; - -} // namespace pim - -std::unique_ptr createBufferizePimPass() { return 
std::make_unique(); } - -} // namespace onnx_mlir diff --git a/src/PIM/Dialect/PIM/CMakeLists.txt b/src/PIM/Dialect/Pim/CMakeLists.txt similarity index 100% rename from src/PIM/Dialect/PIM/CMakeLists.txt rename to src/PIM/Dialect/Pim/CMakeLists.txt diff --git a/src/PIM/Dialect/PIM/Pim.td b/src/PIM/Dialect/Pim/Pim.td similarity index 91% rename from src/PIM/Dialect/PIM/Pim.td rename to src/PIM/Dialect/Pim/Pim.td index e7c4603..273288d 100644 --- a/src/PIM/Dialect/PIM/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -175,7 +175,33 @@ def PimMemCopyOp: PimOp<"memcp", [DestinationStyleOpInterface]> { }]; } -// Computation +// Algebra + +def PimTransposeOp: PimOp<"transpose", [DestinationStyleOpInterface]> { + let description = [{ + Matrix transpose + }]; + + let arguments = (ins + PimTensor: $data, + I64ArrayAttr: $perms, + PimTensor: $outBuf + ); + + let results = (outs + PimTensor: $outRes + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutBufMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $data `,` $outBuf `)` attr-dict `:` `(` type($data) `,` type($outBuf) `)` `->` type($outRes) + }]; +} def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> { let description = [{ @@ -197,6 +223,10 @@ def PimVMMOp: PimOp<"vmm", [DestinationStyleOpInterface]> { return getOutBufMutable(); } }]; + + let assemblyFormat = [{ + `(` $vectorInput `,` $outBuf `)` attr-dict `:` `(` type($vectorInput) `,` type($outBuf) `)` `->` type($outRes) + }]; } def PimMVMOp: PimOp<"mvm", [DestinationStyleOpInterface]> { diff --git a/src/PIM/Dialect/PIM/PimOps.cpp b/src/PIM/Dialect/Pim/PimOps.cpp similarity index 84% rename from src/PIM/Dialect/PIM/PimOps.cpp rename to src/PIM/Dialect/Pim/PimOps.cpp index ebaed4a..3ea7070 100644 --- a/src/PIM/Dialect/PIM/PimOps.cpp +++ b/src/PIM/Dialect/Pim/PimOps.cpp @@ -10,7 +10,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallBitVector.h" -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" 
+#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace mlir; @@ -20,7 +20,7 @@ namespace pim { void PimDialect::initialize() { addOperations< #define GET_OP_LIST -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc" >(); } @@ -45,5 +45,5 @@ POPULATE_DEPENDENCIES(PimVExpOp) //===----------------------------------------------------------------------===// #define GET_OP_CLASSES -#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.cpp.inc" -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.cpp.inc" +#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.cpp.inc" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.cpp.inc" diff --git a/src/PIM/Dialect/PIM/PimOps.hpp b/src/PIM/Dialect/Pim/PimOps.hpp similarity index 78% rename from src/PIM/Dialect/PIM/PimOps.hpp rename to src/PIM/Dialect/Pim/PimOps.hpp index 4811ddb..7e9145a 100644 --- a/src/PIM/Dialect/PIM/PimOps.hpp +++ b/src/PIM/Dialect/Pim/PimOps.hpp @@ -12,7 +12,7 @@ #include /// Include the auto-generated header files containing the declarations -#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc" +#include "src/Accelerators/PIM/Dialect/Pim/PimDialect.hpp.inc" #define GET_OP_CLASSES -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp.inc" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp.inc" diff --git a/src/PIM/Dialect/PIM/Transforms/Bufferization/CMakeLists.txt b/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt similarity index 94% rename from src/PIM/Dialect/PIM/Transforms/Bufferization/CMakeLists.txt rename to src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt index eda1784..1f17eec 100644 --- a/src/PIM/Dialect/PIM/Transforms/Bufferization/CMakeLists.txt +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt @@ -3,7 +3,6 @@ mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") add_public_tablegen_target(PimBufferizationIncGen) 
add_onnx_mlir_library(OMPimBufferization - PimBufferizationPass.hpp PimBufferizationPass.cpp OpBufferizationInterfaces.hpp OpBufferizationInterfaces.cpp diff --git a/src/PIM/Dialect/PIM/Transforms/Bufferization/Common.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp similarity index 84% rename from src/PIM/Dialect/PIM/Transforms/Bufferization/Common.cpp rename to src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp index d6480fe..00ab90f 100644 --- a/src/PIM/Dialect/PIM/Transforms/Bufferization/Common.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp @@ -1,4 +1,4 @@ -#include "Dialect/PIM/Transforms/Bufferization/Common.hpp" +#include "Dialect/Pim/Transforms/Bufferization/Common.hpp" using namespace mlir; diff --git a/src/PIM/Dialect/PIM/Transforms/Bufferization/Common.hpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp similarity index 59% rename from src/PIM/Dialect/PIM/Transforms/Bufferization/Common.hpp rename to src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp index 5bbd3ba..961f724 100644 --- a/src/PIM/Dialect/PIM/Transforms/Bufferization/Common.hpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp @@ -2,12 +2,10 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" -using namespace mlir; - namespace onnx_mlir { namespace pim { -IntegerAttr getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref); +mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref); } // namespace pim } // namespace onnx_mlir diff --git a/src/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp similarity index 87% rename from src/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.cpp rename to src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index 64109fd..acffded 100644 --- a/src/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ 
b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -4,7 +4,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "OpBufferizationInterfaces.hpp" -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace mlir; using namespace bufferization; @@ -76,6 +76,32 @@ struct MemCopyDevToHostOpInterface } }; +struct TransposeOpBufferizeInterface +: DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return !cast(op).isDpsInit(&opOperand); + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto transposeOp = cast(op); + + auto dataOpt = getBuffer(rewriter, transposeOp.getData(), options, state); + if (failed(dataOpt)) + return failure(); + + auto outBufOpt = getBuffer(rewriter, transposeOp.getOutBuf(), options, state); + if (failed(outBufOpt)) + return failure(); + + replaceOpWithNewBufferizedOp( + rewriter, op, outBufOpt->getType(), *dataOpt, transposeOp.getPerms(), *outBufOpt); + return success(); + } +}; + struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return !cast(op).isDpsInit(&opOperand); @@ -176,6 +202,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) { registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) { PimMemCopyHostToDevOp::attachInterface(*ctx); PimMemCopyDevToHostOp::attachInterface(*ctx); + PimTransposeOp::attachInterface(*ctx); PimVMMOp::attachInterface(*ctx); PimMVMOp::attachInterface(*ctx); PimVAddOp::attachInterface(*ctx); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp new file mode 100644 
index 0000000..b60e3ef --- /dev/null +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "mlir/IR/DialectRegistry.h" + +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +namespace onnx_mlir { +namespace pim { + +void registerOpBufferizationInterfaces(mlir::DialectRegistry& registry); + +} // namespace pim +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferization.td b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td similarity index 84% rename from src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferization.td rename to src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td index 78a9c03..bc920e3 100644 --- a/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferization.td +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferization.td @@ -4,7 +4,7 @@ #ifndef OP_BASE include "mlir/IR/PatternBase.td" include "mlir/Dialect/MemRef/IR/MemRefOps.td" -include "src/Accelerators/PIM/Dialect/PIM/Pim.td" +include "src/Accelerators/PIM/Dialect/Pim/Pim.td" #endif // OP_BASE def memrefCopyToPimMemCopyOp : Pat< @@ -16,4 +16,4 @@ def memrefCopyToPimMemCopyOp : Pat< (returnType $dst)) >; -#endif // PIM_BUFFERIZATION \ No newline at end of file +#endif // PIM_BUFFERIZATION diff --git a/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp similarity index 69% rename from src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferizationPass.cpp rename to src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index 7f5eb82..e5badef 100644 --- a/src/PIM/Dialect/PIM/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -7,12 +7,37 @@ #include "Common/PIMCommon.hpp" #include "Compiler/PimCodeGen.hpp" -#include "PimBufferizationPass.hpp" +#include 
"Dialect/Pim/PimOps.hpp" +#include "Dialect/Pim/Transforms/Bufferization/Common.hpp" +#include "src/Accelerators/PIM/Pass/PimPasses.hpp" +#include "src/Compiler/CompilerOptions.hpp" using namespace mlir; using namespace onnx_mlir; using namespace pim; +namespace onnx_mlir { + +namespace { + +#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc" + +struct PimBufferizationPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass) + StringRef getArgument() const override { return "bufferize-pim"; } + StringRef getDescription() const override { return "Bufferize PIM and Spatial ops."; } + + PimBufferizationPass() = default; + PimBufferizationPass(const PimBufferizationPass& pass) {} + + void runOnOperation() final; + +private: + void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const; +}; + +} // namespace + void PimBufferizationPass::runOnOperation() { auto moduleOp = getOperation(); @@ -68,15 +93,18 @@ void PimBufferizationPass::runOnOperation() { } void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { - MLIRContext* ctx = funcOp.getContext(); funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { bool isAlwaysWeight = !getGlobalOp->getUsers().empty() && all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa(user); }); if (isAlwaysWeight) { - auto globalMemrefOp = moduleOp.lookupSymbol(getGlobalOp.getName()); + auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); assert("Weights must be constants" && globalMemrefOp.getConstant()); - getGlobalOp->setAttr("weightAlways", UnitAttr::get(ctx)); - globalMemrefOp->setAttr("weightAlways", UnitAttr::get(ctx)); + markWeightAlways(getGlobalOp); + markWeightAlways(globalMemrefOp); } }); } + +std::unique_ptr createBufferizePimPass() { return std::make_unique(); } + +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index 
35e0176..82dae7e 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -25,7 +25,7 @@ #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -202,9 +202,9 @@ LogicalResult SpatVMaxOp::verify() { LogicalResult SpatImgConcatOp::verify() { auto imgShape = mlir::cast(getType()); - size_t img_w = GET_IMAGE_WIDTH(imgShape); - size_t img_h = GET_IMAGE_HEIGHT(imgShape); - size_t img_c = GET_IMAGE_CHANNEL(imgShape); + size_t img_w = getImageWidth(imgShape); + size_t img_h = getImageHeight(imgShape); + size_t img_c = getImageChannel(imgShape); size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue()); size_t channelTileRest = img_c % crossbarSize; @@ -225,10 +225,10 @@ LogicalResult SpatImgConcatOp::verify() { return emitError("Invalid input type, must be ShapedType"); // N == W == H == 1 - if (GET_IMAGE_N(inputShape) != 1 || GET_IMAGE_WIDTH(inputShape) != 1 || GET_IMAGE_HEIGHT(inputShape) != 1) + if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1) return emitError("Invalid input shape: N,W,H must all be 1"); - size_t inputChannels = GET_IMAGE_CHANNEL(inputShape); + size_t inputChannels = getImageChannel(inputShape); // Check the number of channels in this tile are correct: // - CASE1: last tile of pixel, if there is some rest it must match that @@ -311,9 +311,9 @@ LogicalResult SpatWeightedCompute::verify() { Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) { auto operands = getOperands(); auto imgShape = mlir::cast(getType()); - size_t img_w = GET_IMAGE_WIDTH(imgShape); - size_t img_h = GET_IMAGE_HEIGHT(imgShape); - size_t img_c = GET_IMAGE_CHANNEL(imgShape); + 
size_t img_w = getImageWidth(imgShape); + size_t img_h = getImageHeight(imgShape); + size_t img_c = getImageChannel(imgShape); size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue()); diff --git a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp index b8d3934..3fcf477 100644 --- a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp @@ -18,7 +18,7 @@ #include #include "src/Accelerators/PIM/Common/PIMCommon.hpp" -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp index 3fa6ceb..9013e74 100644 --- a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp @@ -4,14 +4,12 @@ #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -using namespace mlir; - namespace onnx_mlir { namespace spatial { -void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry); +void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry); -void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry); +void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry); } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Pass/CountInstructionPass.cpp b/src/PIM/Pass/CountInstructionPass.cpp index a4e0095..1ae5ef3 100644 --- a/src/PIM/Pass/CountInstructionPass.cpp +++ 
b/src/PIM/Pass/CountInstructionPass.cpp @@ -1,7 +1,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Compiler/CompilerUtils.hpp" diff --git a/src/PIM/Pass/PimFoldHostConstantsPass.cpp b/src/PIM/Pass/PimFoldHostConstantsPass.cpp new file mode 100644 index 0000000..f873c09 --- /dev/null +++ b/src/PIM/Pass/PimFoldHostConstantsPass.cpp @@ -0,0 +1,181 @@ +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" + +#include + +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +static FailureOr transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef perms) { + auto tensorType = dyn_cast(denseAttr.getType()); + if (!tensorType) + return failure(); + + int64_t rank = tensorType.getRank(); + if (static_cast(perms.size()) != rank) + return failure(); + + llvm::SmallBitVector seen(rank); + SmallVector transposedShape; + transposedShape.reserve(rank); + for (int64_t perm : perms) { + if (perm < 0 || perm >= rank || seen.test(perm)) + return failure(); + seen.set(perm); + transposedShape.push_back(tensorType.getShape()[perm]); + } + + auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType()); + if (denseAttr.isSplat()) + return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue()); + + SmallVector originalValues(denseAttr.getValues()); + SmallVector 
transposedValues(originalValues.size()); + + SmallVector originalStrides(rank, 1); + SmallVector transposedStrides(rank, 1); + for (int64_t dim = rank - 2; dim >= 0; --dim) { + originalStrides[dim] = originalStrides[dim + 1] * tensorType.getShape()[dim + 1]; + transposedStrides[dim] = transposedStrides[dim + 1] * transposedShape[dim + 1]; + } + + SmallVector originalIndices(rank); + SmallVector transposedIndices(rank); + for (auto [linearIndex, value] : llvm::enumerate(originalValues)) { + int64_t remaining = static_cast(linearIndex); + for (int64_t dim = 0; dim < rank; ++dim) { + originalIndices[dim] = remaining / originalStrides[dim]; + remaining %= originalStrides[dim]; + } + + for (int64_t dim = 0; dim < rank; ++dim) + transposedIndices[dim] = originalIndices[perms[dim]]; + + int64_t transposedLinearIndex = 0; + for (int64_t dim = 0; dim < rank; ++dim) + transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim]; + + transposedValues[transposedLinearIndex] = value; + } + + return DenseElementsAttr::get(transposedType, transposedValues); +} + +struct FoldConstantTransposePattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override { + auto resultType = dyn_cast(transposeOp.getOutRes().getType()); + if (!resultType || !resultType.hasStaticShape()) + return failure(); + + auto sourceGetGlobal = transposeOp.getData().getDefiningOp(); + if (!sourceGetGlobal) + return failure(); + + auto moduleOp = transposeOp->getParentOfType(); + if (!moduleOp) + return failure(); + + auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal); + if (!sourceGlobal || !sourceGlobal.getConstant() || !sourceGlobal.getInitialValue()) + return failure(); + + auto denseAttr = dyn_cast(*sourceGlobal.getInitialValue()); + if (!denseAttr) + return failure(); + + SmallVector perms; + perms.reserve(transposeOp.getPerms().size()); + for 
(IntegerAttr attr : transposeOp.getPerms().getAsRange()) + perms.push_back(attr.getInt()); + FailureOr transposedAttr = transposeDenseElements(denseAttr, perms); + if (failed(transposedAttr)) + return failure(); + + auto transposedShape = cast(transposedAttr->getType()).getShape(); + if (!llvm::equal(transposedShape, resultType.getShape())) + return failure(); + + MemRefType globalType = resultType; + + auto globalName = sourceGlobal.getName().str() + "__folded_transpose"; + unsigned suffix = 0; + while (moduleOp.lookupSymbol(globalName)) + globalName = sourceGlobal.getName().str() + "__folded_transpose_" + std::to_string(++suffix); + + auto visibility = rewriter.getStringAttr("private"); + OpBuilder moduleBuilder(moduleOp.getBodyRegion()); + moduleBuilder.setInsertionPointToStart(moduleOp.getBody()); + auto newGlobal = memref::GlobalOp::create(moduleBuilder, + transposeOp.getLoc(), + globalName, + visibility, + globalType, + *transposedAttr, + /*constant=*/true, + sourceGlobal.getAlignmentAttr()); + + rewriter.setInsertionPoint(transposeOp); + auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName()); + + bool isAlwaysWeight = + !transposeOp->getUsers().empty() + && llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa(user); }); + if (isAlwaysWeight) { + markWeightAlways(newGlobal); + markWeightAlways(newGetGlobal); + } + + rewriter.replaceOp(transposeOp, newGetGlobal.getResult()); + return success(); + } +}; + +struct PimFoldHostConstantsPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimFoldHostConstantsPass) + + StringRef getArgument() const override { return "fold-pim-host-constants-pass"; } + StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; } + + LogicalResult initialize(MLIRContext* context) override { + RewritePatternSet owningPatterns(context); + for (auto* dialect : context->getLoadedDialects()) + 
dialect->getCanonicalizationPatterns(owningPatterns); + for (RegisteredOperationName op : context->getRegisteredOperations()) + op.getCanonicalizationPatterns(owningPatterns, context); + owningPatterns.add(context); + patterns = std::make_shared(std::move(owningPatterns)); + return success(); + } + + void runOnOperation() override { + GreedyRewriteConfig config; + config.enableFolding(); + if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) + signalPassFailure(); + } + + std::shared_ptr patterns; +}; + +} // namespace + +std::unique_ptr createPimFoldHostConstantsPass() { return std::make_unique(); } + +} // namespace onnx_mlir diff --git a/src/PIM/Pass/PimHostVerificationPass.cpp b/src/PIM/Pass/PimHostVerificationPass.cpp new file mode 100644 index 0000000..9b1dbc0 --- /dev/null +++ b/src/PIM/Pass/PimHostVerificationPass.cpp @@ -0,0 +1,173 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/STLExtras.h" + +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +static bool isAddressOnlyHostOp(Operation* op) { + return isa(op); +} + +static bool isHostAddressableValue(Value value) { + while (true) { + if (auto blockArg = dyn_cast(value)) + return isa(blockArg.getOwner()->getParentOp()); + + Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return false; + + if (isa(definingOp)) + return true; + + if (auto subviewOp = dyn_cast(definingOp)) { + value = subviewOp.getSource(); + continue; + } + if (auto castOp = dyn_cast(definingOp)) { + value = castOp.getSource(); + continue; + } + if (auto collapseOp = dyn_cast(definingOp)) { + value = collapseOp.getSrc(); + continue; + } + if (auto expandOp = dyn_cast(definingOp)) { + value = expandOp.getSrc(); + continue; + } + 
+ return false; + } +} + +struct PimHostVerificationPass : PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass) + + StringRef getArgument() const override { return "verify-pim-host-pass"; } + StringRef getDescription() const override { + return "Verify that no runtime host-side code remains in bufferized PIM IR"; + } + + PimHostVerificationPass() {} + PimHostVerificationPass(const PimHostVerificationPass& pass) {} + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + bool hasFailure = false; + + for (func::FuncOp funcOp : moduleOp.getOps()) { + if (funcOp.isExternal()) + continue; + + for (Operation& op : funcOp.getBody().front().getOperations()) { + if (auto coreOp = dyn_cast(&op)) { + if (failed(verifyCoreWeights(moduleOp, coreOp))) + hasFailure = true; + continue; + } + + if (auto returnOp = dyn_cast(&op)) { + if (failed(verifyReturnOp(returnOp))) + hasFailure = true; + continue; + } + + if (!isAddressOnlyHostOp(&op)) { + op.emitOpError("illegal host-side runtime op remains after PIM bufferization; " + "fold it to constants or lower it into pim.core"); + hasFailure = true; + continue; + } + + if (failed(verifyAddressOnlyHostOp(&op))) + hasFailure = true; + } + } + + if (hasFailure) + signalPassFailure(); + } + +private: + static LogicalResult verifyCoreWeights(ModuleOp moduleOp, pim::PimCoreOp coreOp) { + bool hasFailure = false; + for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) { + auto getGlobalOp = weight.getDefiningOp(); + if (!getGlobalOp) { + coreOp.emitOpError() << "weight #" << weightIndex + << " must be materialized as memref.get_global before JSON codegen"; + hasFailure = true; + continue; + } + + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (!globalOp) { + coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global"; + hasFailure = true; + continue; + } + + if (!globalOp.getConstant() || !globalOp.getInitialValue()) 
{ + coreOp.emitOpError() << "weight #" << weightIndex + << " must come from a constant memref.global with an initial value"; + hasFailure = true; + } + } + + return success(!hasFailure); + } + + static LogicalResult verifyReturnOp(func::ReturnOp returnOp) { + bool hasFailure = false; + for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) { + if (!isHostAddressableValue(operand)) { + returnOp.emitOpError() << "result #" << resultIndex << " is not backed by host-addressable storage"; + hasFailure = true; + } + } + return success(!hasFailure); + } + + static LogicalResult verifyAddressOnlyHostOp(Operation* op) { + if (auto subviewOp = dyn_cast(op)) + return verifyAddressOnlySource(op, subviewOp.getSource()); + if (auto castOp = dyn_cast(op)) + return verifyAddressOnlySource(op, castOp.getSource()); + if (auto collapseOp = dyn_cast(op)) + return verifyAddressOnlySource(op, collapseOp.getSrc()); + if (auto expandOp = dyn_cast(op)) + return verifyAddressOnlySource(op, expandOp.getSrc()); + return success(); + } + + static LogicalResult verifyAddressOnlySource(Operation* op, Value source) { + if (isHostAddressableValue(source)) + return success(); + + op->emitOpError("depends on a value that still requires host-side execution"); + return failure(); + } +}; + +} // namespace + +std::unique_ptr createPimHostVerificationPass() { return std::make_unique(); } + +} // namespace onnx_mlir diff --git a/src/PIM/Pass/PimPasses.hpp b/src/PIM/Pass/PimPasses.hpp index d548f40..781c073 100644 --- a/src/PIM/Pass/PimPasses.hpp +++ b/src/PIM/Pass/PimPasses.hpp @@ -3,23 +3,26 @@ #include "mlir/Pass/Pass.h" #include - -using namespace mlir; +#include namespace onnx_mlir { -std::unique_ptr createONNXToSpatialPass(); +std::unique_ptr createONNXToSpatialPass(); -std::unique_ptr createSpatialToGraphvizPass(); +std::unique_ptr createSpatialToGraphvizPass(); -std::unique_ptr createSpatialToPIMPass(); +std::unique_ptr createSpatialToPIMPass(); -std::unique_ptr 
createBufferizePimPass(); +std::unique_ptr createBufferizePimPass(); -std::unique_ptr createEmitPimJsonPass(); +std::unique_ptr createPimFoldHostConstantsPass(); -std::unique_ptr createMessagePass(std::string message); +std::unique_ptr createPimHostVerificationPass(); -std::unique_ptr createCountInstructionPass(); +std::unique_ptr createEmitPimJsonPass(); + +std::unique_ptr createMessagePass(std::string message); + +std::unique_ptr createCountInstructionPass(); } // namespace onnx_mlir diff --git a/src/PIM/PimAccelerator.cpp b/src/PIM/PimAccelerator.cpp index d50dd30..997370f 100644 --- a/src/PIM/PimAccelerator.cpp +++ b/src/PIM/PimAccelerator.cpp @@ -4,6 +4,7 @@ #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" @@ -12,8 +13,8 @@ #include "llvm/Support/Debug.h" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" -#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" -#include "src/Accelerators/PIM/Dialect/PIM/Transforms/Bufferization/OpBufferizationInterfaces.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp" @@ -40,28 +41,27 @@ PimAccelerator::PimAccelerator() acceleratorTargets.push_back(this); } -PimAccelerator::~PimAccelerator() { delete instance; } - uint64_t PimAccelerator::getVersionNumber() const { return 0x000001; } -void PimAccelerator::addPasses(OwningOpRef& module, - PassManager& 
pm, +void PimAccelerator::addPasses(mlir::OwningOpRef& module, + mlir::PassManager& pm, EmissionTargetType& emissionTarget, std::string outputNameNoExt) const { LLVM_DEBUG(llvm::dbgs() << "Adding passes for PIM accelerator\n"); addPassesPim(module, pm, emissionTarget, outputNameNoExt); } -void PimAccelerator::registerDialects(DialectRegistry& registry) const { +void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const { LLVM_DEBUG(llvm::dbgs() << "Registering dialects for PIM accelerator\n"); - registry.insert(); - registry.insert(); - registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); registry.insert(); registry.insert(); - tensor::registerBufferizableOpInterfaceExternalModels(registry); - arith::registerBufferizableOpInterfaceExternalModels(registry); - bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry); + mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry); + mlir::tensor::registerInferTypeOpInterfaceExternalModels(registry); + mlir::arith::registerBufferizableOpInterfaceExternalModels(registry); + mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry); spatial::registerBufferizableOpInterfaceExternalModels(registry); spatial::registerONNXBufferizableOpInterfaceExternalModels(registry); pim::registerOpBufferizationInterfaces(registry); @@ -73,6 +73,8 @@ void PimAccelerator::registerPasses(int optLevel) const { registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToPIMPass); registerPass(createBufferizePimPass); + registerPass(createPimFoldHostConstantsPass); + registerPass(createPimHostVerificationPass); registerPass(createEmitPimJsonPass); } @@ -81,26 +83,26 @@ void PimAccelerator::configurePasses() const { // TODO: This does nothing for now. 
} -MemRefType PimAccelerator::convertTensorTypeToMemRefType(const TensorType tensorType) const { +mlir::MemRefType PimAccelerator::convertTensorTypeToMemRefType(const mlir::TensorType tensorType) const { // Do not convert tensor types to memref types. return nullptr; } -void PimAccelerator::conversionTargetONNXToKrnl(ConversionTarget& target) const { +void PimAccelerator::conversionTargetONNXToKrnl(mlir::ConversionTarget& target) const { target.addLegalDialect(); } -void PimAccelerator::rewritePatternONNXToKrnl(RewritePatternSet& patterns, - TypeConverter& typeConverter, - MLIRContext* ctx) const { +void PimAccelerator::rewritePatternONNXToKrnl(mlir::RewritePatternSet& patterns, + mlir::TypeConverter& typeConverter, + mlir::MLIRContext* ctx) const { // TODO: Add patterns for conversion } -void PimAccelerator::conversionTargetKrnlToLLVM(ConversionTarget& target) const {} +void PimAccelerator::conversionTargetKrnlToLLVM(mlir::ConversionTarget& target) const {} -void PimAccelerator::rewritePatternKrnlToLLVM(RewritePatternSet& patterns, - LLVMTypeConverter& typeConverter, - MLIRContext* ctx) const { +void PimAccelerator::rewritePatternKrnlToLLVM(mlir::RewritePatternSet& patterns, + mlir::LLVMTypeConverter& typeConverter, + mlir::MLIRContext* ctx) const { // We should not need this, since we offload it all to PIM. } diff --git a/src/PIM/PimAccelerator.hpp b/src/PIM/PimAccelerator.hpp index d7f61df..22ff864 100644 --- a/src/PIM/PimAccelerator.hpp +++ b/src/PIM/PimAccelerator.hpp @@ -18,8 +18,6 @@ public: PimAccelerator(PimAccelerator&) = delete; void operator=(const PimAccelerator&) = delete; - ~PimAccelerator(); - /// Creates an instance on the first invocation. Subsequent invocations /// return the existing instance. 
static PimAccelerator* getInstance(); diff --git a/validation/gen_network_runner.py b/validation/gen_network_runner.py index 6b980c5..7966fc8 100644 --- a/validation/gen_network_runner.py +++ b/validation/gen_network_runner.py @@ -102,46 +102,13 @@ def gen_c(inputs, outputs, entry, so_name): if(!in{i}_tensor){{fprintf(stderr,"ERROR: omTensorCreateWithOwnership failed for input {i}.\\n");return 2;}} """)) - # Output printing + optional per-output CSV dump - out_blocks=[] + # Optional per-output CSV dump csv_write_blocks=[] for oi,name,et,shape in outputs: if et not in DTYPES: raise ValueError(f"Unsupported dtype for output '{name}': {et}") cty, pfmt, _ = DTYPES[et] safe = esc(name) - out_blocks.append(textwrap.dedent(f""" - // ---- Output {oi}: "{safe}" ---- - {{ - OMTensor *t = omTensorListGetOmtByIndex(out_list, {oi}); - int64_t rank = omTensorGetRank(t); - int64_t const *shape = omTensorGetShape(t); - long long numel = 1; for (int64_t k=0;k.csv" @@ -227,9 +194,6 @@ int main(int argc, char **argv) {{ OMTensorList *out_list = {entry}(in_list); if(!out_list){{fprintf(stderr,"ERROR: model returned NULL.\\n");omTensorListDestroy(in_list);return 3;}} - // ---- Print full outputs ---- -{"".join(out_blocks)} - // ---- Optional per-output CSV dump ---- {"".join(csv_write_blocks)} @@ -240,7 +204,7 @@ int main(int argc, char **argv) {{ }} """ -def gen_network_runner(network_onnx, network_so, onnx_include_dir, entry="run_main_graph", out=None): +def gen_network_runner(network_onnx, network_so, onnx_include_dir, entry="run_main_graph", out=None, verbose=True): ins, outs = onnx_io(network_onnx) out_c = out or "runner.c" so_abs = os.path.abspath(network_so) @@ -260,8 +224,9 @@ set_target_properties(model_so PROPERTIES IMPORTED_LOCATION {esc(so_abs)}) target_link_libraries({pathlib.Path(out_c).stem} PUBLIC model_so) """ pathlib.Path(out_c).with_name("CMakeLists.txt").write_text(cmake) - print(f"[OK] Wrote {out_c}") - print("[OK] Wrote CMakeLists.txt") + if verbose: + 
print(f"[OK] Wrote {out_c}") + print("[OK] Wrote CMakeLists.txt") if __name__=="__main__": ap=argparse.ArgumentParser() diff --git a/validation/raptor.py b/validation/raptor.py index 29030e2..8b7d8cb 100644 --- a/validation/raptor.py +++ b/validation/raptor.py @@ -1,9 +1,10 @@ import subprocess from pathlib import Path from colorama import Fore, Style +from subprocess_utils import run_command_with_reporter -def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, crossbar_count): +def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, crossbar_count, reporter=None): # Define the arguments, with the possibility to set crossbar size and count args = [ network_path, @@ -14,16 +15,14 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, crossbar_size, cro f"--crossbar-count={crossbar_count}", ] - # Run the executable with the arguments try: - result = subprocess.run( + run_command_with_reporter( [str(raptor_onnx_path)] + [str(arg) for arg in args], - check=True, - capture_output=True, - text=True, + reporter=reporter, ) - print(result.stdout + Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL) - except subprocess.CalledProcessError as e: - print(Fore.RED + "Error executing ONNX-MLIR:") - print(e.stderr + Style.RESET_ALL) + if reporter is None: + print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL) + except subprocess.CalledProcessError: + if reporter is None: + print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL) raise diff --git a/validation/subprocess_utils.py b/validation/subprocess_utils.py new file mode 100644 index 0000000..fb0ad29 --- /dev/null +++ b/validation/subprocess_utils.py @@ -0,0 +1,70 @@ +import errno +import os +import pty +import selectors +import subprocess + + +def _read_chunk(fd, treat_eio_as_eof=False): + try: + return os.read(fd, 4096) + except OSError as exc: + if treat_eio_as_eof and exc.errno == errno.EIO: + return b"" + raise + + +def _stream_output(fd, 
process, reporter, treat_eio_as_eof=False): + selector = selectors.DefaultSelector() + + try: + selector.register(fd, selectors.EVENT_READ) + + while selector.get_map(): + for key, _ in selector.select(): + data = _read_chunk(key.fileobj, treat_eio_as_eof=treat_eio_as_eof) + if not data: + selector.unregister(key.fileobj) + os.close(key.fileobj) + continue + + reporter._clear() + os.write(1, data) + reporter._render() + finally: + selector.close() + + return_code = process.wait() + if return_code != 0: + raise subprocess.CalledProcessError(return_code, process.args) + + +def run_command_with_reporter(cmd, cwd=None, reporter=None): + if reporter is None: + subprocess.run(cmd, cwd=cwd, check=True) + return + + try: + master_fd, slave_fd = pty.openpty() + except OSError: + process = subprocess.Popen( + cmd, + cwd=cwd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + assert process.stdout is not None + _stream_output(process.stdout.fileno(), process, reporter) + return + + try: + process = subprocess.Popen( + cmd, + cwd=cwd, + stdout=slave_fd, + stderr=slave_fd, + ) + finally: + os.close(slave_fd) + + _stream_output(master_fd, process, reporter, treat_eio_as_eof=True) diff --git a/validation/validate.py b/validation/validate.py index 68dd4ea..697fca5 100644 --- a/validation/validate.py +++ b/validation/validate.py @@ -1,10 +1,11 @@ #!/usr/bin/env python3 import argparse +import subprocess import sys from pathlib import Path from colorama import Style, Fore -from validate_one import validate_network +from validate_one import ProgressReporter, validate_network def main(): @@ -34,32 +35,48 @@ def main(): print(Fore.YELLOW + f"No .onnx files found under {operations_dir}" + Style.RESET_ALL) sys.exit(1) - print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate.\n" + Style.RESET_ALL) + print(Style.BRIGHT + f"Found {len(onnx_files)} ONNX file(s) to validate." 
+ Style.RESET_ALL) + print(f"Operations root: {operations_dir}") + print("=" * 72) results = {} # relative_path -> passed - for onnx_path in onnx_files: + reporter = ProgressReporter(len(onnx_files)) + for index, onnx_path in enumerate(onnx_files, start=1): rel = onnx_path.relative_to(operations_dir) - header = f"{'=' * 60}\n Validating: {rel}\n{'=' * 60}" - print(Style.BRIGHT + Fore.CYAN + header + Style.RESET_ALL) + try: + passed = validate_network( + onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir, + crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, + threshold=a.threshold, + reporter=reporter, + model_index=index, + model_total=len(onnx_files), + ) + results[str(rel)] = passed + except (subprocess.CalledProcessError, Exception): + results[str(rel)] = False - passed = validate_network( - onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir, - crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, - threshold=a.threshold, - ) - - results[str(rel)] = passed + reporter.finish() # Summary - n_passed = sum(results.values()) + n_passed = sum(1 for passed in results.values() if passed) n_total = len(results) - print("\n" + Style.BRIGHT + "=" * 60) - print(" Summary") - print("=" * 60 + Style.RESET_ALL) + status_width = len("Result") + path_width = max(len("Operation"), *(len(rel) for rel in results)) + separator = f"+-{'-' * path_width}-+-{'-' * status_width}-+" + + print("\n" + Style.BRIGHT + Fore.CYAN + "Summary" + Style.RESET_ALL) + print(separator) + print(f"| {'Operation'.ljust(path_width)} | {'Result'.ljust(status_width)} |") + print(separator) for rel, passed in results.items(): - status = Fore.GREEN + "PASS" if passed else Fore.RED + "FAIL" - print(f" {rel}: {status}" + Style.RESET_ALL) - print(Style.BRIGHT + f"\n {n_passed}/{n_total} passed." 
+ Style.RESET_ALL) + plain_status = "PASS" if passed else "FAIL" + status = Fore.GREEN + plain_status.ljust(status_width) + Style.RESET_ALL if passed else \ + Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL + print(f"| {rel.ljust(path_width)} | {status} |") + print(separator) + print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL) + print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL) sys.exit(0 if n_passed == n_total else 1) diff --git a/validation/validate_one.py b/validation/validate_one.py index d31fdb6..8f6b6e7 100644 --- a/validation/validate_one.py +++ b/validation/validate_one.py @@ -2,16 +2,114 @@ import argparse import json import numpy as np import subprocess +import shutil +import sys from pathlib import Path from colorama import Style, Fore from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP from raptor import compile_with_raptor from gen_network_runner import gen_network_runner +from subprocess_utils import run_command_with_reporter -def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir): - subprocess.run([raptor_path, network_onnx_path, "--EmitONNXIR"], check=True) - subprocess.run([raptor_path, network_onnx_path], check=True) +STAGE_COUNT = 6 + + +class ProgressReporter: + def __init__(self, total_models, stages_per_model=STAGE_COUNT): + self.total_models = total_models + self.stages_per_model = stages_per_model + self.total_steps = max(1, total_models * stages_per_model) + self.completed_steps = 0 + self.current_label = "" + self.enabled = True + self.columns = shutil.get_terminal_size((100, 20)).columns + self.suspended = False + + def _clear(self): + if self.enabled: + sys.stdout.write("\033[2K\r") + + def _render(self): + if not self.enabled or self.suspended: + return + bar_width = 24 + filled = int(bar_width * self.completed_steps / self.total_steps) + prefix_text = f"[{'#' * filled}{'-' * (bar_width - filled)}] 
{self.completed_steps}/{self.total_steps}" + if len(prefix_text) > self.columns: + prefix_text = f"{self.completed_steps}/{self.total_steps}" + + label = f" {self.current_label}" if self.current_label else "" + available_label_width = max(0, self.columns - len(prefix_text)) + label = label[:available_label_width] + + if prefix_text.startswith("["): + bar = Fore.GREEN + ("#" * filled) + Fore.CYAN + ("-" * (bar_width - filled)) + prefix = Fore.CYAN + f"[{bar}{Fore.CYAN}] {self.completed_steps}/{self.total_steps}" + Style.RESET_ALL + else: + prefix = Fore.CYAN + prefix_text + Style.RESET_ALL + + sys.stdout.write("\r" + prefix + label + Style.RESET_ALL) + sys.stdout.flush() + + def log(self, message="", color=None): + if self.enabled: + self._clear() + if color: + print(color + message + Style.RESET_ALL) + else: + print(message) + self._render() + + def set_stage(self, model_index, model_total, model_name, stage_name): + self.current_label = f"[{model_index}/{model_total}] {model_name} ยท {stage_name}" + self._render() + + def advance(self): + self.completed_steps = min(self.total_steps, self.completed_steps + 1) + self._render() + + def suspend(self): + self.suspended = True + self._clear() + + def resume(self): + self.suspended = False + self._render() + + def finish(self): + if self.enabled: + self.suspended = True + self._clear() + sys.stdout.flush() + + +def run_command(cmd, cwd=None, reporter=None): + run_command_with_reporter(cmd, cwd=cwd, reporter=reporter) + + +def print_stage(reporter, model_index, model_total, model_name, title): + stage_colors = { + "Compile ONNX": Fore.BLUE, + "Build Runner": Fore.MAGENTA, + "Generate Inputs": Fore.YELLOW, + "Run Reference": Fore.GREEN, + "Compile PIM": Fore.CYAN, + "Run Simulator": Fore.MAGENTA, + "Compare Outputs": Fore.YELLOW, + } + color = stage_colors.get(title, Fore.WHITE) + reporter.log(Style.BRIGHT + color + f"[{title}]" + Style.RESET_ALL) + reporter.set_stage(model_index, model_total, model_name, title) + + +def 
print_info(reporter, message): + reporter.log(f" {message}") + + +def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=None): + run_command([raptor_path, network_onnx_path, "--EmitONNXIR"], reporter=reporter) + run_command([raptor_path, network_onnx_path], reporter=reporter) parent = network_onnx_path.parent stem = network_onnx_path.stem so_path = parent / f"{stem}.so" @@ -25,9 +123,9 @@ def compile_onnx_network(network_onnx_path, raptor_path, raptor_dir, runner_dir) return moved_so, moved_mlir -def build_onnx_runner(source_dir, build_dir): - subprocess.run(["cmake", source_dir], cwd=build_dir, check=True) - subprocess.run(["cmake", "--build", ".", "-j"], cwd=build_dir, check=True) +def build_onnx_runner(source_dir, build_dir, reporter=None): + run_command(["cmake", source_dir], cwd=build_dir, reporter=reporter) + run_command(["cmake", "--build", ".", "-j"], cwd=build_dir, reporter=reporter) return build_dir / "runner" @@ -41,11 +139,12 @@ def build_dump_ranges(config_path, outputs_descriptor): return ",".join(ranges) -def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges): - subprocess.run( +def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None): + run_command( ["cargo", "run", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--", "-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges], - cwd=simulator_dir, check=True + cwd=simulator_dir, + reporter=reporter, ) @@ -64,24 +163,41 @@ def parse_pim_simulator_outputs(output_bin_path, outputs_descriptor): def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1e-3): all_passed = True + rows = [] for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor): csv_name = f"output{oi}_{name}.csv" runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape) max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - 
runner_array.astype(np.float64)))) passed = max_diff <= threshold - status = Fore.GREEN + "[PASS]" if passed else Fore.RED + "[FAIL]" - print(f" {name}: max diff = {max_diff:.6e} {status}" + Style.RESET_ALL) + rows.append((name, f"{max_diff:.6e}", passed)) if not passed: all_passed = False + + name_width = max(len("Output"), *(len(name) for name, _, _ in rows)) + diff_width = max(len("Max diff"), *(len(diff) for _, diff, _ in rows)) + result_width = len("Result") + separator = f" +-{'-' * name_width}-+-{'-' * diff_width}-+-{'-' * result_width}-+" + + print(separator) + print(f" | {'Output'.ljust(name_width)} | {'Max diff'.ljust(diff_width)} | {'Result'} |") + print(separator) + for name, diff_text, passed in rows: + status_text = ("PASS" if passed else "FAIL").ljust(result_width) + status = Fore.GREEN + status_text + Style.RESET_ALL if passed else Fore.RED + status_text + Style.RESET_ALL + print(f" | {name.ljust(name_width)} | {diff_text.ljust(diff_width)} | {status} |") + print(separator) return all_passed def validate_network(network_onnx_path, raptor_path, onnx_include_dir, - simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3): + simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3, + reporter=None, model_index=1, model_total=1): network_onnx_path = Path(network_onnx_path).resolve() raptor_path = Path(raptor_path).resolve() onnx_include_dir = Path(onnx_include_dir).resolve() simulator_dir = Path(simulator_dir).resolve() + owns_reporter = reporter is None + reporter = reporter or ProgressReporter(model_total) workspace_dir = network_onnx_path.parent raptor_dir = workspace_dir / "raptor" @@ -90,40 +206,72 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, Path.mkdir(raptor_dir, exist_ok=True) Path.mkdir(runner_build_dir, parents=True, exist_ok=True) - print(Style.BRIGHT + "\nCompiling the onnx network:" + Style.RESET_ALL) - network_so_path, network_mlir_path = compile_onnx_network(network_onnx_path, raptor_path, 
raptor_dir, runner_dir) + reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL + + f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}") - print(Style.BRIGHT + "\nGenerating and building the runner:" + Style.RESET_ALL) - gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c") - runner_path = build_onnx_runner(runner_dir, runner_build_dir) + try: + print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX") + network_so_path, network_mlir_path = compile_onnx_network( + network_onnx_path, raptor_path, raptor_dir, runner_dir, reporter=reporter) + print_info(reporter, f"MLIR saved to {network_mlir_path}") + print_info(reporter, f"Shared library saved to {network_so_path}") + reporter.advance() - print(Style.BRIGHT + "\nGenerating random inputs:" + Style.RESET_ALL) - inputs_descriptor, outputs_descriptor = onnx_io(network_onnx_path) - inputs_list, _inputs_dict = gen_random_inputs(inputs_descriptor) - flags, _files = save_inputs_to_files(network_onnx_path, inputs_list, out_dir=workspace_dir / "inputs") + print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner") + gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c", verbose=False) + runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter) + print_info(reporter, f"Runner built at {runner_path}") + reporter.advance() - print(Style.BRIGHT + "\nRunning inference with the runner:" + Style.RESET_ALL) - out_dir = workspace_dir / "outputs" - Path.mkdir(out_dir, exist_ok=True) - run_cmd = [runner_path, *flags] - run_cmd += ["--save-csv-dir", f"{out_dir}"] - subprocess.run(run_cmd, cwd=runner_build_dir, check=True) + print_stage(reporter, model_index, model_total, network_onnx_path.name, "Generate Inputs") + inputs_descriptor, outputs_descriptor = onnx_io(network_onnx_path) + inputs_list, _inputs_dict = 
gen_random_inputs(inputs_descriptor) + flags, _files = save_inputs_to_files(network_onnx_path, inputs_list, out_dir=workspace_dir / "inputs") + print_info(reporter, f"Saved {len(inputs_list)} input file(s) to {workspace_dir / 'inputs'}") + reporter.advance() - print(Style.BRIGHT + "\nCompiling for PIM with Raptor:" + Style.RESET_ALL) - compile_with_raptor(network_mlir_path, raptor_path, crossbar_size, crossbar_count) + print_stage(reporter, model_index, model_total, network_onnx_path.name, "Run Reference") + out_dir = workspace_dir / "outputs" + Path.mkdir(out_dir, exist_ok=True) + run_cmd = [runner_path, *flags] + run_cmd += ["--save-csv-dir", f"{out_dir}"] + run_command(run_cmd, cwd=runner_build_dir, reporter=reporter) + print_info(reporter, f"Reference outputs saved to {out_dir}") + reporter.advance() - print(Style.BRIGHT + "\nRunning PIM simulation:" + Style.RESET_ALL) - pim_dir = raptor_dir / "pim" - write_inputs_to_memory_bin(pim_dir / "memory.bin", pim_dir / "config.json", inputs_list) - simulation_dir = workspace_dir / "simulation" - Path.mkdir(simulation_dir, exist_ok=True) - dump_ranges = build_dump_ranges(pim_dir / "config.json", outputs_descriptor) - output_bin_path = simulation_dir / "out.bin" - run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges) + print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM") + compile_with_raptor( + network_mlir_path, raptor_path, crossbar_size, crossbar_count, reporter=reporter) + print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}") + reporter.advance() - print(Style.BRIGHT + "\nValidating the results:" + Style.RESET_ALL) - sim_arrays = parse_pim_simulator_outputs(output_bin_path, outputs_descriptor) - return validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold) + print_stage(reporter, model_index, model_total, network_onnx_path.name, "Run Simulator") + pim_dir = raptor_dir / "pim" + write_inputs_to_memory_bin(pim_dir / "memory.bin", pim_dir 
/ "config.json", inputs_list) + simulation_dir = workspace_dir / "simulation" + Path.mkdir(simulation_dir, exist_ok=True) + dump_ranges = build_dump_ranges(pim_dir / "config.json", outputs_descriptor) + output_bin_path = simulation_dir / "out.bin" + run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=reporter) + print_info(reporter, f"Simulator output saved to {output_bin_path}") + reporter.advance() + + print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compare Outputs") + sim_arrays = parse_pim_simulator_outputs(output_bin_path, outputs_descriptor) + reporter.suspend() + passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold) + reporter.resume() + reporter.advance() + status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL + reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL) + return passed + except Exception: + reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL) + raise + finally: + reporter.log("=" * 72) + if owns_reporter: + reporter.finish() if __name__ == '__main__':