From 87922d994fa1e0f5371c603aad56dd82b5631fe8 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Wed, 22 Apr 2026 18:29:06 +0200 Subject: [PATCH] multiple-output spat computes --- src/PIM/Conversion/ONNXToSpatial/Common.hpp | 12 +- .../ONNXToSpatial/ONNXToSpatialPass.cpp | 52 ++- .../ONNXToSpatial/Patterns/Math/Conv.cpp | 420 +++++++++--------- .../SpatialToGraphviz/SpatialToGraphviz.cpp | 8 +- .../SpatialToPim/SpatialToPimPass.cpp | 20 +- src/PIM/Dialect/Spatial/Spatial.td | 2 +- src/PIM/Dialect/Spatial/SpatialOps.cpp | 13 +- .../DCPGraph/DCPAnalysis.cpp | 64 ++- .../DCPGraph/DCPAnalysis.hpp | 8 +- .../MergeComputeNodes/DCPGraph/Graph.cpp | 8 +- .../MergeComputeNodes/DCPGraph/Graph.hpp | 6 +- .../MergeComputeNodes/DCPGraph/Task.hpp | 18 +- .../MergeComputeNodes/DCPGraph/Utils.hpp | 8 +- .../MergeComputeNodesPass.cpp | 150 ++++--- src/PIM/Pass/CountInstructionPass.cpp | 2 +- validation/validate_one.py | 8 +- 16 files changed, 403 insertions(+), 396 deletions(-) diff --git a/src/PIM/Conversion/ONNXToSpatial/Common.hpp b/src/PIM/Conversion/ONNXToSpatial/Common.hpp index da5626e..5cde963 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common.hpp @@ -182,7 +182,7 @@ auto createSpatCompute(RewriterT& rewriter, mlir::ValueRange inputs, BodyFn&& body) { assert(inputs.size() == NumInputs && "NumInputs must match the number of input values"); - auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs); + auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto* block = new mlir::Block(); for (mlir::Value input : inputs) @@ -198,10 +198,10 @@ auto createSpatCompute(RewriterT& rewriter, if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(computeOp); rewriter.eraseOp(computeOp); - return mlir::FailureOr(mlir::failure()); + return mlir::FailureOr(mlir::failure()); } rewriter.setInsertionPointAfter(computeOp); - return mlir::FailureOr(computeOp); + return mlir::FailureOr(computeOp); } else { static_assert(std::is_same_v, "createSpatCompute body must return void or mlir::LogicalResult"); @@ -219,7 +219,7 @@ auto createSpatCompute(RewriterT& rewriter, mlir::ValueRange weights, mlir::ValueRange inputs, BodyFn&& body) { - auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs); + auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto* block = new mlir::Block(); for (mlir::Value input : inputs) @@ -234,10 +234,10 @@ auto createSpatCompute(RewriterT& rewriter, if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(computeOp); rewriter.eraseOp(computeOp); - return mlir::FailureOr(mlir::failure()); + return mlir::FailureOr(mlir::failure()); } rewriter.setInsertionPointAfter(computeOp); - return mlir::FailureOr(computeOp); + return mlir::FailureOr(computeOp); } else { static_assert(std::is_same_v, "createSpatCompute body must return void or mlir::LogicalResult"); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 908cae6..7909b82 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -133,7 +133,7 @@ void ONNXToSpatialPass::runOnOperation() { if (coresCount != -1) { int computeOpsCount = 0; for (auto& op : entryFunc->getFunctionBody().front().getOperations()) - if (isa(op)) + if (isa(op)) computeOpsCount++; if (computeOpsCount > coresCount) { @@ -167,16 +167,16 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func if (T toRemoveOp = llvm::dyn_cast_if_present(inst)) { Value source = funcSource(toRemoveOp); rewriter.setInsertionPointAfter(toRemoveOp); - if (isa_and_present(source.getDefiningOp())) { - auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source); + if (isa_and_present(source.getDefiningOp())) { + auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source); auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); rewriter.setInsertionPointToEnd(BB); IRMapping mapper; mapper.map(source, BB->getArgument(0)); auto newInst = rewriter.clone(*inst, mapper); - spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0)); - inst->replaceAllUsesWith(newCompute); + spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults()); + inst->replaceAllUsesWith(newCompute->getResults()); inst->erase(); return true; } @@ -189,8 +189,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) { auto sources = toRemoveOp.getInputs(); rewriter.setInsertionPointAfter(toRemoveOp); if (llvm::any_of( - sources, [](auto source) { return isa_and_present(source.getDefiningOp()); })) { - auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources); + sources, [](auto source) { return isa_and_present(source.getDefiningOp()); })) { + auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources); SmallVector sourceTypes; SmallVector sourceLoc; for (auto source : sources) { @@ -204,8 +204,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) { for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments())) mapper.map(source, bbArg); auto newConcat = rewriter.clone(*inst, mapper); - spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0)); - inst->replaceAllUsesWith(newCompute); + spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults()); + inst->replaceAllUsesWith(newCompute->getResults()); inst->erase(); return true; } @@ -298,14 +298,15 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) { void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) { Location loc = funcOp.getLoc(); IRRewriter rewriter(&getContext()); - SmallVector trivialComputes; - llvm::SmallSet toErase; + SmallVector trivialComputes; + llvm::SmallSet toErase; - for (auto compute : funcOp.getOps()) + for (auto compute : funcOp.getOps()) if (compute->hasOneUse()) { - auto user = dyn_cast(*compute->getUsers().begin()); + auto& use = *compute->getUses().begin(); + auto user = dyn_cast(use.getOwner()); - if (user && user.getInputs().size() == 1) + if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size()) trivialComputes.push_back(compute); } @@ -317,12 +318,15 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) { trivialComputes.pop_back(); continue; } - auto child = cast(*compute->getUsers().begin()); + auto& computeUse = *compute->getUses().begin(); + auto child = cast(computeUse.getOwner()); + auto usedResult = cast(computeUse.get()).getResultNumber(); + auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size(); rewriter.setInsertionPointAfter(compute.getOperation()); auto newCompute = - spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); + spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); newCompute.getProperties().setOperandSegmentSizes( {static_cast(compute.getWeights().size()), static_cast(compute.getInputs().size())}); @@ -343,7 +347,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) { compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper); auto newTerminator = newCompute.getBody().front().getTerminator(); - mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0)); + mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult)); newTerminator->erase(); rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end()); for (auto& op : child.getBody().front()) { @@ -371,14 +375,16 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) { toErase.insert(compute); if (newCompute->hasOneUse()) { - auto user = dyn_cast(*newCompute->getUsers().begin()); - if (user && user.getInputs().size() == 1) + auto& use = *newCompute->getUses().begin(); + auto user = dyn_cast(use.getOwner()); + if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size()) trivialComputes.push_back(newCompute); } } for (auto compute : toErase) { - compute.getResult(0).dropAllUses(); + for (Value result : compute->getResults()) + result.dropAllUses(); compute.erase(); } } @@ -386,7 +392,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) { void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { funcOp.walk([&](arith::ConstantOp constantOp) { bool isAlwaysWeight = - llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa(user); }); + llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa(user); }); if (isAlwaysWeight) markWeightAlways(constantOp); }); @@ -394,7 +400,7 @@ void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) { IRRewriter rewriter(&getContext()); - SmallVector computes(funcOp.getOps()); + SmallVector computes(funcOp.getOps()); for (auto compute : computes) { SmallVector promoteInput(compute.getInputs().size(), false); @@ -430,7 +436,7 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun } auto newCompute = - spatial::SpatWeightedCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs); + spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs); auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs); newCompute.getProperties().setOperandSegmentSizes( diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 00c946f..2d55277 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -147,33 +147,37 @@ static Value buildPackedBias(bool hasBias, return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult(); } -static Value createIm2colCompute(Value x, - RankedTensorType xType, - RankedTensorType im2colType, - RankedTensorType rowType, - int64_t batchSize, - int64_t numChannelsIn, - int64_t xHeight, - int64_t xWidth, - int64_t wHeight, - int64_t wWidth, - int64_t padHeightBegin, - int64_t padHeightEnd, - int64_t padWidthBegin, - int64_t padWidthEnd, - int64_t strideHeight, - int64_t strideWidth, - int64_t dilationHeight, - int64_t dilationWidth, - int64_t outWidth, - int64_t patchSize, - int64_t numPatches, - int64_t numPatchesPerBatch, - ConversionPatternRewriter& rewriter, - Location loc) { +static SmallVector createIm2colRowComputes(Value x, + RankedTensorType xType, + RankedTensorType im2colType, + RankedTensorType im2colRowType, + RankedTensorType gemmInputRowType, + int64_t batchSize, + int64_t numChannelsIn, + int64_t xHeight, + int64_t xWidth, + int64_t wHeight, + int64_t wWidth, + int64_t padHeightBegin, + int64_t padHeightEnd, + int64_t padWidthBegin, + int64_t padWidthEnd, + int64_t strideHeight, + int64_t strideWidth, + int64_t dilationHeight, + int64_t dilationWidth, + int64_t outWidth, + int64_t patchSize, + int64_t numPatches, + int64_t numPatchesPerBatch, + int64_t packFactor, + ConversionPatternRewriter& rewriter, + Location loc) { auto elemType = xType.getElementType(); constexpr size_t numInputs = 1; - auto im2colComputeOp = createSpatCompute(rewriter, loc, im2colType, {}, x, [&](Value xArg) { + const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor); + SmallVector resultTypes(packedNumRows, gemmInputRowType); + auto im2colComputeOp = createSpatCompute(rewriter, loc, resultTypes, {}, x, [&](Value xArg) { Value paddedInput = xArg; // Pad input with zeros if needed: @@ -240,7 +244,7 @@ static Value createIm2colCompute(Value x, Value row = tensor::CollapseShapeOp::create(rewriter, loc, - rowType, + im2colRowType, patch, SmallVector { {0}, @@ -256,121 +260,117 @@ static Value createIm2colCompute(Value x, rewriter.setInsertionPointAfter(im2colLoop); Value im2col = im2colLoop.getResult(0); - spatial::SpatYieldOp::create(rewriter, loc, im2col); - }); - return im2colComputeOp.getResult(0); -} -static Value createPackedIm2colRows(Value im2col, - RankedTensorType im2colType, - Type elemType, - int64_t numPatches, - int64_t patchSize, - int64_t packFactor, - ConversionPatternRewriter& rewriter, - Location loc) { - if (packFactor == 1) - return im2col; - - const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor); - const int64_t paddedNumPatches = packedNumRows * packFactor; - auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType); - auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType); - auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) { - Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc); - Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter, - loc, - groupedType, - paddedIm2col, - SmallVector { - {0, 1}, - {2} - }); - Value packedIm2col = tensor::CollapseShapeOp::create(rewriter, - loc, - packedType, - groupedIm2col, - SmallVector { - {0}, - {1, 2} - }); - spatial::SpatYieldOp::create(rewriter, loc, packedIm2col); - }); - return packedComputeOp.getResult(0); -} - -static Value createUnpackedOutput(Value packedOutput, - RankedTensorType gemmOutType, - RankedTensorType outType, - int64_t numPatches, - int64_t numChannelsOut, - int64_t packFactor, - ConversionPatternRewriter& rewriter, - Location loc) { - if (packFactor == 1) - return packedOutput; - - const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor); - const int64_t paddedNumPatches = packedNumRows * packFactor; - auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType()); - auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType()); - auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) { - Value expandedOutput = tensor::ExpandShapeOp::create(rewriter, - loc, - expandedType, - packedOutputArg, - SmallVector { - {0}, - {1, 2} - }); - Value paddedOutput = tensor::CollapseShapeOp::create(rewriter, - loc, - paddedType, - expandedOutput, - SmallVector { - {0, 1}, - {2} - }); - - Value unpackedOutput = paddedOutput; - if (paddedNumPatches != numPatches) { - SmallVector offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; - SmallVector sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)}; - SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - unpackedOutput = - tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides); + Value gemmInputRows = im2col; + if (packFactor != 1) { + const int64_t paddedNumPatches = packedNumRows * packFactor; + auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType); + auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType); + Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc); + Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter, + loc, + groupedType, + paddedIm2col, + SmallVector { + {0, 1}, + {2} + }); + gemmInputRows = tensor::CollapseShapeOp::create(rewriter, + loc, + packedType, + groupedIm2col, + SmallVector { + {0}, + {1, 2} + }); } - spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput); + SmallVector rowResults; + rowResults.reserve(packedNumRows); + for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) { + SmallVector offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)}; + SmallVector sizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(packFactor * patchSize)}; + SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + rowResults.push_back( + tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides)); + } + spatial::SpatYieldOp::create(rewriter, loc, rowResults); }); - return unpackComputeOp.getResult(0); + + SmallVector rows; + rows.reserve(im2colComputeOp.getNumResults()); + for (Value result : im2colComputeOp.getResults()) + rows.push_back(result); + return rows; } -static Value createCollectedConvOutput(Value gemmOut, +static Value createCollectedConvOutput(ValueRange gemmRows, Type convType, + RankedTensorType gemmOutType, RankedTensorType nhwcType, RankedTensorType outType, + int64_t numPatches, + int64_t numChannelsOut, + int64_t packFactor, ConversionPatternRewriter& rewriter, Location loc) { - auto collectComputeOp = - createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) { - Value gemmOutArg = gemmOutArgs.front(); - - // Restore to NCHW layout: - // [numPatches, numChannelsOut] - // -> [1, outHeight, outWidth, numChannelsOut] - // -> [1, numChannelsOut, outHeight, outWidth] - Value nhwcOut = tensor::ExpandShapeOp::create(rewriter, - loc, - nhwcType, - gemmOutArg, - SmallVector { - {0, 1, 2}, - {3} + const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor); + const int64_t paddedNumPatches = packedNumRows * packFactor; + auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) { + Value gemmOut; + if (packFactor == 1) { + gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front() + : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult(); + } + else { + auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType()); + auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType()); + Value packedOutput = + gemmRowArgs.size() == 1 + ? gemmRowArgs.front() + : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult(); + Value expandedOutput = tensor::ExpandShapeOp::create(rewriter, + loc, + expandedType, + packedOutput, + SmallVector { + {0}, + {1, 2} }); - Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2})); - spatial::SpatYieldOp::create(rewriter, loc, nchwOut); + Value paddedOutput = tensor::CollapseShapeOp::create(rewriter, + loc, + paddedType, + expandedOutput, + SmallVector { + {0, 1}, + {2} + }); + + gemmOut = paddedOutput; + if (paddedNumPatches != numPatches) { + SmallVector offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; + SmallVector sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)}; + SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides); + } + } + + // Restore to NCHW layout: + // [numPatches, numChannelsOut] + // -> [1, outHeight, outWidth, numChannelsOut] + // -> [1, numChannelsOut, outHeight, outWidth] + Value nhwcOut = tensor::ExpandShapeOp::create(rewriter, + loc, + nhwcType, + gemmOut, + SmallVector { + {0, 1, 2}, + {3} }); + Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2})); + spatial::SpatYieldOp::create(rewriter, loc, nchwOut); + }); return collectComputeOp.getResult(0); } @@ -487,11 +487,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, // Pass bias through directly; Gemm handles rank-1 C canonicalization. bool hasB = !isa(b.getDefiningOp()); - Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); + Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); Value biasMatrix; DenseElementsAttr biasDenseAttr; if (hasB) { - gemmC = b; + gemmBias = b; biasDenseAttr = getDenseConstantAttr(b); biasMatrix = expandBiasIfNeeded(b, rewriter, loc); } @@ -500,94 +500,86 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, const int64_t effectiveMaxParallelPixels = (canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1; - Value im2col = createIm2colCompute(x, - xType, - im2colType, - rowType, - batchSize, - numChannelsIn, - xHeight, - xWidth, - wHeight, - wWidth, - padHeightBegin, - padHeightEnd, - padWidthBegin, - padWidthEnd, - strideHeight, - strideWidth, - dilationHeight, - dilationWidth, - outWidth, - patchSize, - numPatches, - numPatchesPerBatch, - rewriter, - loc); + // Keep the standard im2col view of convolution: + // A (im2col): [numPatches, patchSize] -- one row per output spatial position + // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns + // and optionally repack several old rows into one GEMM row to use the available crossbar size better. + // + // The im2col compute yields each GEMM input row as a separate result so every GEMM consumes only + // the row it needs instead of receiving a full packed tensor and slicing it locally. + auto gemmInputRowType = + RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType); + auto gemmOutputRowType = + RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType()); + SmallVector gemmInputRows = createIm2colRowComputes(x, + xType, + im2colType, + rowType, + gemmInputRowType, + batchSize, + numChannelsIn, + xHeight, + xWidth, + wHeight, + wWidth, + padHeightBegin, + padHeightEnd, + padWidthBegin, + padWidthEnd, + strideHeight, + strideWidth, + dilationHeight, + dilationWidth, + outWidth, + patchSize, + numPatches, + numPatchesPerBatch, + effectiveMaxParallelPixels, + rewriter, + loc); - Value gemmOut; - if (effectiveMaxParallelPixels == 1) { - // Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels. - gemmOut = ONNXGemmOp::create(rewriter, - loc, - gemmOutType, - im2col, - wTrans, - gemmC, - rewriter.getF32FloatAttr(1.0f), - rewriter.getF32FloatAttr(1.0f), - rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)) - .getY(); - } - else { - // Keep the standard im2col view of convolution: - // A (im2col): [numPatches, patchSize] -- one row per output spatial position - // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns - // but repack several old rows into one new row so we use the available crossbar size better. - // - // We want to process N spatial pixels at the exact same time. Instead of doing N separate - // operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix - // containing N copies of W^T and concatenate N im2col rows into one longer row: - // A_packed: [ceil(numPatches / N), N * patchSize] - // B_packed: [N * patchSize, N * cOut] - // Y_packed: [ceil(numPatches / N), N * cOut] - // The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows. - const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels); - auto packedOutType = - RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType()); + Value gemmB = buildPackedWeight(wDenseAttr, + wTrans, + wType, + numChannelsIn, + numChannelsOut, + wHeight, + wWidth, + patchSize, + effectiveMaxParallelPixels, + rewriter, + loc); + Value gemmC = buildPackedBias( + hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc); - Value packedA = createPackedIm2colRows( - im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc); - Value packedB = buildPackedWeight(wDenseAttr, - wTrans, - wType, - numChannelsIn, - numChannelsOut, - wHeight, - wWidth, - patchSize, - effectiveMaxParallelPixels, - rewriter, - loc); - Value packedC = buildPackedBias( - hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc); - Value packedOut = ONNXGemmOp::create(rewriter, - loc, - packedOutType, - packedA, - packedB, - packedC, - rewriter.getF32FloatAttr(1.0f), - rewriter.getF32FloatAttr(1.0f), - rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)) - .getY(); - gemmOut = createUnpackedOutput( - packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc); + SmallVector gemmRows; + gemmRows.reserve(gemmInputRows.size()); + for (Value gemmInputRow : gemmInputRows) { + Value gemmRow = ONNXGemmOp::create(rewriter, + loc, + gemmOutputRowType, + gemmInputRow, + gemmB, + gemmC, + rewriter.getF32FloatAttr(1.0f), + rewriter.getF32FloatAttr(1.0f), + rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)) + .getY(); + gemmRows.push_back(gemmRow); } - rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc)); + rewriter.replaceOp(convOp, + createCollectedConvOutput(gemmRows, + convOp.getType(), + gemmOutType, + nhwcType, + outType, + numPatches, + numChannelsOut, + effectiveMaxParallelPixels, + rewriter, + loc)); return success(); } diff --git a/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp index bbe8aa0..822fe63 100644 --- a/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp +++ b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp @@ -42,15 +42,15 @@ private: raw_ostream& os; /** - * Draws the subgraph for a given spatial::SpatWeightedCompute, including: + * Draws the subgraph for a given spatial::SpatCompute, including: * 1. Input nodes (block arguments) * 2. Operations * 3. Edges between yield (output) and its users * - * @param op The spatial::SpatWeightedCompute to draw the subgraph for. + * @param op The spatial::SpatCompute to draw the subgraph for. * @param computeNum The number of the compute operation. */ - void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) { + void drawComputeOpSubgraph(spatial::SpatCompute op, size_t computeNum) { os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n" << "\t\tstyle=filled;\n" << "\t\tcolor=lightblue;\n"; @@ -217,7 +217,7 @@ void SpatialToGraphvizPass::runOnOperation() { // 1. Print their subgraph // 2. Print the edges from its inputs to its outputs for (Operation& op : func.getOps()) { - if (auto computeOp = dyn_cast(op)) { + if (auto computeOp = dyn_cast(op)) { drawComputeOpSubgraph(computeOp, computeNum++); } else if (auto concatOp = dyn_cast(op)) { diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 3621336..ecc5d0d 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -62,7 +62,7 @@ private: void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter); void addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter); - void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, + void replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp, unsigned int argIndex, Value channelSourceOp, Value consumerValue, @@ -73,7 +73,7 @@ private: void annotateChannelCoreIds(func::FuncOp funcOp); void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter); - void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter); + void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter); void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter); @@ -116,7 +116,7 @@ static size_t countComputeLeafUsers(Value value) { auto walkUses = [&](Value currentValue, auto& self) -> void { for (OpOperand& use : currentValue.getUses()) { Operation* owner = use.getOwner(); - if (isa(owner)) { + if (isa(owner)) { leafUserCount++; continue; } @@ -174,7 +174,7 @@ void SpatialToPimPass::runOnOperation() { markOpToRemove(receiveOp); runOnReceiveOp(receiveOp, rewriter); } - for (auto computeOp : funcOp.getOps()) { + for (auto computeOp : funcOp.getOps()) { markOpToRemove(computeOp); runOnComputeOp(computeOp, rewriter); } @@ -222,7 +222,7 @@ void SpatialToPimPass::runOnOperation() { dumpModule(moduleOp, "pim0"); } -void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) { +void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) { Location loc = computeOp->getLoc(); auto& block = computeOp.getRegion().front(); @@ -504,7 +504,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu llvm::SmallSet sliceOpsToRemove; for (auto& op : funcOp.getBody().getOps()) - if (auto computeOp = dyn_cast(op)) { + if (auto computeOp = dyn_cast(op)) { unsigned numComputeWeights = computeOp.getWeights().size(); for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) { TypedValue tensorSource; @@ -513,7 +513,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu if (auto sliceOp = dyn_cast(computeOpInput.getDefiningOp())) { tensorSource = cast>(sliceOp.getSource()); - if (isa(tensorSource.getDefiningOp())) + if (isa(tensorSource.getDefiningOp())) continue; ArrayRef sourceShape = tensorSource.getType().getShape(); @@ -538,7 +538,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu tensorSource = cast>(computeOpInput); // Compute results must be transferred through channels via send/receive - if (isa(tensorSource.getDefiningOp())) + if (isa(tensorSource.getDefiningOp())) continue; BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx); @@ -553,7 +553,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu return success(); } -void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, +void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp, unsigned int argIndex, Value channelSourceOp, Value consumerValue, @@ -614,7 +614,7 @@ void SpatialToPimPass::addReceiveOps(Value channelSourceOp, auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void { for (OpOperand& use : currentValue.getUses()) { Operation* owner = use.getOwner(); - if (auto computeUser = dyn_cast(owner)) { + if (auto computeUser = dyn_cast(owner)) { replaceBlockArgumentWithRecvOp( computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter); continue; diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index c8f419e..594d159 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -32,7 +32,7 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> { // Execution //===----------------------------------------------------------------------===// -def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> { +def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> { let summary = "Compute region with attached constant weights"; let arguments = (ins diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index 9c1ee13..c3f1b3b 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -119,7 +119,7 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter, } llvm::FailureOr> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) { - auto wcomputeOp = dyn_cast(weigthedOp->getParentOp()); + auto wcomputeOp = dyn_cast(weigthedOp->getParentOp()); if (wcomputeOp) return cast(wcomputeOp.getWeights()[weightIndex].getType()).getShape(); @@ -134,7 +134,7 @@ llvm::FailureOr> getWeightShapeForWeightedOp(Operation* weigth LogicalResult SpatWeightedMVMOp::verify() { auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); if (failed(matrixShapeOpt)) - return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op"); + return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op"); auto matrixShape = *matrixShapeOpt; auto vectorShape = getInput().getType().getShape(); auto outputShape = getOutput().getType().getShape(); @@ -155,7 +155,7 @@ LogicalResult SpatWeightedMVMOp::verify() { LogicalResult SpatWeightedVMMOp::verify() { auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); if (failed(matrixShapeOpt)) - return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op"); + return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op"); auto matrixShape = *matrixShapeOpt; auto vectorShape = getInput().getType().getShape(); auto outputShape = getOutput().getType().getShape(); @@ -200,9 +200,8 @@ LogicalResult SpatVMaxOp::verify() { return OpTrait::impl::verifySameOperandsAndResultType(*this); } -LogicalResult SpatWeightedCompute::verify() { - // Check that it has a terminator, it is a yieldOp, and it has a single - // operand with the same type as the result +LogicalResult SpatCompute::verify() { + // Check that the terminator yields the same number and types as the compute results. auto& block = getBody().front(); if (block.mightHaveTerminator()) { auto yieldOp = dyn_cast_or_null(block.getTerminator()); @@ -257,7 +256,7 @@ LogicalResult SpatWeightedCompute::verify() { return success(); } -LogicalResult SpatWeightedCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) { +LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) { Block& block = getBody().front(); if (!llvm::hasSingleElement(block)) return failure(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp index bf03e56..27042b3 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp @@ -74,15 +74,15 @@ std::vector aggregateEdges(llvm::ArrayRef edges) { return aggregatedEdges; } -VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef spatWeightedComputes, +VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef spatComputes, llvm::ArrayRef edges) { VirtualGraph graph; - graph.nodes.reserve(spatWeightedComputes.size()); - for (auto [index, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) { + graph.nodes.reserve(spatComputes.size()); + for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) { VirtualNode node; node.originalComputeIndices.push_back(index); - node.weight = getSpatComputeWeight(spatWeightedCompute); - node.crossbarUsage = getSpatComputeCrossbarUsage(spatWeightedCompute); + node.weight = getSpatComputeWeight(spatCompute); + node.crossbarUsage = getSpatComputeCrossbarUsage(spatCompute); graph.nodes.push_back(std::move(node)); } graph.edges = aggregateEdges(edges); @@ -344,22 +344,22 @@ std::vector computeOriginalTopologicalOrder(size_t computeCount, llvm::A } DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, - llvm::ArrayRef spatWeightedComputes, + llvm::ArrayRef spatComputes, llvm::ArrayRef originalEdges) { DCPAnalysisResult result; - std::vector originalToVirtualNode(spatWeightedComputes.size(), 0); + std::vector originalToVirtualNode(spatComputes.size(), 0); for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes)) for (size_t originalIndex : virtualNode.originalComputeIndices) originalToVirtualNode[originalIndex] = virtualNodeIndex; - auto dominanceOrder = computeOriginalTopologicalOrder(spatWeightedComputes.size(), originalEdges); + auto dominanceOrder = computeOriginalTopologicalOrder(spatComputes.size(), originalEdges); result.dominanceOrderCompute.reserve(dominanceOrder.size()); for (size_t originalIndex : dominanceOrder) { - SpatWeightedCompute spatWeightedCompute = spatWeightedComputes[originalIndex]; + SpatCompute spatCompute = spatComputes[originalIndex]; size_t cpu = originalToVirtualNode[originalIndex]; - result.dominanceOrderCompute.push_back(spatWeightedCompute); - result.computeToCpuMap[spatWeightedCompute] = cpu; - result.cpuToLastComputeMap[cpu] = spatWeightedCompute; + result.dominanceOrderCompute.push_back(spatCompute); + result.computeToCpuMap[spatCompute] = cpu; + result.cpuToLastComputeMap[cpu] = spatCompute; } for (auto [cpu, lastCompute] : result.cpuToLastComputeMap) @@ -367,10 +367,10 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, return result; } -DCPAnalysisResult runLegacyDcp(llvm::ArrayRef spatWeightedComputes, +DCPAnalysisResult runLegacyDcp(llvm::ArrayRef spatComputes, llvm::ArrayRef edges, MLIRContext* context) { - GraphDCP graphDCP(spatWeightedComputes, edges); + GraphDCP graphDCP(spatComputes, edges); if (coresCount.getValue() > 0) graphDCP.setMaxCpuCount(static_cast(coresCount.getValue())); graphDCP.setContext(context); @@ -380,7 +380,7 @@ DCPAnalysisResult runLegacyDcp(llvm::ArrayRef spatWeightedC } // namespace -SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) { +SpatCompute getOriginalSpatCompute(Operation* op) { if (!op) return {}; while (auto extract = llvm::dyn_cast(op)) { @@ -388,39 +388,33 @@ SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) { if (!op) return {}; } - if (auto res = llvm::dyn_cast(op)) + if (auto res = llvm::dyn_cast(op)) return res; return {}; } DCPAnalysisResult DCPAnalysis::run() { - SmallVector spatWeightedComputes; + SmallVector spatComputes; SmallVector edges; for (auto& region : entryOp->getRegions()) - for (SpatWeightedCompute spatWeightedCompute : region.getOps()) - spatWeightedComputes.push_back(spatWeightedCompute); + for (SpatCompute spatCompute : region.getOps()) + spatComputes.push_back(spatCompute); - for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) { - for (Value input : spatWeightedCompute.getInputs()) { - if (auto producerCompute = getOriginalSpatWeightedCompute(input.getDefiningOp())) { - auto producerIt = llvm::find(spatWeightedComputes, producerCompute); - assert(producerIt != spatWeightedComputes.end()); - auto indexStartEdge = std::distance(spatWeightedComputes.begin(), producerIt); - ResultRange outputs = producerCompute.getResults(); - int64_t totalSize = 0; - for (auto output : outputs) { - ShapedType resultType = cast(output.getType()); - totalSize += getSizeInBytes(resultType); - } - edges.push_back({indexStartEdge, indexEndEdge, totalSize}); + for (auto [indexEndEdge, spatCompute] : llvm::enumerate(spatComputes)) { + for (Value input : spatCompute.getInputs()) { + if (auto producerCompute = getOriginalSpatCompute(input.getDefiningOp())) { + auto producerIt = llvm::find(spatComputes, producerCompute); + assert(producerIt != spatComputes.end()); + auto indexStartEdge = std::distance(spatComputes.begin(), producerIt); + edges.push_back({indexStartEdge, indexEndEdge, getSizeInBytes(cast(input.getType()))}); } } } if (dcpCriticalWindowSize.getValue() == 0) - return runLegacyDcp(spatWeightedComputes, edges, entryOp->getContext()); + return runLegacyDcp(spatComputes, edges, entryOp->getContext()); - VirtualGraph virtualGraph = buildInitialVirtualGraph(spatWeightedComputes, edges); + VirtualGraph virtualGraph = buildInitialVirtualGraph(spatComputes, edges); std::set> seenCriticalWindows; while (virtualGraph.nodes.size() > 1) { TimingInfo timing = computeTiming(virtualGraph); @@ -446,7 +440,7 @@ DCPAnalysisResult DCPAnalysis::run() { break; } - return buildResultFromVirtualGraph(virtualGraph, spatWeightedComputes, edges); + return buildResultFromVirtualGraph(virtualGraph, spatComputes, edges); } } // namespace spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp index 472e51f..8b1e1d3 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp @@ -10,10 +10,10 @@ #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" struct DCPAnalysisResult { - std::vector dominanceOrderCompute; - llvm::DenseMap computeToCpuMap; - llvm::DenseSet isLastComputeOfCpu; - llvm::DenseMap cpuToLastComputeMap; + std::vector dominanceOrderCompute; + llvm::DenseMap computeToCpuMap; + llvm::DenseSet isLastComputeOfCpu; + llvm::DenseMap cpuToLastComputeMap; }; namespace onnx_mlir { diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp index dfbf6fd..b24d3b1 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp @@ -1260,7 +1260,7 @@ DCPAnalysisResult GraphDCP::getResult() { auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size()); ret.dominanceOrderCompute.reserve(dominanceOrder.size()); for (auto elem : dominanceOrder) - ret.dominanceOrderCompute.push_back(elem->getSpatWeightedCompute()); + ret.dominanceOrderCompute.push_back(elem->getSpatCompute()); for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) { const CpuTaskList* tasks = findCpuTasks(cpu); @@ -1268,10 +1268,10 @@ DCPAnalysisResult GraphDCP::getResult() { continue; size_t i = 0; for (auto node : *tasks) { - ret.computeToCpuMap[node->getSpatWeightedCompute()] = cpu; + ret.computeToCpuMap[node->getSpatCompute()] = cpu; if (i++ == tasks->size() - 1) { - ret.isLastComputeOfCpu.insert(node->getSpatWeightedCompute()); - ret.cpuToLastComputeMap[cpu] = node->getSpatWeightedCompute(); + ret.isLastComputeOfCpu.insert(node->getSpatCompute()); + ret.cpuToLastComputeMap[cpu] = node->getSpatCompute(); } } } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp index a66690f..7e9f4a5 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp @@ -115,11 +115,11 @@ private: public: void runDcp(); - GraphDCP(llvm::ArrayRef spatWeightedComputes, + GraphDCP(llvm::ArrayRef spatComputes, llvm::ArrayRef edges) : nodes(), cpuTasks(), cpuCrossbarUsage() { - for (auto spatWeightedCompute : spatWeightedComputes) - nodes.emplace_back(spatWeightedCompute); + for (auto spatCompute : spatComputes) + nodes.emplace_back(spatCompute); for (auto [start, end, weight] : edges) makeEdge(start, end, weight); } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Task.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Task.hpp index 2290e20..1cd6ec5 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Task.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Task.hpp @@ -8,7 +8,7 @@ #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" class TaskDCP : public onnx_mlir::LabeledListNode { - onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute; + onnx_mlir::spatial::SpatCompute spatCompute; Time aest; Time alst; std::optional scheduledCpu; @@ -38,22 +38,22 @@ public: std::vector parents; std::vector children; TaskDCP() = default; - TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) + TaskDCP(onnx_mlir::spatial::SpatCompute spatCompute) : onnx_mlir::LabeledListNode(), - spatWeightedCompute(spatWeightedCompute), + spatCompute(spatCompute), aest(0), alst(0), scheduledCpu(), - weight(getSpatComputeWeight(spatWeightedCompute)), + weight(getSpatComputeWeight(spatCompute)), baseWeight(weight), - crossbarUsage(getSpatComputeCrossbarUsage(spatWeightedCompute)), + crossbarUsage(getSpatComputeCrossbarUsage(spatCompute)), syntheticId(-1), parents(), children() {} TaskDCP(int64_t id, Weight weight, CrossbarUsage crossbarUsage = 0) : onnx_mlir::LabeledListNode(), - spatWeightedCompute(), + spatCompute(), aest(0), alst(0), scheduledCpu(), @@ -90,14 +90,14 @@ public: void setAlst(Time value) { alst = value; } bool hasDescendant(TaskDCP* child); int64_t Id() const { - if (spatWeightedCompute) - return reinterpret_cast(spatWeightedCompute.getAsOpaquePointer()); + if (spatCompute) + return reinterpret_cast(spatCompute.getAsOpaquePointer()); return syntheticId; } bool isCriticalPath() const { return alst == aest; } bool isScheduled() const { return scheduledCpu.has_value(); } - onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() const { return spatWeightedCompute; } + onnx_mlir::spatial::SpatCompute getSpatCompute() const { return spatCompute; } void setFlag(long long val) { flag = val; } long long getFlag() const { return flag; } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Utils.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Utils.hpp index fc5a010..a4c9465 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Utils.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Utils.hpp @@ -92,18 +92,18 @@ inline T subtractOrZero(T lhs, T rhs) { inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); } -inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) { +inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatCompute spatCompute) { constexpr Weight kOperationWeight = 100; Weight numOperations = 0; - for (auto& block : spatWeightedCompute.getBody()) + for (auto& block : spatCompute.getBody()) for ([[maybe_unused]] auto& op : block) numOperations = checkedAdd(numOperations, static_cast(1)); return checkedMultiply(numOperations, kOperationWeight); } -inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) { +inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatCompute spatCompute) { CrossbarUsage crossbarUsage = 0; - for (auto& region : spatWeightedCompute.getBody()) + for (auto& region : spatCompute.getBody()) for (auto& inst : region) if (llvm::isa(inst)) crossbarUsage = checkedAdd(crossbarUsage, static_cast(1)); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 57b551b..01f9ed0 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -24,30 +24,29 @@ using namespace mlir; namespace onnx_mlir { namespace { -using SpatWeightedCompute = spatial::SpatWeightedCompute; +using SpatCompute = spatial::SpatCompute; struct ComputeValueResults { - // Value yielded by the yieldOp - Value innerValue; + SmallVector innerValues; + + Value get(size_t resultIndex) const { + assert(resultIndex < innerValues.size() && "compute result index out of range"); + return innerValues[resultIndex]; + } }; class LazyInsertComputeResult { using InsertPoint = mlir::IRRewriter::InsertPoint; ComputeValueResults computeResults; - Value channelValue; bool onlyChannel; - std::function channelSendInserter; - InsertPoint sendInsertPoint; - std::function>()> channelNewInserter; + std::function>(size_t)> channelNewInserter; public: LazyInsertComputeResult(ComputeValueResults computeValueResults, - std::function>()> channelNewInserter, + std::function>(size_t)> channelNewInserter, bool isOnlyChannel) : computeResults(computeValueResults), onlyChannel(isOnlyChannel), - channelSendInserter(nullptr), - sendInsertPoint({}), channelNewInserter(channelNewInserter) {} struct ChannelOrLocalOp { @@ -57,12 +56,12 @@ public: bool onlyChanneled() const { return onlyChannel; } - ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute currentCompute) { + ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatCompute currentCompute, size_t resultIndex) { + Value innerValue = computeResults.get(resultIndex); - auto [newChannelValue, senderInserter] = channelNewInserter(); - channelValue = newChannelValue; - channelSendInserter = senderInserter; - auto* block = computeResults.innerValue.getParentBlock(); + auto [channelValue, channelSendInserter] = channelNewInserter(resultIndex); + InsertPoint sendInsertPoint; + auto* block = innerValue.getParentBlock(); if (!block->empty() && isa(block->back())) sendInsertPoint = InsertPoint(block, --block->end()); else @@ -70,28 +69,30 @@ public: if (currentCompute) { for (auto& block : currentCompute.getBody()) if (&block == sendInsertPoint.getBlock()) - return {computeResults.innerValue, false}; + return {innerValue, false}; } channelSendInserter(sendInsertPoint); return {channelValue, true}; } - ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); } + ChannelOrLocalOp getAsChannelValueAndInsertSender(size_t resultIndex) { + return getAsChannelValueAndInsertSender({}, resultIndex); + } }; struct MergeComputeNodesPass : PassWrapper> { private: - DenseMap newComputeNodeResults; - DenseMap oldToNewComputeMap; - DenseMap cpuToNewComputeMap; + DenseMap newComputeNodeResults; + DenseMap oldToNewComputeMap; + DenseMap cpuToNewComputeMap; public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass) StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; } StringRef getDescription() const override { - return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total " + return "Merge Spatial-Compute-Nodes in order to reduce the total " "execution time"; } @@ -105,22 +106,22 @@ public: for (auto currentComputeNode : analysisResult.dominanceOrderCompute) { size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode); if (!cpuToNewComputeMap.contains(cpu)) { - ValueTypeRange newWeightedComputeType = cpuToLastComputeMap.at(cpu).getResultTypes(); - auto [newWeightedCompute, computeValueResult] = createNewComputeNode( - currentComputeNode, newWeightedComputeType, lastComputeOfCpu.contains(currentComputeNode)); - cpuToNewComputeMap[cpu] = newWeightedCompute; + ValueTypeRange newComputeType = cpuToLastComputeMap.at(cpu).getResultTypes(); + auto [newCompute, computeValueResult] = createNewComputeNode( + currentComputeNode, newComputeType, lastComputeOfCpu.contains(currentComputeNode)); + cpuToNewComputeMap[cpu] = newCompute; newComputeNodeResults.insert( std::make_pair(currentComputeNode, createLazyComputeResult( - newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode)))); + newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode)))); } else { - auto [newWeightedCompute, computeValueResult] = mergeIntoComputeNode( + auto [newCompute, computeValueResult] = mergeIntoComputeNode( cpuToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode)); newComputeNodeResults.insert( std::make_pair(currentComputeNode, createLazyComputeResult( - newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode)))); + newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode)))); } } @@ -134,8 +135,8 @@ public: } private: - std::pair createNewComputeNode( - SpatWeightedCompute oldWeightedCompute, ValueTypeRange newWeightedComputeType, bool lastCompute) { + std::pair createNewComputeNode( + SpatCompute oldCompute, ValueTypeRange newComputeType, bool lastCompute) { func::FuncOp func = getOperation(); auto loc = func.getLoc(); IRRewriter rewriter(&getContext()); @@ -148,50 +149,53 @@ private: llvm::SmallVector newBBOperandType; llvm::SmallVector newBBLocations; - for (auto arg : oldWeightedCompute.getWeights()) + for (auto arg : oldCompute.getWeights()) newComputeOperand.push_back(arg); - for (auto arg : oldWeightedCompute.getInputs()) - if (!llvm::isa_and_present(arg.getDefiningOp())) { + for (auto arg : oldCompute.getInputs()) + if (!llvm::isa_and_present(arg.getDefiningOp())) { newComputeOperand.push_back(arg); newBBOperandType.push_back(arg.getType()); newBBLocations.push_back(loc); } - auto newWeightedCompute = SpatWeightedCompute::create(rewriter, loc, newWeightedComputeType, newComputeOperand); + auto newCompute = SpatCompute::create(rewriter, loc, newComputeType, newComputeOperand); rewriter.createBlock( - &newWeightedCompute.getBody(), newWeightedCompute.getBody().end(), newBBOperandType, newBBLocations); - newWeightedCompute.getProperties().setOperandSegmentSizes( - {(int) oldWeightedCompute.getWeights().size(), (int) newBBOperandType.size()}); + &newCompute.getBody(), newCompute.getBody().end(), newBBOperandType, newBBLocations); + newCompute.getProperties().setOperandSegmentSizes( + {(int) oldCompute.getWeights().size(), (int) newBBOperandType.size()}); - auto& newBB = newWeightedCompute.getBody().front(); - auto& oldBB = oldWeightedCompute.getBody().front(); + auto& newBB = newCompute.getBody().front(); + auto& oldBB = oldCompute.getBody().front(); rewriter.setInsertionPointToEnd(&newBB); int indexNew = 0; - size_t indexOld = oldWeightedCompute.getWeights().size(); - size_t indexOldStart = oldWeightedCompute.getWeights().size(); - for (; indexOld < oldWeightedCompute.getNumOperands(); ++indexOld) { - if (!llvm::isa_and_present(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) { + size_t indexOld = oldCompute.getWeights().size(); + size_t indexOldStart = oldCompute.getWeights().size(); + for (; indexOld < oldCompute.getNumOperands(); ++indexOld) { + if (!llvm::isa_and_present(oldCompute.getOperand(indexOld).getDefiningOp())) { mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++)); } else { auto argWeightCompute = - llvm::dyn_cast_if_present(oldWeightedCompute.getOperand(indexOld).getDefiningOp()); + llvm::dyn_cast_if_present(oldCompute.getOperand(indexOld).getDefiningOp()); + auto argResultIndex = cast(oldCompute.getOperand(indexOld)).getResultNumber(); LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute); - auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender(); + auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender(argResultIndex); assert(isChannel == true); - spatial::SpatChannelReceiveOp receiveOp = - spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelVal); + spatial::SpatChannelReceiveOp receiveOp = spatial::SpatChannelReceiveOp::create( + rewriter, loc, oldCompute.getOperand(indexOld).getType(), channelVal); mapper.map(oldBB.getArgument(indexOld - indexOldStart), receiveOp); } } - for (auto& op : oldWeightedCompute.getOps()) { + for (auto& op : oldCompute.getOps()) { if (auto yield = dyn_cast(&op)) { - computeValueResults.innerValue = mapper.lookup(yield.getOperand(0)); + computeValueResults.innerValues.reserve(yield.getNumOperands()); + for (Value yieldOperand : yield.getOperands()) + computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand)); if (lastCompute) rewriter.clone(op, mapper); } @@ -199,16 +203,18 @@ private: rewriter.clone(op, mapper); } - for (auto& use : llvm::make_early_inc_range(oldWeightedCompute->getUses())) - if (isa(use.getOwner())) - use.assign(newWeightedCompute.getResult(0)); + for (auto& use : llvm::make_early_inc_range(oldCompute->getUses())) + if (isa(use.getOwner())) { + auto resultIndex = cast(use.get()).getResultNumber(); + use.assign(newCompute.getResult(resultIndex)); + } - oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute}); - return {cast(newWeightedCompute), computeValueResults}; + oldToNewComputeMap.insert({oldCompute, newCompute}); + return {cast(newCompute), computeValueResults}; } - std::pair - mergeIntoComputeNode(SpatWeightedCompute toCompute, SpatWeightedCompute fromCompute, bool lastCompute) { + std::pair + mergeIntoComputeNode(SpatCompute toCompute, SpatCompute fromCompute, bool lastCompute) { func::FuncOp func = getOperation(); auto loc = func.getLoc(); IRRewriter rewriter(&getContext()); @@ -239,14 +245,15 @@ private: // Insert receiveOp rewriter.setInsertionPointToEnd(&toBB); for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) { - if (auto argWeightCompute = llvm::dyn_cast_if_present(arg.getDefiningOp())) { + if (auto argWeightCompute = llvm::dyn_cast_if_present(arg.getDefiningOp())) { LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute); + auto argResultIndex = cast(arg).getResultNumber(); LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal = - lazyArgWeight.getAsChannelValueAndInsertSender(toCompute); + lazyArgWeight.getAsChannelValueAndInsertSender(toCompute, argResultIndex); if (channelOrLocal.isChannel) { spatial::SpatChannelReceiveOp receiveOp = - spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data); + spatial::SpatChannelReceiveOp::create(rewriter, loc, arg.getType(), channelOrLocal.data); mapper.map(fromBB.getArgument(bbIndex), receiveOp.getResult()); } else { @@ -286,7 +293,9 @@ private: }; for (auto& op : fromCompute.getOps()) { if (auto yield = dyn_cast(&op)) { - computeValueResults.innerValue = mapper.lookup(yield.getOperand(0)); + computeValueResults.innerValues.reserve(yield.getNumOperands()); + for (Value yieldOperand : yield.getOperands()) + computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand)); if (lastCompute) rewriter.clone(op, mapper); } @@ -299,33 +308,36 @@ private: } } - for (auto users : fromCompute->getUsers()) - if (auto funcRet = dyn_cast(users)) - funcRet.setOperand(0, toCompute.getResult(0)); + for (auto& use : llvm::make_early_inc_range(fromCompute->getUses())) + if (isa(use.getOwner())) { + auto resultIndex = cast(use.get()).getResultNumber(); + use.assign(toCompute.getResult(resultIndex)); + } oldToNewComputeMap.insert({fromCompute, toCompute}); - return {cast(toCompute), computeValueResults}; + return {cast(toCompute), computeValueResults}; } - LazyInsertComputeResult createLazyComputeResult(SpatWeightedCompute weightedCompute, + LazyInsertComputeResult createLazyComputeResult(SpatCompute compute, ComputeValueResults computeValueResults, bool lastCompute) { - func::FuncOp funcOp = cast(weightedCompute->getParentOp()); + func::FuncOp funcOp = cast(compute->getParentOp()); auto* context = &getContext(); auto loc = funcOp.getLoc(); IRRewriter rewriter(context); rewriter.setInsertionPointToStart(&funcOp.front()); auto savedChannelInsertPoint = rewriter.saveInsertionPoint(); - auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults]() { + auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults](size_t resultIndex) { IRRewriter rewriter(context); rewriter.restoreInsertionPoint(savedChannelInsertPoint); auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context)); auto channelVal = channelOp.getResult(); - auto insertVal = [&context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint sendInsertPoint) { + auto insertVal = + [&context, loc, computeValueResults, channelVal, resultIndex](mlir::IRRewriter::InsertPoint sendInsertPoint) { IRRewriter rewriter(context); rewriter.restoreInsertionPoint(sendInsertPoint); - auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue); + auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.get(resultIndex)); return spatSend; }; std::pair> ret {channelVal, insertVal}; diff --git a/src/PIM/Pass/CountInstructionPass.cpp b/src/PIM/Pass/CountInstructionPass.cpp index 6d0d35b..2a8ad9e 100644 --- a/src/PIM/Pass/CountInstructionPass.cpp +++ b/src/PIM/Pass/CountInstructionPass.cpp @@ -31,7 +31,7 @@ struct CountInstructionPass : public PassWrapper()) { + for (auto computeOp : func.getOps()) { unsigned instructionCount = 0; instructionCount += computeOp.getBody().front().getOperations().size(); llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n"; diff --git a/validation/validate_one.py b/validation/validate_one.py index 3bca810..e75a687 100644 --- a/validation/validate_one.py +++ b/validation/validate_one.py @@ -26,6 +26,10 @@ STAGE_COUNT = len(STAGE_TITLES) GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation") +def sanitize_output_name(name): + return "".join(ch if ch.isalnum() or ch in "_.-" else "_" for ch in name[:255]) + + @dataclass class ValidationResult: passed: bool @@ -205,7 +209,7 @@ def build_dump_ranges(config_path, outputs_descriptor): def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None): run_command( - ["cargo", "run", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--", + ["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--", "-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges], cwd=simulator_dir, reporter=reporter, @@ -229,7 +233,7 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1 all_passed = True rows = [] for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor): - csv_name = f"output{oi}_{name}.csv" + csv_name = f"output{oi}_{sanitize_output_name(name)}.csv" runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape) max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64)))) passed = max_diff <= threshold