Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cff929a083 | |||
| 89b3501aa8 | |||
| 412ca957f6 |
@@ -55,15 +55,23 @@ pub trait HasSigm {
|
|||||||
|
|
||||||
impl HasSigm for f32 {
|
impl HasSigm for f32 {
|
||||||
fn sigm(self) -> Self {
|
fn sigm(self) -> Self {
|
||||||
let ex = self.exp();
|
if self >= 0.0 {
|
||||||
ex / (1.0 + ex)
|
1.0 / (1.0 + (-self).exp())
|
||||||
|
} else {
|
||||||
|
let ex = self.exp();
|
||||||
|
ex / (1.0 + ex)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HasSigm for f64 {
|
impl HasSigm for f64 {
|
||||||
fn sigm(self) -> Self {
|
fn sigm(self) -> Self {
|
||||||
let ex = self.exp();
|
if self >= 0.0 {
|
||||||
ex / (1.0 + ex)
|
1.0 / (1.0 + (-self).exp())
|
||||||
|
} else {
|
||||||
|
let ex = self.exp();
|
||||||
|
ex / (1.0 + ex)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallSet.h"
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
@@ -96,6 +97,53 @@ void markWeightAlways(Operation* op) {
|
|||||||
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||||
|
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||||
|
bool found = false;
|
||||||
|
parentOp.walk([&](Operation* op) {
|
||||||
|
if (auto mvmOp = dyn_cast<MVMOpTy>(op))
|
||||||
|
found |= mvmOp.getWeightIndex() == weightIndex;
|
||||||
|
else if (auto vmmOp = dyn_cast<VMMOpTy>(op))
|
||||||
|
found |= vmmOp.getWeightIndex() == weightIndex;
|
||||||
|
});
|
||||||
|
return found;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||||
|
void walkMvmVmmWeightUses(ParentOpTy parentOp, function_ref<void(OpOperand&)> callback) {
|
||||||
|
auto weights = parentOp.getWeights();
|
||||||
|
llvm::SmallSet<unsigned, 8> visited;
|
||||||
|
auto walkWeightIndex = [&](unsigned weightIndex) {
|
||||||
|
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
|
||||||
|
callback(parentOp->getOpOperand(weightIndex));
|
||||||
|
};
|
||||||
|
|
||||||
|
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||||
|
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/// Returns true when `use` is a weight operand of a `spatial::SpatCompute`
/// operation and that weight is referenced by at least one nested
/// `spatial::SpatWeightedMVMOp` or `spatial::SpatWeightedVMMOp`.
///
/// Uses owned by any other kind of operation, and operands of a SpatCompute
/// that are not part of its weights, yield false.
bool isSpatialMvmVmmWeightUse(OpOperand& use) {
  auto computeOp = dyn_cast<spatial::SpatCompute>(use.getOwner());
  if (!computeOp)
    return false;

  auto weights = computeOp.getWeights();
  if (weights.empty())
    return false;

  // Translate the operand number into a position within the weights group
  // instead of assuming that the group starts at operand 0.
  unsigned firstWeightOperand = weights.getBeginOperandIndex();
  unsigned operandIndex = use.getOperandNumber();
  if (operandIndex < firstWeightOperand || operandIndex - firstWeightOperand >= weights.size())
    return false;

  unsigned weightIndex = operandIndex - firstWeightOperand;
  return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, weightIndex);
}
|
||||||
|
|
||||||
|
/// Calls `callback` for the weight operands used by PIM MVM/VMM operations in
/// every `pim::PimCoreOp` reached by walking `root`.
///
/// \param root     operation to walk; must not be null.
/// \param callback forwarded to the per-core weight walk.
void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
  assert(root && "expected valid root op");

  // Per-core visitor: report the weights used by that core's MVM/VMM ops.
  auto visitCore = [&](pim::PimCoreOp core) {
    walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(core, callback);
  };
  root->walk(visitCore);
}
|
||||||
|
|
||||||
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
||||||
if (!moduleOp || !getGlobalOp)
|
if (!moduleOp || !getGlobalOp)
|
||||||
return {};
|
return {};
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
@@ -40,6 +41,10 @@ bool hasWeightAlways(mlir::Operation* op);
|
|||||||
|
|
||||||
void markWeightAlways(mlir::Operation* op);
|
void markWeightAlways(mlir::Operation* op);
|
||||||
|
|
||||||
|
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
|
||||||
|
|
||||||
|
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||||
|
|
||||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||||
|
|
||||||
llvm::FailureOr<mlir::Operation*>
|
llvm::FailureOr<mlir::Operation*>
|
||||||
|
|||||||
@@ -392,7 +392,9 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||||
bool isAlwaysWeight =
|
bool isAlwaysWeight =
|
||||||
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
|
!constantOp->use_empty() && llvm::all_of(constantOp->getUses(), [](OpOperand& use) -> bool {
|
||||||
|
return isSpatialMvmVmmWeightUse(use);
|
||||||
|
});
|
||||||
if (isAlwaysWeight)
|
if (isAlwaysWeight)
|
||||||
markWeightAlways(constantOp);
|
markWeightAlways(constantOp);
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -289,8 +289,7 @@ static SmallVector<Value> createIm2colRowComputes(Value x,
|
|||||||
rowResults.reserve(packedNumRows);
|
rowResults.reserve(packedNumRows);
|
||||||
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
|
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(packFactor * patchSize)};
|
||||||
rewriter.getIndexAttr(packFactor * patchSize)};
|
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
rowResults.push_back(
|
rowResults.push_back(
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
|
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
|
||||||
@@ -326,10 +325,9 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
|||||||
else {
|
else {
|
||||||
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
||||||
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
||||||
Value packedOutput =
|
Value packedOutput = gemmRowArgs.size() == 1
|
||||||
gemmRowArgs.size() == 1
|
? gemmRowArgs.front()
|
||||||
? gemmRowArgs.front()
|
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
|
||||||
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
|
|
||||||
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
expandedType,
|
expandedType,
|
||||||
@@ -505,38 +503,41 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||||
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
|
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
|
||||||
//
|
//
|
||||||
// The im2col compute yields each GEMM input row as a separate result so every GEMM consumes only
|
// We want to process N pixels at the same time. Instead of doing N separate operations
|
||||||
// the row it needs instead of receiving a full packed tensor and slicing it locally.
|
// of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
|
||||||
auto gemmInputRowType =
|
// containing N copies of W^T and concatenate N im2col rows into one longer row:
|
||||||
RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
|
// A_packed: [ceil(numPatches / N), N * patchSize]
|
||||||
|
// B_packed: [N * patchSize, N * cOut]
|
||||||
|
// Y_packed: [ceil(numPatches / N), N * cOut]
|
||||||
|
auto gemmInputRowType = RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
|
||||||
auto gemmOutputRowType =
|
auto gemmOutputRowType =
|
||||||
RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
||||||
SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
|
SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
|
||||||
xType,
|
xType,
|
||||||
im2colType,
|
im2colType,
|
||||||
rowType,
|
rowType,
|
||||||
gemmInputRowType,
|
gemmInputRowType,
|
||||||
batchSize,
|
batchSize,
|
||||||
numChannelsIn,
|
numChannelsIn,
|
||||||
xHeight,
|
xHeight,
|
||||||
xWidth,
|
xWidth,
|
||||||
wHeight,
|
wHeight,
|
||||||
wWidth,
|
wWidth,
|
||||||
padHeightBegin,
|
padHeightBegin,
|
||||||
padHeightEnd,
|
padHeightEnd,
|
||||||
padWidthBegin,
|
padWidthBegin,
|
||||||
padWidthEnd,
|
padWidthEnd,
|
||||||
strideHeight,
|
strideHeight,
|
||||||
strideWidth,
|
strideWidth,
|
||||||
dilationHeight,
|
dilationHeight,
|
||||||
dilationWidth,
|
dilationWidth,
|
||||||
outWidth,
|
outWidth,
|
||||||
patchSize,
|
patchSize,
|
||||||
numPatches,
|
numPatches,
|
||||||
numPatchesPerBatch,
|
numPatchesPerBatch,
|
||||||
effectiveMaxParallelPixels,
|
effectiveMaxParallelPixels,
|
||||||
rewriter,
|
rewriter,
|
||||||
loc);
|
loc);
|
||||||
|
|
||||||
Value gemmB = buildPackedWeight(wDenseAttr,
|
Value gemmB = buildPackedWeight(wDenseAttr,
|
||||||
wTrans,
|
wTrans,
|
||||||
|
|||||||
@@ -94,10 +94,8 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
|
|
||||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||||
funcOp.walk([&](PimCoreOp coreOp) {
|
funcOp.walk([&](PimCoreOp coreOp) {
|
||||||
auto annotateWeight = [&](unsigned weightIndex) {
|
walkPimMvmVmmWeightUses(coreOp, [&](OpOperand& weightUse) {
|
||||||
if (weightIndex >= coreOp.getWeights().size())
|
Value weight = weightUse.get();
|
||||||
return;
|
|
||||||
Value weight = coreOp.getWeights()[weightIndex];
|
|
||||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||||
if (!getGlobalOp)
|
if (!getGlobalOp)
|
||||||
return;
|
return;
|
||||||
@@ -105,10 +103,7 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
|
|||||||
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
||||||
markWeightAlways(getGlobalOp);
|
markWeightAlways(getGlobalOp);
|
||||||
markWeightAlways(globalMemrefOp);
|
markWeightAlways(globalMemrefOp);
|
||||||
};
|
});
|
||||||
|
|
||||||
coreOp.walk([&](PimMVMOp mvmOp) { annotateWeight(mvmOp.getWeightIndex()); });
|
|
||||||
coreOp.walk([&](PimVMMOp vmmOp) { annotateWeight(vmmOp.getWeightIndex()); });
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/Dialect/Traits.h"
|
|
||||||
#include "mlir/IR/Block.h"
|
#include "mlir/IR/Block.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
@@ -14,10 +13,7 @@
|
|||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SetVector.h"
|
|
||||||
#include "llvm/ADT/SmallBitVector.h"
|
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
#include "llvm/Support/LogicalResult.h"
|
#include "llvm/Support/LogicalResult.h"
|
||||||
|
|
||||||
@@ -119,13 +115,10 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
|
|||||||
}
|
}
|
||||||
|
|
||||||
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
|
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
|
||||||
auto wcomputeOp = dyn_cast<SpatCompute>(weigthedOp->getParentOp());
|
if (auto computeOp = dyn_cast<SpatCompute>(weigthedOp->getParentOp()))
|
||||||
if (wcomputeOp)
|
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
|
||||||
return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
|
|
||||||
|
|
||||||
auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp());
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp()))
|
||||||
|
|
||||||
if (coreOp)
|
|
||||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||||
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ using namespace mlir;
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct VirtualNode {
|
struct VirtualNode {
|
||||||
llvm::SmallVector<size_t, 4> originalComputeIndices;
|
SmallVector<size_t, 4> originalComputeIndices;
|
||||||
Weight weight = 0;
|
Weight weight = 0;
|
||||||
CrossbarUsage crossbarUsage = 0;
|
CrossbarUsage crossbarUsage = 0;
|
||||||
};
|
};
|
||||||
@@ -50,7 +50,7 @@ struct WindowScheduleResult {
|
|||||||
bool usedAllAvailableCpus = false;
|
bool usedAllAvailableCpus = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
|
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
|
||||||
std::map<std::pair<size_t, size_t>, Weight> edgeWeights;
|
std::map<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||||
for (auto [start, end, weight] : edges) {
|
for (auto [start, end, weight] : edges) {
|
||||||
size_t startIndex = static_cast<size_t>(start);
|
size_t startIndex = static_cast<size_t>(start);
|
||||||
@@ -74,8 +74,7 @@ std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
|
|||||||
return aggregatedEdges;
|
return aggregatedEdges;
|
||||||
}
|
}
|
||||||
|
|
||||||
VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatCompute> spatComputes,
|
VirtualGraph buildInitialVirtualGraph(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges) {
|
||||||
llvm::ArrayRef<IndexedEdge> edges) {
|
|
||||||
VirtualGraph graph;
|
VirtualGraph graph;
|
||||||
graph.nodes.reserve(spatComputes.size());
|
graph.nodes.reserve(spatComputes.size());
|
||||||
for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) {
|
for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) {
|
||||||
@@ -174,7 +173,7 @@ std::vector<size_t> selectCriticalWindow(const TimingInfo& timing, size_t window
|
|||||||
return selected;
|
return selected;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes) {
|
std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes) {
|
||||||
std::vector<size_t> signature;
|
std::vector<size_t> signature;
|
||||||
for (size_t nodeIndex : selectedNodes) {
|
for (size_t nodeIndex : selectedNodes) {
|
||||||
const VirtualNode& node = graph.nodes[nodeIndex];
|
const VirtualNode& node = graph.nodes[nodeIndex];
|
||||||
@@ -197,8 +196,7 @@ std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::
|
|||||||
return aggregateEdges(windowEdges);
|
return aggregateEdges(windowEdges);
|
||||||
}
|
}
|
||||||
|
|
||||||
WindowScheduleResult
|
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
|
||||||
scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes, MLIRContext* context) {
|
|
||||||
std::vector<Weight> windowWeights;
|
std::vector<Weight> windowWeights;
|
||||||
std::vector<CrossbarUsage> windowCrossbarUsage;
|
std::vector<CrossbarUsage> windowCrossbarUsage;
|
||||||
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
|
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
|
||||||
@@ -234,9 +232,7 @@ scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes,
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool coarsenGraph(const VirtualGraph& graph,
|
bool coarsenGraph(const VirtualGraph& graph, ArrayRef<std::vector<size_t>> mergeGroups, VirtualGraph& coarsenedGraph) {
|
||||||
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
|
||||||
VirtualGraph& coarsenedGraph) {
|
|
||||||
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
||||||
for (auto [groupIndex, mergeGroup] : llvm::enumerate(mergeGroups)) {
|
for (auto [groupIndex, mergeGroup] : llvm::enumerate(mergeGroups)) {
|
||||||
if (mergeGroup.size() < 2)
|
if (mergeGroup.size() < 2)
|
||||||
@@ -303,7 +299,7 @@ bool coarsenGraph(const VirtualGraph& graph,
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool coarsenGraphWithFallback(const VirtualGraph& graph,
|
bool coarsenGraphWithFallback(const VirtualGraph& graph,
|
||||||
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
ArrayRef<std::vector<size_t>> mergeGroups,
|
||||||
VirtualGraph& coarsenedGraph) {
|
VirtualGraph& coarsenedGraph) {
|
||||||
if (coarsenGraph(graph, mergeGroups, coarsenedGraph))
|
if (coarsenGraph(graph, mergeGroups, coarsenedGraph))
|
||||||
return true;
|
return true;
|
||||||
@@ -330,7 +326,7 @@ bool coarsenGraphWithFallback(const VirtualGraph& graph,
|
|||||||
return !acceptedMergeGroups.empty();
|
return !acceptedMergeGroups.empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::ArrayRef<IndexedEdge> edges) {
|
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, ArrayRef<IndexedEdge> edges) {
|
||||||
VirtualGraph graph;
|
VirtualGraph graph;
|
||||||
graph.nodes.resize(computeCount);
|
graph.nodes.resize(computeCount);
|
||||||
graph.edges = aggregateEdges(edges);
|
graph.edges = aggregateEdges(edges);
|
||||||
@@ -344,8 +340,8 @@ std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::A
|
|||||||
}
|
}
|
||||||
|
|
||||||
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
|
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
|
||||||
llvm::ArrayRef<SpatCompute> spatComputes,
|
ArrayRef<SpatCompute> spatComputes,
|
||||||
llvm::ArrayRef<IndexedEdge> originalEdges) {
|
ArrayRef<IndexedEdge> originalEdges) {
|
||||||
DCPAnalysisResult result;
|
DCPAnalysisResult result;
|
||||||
std::vector<size_t> originalToVirtualNode(spatComputes.size(), 0);
|
std::vector<size_t> originalToVirtualNode(spatComputes.size(), 0);
|
||||||
for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
|
for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
|
||||||
@@ -367,9 +363,7 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatCompute> spatComputes,
|
DCPAnalysisResult runLegacyDcp(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
|
||||||
llvm::ArrayRef<IndexedEdge> edges,
|
|
||||||
MLIRContext* context) {
|
|
||||||
GraphDCP graphDCP(spatComputes, edges);
|
GraphDCP graphDCP(spatComputes, edges);
|
||||||
if (coresCount.getValue() > 0)
|
if (coresCount.getValue() > 0)
|
||||||
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
||||||
@@ -383,12 +377,12 @@ DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatCompute> spatComputes,
|
|||||||
SpatCompute getOriginalSpatCompute(Operation* op) {
|
SpatCompute getOriginalSpatCompute(Operation* op) {
|
||||||
if (!op)
|
if (!op)
|
||||||
return {};
|
return {};
|
||||||
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
|
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
op = extract.getSource().getDefiningOp();
|
op = extract.getSource().getDefiningOp();
|
||||||
if (!op)
|
if (!op)
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
if (auto res = llvm::dyn_cast<SpatCompute>(op))
|
if (auto res = dyn_cast<SpatCompute>(op))
|
||||||
return res;
|
return res;
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
add_custom_target(pim-unittest)
|
add_custom_target(pim-unittest)
|
||||||
set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")
|
set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user