1 Commits

Author SHA1 Message Date
NiccoloN 412ca957f6 multiple-output spat computes
Validate Operations / validate-operations (push) Successful in 22m38s
2026-04-23 09:28:57 +02:00
3 changed files with 52 additions and 64 deletions
@@ -289,8 +289,7 @@ static SmallVector<Value> createIm2colRowComputes(Value x,
rowResults.reserve(packedNumRows); rowResults.reserve(packedNumRows);
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) { for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(packFactor * patchSize)};
rewriter.getIndexAttr(packFactor * patchSize)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
rowResults.push_back( rowResults.push_back(
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides)); tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
@@ -326,8 +325,7 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
else { else {
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType()); auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType()); auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
Value packedOutput = Value packedOutput = gemmRowArgs.size() == 1
gemmRowArgs.size() == 1
? gemmRowArgs.front() ? gemmRowArgs.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult(); : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter, Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
@@ -505,10 +503,13 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// and optionally repack several old rows into one GEMM row to use the available crossbar size better. // and optionally repack several old rows into one GEMM row to use the available crossbar size better.
// //
// The im2col compute yields each GEMM input row as a separate result so every GEMM consumes only // We want to process N pixels at the same time. Instead of doing N separate operations
// the row it needs instead of receiving a full packed tensor and slicing it locally. // of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
auto gemmInputRowType = // containing N copies of W^T and concatenate N im2col rows into one longer row:
RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType); // A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
auto gemmInputRowType = RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
auto gemmOutputRowType = auto gemmOutputRowType =
RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType()); RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
SmallVector<Value> gemmInputRows = createIm2colRowComputes(x, SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
+3 -10
View File
@@ -1,5 +1,4 @@
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
@@ -14,10 +13,7 @@
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h" #include "llvm/Support/LogicalResult.h"
@@ -119,13 +115,10 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
} }
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) { llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
auto wcomputeOp = dyn_cast<SpatCompute>(weigthedOp->getParentOp()); if (auto computeOp = dyn_cast<SpatCompute>(weigthedOp->getParentOp()))
if (wcomputeOp) return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp()); if (auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp()))
if (coreOp)
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape(); return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
return failure(); return failure();
@@ -28,7 +28,7 @@ using namespace mlir;
namespace { namespace {
struct VirtualNode { struct VirtualNode {
llvm::SmallVector<size_t, 4> originalComputeIndices; SmallVector<size_t, 4> originalComputeIndices;
Weight weight = 0; Weight weight = 0;
CrossbarUsage crossbarUsage = 0; CrossbarUsage crossbarUsage = 0;
}; };
@@ -50,7 +50,7 @@ struct WindowScheduleResult {
bool usedAllAvailableCpus = false; bool usedAllAvailableCpus = false;
}; };
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) { std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
std::map<std::pair<size_t, size_t>, Weight> edgeWeights; std::map<std::pair<size_t, size_t>, Weight> edgeWeights;
for (auto [start, end, weight] : edges) { for (auto [start, end, weight] : edges) {
size_t startIndex = static_cast<size_t>(start); size_t startIndex = static_cast<size_t>(start);
@@ -74,8 +74,7 @@ std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
return aggregatedEdges; return aggregatedEdges;
} }
VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatCompute> spatComputes, VirtualGraph buildInitialVirtualGraph(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges) {
llvm::ArrayRef<IndexedEdge> edges) {
VirtualGraph graph; VirtualGraph graph;
graph.nodes.reserve(spatComputes.size()); graph.nodes.reserve(spatComputes.size());
for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) { for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) {
@@ -174,7 +173,7 @@ std::vector<size_t> selectCriticalWindow(const TimingInfo& timing, size_t window
return selected; return selected;
} }
std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes) { std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes) {
std::vector<size_t> signature; std::vector<size_t> signature;
for (size_t nodeIndex : selectedNodes) { for (size_t nodeIndex : selectedNodes) {
const VirtualNode& node = graph.nodes[nodeIndex]; const VirtualNode& node = graph.nodes[nodeIndex];
@@ -197,8 +196,7 @@ std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::
return aggregateEdges(windowEdges); return aggregateEdges(windowEdges);
} }
WindowScheduleResult WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes, MLIRContext* context) {
std::vector<Weight> windowWeights; std::vector<Weight> windowWeights;
std::vector<CrossbarUsage> windowCrossbarUsage; std::vector<CrossbarUsage> windowCrossbarUsage;
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1); std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
@@ -234,9 +232,7 @@ scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes,
return result; return result;
} }
bool coarsenGraph(const VirtualGraph& graph, bool coarsenGraph(const VirtualGraph& graph, ArrayRef<std::vector<size_t>> mergeGroups, VirtualGraph& coarsenedGraph) {
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph) {
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1); std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
for (auto [groupIndex, mergeGroup] : llvm::enumerate(mergeGroups)) { for (auto [groupIndex, mergeGroup] : llvm::enumerate(mergeGroups)) {
if (mergeGroup.size() < 2) if (mergeGroup.size() < 2)
@@ -303,7 +299,7 @@ bool coarsenGraph(const VirtualGraph& graph,
} }
bool coarsenGraphWithFallback(const VirtualGraph& graph, bool coarsenGraphWithFallback(const VirtualGraph& graph,
llvm::ArrayRef<std::vector<size_t>> mergeGroups, ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph) { VirtualGraph& coarsenedGraph) {
if (coarsenGraph(graph, mergeGroups, coarsenedGraph)) if (coarsenGraph(graph, mergeGroups, coarsenedGraph))
return true; return true;
@@ -330,7 +326,7 @@ bool coarsenGraphWithFallback(const VirtualGraph& graph,
return !acceptedMergeGroups.empty(); return !acceptedMergeGroups.empty();
} }
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::ArrayRef<IndexedEdge> edges) { std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, ArrayRef<IndexedEdge> edges) {
VirtualGraph graph; VirtualGraph graph;
graph.nodes.resize(computeCount); graph.nodes.resize(computeCount);
graph.edges = aggregateEdges(edges); graph.edges = aggregateEdges(edges);
@@ -344,8 +340,8 @@ std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::A
} }
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
llvm::ArrayRef<SpatCompute> spatComputes, ArrayRef<SpatCompute> spatComputes,
llvm::ArrayRef<IndexedEdge> originalEdges) { ArrayRef<IndexedEdge> originalEdges) {
DCPAnalysisResult result; DCPAnalysisResult result;
std::vector<size_t> originalToVirtualNode(spatComputes.size(), 0); std::vector<size_t> originalToVirtualNode(spatComputes.size(), 0);
for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes)) for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
@@ -367,9 +363,7 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
return result; return result;
} }
DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatCompute> spatComputes, DCPAnalysisResult runLegacyDcp(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
llvm::ArrayRef<IndexedEdge> edges,
MLIRContext* context) {
GraphDCP graphDCP(spatComputes, edges); GraphDCP graphDCP(spatComputes, edges);
if (coresCount.getValue() > 0) if (coresCount.getValue() > 0)
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue())); graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
@@ -383,12 +377,12 @@ DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatCompute> spatComputes,
SpatCompute getOriginalSpatCompute(Operation* op) { SpatCompute getOriginalSpatCompute(Operation* op) {
if (!op) if (!op)
return {}; return {};
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) { while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
op = extract.getSource().getDefiningOp(); op = extract.getSource().getDefiningOp();
if (!op) if (!op)
return {}; return {};
} }
if (auto res = llvm::dyn_cast<SpatCompute>(op)) if (auto res = dyn_cast<SpatCompute>(op))
return res; return res;
return {}; return {};
} }