From bdc4ca33f3dd7ca5f4a833d5ca7ef32540a78663 Mon Sep 17 00:00:00 2001 From: ilgeco Date: Mon, 25 May 2026 18:19:43 +0200 Subject: [PATCH] No extract no more --- .../Conversion/ONNXToSpatial/CMakeLists.txt | 2 +- .../Conversion/ONNXToSpatial/CompileTime.cpp | 67 +++++++++++++------ .../Conversion/ONNXToSpatial/CompileTime.hpp | 7 ++ .../Conversion/ONNXToSpatial/HostLegality.cpp | 34 ---------- .../ONNXToSpatial/ONNXToSpatialPass.cpp | 4 +- .../ONNXToSpatial/ONNXToSpatialVerifier.cpp | 49 ++++++++++++++ ...Legality.hpp => ONNXToSpatialVerifier.hpp} | 2 +- src/PIM/Dialect/Spatial/SpatialOps.cpp | 45 +++++++++---- .../Scheduling/PeftScheduler.cpp | 31 ++++++--- 9 files changed, 160 insertions(+), 81 deletions(-) delete mode 100644 src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp rename src/PIM/Conversion/ONNXToSpatial/{HostLegality.hpp => ONNXToSpatialVerifier.hpp} (64%) diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index 776e445..d19c4cd 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -5,7 +5,7 @@ add_public_tablegen_target(ONNXToSpatialIncGen) add_pim_library(OMONNXToSpatial ConversionPatterns.cpp CompileTime.cpp - HostLegality.cpp + ONNXToSpatialVerifier.cpp PrePatterns.cpp PostPatterns.cpp Patterns/Math/Conv.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp b/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp index 1ae980c..f96061e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp @@ -7,6 +7,9 @@ #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" + +#include #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -195,66 +198,90 @@ static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm: return nullptr; } -static bool isCompileTimeOpImpl(Operation* op, llvm::SmallPtrSetImpl& visited) { + +static std::optional +getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl& visited, size_t chainLength = 0) { if (!op) - return false; + return std::nullopt; if (!visited.insert(op).second) - return true; + return { + {op, chainLength} + }; if (isa(op)) - return true; + return { + {op, chainLength} + }; + + chainLength += 1; if (auto extractOp = dyn_cast(op)) - return hasConstantIndices(extractOp) && isCompileTimeOpImpl(extractOp, visited); + return hasConstantIndices(extractOp) ? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength) : std::nullopt; if (!isStaticTensorResult(op)) - return false; + return std::nullopt; if (auto transposeOp = dyn_cast(op)) - return isCompileTimeOpImpl(transposeOp, visited); + return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength); if (auto collapseShapeOp = dyn_cast(op)) - return isCompileTimeOpImpl(collapseShapeOp,visited); + return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength); if (auto expandShapeOp = dyn_cast(op)) - return isCompileTimeOpImpl(expandShapeOp, visited); + return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength); if (auto extractSliceOp = dyn_cast(op)) - return hasStaticUnitStrides(extractSliceOp) && isCompileTimeOpImpl(extractSliceOp, visited); + return hasStaticUnitStrides(extractSliceOp) ? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength) + : std::nullopt; if (auto splatOp = dyn_cast(op)) - return isCompileTimeOpImpl(splatOp, visited); + return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength); if (auto extractRowsOp = dyn_cast(op)) - return isCompileTimeOpImpl(extractRowsOp, visited); + return getCompileTimeSourceImpl(extractRowsOp.getInput().getDefiningOp(), visited, chainLength); - if (auto concatOp = dyn_cast(op)){ - bool res = true; - for(auto operandValue : concatOp.getOperands()){ - res &= isCompileTimeOpImpl(operandValue.getDefiningOp(), visited); - if(!res) break; + if (auto concatOp = dyn_cast(op)) { + std::optional res = {}; + for (auto operandValue : concatOp.getOperands()) { + auto partialRes = getCompileTimeSourceImpl(operandValue.getDefiningOp(), visited, chainLength); + if (!partialRes) + return std::nullopt; + + if (!res) { + res = partialRes; + continue; + } + if(res->chainLength < partialRes->chainLength){ + res = partialRes; + } } return res; } - return false; + return std::nullopt; } } // namespace + + std::optional getCompileTimeSource(Operation* op) { + llvm::SmallPtrSet visited; + return getCompileTimeSourceImpl(op, visited); +} + bool isCompileTimeComputable(Value value) { auto* definingOp = value.getDefiningOp(); if (!definingOp) return false; llvm::SmallPtrSet visited; - return isCompileTimeOpImpl(definingOp, visited); + return getCompileTimeSourceImpl(definingOp, visited).has_value(); } bool isCompileTimeOp(Operation* op) { llvm::SmallPtrSet visited; - return isCompileTimeOpImpl(op, visited); + return getCompileTimeSourceImpl(op, visited).has_value(); } DenseElementsAttr getHostConstDenseElementsAttr(Value value) { diff --git a/src/PIM/Conversion/ONNXToSpatial/CompileTime.hpp b/src/PIM/Conversion/ONNXToSpatial/CompileTime.hpp index f0574b5..c9ad809 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CompileTime.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/CompileTime.hpp @@ -6,6 +6,13 @@ namespace onnx_mlir { +struct CompileTimeSource { + mlir::Operation* source; + size_t chainLength; +}; + +std::optional getCompileTimeSource(mlir::Operation* op); + bool isCompileTimeComputable(mlir::Value value); bool isCompileTimeOp(mlir::Operation* op); diff --git a/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp b/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp deleted file mode 100644 index 9a9c297..0000000 --- a/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" - -#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - -using namespace mlir; - -namespace onnx_mlir { - -LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) { - pim::CappedDiagnosticReporter diagnostics; - - for (Operation& op : funcOp.getFunctionBody().front()) { - if (isa(&op)) - continue; - if (isCompileTimeOp(&op)) - continue; - - diagnostics.report(&op, [](Operation* illegalOp) { - illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside " - "spat.compute"); - }); - } - - diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures"); - - return success(!diagnostics.hasFailure()); -} - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 007f019..c9f0893 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -14,7 +14,7 @@ #include "Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -220,7 +220,7 @@ void ONNXToSpatialPass::runOnOperation() { wrapTopLevelRuntimeTransposes(*entryFunc); - if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) { + if (failed(verifyONNXToSpatial(*entryFunc))) { moduleOp.emitError("ONNX-to-Spatial host legality verification failed"); signalPassFailure(); return; diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp new file mode 100644 index 0000000..aeeb47c --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp @@ -0,0 +1,49 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Support/LLVM.h" + +#include "Common/IR/WeightUtils.hpp" +#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +void checkWeightsDirectlyExtracted(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) { + for (auto extractSlice : func.getOps()) { + auto source = getCompileTimeSource(extractSlice.getOperation()); + if (source && hasWeightAlways(source->source) && source->chainLength > 1) { + + diagnostics.report(extractSlice.getOperation(), + [](Operation* illegalOp) { illegalOp->emitOpError("Weight not directly extracted"); }); + } + } +} + +LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { + pim::CappedDiagnosticReporter diagnostics; + + for (Operation& op : funcOp.getOps()) { + if (isa(&op)) + continue; + if (isCompileTimeOp(&op)) + continue; + + diagnostics.report(&op, [](Operation* illegalOp) { + illegalOp->emitOpError( + "non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute"); + }); + } + + checkWeightsDirectlyExtracted(funcOp, diagnostics); + + diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed"); + + return success(!diagnostics.hasFailure()); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/HostLegality.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp similarity index 64% rename from src/PIM/Conversion/ONNXToSpatial/HostLegality.hpp rename to src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp index 3521eae..a5fd052 100644 --- a/src/PIM/Conversion/ONNXToSpatial/HostLegality.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp @@ -5,6 +5,6 @@ namespace onnx_mlir { -mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp); +mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp); } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index 0ced6d2..20a132b 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -1,3 +1,6 @@ +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" + #include #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -8,7 +11,7 @@ namespace onnx_mlir { namespace spatial { namespace { -std::optional getBatchBodyArgument(Region& body, unsigned argIdx) { +std::optional getBlockArgument(Region& body, unsigned argIdx) { if (body.empty()) return std::nullopt; @@ -18,7 +21,7 @@ std::optional getBatchBodyArgument(Region& body, unsigned argIdx) return block.getArgument(argIdx); } -std::optional insertBatchBodyArgument(Region& body, unsigned argIdx, Type type, Location loc) { +std::optional insertBlockArgument(Region& body, unsigned argIdx, Type type, Location loc) { if (body.empty()) return std::nullopt; return body.insertArgument(argIdx, type, loc); @@ -34,21 +37,27 @@ void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t i } // namespace -std::optional SpatCompute::getWeightArgument(unsigned idx) { - return getBatchBodyArgument(getBody(), idx); -} +std::optional SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); } std::optional SpatCompute::getInputArgument(unsigned idx) { - return getBatchBodyArgument(getBody(), getWeights().size() + idx); + return getBlockArgument(getBody(), getWeights().size() + idx); } std::optional> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) { + if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) { + llvm::dbgs() << "Disse netanyao\n"; + auto index = std::distance(getWeights().begin(), existing); + return { + {*existing, *getWeightArgument(index)} + }; + } + unsigned weightCount = getWeights().size(); unsigned inputCount = getInputs().size(); getOperation()->insertOperands(idx, ValueRange {weight}); setComputeOperandSegmentSizes( getOperation(), static_cast(weightCount + 1), static_cast(inputCount)); - auto blockArg = insertBatchBodyArgument(getBody(), idx, weight.getType(), loc); + auto blockArg = insertBlockArgument(getBody(), idx, weight.getType(), loc); if (!blockArg) return std::nullopt; return std::make_tuple(getOperation()->getOperand(idx), *blockArg); @@ -60,7 +69,7 @@ std::optional> SpatCompute::insertInput(unsigne getOperation()->insertOperands(weightCount + idx, ValueRange {input}); setComputeOperandSegmentSizes( getOperation(), static_cast(weightCount), static_cast(inputCount + 1)); - auto blockArg = insertBatchBodyArgument(getBody(), weightCount + idx, input.getType(), loc); + auto blockArg = insertBlockArgument(getBody(), weightCount + idx, input.getType(), loc); if (!blockArg) return std::nullopt; return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg); @@ -100,28 +109,36 @@ void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn s setNameFn(*inputArg, ("in" + std::to_string(index)).c_str()); } -std::optional SpatComputeBatch::getLaneArgument() { return getBatchBodyArgument(getBody(), 0); } +std::optional SpatComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); } std::optional SpatComputeBatch::getWeightArgument(unsigned idx) { - return getBatchBodyArgument(getBody(), 1 + idx); + return getBlockArgument(getBody(), 1 + idx); } std::optional SpatComputeBatch::getInputArgument(unsigned idx) { - return getBatchBodyArgument(getBody(), 1 + getWeights().size() + idx); + return getBlockArgument(getBody(), 1 + getWeights().size() + idx); } std::optional SpatComputeBatch::getOutputArgument(unsigned idx) { - return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx); + return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx); } std::optional> SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) { + if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) { + auto index = std::distance(getWeights().begin(), existing); + llvm::dbgs() << "Bum bum bum bum\n"; + return { + {*existing, *getWeightArgument(index)} + }; + } + unsigned weightCount = getWeights().size(); unsigned inputCount = getInputs().size(); getOperation()->insertOperands(idx, ValueRange {weight}); setComputeOperandSegmentSizes( getOperation(), static_cast(weightCount + 1), static_cast(inputCount)); - auto blockArg = insertBatchBodyArgument(getBody(), 1 + idx, weight.getType(), loc); + auto blockArg = insertBlockArgument(getBody(), 1 + idx, weight.getType(), loc); if (!blockArg) return std::nullopt; return std::make_tuple(getOperation()->getOperand(idx), *blockArg); @@ -133,7 +150,7 @@ std::optional> SpatComputeBatch::insertInput(un getOperation()->insertOperands(weightCount + idx, ValueRange {input}); setComputeOperandSegmentSizes( getOperation(), static_cast(weightCount), static_cast(inputCount + 1)); - auto blockArg = insertBatchBodyArgument(getBody(), 1 + weightCount + idx, input.getType(), loc); + auto blockArg = insertBlockArgument(getBody(), 1 + weightCount + idx, input.getType(), loc); if (!blockArg) return std::nullopt; return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp index b51896f..30759b9 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp @@ -5,6 +5,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" +#include #include #include #include @@ -177,20 +178,25 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu Time bestEst = 0; Time bestEft = 0; Time bestOeft = std::numeric_limits