This commit is contained in:
@@ -5,7 +5,7 @@ add_public_tablegen_target(ONNXToSpatialIncGen)
|
|||||||
add_pim_library(OMONNXToSpatial
|
add_pim_library(OMONNXToSpatial
|
||||||
ConversionPatterns.cpp
|
ConversionPatterns.cpp
|
||||||
CompileTime.cpp
|
CompileTime.cpp
|
||||||
HostLegality.cpp
|
ONNXToSpatialVerifier.cpp
|
||||||
PrePatterns.cpp
|
PrePatterns.cpp
|
||||||
PostPatterns.cpp
|
PostPatterns.cpp
|
||||||
Patterns/Math/Conv.cpp
|
Patterns/Math/Conv.cpp
|
||||||
|
|||||||
@@ -7,6 +7,9 @@
|
|||||||
#include "llvm/ADT/SmallBitVector.h"
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
@@ -195,66 +198,90 @@ static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm:
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isCompileTimeOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
|
||||||
|
static std::optional<CompileTimeSource>
|
||||||
|
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
|
||||||
if (!op)
|
if (!op)
|
||||||
return false;
|
return std::nullopt;
|
||||||
|
|
||||||
if (!visited.insert(op).second)
|
if (!visited.insert(op).second)
|
||||||
return true;
|
return {
|
||||||
|
{op, chainLength}
|
||||||
|
};
|
||||||
|
|
||||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||||
return true;
|
return {
|
||||||
|
{op, chainLength}
|
||||||
|
};
|
||||||
|
|
||||||
|
chainLength += 1;
|
||||||
|
|
||||||
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||||
return hasConstantIndices(extractOp) && isCompileTimeOpImpl(extractOp, visited);
|
return hasConstantIndices(extractOp) ? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength) : std::nullopt;
|
||||||
|
|
||||||
if (!isStaticTensorResult(op))
|
if (!isStaticTensorResult(op))
|
||||||
return false;
|
return std::nullopt;
|
||||||
|
|
||||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
||||||
return isCompileTimeOpImpl(transposeOp, visited);
|
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
|
||||||
|
|
||||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
||||||
return isCompileTimeOpImpl(collapseShapeOp,visited);
|
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
||||||
|
|
||||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
|
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
|
||||||
return isCompileTimeOpImpl(expandShapeOp, visited);
|
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
||||||
|
|
||||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||||
return hasStaticUnitStrides(extractSliceOp) && isCompileTimeOpImpl(extractSliceOp, visited);
|
return hasStaticUnitStrides(extractSliceOp) ? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
|
||||||
|
: std::nullopt;
|
||||||
|
|
||||||
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||||
return isCompileTimeOpImpl(splatOp, visited);
|
return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength);
|
||||||
|
|
||||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||||
return isCompileTimeOpImpl(extractRowsOp, visited);
|
return getCompileTimeSourceImpl(extractRowsOp.getInput().getDefiningOp(), visited, chainLength);
|
||||||
|
|
||||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
||||||
bool res = true;
|
std::optional<CompileTimeSource> res = {};
|
||||||
for (auto operandValue : concatOp.getOperands()) {
|
for (auto operandValue : concatOp.getOperands()) {
|
||||||
res &= isCompileTimeOpImpl(operandValue.getDefiningOp(), visited);
|
auto partialRes = getCompileTimeSourceImpl(operandValue.getDefiningOp(), visited, chainLength);
|
||||||
if(!res) break;
|
if (!partialRes)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
if (!res) {
|
||||||
|
res = partialRes;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if(res->chainLength < partialRes->chainLength){
|
||||||
|
res = partialRes;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
||||||
|
/// Public entry point for the source lookup: runs getCompileTimeSourceImpl
/// on `op` with a fresh visited set and a chain length of zero.
///
/// \returns The source operation and chain length found for `op`, or
///          `std::nullopt` when the lookup fails (including a null `op`).
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
  llvm::SmallPtrSet<Operation*, 8> seenOps;
  std::optional<CompileTimeSource> origin = getCompileTimeSourceImpl(op, seenOps);
  return origin;
}
|
||||||
|
|
||||||
/// Reports whether `value` is produced by an operation for which
/// getCompileTimeSourceImpl finds a source.
///
/// A value without a defining operation is never considered compile-time
/// computable.
bool isCompileTimeComputable(Value value) {
  if (Operation* producer = value.getDefiningOp()) {
    llvm::SmallPtrSet<Operation*, 8> seenOps;
    const auto origin = getCompileTimeSourceImpl(producer, seenOps);
    return origin.has_value();
  }
  return false;
}
|
||||||
|
|
||||||
/// Reports whether getCompileTimeSourceImpl finds a source for `op`, using a
/// fresh visited set. A null `op` yields false.
bool isCompileTimeOp(Operation* op) {
  llvm::SmallPtrSet<Operation*, 8> seenOps;
  const auto origin = getCompileTimeSourceImpl(op, seenOps);
  return origin.has_value();
}
|
||||||
|
|
||||||
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {
|
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {
|
||||||
|
|||||||
@@ -6,6 +6,13 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
struct CompileTimeSource {
|
||||||
|
mlir::Operation* source;
|
||||||
|
size_t chainLength;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::optional<CompileTimeSource> getCompileTimeSource(mlir::Operation* op);
|
||||||
|
|
||||||
bool isCompileTimeComputable(mlir::Value value);
|
bool isCompileTimeComputable(mlir::Value value);
|
||||||
|
|
||||||
bool isCompileTimeOp(mlir::Operation* op);
|
bool isCompileTimeOp(mlir::Operation* op);
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
/// Checks the operations in the first block of `funcOp`'s body.
///
/// func.return, spat.compute and spat.compute_batch operations are accepted
/// as they are, and so is every operation for which isCompileTimeOp returns
/// true. Any other operation is reported through the capped diagnostic
/// reporter.
///
/// \returns success when no operation was reported, failure otherwise.
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
  pim::CappedDiagnosticReporter reporter;

  Block& entryBlock = funcOp.getFunctionBody().front();
  for (Operation& candidate : entryBlock) {
    const bool allowedAtTopLevel =
        isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&candidate) || isCompileTimeOp(&candidate);
    if (allowedAtTopLevel)
      continue;

    reporter.report(&candidate, [](Operation* illegalOp) {
      illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
                             "spat.compute");
    });
  }

  reporter.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");

  return failure(reporter.hasFailure());
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
@@ -220,7 +220,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
|
|
||||||
wrapTopLevelRuntimeTransposes(*entryFunc);
|
wrapTopLevelRuntimeTransposes(*entryFunc);
|
||||||
|
|
||||||
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
||||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
#include "Common/IR/WeightUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Reports every tensor.extract_slice found by `func.getOps` whose
/// compile-time source satisfies hasWeightAlways but is reached through a
/// chain longer than one operation, i.e. with further operations between the
/// slice and that source.
///
/// \param func        Function whose extract_slice operations are inspected.
/// \param diagnostics Reporter that receives one entry per offending slice.
void checkWeightsDirectlyExtracted(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
  for (tensor::ExtractSliceOp sliceOp : func.getOps<tensor::ExtractSliceOp>()) {
    Operation* sliceOperation = sliceOp.getOperation();

    std::optional<CompileTimeSource> origin = getCompileTimeSource(sliceOperation);
    if (!origin)
      continue;
    if (!hasWeightAlways(origin->source))
      continue;
    // A chain length of one means the slice reads the source directly.
    if (!(origin->chainLength > 1))
      continue;

    diagnostics.report(sliceOperation,
                       [](Operation* illegalOp) { illegalOp->emitOpError("Weight not directly extracted"); });
  }
}
|
||||||
|
|
||||||
|
/// Verifies `funcOp` after the ONNX-to-Spatial conversion.
///
/// Two checks feed one capped diagnostic reporter:
///  1. Every operation returned by `funcOp.getOps()` must be a func.return,
///     a spat.compute, a spat.compute_batch, or an operation accepted by
///     isCompileTimeOp.
///  2. checkWeightsDirectlyExtracted is run over the function.
///
/// \returns success when neither check reported anything, failure otherwise.
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
  pim::CappedDiagnosticReporter reporter;

  for (Operation& candidate : funcOp.getOps()) {
    const bool allowedAtTopLevel =
        isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&candidate) || isCompileTimeOp(&candidate);
    if (allowedAtTopLevel)
      continue;

    reporter.report(&candidate, [](Operation* illegalOp) {
      illegalOp->emitOpError(
          "non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
    });
  }

  checkWeightsDirectlyExtracted(funcOp, reporter);

  reporter.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");

  return failure(reporter.hasFailure());
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
+1
-1
@@ -5,6 +5,6 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
|
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -1,3 +1,6 @@
|
|||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
@@ -8,7 +11,7 @@ namespace onnx_mlir {
|
|||||||
namespace spatial {
|
namespace spatial {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
std::optional<BlockArgument> getBatchBodyArgument(Region& body, unsigned argIdx) {
|
std::optional<BlockArgument> getBlockArgument(Region& body, unsigned argIdx) {
|
||||||
if (body.empty())
|
if (body.empty())
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
@@ -18,7 +21,7 @@ std::optional<BlockArgument> getBatchBodyArgument(Region& body, unsigned argIdx)
|
|||||||
return block.getArgument(argIdx);
|
return block.getArgument(argIdx);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<BlockArgument> insertBatchBodyArgument(Region& body, unsigned argIdx, Type type, Location loc) {
|
std::optional<BlockArgument> insertBlockArgument(Region& body, unsigned argIdx, Type type, Location loc) {
|
||||||
if (body.empty())
|
if (body.empty())
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
return body.insertArgument(argIdx, type, loc);
|
return body.insertArgument(argIdx, type, loc);
|
||||||
@@ -34,21 +37,27 @@ void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t i
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) {
|
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); }
|
||||||
return getBatchBodyArgument(getBody(), idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
|
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
|
||||||
return getBatchBodyArgument(getBody(), getWeights().size() + idx);
|
return getBlockArgument(getBody(), getWeights().size() + idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
|
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
||||||
|
llvm::dbgs() << "Disse netanyao\n";
|
||||||
|
auto index = std::distance(getWeights().begin(), existing);
|
||||||
|
return {
|
||||||
|
{*existing, *getWeightArgument(index)}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
unsigned weightCount = getWeights().size();
|
unsigned weightCount = getWeights().size();
|
||||||
unsigned inputCount = getInputs().size();
|
unsigned inputCount = getInputs().size();
|
||||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
getOperation()->insertOperands(idx, ValueRange {weight});
|
||||||
setComputeOperandSegmentSizes(
|
setComputeOperandSegmentSizes(
|
||||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||||
auto blockArg = insertBatchBodyArgument(getBody(), idx, weight.getType(), loc);
|
auto blockArg = insertBlockArgument(getBody(), idx, weight.getType(), loc);
|
||||||
if (!blockArg)
|
if (!blockArg)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
||||||
@@ -60,7 +69,7 @@ std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigne
|
|||||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||||
setComputeOperandSegmentSizes(
|
setComputeOperandSegmentSizes(
|
||||||
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||||
auto blockArg = insertBatchBodyArgument(getBody(), weightCount + idx, input.getType(), loc);
|
auto blockArg = insertBlockArgument(getBody(), weightCount + idx, input.getType(), loc);
|
||||||
if (!blockArg)
|
if (!blockArg)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||||
@@ -100,28 +109,36 @@ void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn s
|
|||||||
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBatchBodyArgument(getBody(), 0); }
|
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) {
|
std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) {
|
||||||
return getBatchBodyArgument(getBody(), 1 + idx);
|
return getBlockArgument(getBody(), 1 + idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
|
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
|
||||||
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + idx);
|
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
|
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
|
||||||
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<std::tuple<Value, BlockArgument>>
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
|
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
||||||
|
auto index = std::distance(getWeights().begin(), existing);
|
||||||
|
llvm::dbgs() << "Bum bum bum bum\n";
|
||||||
|
return {
|
||||||
|
{*existing, *getWeightArgument(index)}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
unsigned weightCount = getWeights().size();
|
unsigned weightCount = getWeights().size();
|
||||||
unsigned inputCount = getInputs().size();
|
unsigned inputCount = getInputs().size();
|
||||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
getOperation()->insertOperands(idx, ValueRange {weight});
|
||||||
setComputeOperandSegmentSizes(
|
setComputeOperandSegmentSizes(
|
||||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||||
auto blockArg = insertBatchBodyArgument(getBody(), 1 + idx, weight.getType(), loc);
|
auto blockArg = insertBlockArgument(getBody(), 1 + idx, weight.getType(), loc);
|
||||||
if (!blockArg)
|
if (!blockArg)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
||||||
@@ -133,7 +150,7 @@ std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(un
|
|||||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||||
setComputeOperandSegmentSizes(
|
setComputeOperandSegmentSizes(
|
||||||
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||||
auto blockArg = insertBatchBodyArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
|
auto blockArg = insertBlockArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
|
||||||
if (!blockArg)
|
if (!blockArg)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "llvm/Support/ErrorHandling.h"
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@@ -177,20 +178,25 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
Time bestEst = 0;
|
Time bestEst = 0;
|
||||||
Time bestEft = 0;
|
Time bestEft = 0;
|
||||||
Time bestOeft = std::numeric_limits<Time>::max();
|
Time bestOeft = std::numeric_limits<Time>::max();
|
||||||
|
unsigned int bestOverlapWeight = 0;
|
||||||
bool crossbarRejected = false;
|
bool crossbarRejected = false;
|
||||||
|
|
||||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
auto crossbarsAreContainedInProcessor = [&processorCrossbars](mlir::Value nodeCrossbar, size_t processor) {
|
||||||
if (graph.nodes[task].crossbarUsage.size() != 0
|
|
||||||
&& !llvm::all_of(graph.nodes[task].crossbarUsage,
|
|
||||||
[&processorCrossbars, processor](mlir::Value nodeCrossbar) {
|
|
||||||
return llvm::is_contained(processorCrossbars[processor], nodeCrossbar);
|
return llvm::is_contained(processorCrossbars[processor], nodeCrossbar);
|
||||||
})
|
};
|
||||||
|
|
||||||
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
|
auto crossbarsAreContained = std::bind(crossbarsAreContainedInProcessor, std::placeholders::_1, processor);
|
||||||
|
if (graph.nodes[task].crossbarUsage.size() != 0
|
||||||
|
&& !llvm::all_of(graph.nodes[task].crossbarUsage, crossbarsAreContained)
|
||||||
&& addOrMax(processorCrossbars[processor].size(), graph.nodes[task].crossbarUsage.size())
|
&& addOrMax(processorCrossbars[processor].size(), graph.nodes[task].crossbarUsage.size())
|
||||||
> options.crossbarCapacity) {
|
> options.crossbarCapacity) {
|
||||||
|
|
||||||
crossbarRejected = true;
|
crossbarRejected = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
unsigned int overlapWeight =
|
||||||
|
llvm::count_if(graph.nodes[task].crossbarUsage, crossbarsAreContained);
|
||||||
|
|
||||||
Time dataReady = 0;
|
Time dataReady = 0;
|
||||||
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
||||||
@@ -224,12 +230,19 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||||
|
|
||||||
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
||||||
|| (oeft == bestOeft && eft == bestEft && est < bestEst)
|
|| (oeft == bestOeft && eft == bestEft && est < bestEst)) {
|
||||||
|| (oeft == bestOeft && eft == bestEft && est == bestEst && processor < bestProcessor)) {
|
|
||||||
bestProcessor = processor;
|
bestProcessor = processor;
|
||||||
bestEst = est;
|
bestEst = est;
|
||||||
bestEft = eft;
|
bestEft = eft;
|
||||||
bestOeft = oeft;
|
bestOeft = oeft;
|
||||||
|
bestOverlapWeight = overlapWeight;
|
||||||
|
}
|
||||||
|
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapWeight < bestOverlapWeight) {
|
||||||
|
bestProcessor = processor;
|
||||||
|
bestEst = est;
|
||||||
|
bestEft = eft;
|
||||||
|
bestOeft = oeft;
|
||||||
|
bestOverlapWeight = overlapWeight;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user