1 Commits

Author SHA1 Message Date
NiccoloN 87922d994f multiple-output spat computes
Validate Operations / validate-operations (push) Successful in 1h2m3s
2026-04-22 18:29:06 +02:00
9 changed files with 79 additions and 123 deletions
@@ -55,24 +55,16 @@ pub trait HasSigm {
impl HasSigm for f32 {
fn sigm(self) -> Self {
if self >= 0.0 {
1.0 / (1.0 + (-self).exp())
} else {
let ex = self.exp();
ex / (1.0 + ex)
}
}
}
impl HasSigm for f64 {
fn sigm(self) -> Self {
if self >= 0.0 {
1.0 / (1.0 + (-self).exp())
} else {
let ex = self.exp();
ex / (1.0 + ex)
}
}
}
pub trait HasExp {
-48
View File
@@ -4,7 +4,6 @@
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
@@ -97,53 +96,6 @@ void markWeightAlways(Operation* op) {
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
}
namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
bool found = false;
parentOp.walk([&](Operation* op) {
if (auto mvmOp = dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex;
else if (auto vmmOp = dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex;
});
return found;
}
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, function_ref<void(OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
};
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
}
} // namespace
bool isSpatialMvmVmmWeightUse(OpOperand& use) {
Operation* user = use.getOwner();
unsigned operandIndex = use.getOperandNumber();
auto computeOp = dyn_cast<spatial::SpatCompute>(user);
if (!computeOp || operandIndex >= computeOp.getWeights().size())
return false;
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
}
void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) {
walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback);
});
}
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};
-5
View File
@@ -7,7 +7,6 @@
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
@@ -41,10 +40,6 @@ bool hasWeightAlways(mlir::Operation* op);
void markWeightAlways(mlir::Operation* op);
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
llvm::FailureOr<mlir::Operation*>
@@ -392,9 +392,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight =
!constantOp->use_empty() && llvm::all_of(constantOp->getUses(), [](OpOperand& use) -> bool {
return isSpatialMvmVmmWeightUse(use);
});
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
if (isAlwaysWeight)
markWeightAlways(constantOp);
});
@@ -289,7 +289,8 @@ static SmallVector<Value> createIm2colRowComputes(Value x,
rowResults.reserve(packedNumRows);
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(packFactor * patchSize)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(packFactor * patchSize)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
rowResults.push_back(
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
@@ -325,7 +326,8 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
else {
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
Value packedOutput = gemmRowArgs.size() == 1
Value packedOutput =
gemmRowArgs.size() == 1
? gemmRowArgs.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
@@ -503,13 +505,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
//
// We want to process N pixels at the same time. Instead of doing N separate operations
// of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
auto gemmInputRowType = RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
// The im2col compute yields each GEMM input row as a separate result so every GEMM consumes only
// the row it needs instead of receiving a full packed tensor and slicing it locally.
auto gemmInputRowType =
RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
auto gemmOutputRowType =
RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
@@ -94,8 +94,10 @@ void PimBufferizationPass::runOnOperation() {
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
funcOp.walk([&](PimCoreOp coreOp) {
walkPimMvmVmmWeightUses(coreOp, [&](OpOperand& weightUse) {
Value weight = weightUse.get();
auto annotateWeight = [&](unsigned weightIndex) {
if (weightIndex >= coreOp.getWeights().size())
return;
Value weight = coreOp.getWeights()[weightIndex];
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp)
return;
@@ -103,7 +105,10 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
assert("Weights must be constants" && globalMemrefOp.getConstant());
markWeightAlways(getGlobalOp);
markWeightAlways(globalMemrefOp);
});
};
coreOp.walk([&](PimMVMOp mvmOp) { annotateWeight(mvmOp.getWeightIndex()); });
coreOp.walk([&](PimVMMOp vmmOp) { annotateWeight(vmmOp.getWeightIndex()); });
});
}
+10 -3
View File
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -13,7 +14,10 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"
@@ -115,10 +119,13 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
}
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
if (auto computeOp = dyn_cast<SpatCompute>(weigthedOp->getParentOp()))
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
auto wcomputeOp = dyn_cast<SpatCompute>(weigthedOp->getParentOp());
if (wcomputeOp)
return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
if (auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp()))
auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp());
if (coreOp)
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
return failure();
@@ -28,7 +28,7 @@ using namespace mlir;
namespace {
struct VirtualNode {
SmallVector<size_t, 4> originalComputeIndices;
llvm::SmallVector<size_t, 4> originalComputeIndices;
Weight weight = 0;
CrossbarUsage crossbarUsage = 0;
};
@@ -50,7 +50,7 @@ struct WindowScheduleResult {
bool usedAllAvailableCpus = false;
};
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
std::map<std::pair<size_t, size_t>, Weight> edgeWeights;
for (auto [start, end, weight] : edges) {
size_t startIndex = static_cast<size_t>(start);
@@ -74,7 +74,8 @@ std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
return aggregatedEdges;
}
VirtualGraph buildInitialVirtualGraph(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges) {
VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatCompute> spatComputes,
llvm::ArrayRef<IndexedEdge> edges) {
VirtualGraph graph;
graph.nodes.reserve(spatComputes.size());
for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) {
@@ -173,7 +174,7 @@ std::vector<size_t> selectCriticalWindow(const TimingInfo& timing, size_t window
return selected;
}
std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes) {
std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes) {
std::vector<size_t> signature;
for (size_t nodeIndex : selectedNodes) {
const VirtualNode& node = graph.nodes[nodeIndex];
@@ -196,7 +197,8 @@ std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::
return aggregateEdges(windowEdges);
}
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
WindowScheduleResult
scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes, MLIRContext* context) {
std::vector<Weight> windowWeights;
std::vector<CrossbarUsage> windowCrossbarUsage;
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
@@ -232,7 +234,9 @@ WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t>
return result;
}
bool coarsenGraph(const VirtualGraph& graph, ArrayRef<std::vector<size_t>> mergeGroups, VirtualGraph& coarsenedGraph) {
bool coarsenGraph(const VirtualGraph& graph,
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph) {
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
for (auto [groupIndex, mergeGroup] : llvm::enumerate(mergeGroups)) {
if (mergeGroup.size() < 2)
@@ -299,7 +303,7 @@ bool coarsenGraph(const VirtualGraph& graph, ArrayRef<std::vector<size_t>> merge
}
bool coarsenGraphWithFallback(const VirtualGraph& graph,
ArrayRef<std::vector<size_t>> mergeGroups,
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph) {
if (coarsenGraph(graph, mergeGroups, coarsenedGraph))
return true;
@@ -326,7 +330,7 @@ bool coarsenGraphWithFallback(const VirtualGraph& graph,
return !acceptedMergeGroups.empty();
}
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, ArrayRef<IndexedEdge> edges) {
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::ArrayRef<IndexedEdge> edges) {
VirtualGraph graph;
graph.nodes.resize(computeCount);
graph.edges = aggregateEdges(edges);
@@ -340,8 +344,8 @@ std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, ArrayRe
}
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
ArrayRef<SpatCompute> spatComputes,
ArrayRef<IndexedEdge> originalEdges) {
llvm::ArrayRef<SpatCompute> spatComputes,
llvm::ArrayRef<IndexedEdge> originalEdges) {
DCPAnalysisResult result;
std::vector<size_t> originalToVirtualNode(spatComputes.size(), 0);
for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
@@ -363,7 +367,9 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
return result;
}
DCPAnalysisResult runLegacyDcp(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatCompute> spatComputes,
llvm::ArrayRef<IndexedEdge> edges,
MLIRContext* context) {
GraphDCP graphDCP(spatComputes, edges);
if (coresCount.getValue() > 0)
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
@@ -377,12 +383,12 @@ DCPAnalysisResult runLegacyDcp(ArrayRef<SpatCompute> spatComputes, ArrayRef<Inde
SpatCompute getOriginalSpatCompute(Operation* op) {
if (!op)
return {};
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
op = extract.getSource().getDefiningOp();
if (!op)
return {};
}
if (auto res = dyn_cast<SpatCompute>(op))
if (auto res = llvm::dyn_cast<SpatCompute>(op))
return res;
return {};
}
+2
View File
@@ -1,3 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
add_custom_target(pim-unittest)
set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")