diff --git a/src/PIM/Compiler/PimCompilerOptions.cpp b/src/PIM/Compiler/PimCompilerOptions.cpp index 515bff5..4eee65c 100644 --- a/src/PIM/Compiler/PimCompilerOptions.cpp +++ b/src/PIM/Compiler/PimCompilerOptions.cpp @@ -19,7 +19,6 @@ llvm::cl::opt pimMergeScheduler("pim-merge-scheduler", llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"), llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")), - llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")), llvm::cl::init(MergeSchedulerPeft), llvm::cl::cat(OnnxMlirOptions)); @@ -49,12 +48,6 @@ llvm::cl::opt coresCount("core-count", llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."), llvm::cl::init(-1)); -llvm::cl::opt dcpCriticalWindowSize( - "dcp-critical-window-size", - llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. " - "Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."), - llvm::cl::init(4000)); - llvm::cl::opt ignoreConcatError("ignore-concat-error", llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"), diff --git a/src/PIM/Compiler/PimCompilerOptions.hpp b/src/PIM/Compiler/PimCompilerOptions.hpp index a5f1d16..05e51ab 100644 --- a/src/PIM/Compiler/PimCompilerOptions.hpp +++ b/src/PIM/Compiler/PimCompilerOptions.hpp @@ -22,7 +22,6 @@ typedef enum { typedef enum { MergeSchedulerPeft = 0, - MergeSchedulerDcp = 1, } PimMergeSchedulerType; extern llvm::cl::OptionCategory OnnxMlirOptions; @@ -36,7 +35,6 @@ extern llvm::cl::opt pimEmitJson; extern llvm::cl::opt crossbarSize; extern llvm::cl::opt crossbarCountInCore; extern llvm::cl::opt coresCount; -extern llvm::cl::opt dcpCriticalWindowSize; bool hasExplicitPimCoreCount(); void verifyExplicitPimCoreCount(); diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index bccab97..776e445 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -4,7 +4,7 @@ add_public_tablegen_target(ONNXToSpatialIncGen) add_pim_library(OMONNXToSpatial ConversionPatterns.cpp - HostFoldability.cpp + CompileTime.cpp HostLegality.cpp PrePatterns.cpp PostPatterns.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp index 6f518eb..52a1efb 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp @@ -6,7 +6,7 @@ #include "ShapeTilingUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" using namespace mlir; @@ -44,7 +44,7 @@ SmallVector sliceTensor( RankedTensorType::get(sliceShape, cast(tensorToSlice.getType()).getElementType()); Value slice; - if (isHostFoldableValue(tensorToSlice)) { + if (isCompileTimeComputable(tensorToSlice)) { slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides); } else { @@ -113,7 +113,7 @@ Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatte return tensor::SplatOp::create(rewriter, loc, type, elementValue); }; - if (isHostFoldableValue(scalarToBroadcast)) + if (isCompileTimeComputable(scalarToBroadcast)) return buildBroadcast(scalarToBroadcast); auto broadcastCompute = diff --git a/src/PIM/Conversion/ONNXToSpatial/HostFoldability.cpp b/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp similarity index 84% rename from src/PIM/Conversion/ONNXToSpatial/HostFoldability.cpp rename to src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp index c42e186..1ae980c 100644 --- a/src/PIM/Conversion/ONNXToSpatial/HostFoldability.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp @@ -8,7 +8,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -145,7 +145,7 @@ static DenseElementsAttr getDirectDenseConstantAttr(Value value) { return nullptr; } -static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl& visited) { +static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl& visited) { auto* definingOp = value.getDefiningOp(); if (!definingOp || !visited.insert(definingOp).second) return nullptr; @@ -156,7 +156,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm: return denseAttr; if (auto transposeOp = dyn_cast(definingOp)) { - auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited); + auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getData(), visited); if (!inputAttr) return nullptr; @@ -169,7 +169,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm: } if (auto collapseShapeOp = dyn_cast(definingOp)) { - auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited); + auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited); if (!inputAttr) return nullptr; auto reshapedAttr = reshapeDenseElements(inputAttr, cast(collapseShapeOp.getType())); @@ -177,7 +177,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm: } if (auto expandShapeOp = dyn_cast(definingOp)) { - auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited); + auto inputAttr = getHostConstantDenseElementsAttrImpl(expandShapeOp.getSrc(), visited); if (!inputAttr) return nullptr; auto reshapedAttr = reshapeDenseElements(inputAttr, cast(expandShapeOp.getType())); @@ -185,7 +185,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm: } if (auto extractSliceOp = dyn_cast(definingOp)) { - auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited); + auto inputAttr = getHostConstantDenseElementsAttrImpl(extractSliceOp.getSource(), visited); if (!inputAttr) return nullptr; auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp); @@ -195,62 +195,71 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm: return nullptr; } -static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl& visited) { - if (!op || !visited.insert(op).second) +static bool isCompileTimeOpImpl(Operation* op, llvm::SmallPtrSetImpl& visited) { + if (!op) return false; + if (!visited.insert(op).second) + return true; + if (isa(op)) return true; if (auto extractOp = dyn_cast(op)) - return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor()); + return hasConstantIndices(extractOp) && isCompileTimeOpImpl(extractOp, visited); if (!isStaticTensorResult(op)) return false; if (auto transposeOp = dyn_cast(op)) - return isHostFoldableValue(transposeOp.getData()); + return isCompileTimeOpImpl(transposeOp, visited); if (auto collapseShapeOp = dyn_cast(op)) - return isHostFoldableValue(collapseShapeOp.getSrc()); + return isCompileTimeOpImpl(collapseShapeOp,visited); if (auto expandShapeOp = dyn_cast(op)) - return isHostFoldableValue(expandShapeOp.getSrc()); + return isCompileTimeOpImpl(expandShapeOp, visited); if (auto extractSliceOp = dyn_cast(op)) - return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource()); + return hasStaticUnitStrides(extractSliceOp) && isCompileTimeOpImpl(extractSliceOp, visited); if (auto splatOp = dyn_cast(op)) - return isHostFoldableValue(splatOp.getInput()); + return isCompileTimeOpImpl(splatOp, visited); if (auto extractRowsOp = dyn_cast(op)) - return isHostFoldableValue(extractRowsOp.getInput()); + return isCompileTimeOpImpl(extractRowsOp, visited); - if (auto concatOp = dyn_cast(op)) - return llvm::all_of(concatOp.getInputs(), isHostFoldableValue); + if (auto concatOp = dyn_cast(op)){ + bool res = true; + for(auto operandValue : concatOp.getOperands()){ + res &= isCompileTimeOpImpl(operandValue.getDefiningOp(), visited); + if(!res) break; + } + return res; + } return false; } } // namespace -bool isHostFoldableValue(Value value) { +bool isCompileTimeComputable(Value value) { auto* definingOp = value.getDefiningOp(); if (!definingOp) return false; llvm::SmallPtrSet visited; - return isHostFoldableOpImpl(definingOp, visited); + return isCompileTimeOpImpl(definingOp, visited); } -bool isHostFoldableOp(Operation* op) { +bool isCompileTimeOp(Operation* op) { llvm::SmallPtrSet visited; - return isHostFoldableOpImpl(op, visited); + return isCompileTimeOpImpl(op, visited); } -DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) { +DenseElementsAttr getHostConstDenseElementsAttr(Value value) { llvm::SmallPtrSet visited; - return getHostFoldableDenseElementsAttrImpl(value, visited); + return getHostConstantDenseElementsAttrImpl(value, visited); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/CompileTime.hpp b/src/PIM/Conversion/ONNXToSpatial/CompileTime.hpp new file mode 100644 index 0000000..f0574b5 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/CompileTime.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" + +namespace onnx_mlir { + +bool isCompileTimeComputable(mlir::Value value); + +bool isCompileTimeOp(mlir::Operation* op); + +mlir::DenseElementsAttr getHostConstDenseElementsAttr(mlir::Value value); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp b/src/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp deleted file mode 100644 index 3e3437c..0000000 --- a/src/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Value.h" - -namespace onnx_mlir { - -bool isHostFoldableValue(mlir::Value value); - -bool isHostFoldableOp(mlir::Operation* op); - -mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value); - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp b/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp index 252631a..9a9c297 100644 --- a/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp @@ -3,7 +3,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -17,7 +17,7 @@ LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) { for (Operation& op : funcOp.getFunctionBody().front()) { if (isa(&op)) continue; - if (isHostFoldableOp(&op)) + if (isCompileTimeOp(&op)) continue; diagnostics.report(&op, [](Operation* illegalOp) { diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 87da1e9..007f019 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -13,7 +13,7 @@ #include "Common/Common.hpp" #include "Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp" @@ -92,7 +92,7 @@ static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) { for (Operation& op : llvm::make_early_inc_range(entryBlock)) { auto transposeOp = dyn_cast(&op); - if (!transposeOp || isHostFoldableOp(transposeOp)) + if (!transposeOp || isCompileTimeOp(transposeOp)) continue; // Transpose stays globally legal because constant/view-only cases are diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 0a855c4..c1f9a80 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -11,7 +11,7 @@ #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -391,7 +391,7 @@ static Value lowerSingleConvGroup(Value x, const int64_t xbarSize = static_cast(crossbarSize.getValue()); const int64_t wMaxDim = std::max(patchSize, numChannelsOut); const int64_t maxParallelPixels = std::max(1, xbarSize / wMaxDim); - auto wDenseAttr = getHostFoldableDenseElementsAttr(w); + auto wDenseAttr = getHostConstDenseElementsAttr(w); // Prepare weight matrix W for crossbar storage: // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut] @@ -412,7 +412,7 @@ static Value lowerSingleConvGroup(Value x, DenseElementsAttr biasDenseAttr; if (hasB) { gemmBias = b; - biasDenseAttr = getHostFoldableDenseElementsAttr(b); + biasDenseAttr = getHostConstDenseElementsAttr(b); biasMatrix = expandBiasIfNeeded(b, rewriter, loc); } const bool canPackWeightsAsConstants = static_cast(wDenseAttr); @@ -717,7 +717,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, } Value result; - if (llvm::all_of(groupResults, isHostFoldableValue)) { + if (llvm::all_of(groupResults, isCompileTimeComputable)) { result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults); } else { diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 20511b3..5284daf 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -11,7 +11,7 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -55,7 +55,7 @@ static Value transposeForSpatial(Value value, ArrayRef permutation, ConversionPatternRewriter& rewriter, Location loc) { - if (isHostFoldableValue(value)) + if (isCompileTimeComputable(value)) return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation)); auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) { @@ -67,7 +67,7 @@ static Value transposeForSpatial(Value value, static Value expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) { - if (isHostFoldableValue(value)) + if (isCompileTimeComputable(value)) return tensor::ExpandShapeOp::create(rewriter, loc, resultType, @@ -121,7 +121,7 @@ static SmallVector materializeBatchRowSlices(Value matrix, auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType()); SmallVector resultTypes(static_cast(numRows), rowType); - if (isHostFoldableValue(matrix)) { + if (isCompileTimeComputable(matrix)) { auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix); return SmallVector(extractRowsOp->result_begin(), extractRowsOp->result_end()); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index 03a972a..7609b91 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -10,7 +10,7 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -55,7 +55,7 @@ collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, Pa return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation); }; - if (isHostFoldableValue(value)) + if (isCompileTimeComputable(value)) return buildCollapsed(value); auto collapseCompute = @@ -114,7 +114,7 @@ static Value extractBatchMatrix(Value value, }); }; - if (isHostFoldableValue(value)) + if (isCompileTimeComputable(value)) return buildMatrix(value); auto batchMatrixCompute = @@ -142,7 +142,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm)); }; - if (isHostFoldableValue(value)) + if (isCompileTimeComputable(value)) return buildTranspose(value); auto transposeCompute = @@ -182,7 +182,7 @@ static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewr outputShape[axis] = concatDimSize; auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding()); - if (llvm::all_of(inputs, isHostFoldableValue)) + if (llvm::all_of(inputs, isCompileTimeComputable)) return createSpatConcat(rewriter, loc, axis, inputs); auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) { @@ -235,7 +235,7 @@ struct MatMulToGemm : OpRewritePattern { } Location loc = matmulOp.getLoc(); - bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB()); + bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB()); Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc); Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp index fa96e2e..b2f8381 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp @@ -7,7 +7,7 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -91,7 +91,7 @@ static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewr outputShape[axis] = concatDimSize; auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding()); - if (llvm::all_of(inputs, isHostFoldableValue)) + if (llvm::all_of(inputs, isCompileTimeComputable)) return createSpatConcat(rewriter, loc, axis, inputs); auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) { @@ -135,7 +135,7 @@ static Value squeezeReducedAxes(Value keepdimsValue, } auto reassociation = buildCollapseReassociation(reducedAxes); - if (isHostFoldableValue(keepdimsValue)) + if (isCompileTimeComputable(keepdimsValue)) return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult(); auto squeezeCompute = diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp index 4d616e8..86a96a6 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp @@ -3,7 +3,7 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -20,7 +20,7 @@ struct Concat : public OpConversionPattern { auto inputs = adaptor.getInputs(); int64_t axis = adaptor.getAxis(); - if (llvm::all_of(inputs, isHostFoldableValue)) { + if (llvm::all_of(inputs, isCompileTimeComputable)) { rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs)); return success(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp index f44c524..4c1131f 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp @@ -5,7 +5,7 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -115,7 +115,7 @@ struct Reshape : OpConversionPattern { } auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult { - if (isHostFoldableValue(adaptor.getData())) { + if (isCompileTimeComputable(adaptor.getData())) { rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData())); return success(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp index 9a3e6b2..8683481 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp @@ -3,7 +3,7 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -61,7 +61,7 @@ struct Split : OpConversionPattern { sliceSizes.push_back(resultType.getShape()[axis]); } - if (isHostFoldableValue(adaptor.getInput())) { + if (isCompileTimeComputable(adaptor.getInput())) { for (int64_t sliceSize : sliceSizes) { outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc())); offset += sliceSize; diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index cf2b5ba..f843879 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -11,14 +11,8 @@ add_pim_library(SpatialOps Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp - Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp - Transforms/MergeComputeNodes/DCPGraph/Graph.cpp - Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp - Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp - Transforms/MergeComputeNodes/DCPGraph/Task.cpp - Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index 2fbf0e2..38646c4 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -15,7 +15,7 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -120,7 +120,7 @@ static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) { template static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind) { for (Value weight : computeOp.getWeights()) { - if (!isHostFoldableValue(weight)) + if (!isCompileTimeComputable(weight)) return computeOp.emitOpError() << kind << " weights must be statically computed from constants"; } return success(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp deleted file mode 100644 index cfc03d3..0000000 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include "../Scheduling/ComputeGraph.hpp" -#include "../Scheduling/DcpScheduler.hpp" -#include "DCPAnalysis.hpp" -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" - -namespace onnx_mlir { -namespace spatial { - -DCPAnalysisResult DCPAnalysis::run() { - ComputeGraph graph = buildComputeGraph(entryOp); - DcpScheduleOptions options; - if (coresCount.getValue() > 0) - options.processorCount = static_cast(coresCount.getValue()); - options.criticalWindowSize = dcpCriticalWindowSize.getValue(); - options.allowFallbackForAutoCoreCount = true; - return runDcpScheduler(graph, options, entryOp->getContext()); -} - -} // namespace spatial -} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp deleted file mode 100644 index eec2d6b..0000000 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -#include "mlir/IR/Operation.h" - -#include "../Scheduling/MergeSchedule.hpp" - -namespace onnx_mlir { -namespace spatial { -using DCPAnalysisResult = MergeScheduleResult; - -struct DCPAnalysis { -private: - DCPAnalysisResult result; - mlir::Operation* entryOp; - DCPAnalysisResult run(); - -public: - DCPAnalysis(mlir::Operation* op) - : entryOp(op) { - result = run(); - } - DCPAnalysisResult& getResult() { return result; } -}; - -} // namespace spatial -} // namespace onnx_mlir - -using DCPAnalysisResult = onnx_mlir::spatial::DCPAnalysisResult; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp deleted file mode 100644 index 3bca16c..0000000 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp +++ /dev/null @@ -1,1622 +0,0 @@ -//===----------------------------------------------------------------------===// -// DCP-inspired task scheduler. -// -// Input: a DAG of compute tasks. Each task has an execution weight; each edge -// carries an inter-task transfer cost that only applies when producer and -// consumer land on different CPUs. -// -// Output: an assignment of every task to a CPU and an order within that CPU, -// aiming to minimize the overall critical-path length (DCPL). -// -// Every task keeps two timing estimates: -// AEST - earliest start time, driven by parent completions + transfers. -// ALST - latest start time that still keeps the task on the critical path. -// A task is "critical" when its slack (ALST - AEST) is zero. -// -// Main loop (runDcp): -// 1. Build a topological order and seed AEST/ALST from the unscheduled DAG. -// 2. While there are ready tasks (all dependency parents scheduled): -// a. Pick the candidate with the tightest slack (earliest AEST breaks ties). -// b. selectProcessor() tries every candidate CPU and picks the one that -// minimizes a composite cost (own slot + the smallest unscheduled child). -// c. Commit the placement and refresh AEST/ALST. -// d. Release any child whose dependency parents are now all scheduled. -// -// Heuristic notes: classic DCP assumes identical task costs on every CPU and -// a single-issue processor model. We diverge - crossbar capacity can make a -// task infeasible on a CPU, and placement happens incrementally. That makes -// this a heuristic rather than a faithful DCP implementation. -// -// Parallelism: selectProcessor's per-CPU findSlot sweep is read-only, so we -// run it concurrently across CPUs via mlir::parallelFor. The subsequent -// sequential evaluation benefits from ordering CPUs by ascending slot.aest, -// which tightens the bestComposite early-prune bound. -//===----------------------------------------------------------------------===// - -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Threading.h" - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/ErrorHandling.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "DCPAnalysis.hpp" -#include "Graph.hpp" -#include "GraphDebug.hpp" -#include "GraphSupport.hpp" -#include "Task.hpp" -#include "UniqueWorklist.hpp" -#include "Utils.hpp" -#include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" - -#ifdef DCP_DEBUG_ENABLED -namespace { -// Coarse-grained phase timers printed when DCP_SELECT_PROFILE is set. -struct SelectTimers { - double findSlot = 0.0; - double dedup = 0.0; - double precheck = 0.0; - double snapshotInsertUpdate = 0.0; - double childSlot = 0.0; - double rollbackRestore = 0.0; - long iterations = 0; - long passedPrecheck = 0; - long passedDcpl = 0; - long tasksProcessed = 0; - void dump(const char* label) const { - std::fprintf(stderr, - "[selectProfile:%s] tasks=%ld dedup=%.2fs findSlot=%.2fs precheck=%.2fs snapUpd=%.2fs " - "childSlot=%.2fs rollback=%.2fs iter=%ld precheckPass=%ld dcplPass=%ld\n", - label, - tasksProcessed, - dedup, - findSlot, - precheck, - snapshotInsertUpdate, - childSlot, - rollbackRestore, - iterations, - passedPrecheck, - passedDcpl); - } - ~SelectTimers() { - if (std::getenv("DCP_SELECT_PROFILE")) - dump("exit"); - } -}; -static SelectTimers gSelectTimers; -} // namespace -#endif - -namespace { - -uint64_t mixHash(uint64_t seed, uint64_t value) { - seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2); - return seed; -} - -uint64_t finishHash(uint64_t seed) { - seed ^= seed >> 33; - seed *= 0xff51afd7ed558ccdULL; - seed ^= seed >> 33; - seed *= 0xc4ceb9fe1a85ec53ULL; - seed ^= seed >> 33; - return seed; -} - -uint64_t hashEdgeSignature(uint64_t neighborHash, Weight weight, uint64_t direction) { - uint64_t hash = mixHash(0x84222325cbf29ce4ULL, direction); - hash = mixHash(hash, neighborHash); - hash = mixHash(hash, static_cast(weight)); - return finishHash(hash); -} - -struct CpuAestCache { - Time defaultAest = 0; - llvm::SmallDenseMap colocatedParentAests; - - Time get(CPU cpu) const { - auto it = colocatedParentAests.find(cpu); - if (it == colocatedParentAests.end()) - return defaultAest; - return it->second; - } -}; - -struct CpuTimeMax { - CPU cpu = -1; - Time time = 0; -}; - -void updateCpuTimeMax(CpuTimeMax& first, CpuTimeMax& second, CPU cpu, Time time) { - if (first.cpu == cpu) { - first.time = std::max(first.time, time); - return; - } - if (second.cpu == cpu) { - second.time = std::max(second.time, time); - if (second.time > first.time) - std::swap(first, second); - return; - } - if (time >= first.time) { - second = first; - first = {cpu, time}; - return; - } - if (time > second.time) - second = {cpu, time}; -} - -CpuAestCache computeCpuAestCache(TaskDCP* task) { - CpuAestCache cache; - llvm::SmallDenseMap transferAestByCpu; - llvm::SmallDenseMap localAestByCpu; - Time unscheduledTransferAest = 0; - - for (const Edge& parentEdge : task->parents) { - Time parentFinish = addOrMax(parentEdge.first->getAest(), parentEdge.first->getWeight()); - Time transferAest = addOrMax(parentFinish, getTransferCost(parentEdge.first, task)); - if (std::optional parentCpu = parentEdge.first->getCpu()) { - Time& cpuTransferAest = transferAestByCpu[*parentCpu]; - cpuTransferAest = std::max(cpuTransferAest, transferAest); - Time& cpuLocalAest = localAestByCpu[*parentCpu]; - cpuLocalAest = std::max(cpuLocalAest, parentFinish); - continue; - } - unscheduledTransferAest = std::max(unscheduledTransferAest, transferAest); - } - - CpuTimeMax firstOther {-1, unscheduledTransferAest}; - CpuTimeMax secondOther {-1, 0}; - for (const auto& entry : transferAestByCpu) - updateCpuTimeMax(firstOther, secondOther, entry.first, entry.second); - - cache.defaultAest = firstOther.time; - for (const auto& entry : localAestByCpu) { - CPU cpu = entry.first; - Time bestNonLocalParentAest = firstOther.cpu == cpu ? secondOther.time : firstOther.time; - cache.colocatedParentAests[cpu] = std::max(bestNonLocalParentAest, entry.second); - } - return cache; -} - -} // namespace - -//===----------------------------------------------------------------------===// -// Edge manipulation -//===----------------------------------------------------------------------===// - -std::optional addEdge(TaskDCP* parent, TaskDCP* child, Weight weight, bool isScheduling) { - auto oldChild = parent->addChild(child, weight, isScheduling); - auto oldParent = child->addParent(parent, weight, isScheduling); - assert(oldChild.has_value() == oldParent.has_value() && "The edge must be present in both element"); - if (oldChild.has_value()) { - return { - {*oldParent, *oldChild} - }; - } - return {}; -} - -void removeEdge(TaskDCP* parent, TaskDCP* child, bool isScheduling) { - parent->removeChild(child, isScheduling); - child->removeParent(parent, isScheduling); -} - -// A dependency edge may appear multiple times (e.g. from separate data inputs); -// the transfer cost is the maximum across those parallel edges. Cost is zero -// when both endpoints share a CPU. -Weight getTransferCost(TaskDCP* parent, TaskDCP* child) { - if (parent->scheduledCpu.has_value() && child->scheduledCpu.has_value() - && *parent->scheduledCpu == *child->scheduledCpu) - return 0; - Weight maxTransferCost = 0; - bool foundTransferCost = false; - for (const auto& edge : parent->children) - if (edge.first == child && !edge.isScheduling) { - maxTransferCost = std::max(maxTransferCost, edge.second); - foundTransferCost = true; - } - assert(foundTransferCost && "missing transfer cost for dependency edge"); - return maxTransferCost; -} - -//===----------------------------------------------------------------------===// -// Indexing and CPU task lists -//===----------------------------------------------------------------------===// - -size_t GraphDCP::getNodeIndex(const TaskDCP* task) const { - assert(task >= nodes.data() && task < nodes.data() + nodes.size() && "task must belong to graph"); - return static_cast(task - nodes.data()); -} - -GraphDCP::CpuTaskList& GraphDCP::getOrCreateCpuTasks(CPU cpu) { - assert(cpu >= 0 && "cpu id must be non-negative"); - size_t cpuIndex = static_cast(cpu); - if (cpuTasks.size() <= cpuIndex) - cpuTasks.resize(cpuIndex + 1); - return cpuTasks[cpuIndex]; -} - -const GraphDCP::CpuTaskList* GraphDCP::findCpuTasks(CPU cpu) const { - if (cpu < 0) - return nullptr; - size_t cpuIndex = static_cast(cpu); - if (cpuIndex >= cpuTasks.size()) - return nullptr; - return &cpuTasks[cpuIndex]; -} - -std::vector GraphDCP::getRoots() { - std::vector tmp; - for (auto& task : nodes) - if (!task.hasParents()) - tmp.push_back(&task); - return tmp; -} - -void GraphDCP::initTaskStructureHashes() { - taskStructureHashes.resize(nodes.size()); - for (auto [index, task] : llvm::enumerate(nodes)) { - uint64_t hash = mixHash(0x7442b1129fd01363ULL, static_cast(task.getWeight())); - hash = mixHash(hash, static_cast(task.getCrossbarUsage())); - taskStructureHashes[index] = finishHash(hash); - } - - std::vector nextHashes(nodes.size()); - std::vector edgeHashes; - for (int iteration = 0; iteration < 4; ++iteration) { - for (auto [index, task] : llvm::enumerate(nodes)) { - uint64_t hash = mixHash(0x464dcab27ac82291ULL, taskStructureHashes[index]); - edgeHashes.clear(); - edgeHashes.reserve(task.parents.size() + task.children.size()); - for (const Edge& parent : task.parents) - if (!parent.isScheduling) - edgeHashes.push_back( - hashEdgeSignature(taskStructureHashes[getNodeIndex(parent.first)], parent.second, /*direction=*/0)); - for (const Edge& child : task.children) - if (!child.isScheduling) - edgeHashes.push_back( - hashEdgeSignature(taskStructureHashes[getNodeIndex(child.first)], child.second, /*direction=*/1)); - llvm::sort(edgeHashes); - hash = mixHash(hash, static_cast(edgeHashes.size())); - for (uint64_t edgeHash : edgeHashes) - hash = mixHash(hash, edgeHash); - nextHashes[index] = finishHash(hash); - } - taskStructureHashes.swap(nextHashes); - } -} - -// Compact dedup key for CPU `c` vs `candidate`: mixes candidateAest, crossbar -// usage, and the incremental cpu structure hash. No heap allocation. -uint64_t GraphDCP::computeCpuCandidateKey(Time candidateAest, CPU cpu) { - uint64_t hash = mixHash(0xd6e8feb86659fd93ULL, static_cast(candidateAest)); - hash = mixHash(hash, static_cast(getCpuCrossbarUsage(cpu))); - auto it = cpuStructureHashes.find(cpu); - hash = mixHash(hash, it != cpuStructureHashes.end() ? it->second : 0ULL); - return finishHash(hash); -} - -// Inserts `task` at `position` on `cpu`, wiring up scheduling edges with the -// neighbouring tasks and keeping the global topological order consistent. -TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) { - TaskInsertion ret; - Weight scheduledWeight = task->computeWeightOnCpu(this, cpu); - task->setCpu(cpu); - task->setWeight(scheduledWeight); - reserveTaskCrossbars(cpu, task); - cpuStructureHashes[cpu] ^= taskStructureHashes[getNodeIndex(task)]; - auto& tasksInCpu = getOrCreateCpuTasks(cpu); - unsigned int numCpuTasks = tasksInCpu.size(); - assert(position <= numCpuTasks && "Inserting in a not valid position"); - auto insertedPoint = tasksInCpu.insert(std::next(tasksInCpu.begin(), position), task); - ret.cpuModified = cpu; - ret.taskInserted = task; - ret.graph = this; - - // If we split an existing neighbour-neighbour scheduling edge, drop it; the - // two new edges below recreate the ordering with `task` in between. - if (insertedPoint != tasksInCpu.begin() && std::next(insertedPoint) != tasksInCpu.end()) { - auto precedentPoint = std::prev(insertedPoint, 1); - auto nextPoint = std::next(insertedPoint, 1); - removeEdge(*precedentPoint, *nextPoint, true); - } - - if (insertedPoint != tasksInCpu.begin()) { - auto precedentPoint = std::prev(insertedPoint, 1); - auto oldEdge = addEdge(*precedentPoint, *insertedPoint, 0, true); - ret.beforeNode = oldEdge; - - if (*task < **precedentPoint) - topologicalMoveAfter(task, *precedentPoint, &ret); - } - - if (std::next(insertedPoint) != tasksInCpu.end()) { - auto nextPoint = std::next(insertedPoint, 1); - auto oldEdge = addEdge(*insertedPoint, *nextPoint, 0, true); - ret.afterNode = oldEdge; - if (**nextPoint < *task) - topologicalMoveBefore(task, *nextPoint, &ret); - } - return ret; -} - -void GraphDCP::removeTaskFromCPU(CPU cpu, TaskDCP* task) { - releaseTaskCrossbars(cpu, task); - cpuStructureHashes[cpu] ^= taskStructureHashes[getNodeIndex(task)]; - task->resetCpu(); - task->resetWeight(); - auto& scheduledTasks = getOrCreateCpuTasks(cpu); - auto taskPosition = std::find(scheduledTasks.begin(), scheduledTasks.end(), task); - assert(taskPosition != scheduledTasks.end() && "Removing a not present task"); - TaskDCP* previousTask = nullptr; - TaskDCP* nextTask = nullptr; - if (taskPosition != scheduledTasks.begin()) { - auto previousPoint = std::prev(taskPosition, 1); - previousTask = *previousPoint; - removeEdge(*previousPoint, *taskPosition, true); - } - - if (std::next(taskPosition) != scheduledTasks.end()) { - auto nextPoint = std::next(taskPosition, 1); - nextTask = *nextPoint; - removeEdge(*taskPosition, *nextPoint, true); - } - if (previousTask != nullptr && nextTask != nullptr) - addEdge(previousTask, nextTask, 0, true); - scheduledTasks.erase(taskPosition); -} - -//===----------------------------------------------------------------------===// -// Crossbar capacity bookkeeping -//===----------------------------------------------------------------------===// - -CrossbarUsage GraphDCP::getCpuCrossbarUsage(CPU cpu) const { - auto it = cpuCrossbarUsage.find(cpu); - if (it == cpuCrossbarUsage.end()) - return 0; - return it->second; -} - -CrossbarUsage GraphDCP::getCpuCrossbarCapacity() const { - assert(onnx_mlir::crossbarSize.getValue() > 0 && "crossbar-size must be strictly positive"); - assert(onnx_mlir::crossbarCountInCore.getValue() > 0 && "crossbar-count must be strictly positive"); - CrossbarUsage crossbarEdge = static_cast(onnx_mlir::crossbarSize.getValue()); - CrossbarUsage crossbarArea = checkedMultiply(crossbarEdge, crossbarEdge); - return checkedMultiply(static_cast(onnx_mlir::crossbarCountInCore.getValue()), crossbarArea); -} - -CrossbarUsage GraphDCP::getTaskCrossbarFootprint(const TaskDCP* task) const { - CrossbarUsage crossbarCount = task->getCrossbarUsage(); - if (crossbarCount == 0) - return 0; - CrossbarUsage crossbarEdge = static_cast(onnx_mlir::crossbarSize.getValue()); - CrossbarUsage crossbarArea = checkedMultiply(crossbarEdge, crossbarEdge); - return checkedMultiply(crossbarCount, crossbarArea); -} - -void GraphDCP::reserveTaskCrossbars(CPU cpu, const TaskDCP* task) { - cpuCrossbarUsage[cpu] = checkedAdd(getCpuCrossbarUsage(cpu), getTaskCrossbarFootprint(task)); -} - -void GraphDCP::releaseTaskCrossbars(CPU cpu, const TaskDCP* task) { - CrossbarUsage footprint = getTaskCrossbarFootprint(task); - CrossbarUsage currentUsage = getCpuCrossbarUsage(cpu); - assert(currentUsage >= footprint && "crossbar usage underflow"); - cpuCrossbarUsage[cpu] = currentUsage - footprint; -} - -bool GraphDCP::wouldExhaustCrossbarCapacity(CPU cpu, const TaskDCP* task) const { - CrossbarUsage footprint = getTaskCrossbarFootprint(task); - if (footprint == 0) - return false; - CrossbarUsage nextUsage = checkedAdd(getCpuCrossbarUsage(cpu), footprint); - return nextUsage >= getCpuCrossbarCapacity(); -} - -size_t GraphDCP::crossbarsUsed() const { - CrossbarUsage crossbarEdge = static_cast(onnx_mlir::crossbarSize.getValue()); - CrossbarUsage crossbarArea = crossbarEdge * crossbarEdge; - if (crossbarArea == 0) - return 0; - CrossbarUsage totalArea = 0; - for (const auto& [cpu, usage] : cpuCrossbarUsage) - totalArea = checkedAdd(totalArea, usage); - return static_cast(totalArea / crossbarArea); -} - -size_t GraphDCP::crossbarsAvailable() const { - return static_cast(lastCpu) * onnx_mlir::crossbarCountInCore.getValue(); -} - -//===----------------------------------------------------------------------===// -// AEST / ALST computation -//===----------------------------------------------------------------------===// - -// Walks the topological order once and fills AEST from parent completions, -// while tracking the top two completion times so DCPL updates can avoid a -// second pass. secondMaxCompletion lets us invalidate `maxCompletionTask` -// locally when its AEST moves. -void GraphDCP::initAest() { - auto& worklist = topologicalOrder; - Time maxDcpl = 0; - Time secondMaxCompletionCandidate = 0; - TaskDCP* maxCompletionTaskCandidate = nullptr; - for (auto& task : worklist) { - Time maxParentAest = 0; - for (Edge parentEdge : task.parents) { - maxParentAest = std::max(addOrMax(addOrMax(parentEdge.first->getAest(), parentEdge.first->getWeight()), - getTransferCost(parentEdge.first, &task)), - maxParentAest); - } - task.setAest(maxParentAest); - Time completion = addOrMax(maxParentAest, task.getWeight()); - if (completion >= maxDcpl) { - secondMaxCompletionCandidate = maxDcpl; - maxDcpl = completion; - maxCompletionTaskCandidate = &task; - } - else if (completion > secondMaxCompletionCandidate) { - secondMaxCompletionCandidate = completion; - } - } - dcpl = maxDcpl; - maxCompletion = maxDcpl; - secondMaxCompletion = secondMaxCompletionCandidate; - maxCompletionTask = maxCompletionTaskCandidate; -} - -// Same backward pass as initAest but over the reverse topological order, -// seeding ALST from scheduleDcpl on leaves. -void GraphDCP::initAlst() { - Time scheduleDcpl = getDcpl(); - auto& worklist = topologicalOrder; - - for (TaskDCP& node : llvm::reverse(worklist)) { - Time minAlst = std::numeric_limits