From 76a37e198f3a2537002c071f91c637ef4bc6ca91 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Sat, 23 May 2026 11:17:36 +0200 Subject: [PATCH] better MaterializeMergeSchedule.cpp with both send and receive compaction in for loops --- src/PIM/Dialect/Spatial/CMakeLists.txt | 2 - .../MaterializeMergeSchedule.cpp | 326 +++++- .../MergeComputeNodesPass.cpp | 9 - .../MergeComputeNodes/PostMergeCompaction.cpp | 532 --------- .../MergeComputeNodes/PostMergeCompaction.hpp | 12 - .../MergeComputeNodes/RegularOpCompaction.cpp | 1043 ----------------- .../MergeComputeNodes/RegularOpCompaction.hpp | 15 - 7 files changed, 291 insertions(+), 1648 deletions(-) delete mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.cpp delete mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.hpp delete mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp delete mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index bd54ee1..cf2b5ba 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -9,8 +9,6 @@ add_pim_library(SpatialOps SpatialOpsCanonicalization.cpp Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp - Transforms/MergeComputeNodes/PostMergeCompaction.cpp - Transforms/MergeComputeNodes/RegularOpCompaction.cpp Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 272c9df..e320916 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -1,5 +1,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" @@ -557,22 +558,6 @@ Value createIndexConstant(MaterializerState& state, Operation* anchor, int64_t v return getOrCreateHostIndexConstant(anchor, value, state.constantFolder); } -SmallVector createIndexConstants(MaterializerState& state, Operation* anchor, ArrayRef values) { - SmallVector constants; - constants.reserve(values.size()); - for (int64_t value : values) - constants.push_back(createIndexConstant(state, anchor, value)); - return constants; -} - -SmallVector createIndexConstants(MaterializerState& state, Operation* anchor, ArrayRef values) { - SmallVector widened; - widened.reserve(values.size()); - for (int32_t value : values) - widened.push_back(value); - return createIndexConstants(state, anchor, ArrayRef(widened)); -} - Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef values) { SmallVector elements; elements.reserve(values.size()); @@ -644,6 +629,28 @@ Value createLaneIndexedIndexValue(MaterializerState& state, return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult(); } +Value createIndexedIndexValue( + MaterializerState& state, Operation* anchor, ArrayRef values, Value index, Location loc) { + assert(!values.empty() && "expected at least one indexed value"); + + if (allEqual(values)) + return createIndexConstant(state, anchor, values.front()); + + Value table = createIndexTensorConstant(state, anchor, values); + return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult(); +} + +Value createIndexedIndexValue( + MaterializerState& state, Operation* anchor, ArrayRef values, Value index, Location loc) { + assert(!values.empty() && "expected at least one indexed value"); + + if (allEqual(values)) + return createIndexConstant(state, anchor, values.front()); + + Value table = createIndexTensorConstant(state, anchor, values); + return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult(); +} + FailureOr> getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) { SmallVector peers; @@ -808,6 +815,53 @@ ArrayRef getDestinationClasses(MaterializerState& state, ProducerKey ke return it->second; } +void appendScalarSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + int64_t channelId, + int32_t sourceCoreId, + int32_t targetCoreId, + Location loc) { + assert(!sourceClass.isBatch && "scalar send helper expects a scalar source class"); + + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + Value channelIdValue = createIndexConstant(state, sourceClass.op, channelId); + Value sourceCoreIdValue = createIndexConstant(state, sourceClass.op, sourceCoreId); + Value targetCoreIdValue = createIndexConstant(state, sourceClass.op, targetCoreId); + SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); +} + +void appendScalarSendLoop(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { + assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class"); + assert(channelIds.size() > 1 && "send loop is only useful for multiple sends"); + assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); + assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); + + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + + Value lowerBound = createIndexConstant(state, sourceClass.op, 0); + Value upperBound = createIndexConstant(state, sourceClass.op, static_cast(channelIds.size())); + Value step = createIndexConstant(state, sourceClass.op, 1); + + auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPointToStart(loop.getBody()); + + Value index = loop.getInductionVar(); + Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc); + Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc); + Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc); + + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); +} + void appendSend(MaterializerState& state, MaterializedClass& sourceClass, Value payload, @@ -819,9 +873,9 @@ void appendSend(MaterializerState& state, assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); assert(!channelIds.empty() && "expected at least one send"); - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - if (sourceClass.isBatch) { + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + Value channelId = createLaneIndexedIndexValue(state, sourceClass, channelIds, loc); Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc); Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc); @@ -829,12 +883,13 @@ void appendSend(MaterializerState& state, return; } - for (auto index : llvm::seq(0, channelIds.size())) { - Value channelId = createIndexConstant(state, sourceClass.op, channelIds[index]); - Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds[index]); - Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds[index]); - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); + if (channelIds.size() == 1) { + appendScalarSend( + state, sourceClass, payload, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc); + return; } + + appendScalarSendLoop(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); } Value appendScalarReceive(MaterializerState& state, @@ -911,6 +966,212 @@ Value appendPackedScalarReceives(MaterializerState& state, return packed; } +std::optional getConstantIndexValue(Value value) { + if (auto constant = value.getDefiningOp()) + return constant.value(); + + APInt constantValue; + if (matchPattern(value, m_ConstantInt(&constantValue))) + return constantValue.getSExtValue(); + + return std::nullopt; +} + +bool getReceiveMetadata(SpatChannelReceiveOp receive, + int64_t& channelId, + int64_t& sourceCoreId, + int64_t& targetCoreId) { + // SpatChannelReceiveOp operands are: channel, source core, target core. + std::optional channel = getConstantIndexValue(receive->getOperand(0)); + std::optional source = getConstantIndexValue(receive->getOperand(1)); + std::optional target = getConstantIndexValue(receive->getOperand(2)); + if (!channel || !source || !target) + return false; + + channelId = *channel; + sourceCoreId = *source; + targetCoreId = *target; + return true; +} + +bool hasCompatibleConcatTypes(RankedTensorType concatType, RankedTensorType fragmentType, size_t fragmentCount) { + if (!concatType.hasStaticShape() || !fragmentType.hasStaticShape()) + return false; + if (concatType.getRank() != fragmentType.getRank()) + return false; + if (concatType.getRank() == 0) + return false; + if (concatType.getElementType() != fragmentType.getElementType()) + return false; + + if (concatType.getDimSize(0) != fragmentType.getDimSize(0) * static_cast(fragmentCount)) + return false; + + for (int64_t dim = 1; dim < concatType.getRank(); ++dim) + if (concatType.getDimSize(dim) != fragmentType.getDimSize(dim)) + return false; + + return true; +} + +Value createReceiveConcatLoop(MaterializerState& state, + Operation* anchor, + Operation* insertionPoint, + RankedTensorType concatType, + RankedTensorType fragmentType, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { + assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); + assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); + assert(!channelIds.empty() && "expected at least one receive"); + + Value lowerBound = createIndexConstant(state, anchor, 0); + Value upperBound = createIndexConstant(state, anchor, static_cast(channelIds.size())); + Value step = createIndexConstant(state, anchor, 1); + + state.rewriter.setInsertionPoint(insertionPoint); + Value init = + tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult(); + auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); + + Block* body = loop.getBody(); + if (!body->empty()) + if (auto yield = dyn_cast(body->back())) + state.rewriter.eraseOp(yield); + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPointToEnd(body); + + Value index = loop.getInductionVar(); + Value acc = body->getArgument(1); + + Value channelId = createIndexedIndexValue(state, anchor, channelIds, index, loc); + Value sourceCoreId = createIndexedIndexValue(state, anchor, sourceCoreIds, index, loc); + Value targetCoreId = createIndexedIndexValue(state, anchor, targetCoreIds, index, loc); + + Value received = + SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId).getOutput(); + + Value firstOffset = index; + if (fragmentType.getDimSize(0) != 1) { + Value rowsPerFragment = createIndexConstant(state, anchor, fragmentType.getDimSize(0)); + firstOffset = arith::MulIOp::create(state.rewriter, loc, index, rowsPerFragment).getResult(); + } + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(fragmentType.getRank()); + sizes.reserve(fragmentType.getRank()); + strides.reserve(fragmentType.getRank()); + + offsets.push_back(firstOffset); + sizes.push_back(state.rewriter.getIndexAttr(fragmentType.getDimSize(0))); + strides.push_back(state.rewriter.getIndexAttr(1)); + + for (int64_t dim = 1; dim < fragmentType.getRank(); ++dim) { + offsets.push_back(state.rewriter.getIndexAttr(0)); + sizes.push_back(state.rewriter.getIndexAttr(fragmentType.getDimSize(dim))); + strides.push_back(state.rewriter.getIndexAttr(1)); + } + + Value next = tensor::InsertSliceOp::create(state.rewriter, loc, received, acc, offsets, sizes, strides).getResult(); + scf::YieldOp::create(state.rewriter, loc, next); + + return loop.getResult(0); +} + +bool compactReceiveConcat(MaterializerState& state, MaterializedClass& materializedClass, tensor::ConcatOp concat) { + auto dimAttr = concat->getAttrOfType("dim"); + if (!dimAttr || dimAttr.getInt() != 0) + return false; + + OperandRange inputs = concat->getOperands(); + if (inputs.size() < 2) + return false; + + SmallVector receives; + receives.reserve(inputs.size()); + + for (Value input : inputs) { + auto receive = input.getDefiningOp(); + if (!receive) + return false; + if (receive->getBlock() != concat->getBlock()) + return false; + if (!receive->getResult(0).hasOneUse()) + return false; + receives.push_back(receive); + } + + Operation* expected = concat.getOperation(); + for (SpatChannelReceiveOp receive : llvm::reverse(receives)) { + Operation* previous = expected->getPrevNode(); + if (previous != receive.getOperation()) + return false; + expected = previous; + } + + auto concatType = dyn_cast(concat->getResult(0).getType()); + auto fragmentType = dyn_cast(receives.front()->getResult(0).getType()); + if (!concatType || !fragmentType) + return false; + if (!hasCompatibleConcatTypes(concatType, fragmentType, receives.size())) + return false; + + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + channelIds.reserve(receives.size()); + sourceCoreIds.reserve(receives.size()); + targetCoreIds.reserve(receives.size()); + + for (SpatChannelReceiveOp receive : receives) { + if (receive->getResult(0).getType() != fragmentType) + return false; + + int64_t channelId = 0; + int64_t sourceCoreId = 0; + int64_t targetCoreId = 0; + if (!getReceiveMetadata(receive, channelId, sourceCoreId, targetCoreId)) + return false; + + channelIds.push_back(channelId); + sourceCoreIds.push_back(sourceCoreId); + targetCoreIds.push_back(targetCoreId); + } + + Value replacement = createReceiveConcatLoop(state, + materializedClass.op, + receives.front().getOperation(), + concatType, + fragmentType, + channelIds, + sourceCoreIds, + targetCoreIds, + concat.getLoc()); + + concat->getResult(0).replaceAllUsesWith(replacement); + state.rewriter.eraseOp(concat.getOperation()); + + for (SpatChannelReceiveOp receive : llvm::reverse(receives)) + state.rewriter.eraseOp(receive.getOperation()); + + return true; +} + +void compactReceiveConcats(MaterializerState& state) { + SmallVector, 16> concatOps; + + for (MaterializedClass& materializedClass : state.classes) + materializedClass.op->walk([&](tensor::ConcatOp concat) { concatOps.push_back({&materializedClass, concat}); }); + + for (auto [materializedClass, concat] : concatOps) + compactReceiveConcat(state, *materializedClass, concat); +} + LogicalResult emitClassToClassCommunication(MaterializerState& state, MaterializedClass& sourceClass, MaterializedClass& targetClass, @@ -1086,15 +1347,8 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val return success(); } -LogicalResult emitHostCommunication(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value payload, - Value originalOutput, - Location loc) { - (void) keys; - (void) loc; - +LogicalResult +emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, Value payload, Value originalOutput) { if (!hasLiveExternalUse(originalOutput, state.oldComputeOps)) return success(); @@ -1115,7 +1369,7 @@ LogicalResult emitOutputFanout(MaterializerState& state, if (failed( emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) return failure(); - if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc))) + if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput))) return failure(); state.availableValues[keys.front()][sourceClass.id] = payload; return success(); @@ -1129,7 +1383,7 @@ LogicalResult emitOutputFanout(MaterializerState& state, if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) return failure(); - if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc))) + if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput))) return failure(); for (ProducerKey key : keys) @@ -1146,7 +1400,7 @@ FailureOr materializeWholeBatchInput( state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - uint32_t batchLaneCount = static_cast(batch.getLaneCount()); + uint32_t batchLaneCount = batch.getLaneCount(); SmallVector fragments; uint32_t lane = 0; @@ -1426,6 +1680,8 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch if (failed(materializeInstanceSlot(state, instance))) return failure(); + compactReceiveConcats(state); + replaceHostUses(state); if (failed(eraseOldComputeOps(state))) return failure(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index ad7abe2..e5bd394 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -27,7 +27,6 @@ #include #include #include -#include #include #include #include @@ -35,7 +34,6 @@ #include #include "MaterializeMergeSchedule.hpp" -#include "PostMergeCompaction.hpp" #include "Scheduling/ComputeInstanceUtils.hpp" #include "Scheduling/MergeSchedulingAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" @@ -660,13 +658,6 @@ public: emitMergeIrCounts("after-materialization", func); - /*if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) { - signalPassFailure(); - return; - } - - emitMergeIrCounts("after-post-merge-compaction", func);*/ - { ScopedMergePhaseTimer timer("cleanup-topological-sort-report"); if (!sortTopologically(&func.getBody().front())) { diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.cpp deleted file mode 100644 index cefaa1a..0000000 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.cpp +++ /dev/null @@ -1,532 +0,0 @@ -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/Hashing.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/raw_ostream.h" - -#include -#include -#include -#include - -#include "PostMergeCompaction.hpp" -#include "RegularOpCompaction.hpp" -#include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - -using namespace mlir; - -namespace onnx_mlir { -namespace { - -using SpatCompute = spatial::SpatCompute; -using SpatComputeBatch = spatial::SpatComputeBatch; - -bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; } - -class ScopedMergePhaseTimer { -public: - explicit ScopedMergePhaseTimer(StringRef phaseName) - : enabled(isMergeProfilingEnabled()), phase(phaseName.str()) { - if (enabled) - start = std::chrono::steady_clock::now(); - } - - ~ScopedMergePhaseTimer() { - if (!enabled) - return; - auto elapsed = std::chrono::steady_clock::now() - start; - double millis = std::chrono::duration(elapsed).count(); - llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n"; - } - -private: - bool enabled = false; - std::string phase; - std::chrono::steady_clock::time_point start; -}; - -std::optional getComputeCoreId(SpatCompute compute) { - if (auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName)) - return static_cast(coreIdAttr.getInt()); - return std::nullopt; -} - -static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase"; - -static FailureOr getConstantI64Value(Value value) { - APInt constantValue; - if (!matchPattern(value, m_ConstantInt(&constantValue))) - return failure(); - return constantValue.getSExtValue(); -} - -static FailureOr getConstantI32Value(Value value) { - APInt constantValue; - if (!matchPattern(value, m_ConstantInt(&constantValue))) - return failure(); - return static_cast(constantValue.getSExtValue()); -} - -static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op, - uint64_t& channelId, - uint32_t& sourceCoreId, - uint32_t& targetCoreId) { - FailureOr constantChannelId = getConstantI64Value(op.getChannelId()); - FailureOr constantSourceCoreId = getConstantI32Value(op.getSourceCoreId()); - FailureOr constantTargetCoreId = getConstantI32Value(op.getTargetCoreId()); - if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId)) - return false; - channelId = static_cast(*constantChannelId); - sourceCoreId = static_cast(*constantSourceCoreId); - targetCoreId = static_cast(*constantTargetCoreId); - return true; -} - -static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op, - uint64_t& channelId, - uint32_t& sourceCoreId, - uint32_t& targetCoreId) { - FailureOr constantChannelId = getConstantI64Value(op.getChannelId()); - FailureOr constantSourceCoreId = getConstantI32Value(op.getSourceCoreId()); - FailureOr constantTargetCoreId = getConstantI32Value(op.getTargetCoreId()); - if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId)) - return false; - channelId = static_cast(*constantChannelId); - sourceCoreId = static_cast(*constantSourceCoreId); - targetCoreId = static_cast(*constantTargetCoreId); - return true; -} - -static SmallVector createIndexConstants(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { - SmallVector constants; - constants.reserve(values.size()); - for (int64_t value : values) - constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder)); - return constants; -} - -static SmallVector createIndexConstants(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { - SmallVector constants; - constants.reserve(values.size()); - for (int32_t value : values) - constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder)); - return constants; -} - -std::optional getComputeRebatchPhase(SpatCompute compute) { - if (auto phaseAttr = compute->getAttrOfType(kRebatchPhaseAttrName)) - return static_cast(phaseAttr.getInt()); - return std::nullopt; -} - -struct RebatchKey { - unsigned inputCount = 0; - unsigned resultCount = 0; - unsigned weightCount = 0; - uint64_t phase = 0; - bool hasPhase = false; - uint64_t structureHash = 0; - - bool operator==(const RebatchKey& other) const { - return inputCount == other.inputCount && resultCount == other.resultCount && weightCount == other.weightCount - && phase == other.phase && hasPhase == other.hasPhase && structureHash == other.structureHash; - } -}; - -struct RebatchKeyInfo { - static inline RebatchKey getEmptyKey() { return {std::numeric_limits::max(), 0, 0, 0, false, 0}; } - - static inline RebatchKey getTombstoneKey() { return {std::numeric_limits::max() - 1, 0, 0, 0, false, 0}; } - - static unsigned getHashValue(const RebatchKey& key) { - return static_cast( - llvm::hash_combine(key.inputCount, key.resultCount, key.weightCount, key.phase, key.hasPhase, key.structureHash)); - } - - static bool isEqual(const RebatchKey& lhs, const RebatchKey& rhs) { return lhs == rhs; } -}; - -uint64_t getTypeHash(Type type) { return reinterpret_cast(type.getAsOpaquePointer()); } - -uint64_t getValueHash(Value value) { return reinterpret_cast(value.getAsOpaquePointer()); } - -uint64_t getAttributeHash(Attribute attr) { return reinterpret_cast(attr.getAsOpaquePointer()); } - -RebatchKey computeRebatchKey(SpatCompute compute) { - llvm::hash_code structureHash = - llvm::hash_combine(compute.getInputs().size(), compute.getResultTypes().size(), compute.getWeights().size()); - - for (Value weight : compute.getWeights()) - structureHash = llvm::hash_combine(structureHash, getValueHash(weight)); - if (std::optional phase = getComputeRebatchPhase(compute)) - structureHash = llvm::hash_combine(structureHash, *phase); - - Block& body = compute.getBody().front(); - structureHash = llvm::hash_combine(structureHash, body.getNumArguments()); - for (BlockArgument arg : body.getArguments()) - structureHash = llvm::hash_combine(structureHash, getTypeHash(arg.getType())); - - for (Operation& op : body) { - structureHash = llvm::hash_combine( - structureHash, op.getName().getStringRef(), op.getNumOperands(), op.getNumResults(), op.getNumRegions()); - for (Type type : op.getResultTypes()) - structureHash = llvm::hash_combine(structureHash, getTypeHash(type)); - for (NamedAttribute attr : op.getAttrs()) - structureHash = llvm::hash_combine(structureHash, attr.getName().strref(), getAttributeHash(attr.getValue())); - } - - std::optional phase = getComputeRebatchPhase(compute); - return {static_cast(compute.getInputs().size()), - static_cast(compute.getResultTypes().size()), - static_cast(compute.getWeights().size()), - phase.value_or(0), - phase.has_value(), - static_cast(structureHash)}; -} - -bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) { - if (!lhs || !rhs) - return false; - if (lhs.getInputs().size() != rhs.getInputs().size()) - return false; - if (lhs.getResultTypes() != rhs.getResultTypes()) - return false; - if (lhs.getWeights().size() != rhs.getWeights().size()) - return false; - if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs)) - return false; - if (!llvm::equal(lhs.getWeights(), rhs.getWeights())) - return false; - - auto& lhsBlock = lhs.getBody().front(); - auto& rhsBlock = rhs.getBody().front(); - if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments()) - return false; - - DenseMap mappedValues; - for (auto [lhsArg, rhsArg] : llvm::zip(lhsBlock.getArguments(), rhsBlock.getArguments())) { - if (lhsArg.getType() != rhsArg.getType()) - return false; - mappedValues[lhsArg] = rhsArg; - } - auto lhsIt = lhsBlock.begin(); - auto rhsIt = rhsBlock.begin(); - for (; lhsIt != lhsBlock.end() && rhsIt != rhsBlock.end(); ++lhsIt, ++rhsIt) { - Operation& lhsOp = *lhsIt; - Operation& rhsOp = *rhsIt; - - if (lhsOp.getName() != rhsOp.getName()) - return false; - if (lhsOp.getNumOperands() != rhsOp.getNumOperands()) - return false; - if (lhsOp.getNumResults() != rhsOp.getNumResults()) - return false; - if (lhsOp.getNumRegions() != 0 || rhsOp.getNumRegions() != 0) - return false; - - for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOp.getOperands(), rhsOp.getOperands())) { - auto mapped = mappedValues.find(lhsOperand); - if (mapped != mappedValues.end()) { - if (mapped->second != rhsOperand) - return false; - continue; - } - if (lhsOperand != rhsOperand) - return false; - } - - if (auto lhsReceive = dyn_cast(lhsOp)) { - auto rhsReceive = cast(rhsOp); - if (lhsReceive.getOutput().getType() != rhsReceive.getOutput().getType()) - return false; - } - else if (auto lhsSend = dyn_cast(lhsOp)) { - auto rhsSend = cast(rhsOp); - if (lhsSend.getInput().getType() != rhsSend.getInput().getType()) - return false; - } - else if (lhsOp.getAttrs() != rhsOp.getAttrs()) { - return false; - } - - if (lhsOp.getResultTypes() != rhsOp.getResultTypes()) - return false; - for (auto [lhsResult, rhsResult] : llvm::zip(lhsOp.getResults(), rhsOp.getResults())) - mappedValues[lhsResult] = rhsResult; - } - - return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end(); -} - -void rebatchEquivalentComputes(func::FuncOp funcOp) { - IRRewriter rewriter(funcOp.getContext()); - OperationFolder constantFolder(funcOp.getContext()); - SmallVector computes(funcOp.getOps()); - DenseSet consumed; - DenseMap computeOrder; - DenseMap, RebatchKeyInfo> candidatesByKey; - - for (auto [index, compute] : llvm::enumerate(computes)) { - computeOrder[compute.getOperation()] = index; - if (compute.getInputs().size() <= 1 && compute.getResults().empty()) - candidatesByKey[computeRebatchKey(compute)].push_back(compute); - } - - for (size_t index = 0; index < computes.size(); ++index) { - auto anchor = computes[index]; - if (consumed.contains(anchor)) - continue; - if (anchor.getInputs().size() > 1) - continue; - if (!anchor.getResults().empty()) - continue; - - SmallVector group {anchor}; - llvm::SmallDenseSet usedCoreIds; - if (auto coreId = getComputeCoreId(anchor)) - usedCoreIds.insert(*coreId); - - auto bucketIt = candidatesByKey.find(computeRebatchKey(anchor)); - if (bucketIt == candidatesByKey.end()) - continue; - - for (auto candidate : bucketIt->second) { - if (computeOrder.lookup(candidate.getOperation()) <= index) - continue; - if (consumed.contains(candidate)) - continue; - if (!areEquivalentForRebatch(anchor, candidate)) - continue; - - if (auto coreId = getComputeCoreId(candidate)) - if (!usedCoreIds.insert(*coreId).second) - continue; - - group.push_back(candidate); - } - - if (group.size() <= 1) - continue; - - auto insertionAnchor = group.front(); - if (llvm::all_of(group, [](SpatCompute compute) { return getComputeCoreId(compute).has_value(); })) { - llvm::stable_sort( - group, [](SpatCompute lhs, SpatCompute rhs) { return *getComputeCoreId(lhs) < *getComputeCoreId(rhs); }); - } - - SmallVector weights; - weights.reserve(group.size() * anchor.getWeights().size()); - SmallVector inputs; - inputs.reserve(group.size() * anchor.getInputs().size()); - SmallVector coreIds; - coreIds.reserve(group.size()); - bool haveAllCoreIds = true; - for (auto compute : group) { - llvm::append_range(weights, compute.getWeights()); - llvm::append_range(inputs, compute.getInputs()); - auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName); - if (!coreIdAttr) - haveAllCoreIds = false; - else if (haveAllCoreIds) - coreIds.push_back(static_cast(coreIdAttr.getInt())); - } - - rewriter.setInsertionPoint(insertionAnchor); - auto rebatched = SpatComputeBatch::create(rewriter, - insertionAnchor.getLoc(), - TypeRange {}, - rewriter.getI32IntegerAttr(static_cast(group.size())), - ValueRange(weights), - ValueRange(inputs)); - rebatched.getProperties().setOperandSegmentSizes( - {static_cast(weights.size()), static_cast(inputs.size())}); - if (haveAllCoreIds) - rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); - - SmallVector blockArgTypes; - SmallVector blockArgLocs; - for (BlockArgument arg : anchor.getBody().front().getArguments()) { - blockArgTypes.push_back(arg.getType()); - blockArgLocs.push_back(arg.getLoc()); - } - auto* newBlock = - rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); - rewriter.setInsertionPointToEnd(newBlock); - - IRMapping mapper; - auto& anchorBlock = anchor.getBody().front(); - for (auto [oldArg, newArg] : llvm::zip(anchorBlock.getArguments(), newBlock->getArguments())) - mapper.map(oldArg, newArg); - auto opIts = llvm::map_to_vector(group, [](SpatCompute compute) { return compute.getBody().front().begin(); }); - for (Operation& anchorOp : anchorBlock) { - if (auto receiveOp = dyn_cast(&anchorOp)) { - struct BatchReceiveEntry { - uint64_t channelId = 0; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - }; - SmallVector entries; - entries.reserve(group.size()); - for (auto [groupIndex, compute] : llvm::enumerate(group)) { - auto groupReceive = cast(&*opIts[groupIndex]); - BatchReceiveEntry entry; - if (!getScalarChannelMetadata(groupReceive, entry.channelId, entry.sourceCoreId, entry.targetCoreId)) - return; - entries.push_back(entry); - ++opIts[groupIndex]; - } - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - channelIds.reserve(group.size()); - sourceCoreIds.reserve(group.size()); - targetCoreIds.reserve(group.size()); - for (const BatchReceiveEntry& entry : entries) { - channelIds.push_back(static_cast(entry.channelId)); - sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); - targetCoreIds.push_back(static_cast(entry.targetCoreId)); - } - SmallVector channelIdValues = createIndexConstants(receiveOp, channelIds, constantFolder); - SmallVector sourceCoreIdValues = createIndexConstants(receiveOp, sourceCoreIds, constantFolder); - SmallVector targetCoreIdValues = createIndexConstants(receiveOp, targetCoreIds, constantFolder); - auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter, - receiveOp.getLoc(), - receiveOp.getOutput().getType(), - channelIdValues, - sourceCoreIdValues, - targetCoreIdValues); - mapper.map(receiveOp.getOutput(), batchReceive.getOutput()); - continue; - } - - if (auto sendOp = dyn_cast(&anchorOp)) { - struct BatchSendEntry { - uint64_t channelId = 0; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - }; - SmallVector entries; - entries.reserve(group.size()); - for (auto [groupIndex, compute] : llvm::enumerate(group)) { - auto groupSend = cast(&*opIts[groupIndex]); - BatchSendEntry entry; - if (!getScalarChannelMetadata(groupSend, entry.channelId, entry.sourceCoreId, entry.targetCoreId)) - return; - entries.push_back(entry); - ++opIts[groupIndex]; - } - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - channelIds.reserve(group.size()); - sourceCoreIds.reserve(group.size()); - targetCoreIds.reserve(group.size()); - for (const BatchSendEntry& entry : entries) { - channelIds.push_back(static_cast(entry.channelId)); - sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); - targetCoreIds.push_back(static_cast(entry.targetCoreId)); - } - SmallVector channelIdValues = createIndexConstants(sendOp, channelIds, constantFolder); - SmallVector sourceCoreIdValues = createIndexConstants(sendOp, sourceCoreIds, constantFolder); - SmallVector targetCoreIdValues = createIndexConstants(sendOp, targetCoreIds, constantFolder); - spatial::SpatChannelSendBatchOp::create(rewriter, - sendOp.getLoc(), - channelIdValues, - sourceCoreIdValues, - targetCoreIdValues, - mapper.lookup(sendOp.getInput())); - continue; - } - - if (isa(anchorOp)) { - for (auto& opIt : opIts) - ++opIt; - spatial::SpatYieldOp::create(rewriter, anchorOp.getLoc(), ValueRange {}); - continue; - } - - Operation* cloned = rewriter.clone(anchorOp, mapper); - for (auto [originalResult, clonedResult] : llvm::zip(anchorOp.getResults(), cloned->getResults())) - mapper.map(originalResult, clonedResult); - for (auto& opIt : opIts) - ++opIt; - } - - for (auto compute : group) { - compute->removeAttr(kRebatchPhaseAttrName); - consumed.insert(compute); - rewriter.eraseOp(compute); - } - } - - for (auto compute : funcOp.getOps()) - compute->removeAttr(kRebatchPhaseAttrName); -} - -void cleanupDeadPackingOps(func::FuncOp funcOp) { - auto eraseUnusedOps = [&](auto tag) { - using OpTy = decltype(tag); - SmallVector ops; - funcOp.walk([&](OpTy op) { ops.push_back(op); }); - for (auto op : llvm::reverse(ops)) - if (op->use_empty()) - op.erase(); - }; - eraseUnusedOps(tensor::ExtractSliceOp {}); - eraseUnusedOps(spatial::SpatConcatOp {}); - eraseUnusedOps(spatial::SpatExtractRowsOp {}); -} - -} // namespace - -LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextChannelId) { - { - ScopedMergePhaseTimer timer("order-bilateral-channel-ops"); - orderBilateralChannelOps(funcOp); - } - { - ScopedMergePhaseTimer timer("rebatch-equivalent-computes"); - rebatchEquivalentComputes(funcOp); - } - { - ScopedMergePhaseTimer timer("compact-scalar-channel-runs-1"); - compactScalarChannelRuns(funcOp, nextChannelId); - } - { - ScopedMergePhaseTimer timer("compact-batch-channel-runs-1"); - compactBatchChannelRuns(funcOp); - } - { - ScopedMergePhaseTimer timer("compact-regular-op-runs"); - compactRegularOpRuns(funcOp); - } - { - ScopedMergePhaseTimer timer("compact-row-wise-wvmm-runs"); - compactRowWiseWvmmRuns(funcOp); - } - { - ScopedMergePhaseTimer timer("compact-scalar-channel-runs-2"); - compactScalarChannelRuns(funcOp, nextChannelId); - } - { - ScopedMergePhaseTimer timer("compact-batch-channel-runs-2"); - compactBatchChannelRuns(funcOp); - } - { - ScopedMergePhaseTimer timer("cleanup-dead-packing-ops"); - cleanupDeadPackingOps(funcOp); - } - - return success(); -} - -} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.hpp deleted file mode 100644 index 2797fc2..0000000 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.hpp +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Support/LogicalResult.h" - -#include - -namespace onnx_mlir { - -mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t& nextChannelId); - -} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp deleted file mode 100644 index a13c66f..0000000 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp +++ /dev/null @@ -1,1043 +0,0 @@ -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" - -#include "RegularOpCompaction.hpp" -#include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - -using namespace mlir; - -namespace onnx_mlir { -namespace { - -enum class RegularStepKind { - Wvmm, - VAddLhs, - VAddRhs, -}; - -struct RegularStep { - RegularStepKind kind; - Value weight; - Value invariantOperand; - Type resultType; -}; - -struct RegularChunk { - Operation* startOp = nullptr; - SmallVector ops; - SmallVector steps; - Value input; - Value output; -}; - -struct RegularCompactionResult { - bool changed = false; - Operation* resumeAfter = nullptr; -}; - -template -struct ConsecutiveRun { - SmallVector ops; - Block::iterator end; -}; - -template -static ConsecutiveRun -collectConsecutiveRun(Block::iterator start, Block::iterator blockEnd, Predicate predicate) { - ConsecutiveRun run; - run.end = start; - while (run.end != blockEnd) { - auto current = dyn_cast(&*run.end); - if (!current || !predicate(current)) - break; - run.ops.push_back(current); - ++run.end; - } - return run; -} - -static uint64_t getEndpointKey(uint32_t sourceCoreId, uint32_t targetCoreId) { - return (static_cast(sourceCoreId) << 32) | static_cast(targetCoreId); -} - -static FailureOr getConstantI64Value(Value value) { - APInt constantValue; - if (!matchPattern(value, m_ConstantInt(&constantValue))) - return failure(); - return constantValue.getSExtValue(); -} - -static FailureOr getConstantI32Value(Value value) { - APInt constantValue; - if (!matchPattern(value, m_ConstantInt(&constantValue))) - return failure(); - return static_cast(constantValue.getSExtValue()); -} - -static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op, - uint64_t& channelId, - uint32_t& sourceCoreId, - uint32_t& targetCoreId) { - FailureOr constantChannelId = getConstantI64Value(op.getChannelId()); - FailureOr constantSourceCoreId = getConstantI32Value(op.getSourceCoreId()); - FailureOr constantTargetCoreId = getConstantI32Value(op.getTargetCoreId()); - if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId)) - return false; - channelId = static_cast(*constantChannelId); - sourceCoreId = static_cast(*constantSourceCoreId); - targetCoreId = static_cast(*constantTargetCoreId); - return true; -} - -static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op, - uint64_t& channelId, - uint32_t& sourceCoreId, - uint32_t& targetCoreId) { - FailureOr constantChannelId = getConstantI64Value(op.getChannelId()); - FailureOr constantSourceCoreId = getConstantI32Value(op.getSourceCoreId()); - FailureOr constantTargetCoreId = getConstantI32Value(op.getTargetCoreId()); - if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId)) - return false; - channelId = static_cast(*constantChannelId); - sourceCoreId = static_cast(*constantSourceCoreId); - targetCoreId = static_cast(*constantTargetCoreId); - return true; -} - -static SmallVector createIndexConstants(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { - SmallVector constants; - constants.reserve(values.size()); - for (int64_t value : values) - constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder)); - return constants; -} - -static SmallVector createIndexConstants(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { - SmallVector constants; - constants.reserve(values.size()); - for (int32_t value : values) - constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder)); - return constants; -} - -static SmallVector getScalarChannelMetadataDefs(Operation* channelOp, unsigned metadataOperandCount) { - SmallVector defs; - defs.reserve(metadataOperandCount); - for (unsigned operandIndex = 0; operandIndex < metadataOperandCount; ++operandIndex) { - Operation* def = channelOp->getOperand(operandIndex).getDefiningOp(); - auto constantOp = dyn_cast_or_null(def); - if (!constantOp || def->getBlock() != channelOp->getBlock()) - continue; - defs.push_back(def); - } - llvm::sort(defs, [](Operation* lhs, Operation* rhs) { return lhs->isBeforeInBlock(rhs); }); - return defs; -} - -static void moveScalarChannelBundleBefore(Operation* channelOp, Operation* insertionPoint) { - for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3)) - metadataDef->moveBefore(insertionPoint); - channelOp->moveBefore(insertionPoint); -} - -static void moveScalarChannelBundleBefore(Operation* channelOp, Block* block, Block::iterator insertionPoint) { - for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3)) - metadataDef->moveBefore(block, insertionPoint); - channelOp->moveBefore(block, insertionPoint); -} - -static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) { - if (values.empty() || !values.front().hasOneUse()) - return {}; - - OpOperand& firstUse = *values.front().getUses().begin(); - auto concatOp = dyn_cast(firstUse.getOwner()); - if (!concatOp) - return {}; - - startOperandIndex = firstUse.getOperandNumber(); - for (auto [index, value] : llvm::enumerate(values)) { - if (!value.hasOneUse()) - return {}; - OpOperand& use = *value.getUses().begin(); - if (use.getOwner() != concatOp || use.getOperandNumber() != startOperandIndex + index) - return {}; - } - - return concatOp; -} - -static void replaceConcatRunWithPackedValue(spatial::SpatConcatOp concatOp, - unsigned startOperandIndex, - unsigned operandCount, - Value packedValue, - IRRewriter& rewriter) { - SmallVector newInputs; - newInputs.reserve(concatOp.getInputs().size() - operandCount + 1); - for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) { - if (operandIndex == startOperandIndex) - newInputs.push_back(packedValue); - if (operandIndex < startOperandIndex || operandIndex >= startOperandIndex + operandCount) - newInputs.push_back(operand); - } - if (newInputs.size() == 1 && newInputs.front().getType() == concatOp.getOutput().getType()) { - rewriter.replaceOp(concatOp, newInputs.front()); - return; - } - rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newInputs); }); -} - -static RankedTensorType -getPackedConcatSliceType(spatial::SpatConcatOp concatOp, unsigned startOperandIndex, unsigned operandCount) { - auto firstType = dyn_cast(concatOp.getInputs()[startOperandIndex].getType()); - if (!firstType || !firstType.hasStaticShape()) - return {}; - - int64_t axis = concatOp.getAxis(); - if (axis < 0 || axis >= firstType.getRank()) - return {}; - - SmallVector shape(firstType.getShape().begin(), firstType.getShape().end()); - shape[axis] = 0; - for (unsigned index = 0; index < operandCount; ++index) { - auto operandType = dyn_cast(concatOp.getInputs()[startOperandIndex + index].getType()); - if (!operandType || !operandType.hasStaticShape() || operandType.getRank() != firstType.getRank()) - return {}; - - for (int64_t dim = 0; dim < firstType.getRank(); ++dim) { - if (dim == axis) - continue; - if (operandType.getShape()[dim] != shape[dim]) - return {}; - } - - shape[axis] += operandType.getShape()[axis]; - } - - return RankedTensorType::get(shape, firstType.getElementType()); -} - -static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) { - if (values.empty()) - return false; - - auto firstResult = dyn_cast(values.front()); - if (!firstResult) - return false; - - owner = firstResult.getOwner(); - startIndex = firstResult.getResultNumber(); - for (auto [index, value] : llvm::enumerate(values)) { - auto result = dyn_cast(value); - if (!result || result.getOwner() != owner || result.getResultNumber() != startIndex + index) - return false; - } - - return true; -} - -static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter, Location loc) { - if (values.empty()) - return {}; - if (Value packedSlice = createPackedExtractSliceTensor(values, rewriter, loc)) - return packedSlice; - - Operation* owner = nullptr; - unsigned startIndex = 0; - if (getContiguousOpResults(values, owner, startIndex)) - if (auto extractRowsOp = dyn_cast(owner)) - return createPackedExtractRowsSlice( - extractRowsOp, startIndex, static_cast(values.size()), rewriter, loc); - - auto firstType = dyn_cast(values.front().getType()); - if (!firstType || !firstType.hasStaticShape() || firstType.getRank() == 0) - return {}; - if (!llvm::all_of(values.drop_front(), [&](Value value) { return value.getType() == firstType; })) - return {}; - return {}; -} - -static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) { - return lhs.kind == rhs.kind && lhs.weight == rhs.weight && lhs.invariantOperand == rhs.invariantOperand - && lhs.resultType == rhs.resultType; -} - -static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChunk& rhs) { - if (lhs.input.getType() != rhs.input.getType() || lhs.output.getType() != rhs.output.getType() - || lhs.steps.size() != rhs.steps.size()) { - return false; - } - - return llvm::all_of(llvm::zip_equal(lhs.steps, rhs.steps), - [](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); }); -} - -static bool isForwardedChannelPayload(Value value, Block& block) { - Operation* op = value.getDefiningOp(); - if (!op || op->getBlock() != &block) - return true; - - if (auto extractSliceOp = dyn_cast(op)) - return isForwardedChannelPayload(extractSliceOp.getSource(), block); - - return isa(op); -} - -static FailureOr analyzeRegularChunk(spatial::SpatVMMOp startOp) { - RegularChunk chunk; - chunk.startOp = startOp.getOperation(); - chunk.input = startOp.getInput(); - chunk.output = startOp.getOutput(); - chunk.ops.push_back(startOp.getOperation()); - chunk.steps.push_back({RegularStepKind::Wvmm, startOp.getWeight(), Value(), startOp.getOutput().getType()}); - - Value currentValue = startOp.getOutput(); - while (currentValue.hasOneUse()) { - Operation* user = *currentValue.getUsers().begin(); - if (user->getBlock() != startOp->getBlock()) - break; - - auto vaddOp = dyn_cast(user); - if (!vaddOp) - break; - - if (vaddOp.getLhs() == currentValue) - chunk.steps.push_back({RegularStepKind::VAddLhs, Value(), vaddOp.getRhs(), vaddOp.getOutput().getType()}); - else if (vaddOp.getRhs() == currentValue) - chunk.steps.push_back({RegularStepKind::VAddRhs, Value(), vaddOp.getLhs(), vaddOp.getOutput().getType()}); - else - break; - - chunk.ops.push_back(vaddOp); - chunk.output = vaddOp.getOutput(); - currentValue = vaddOp.getOutput(); - } - - return chunk; -} - -static RegularCompactionResult -compactRegularChunkRun(IRRewriter& rewriter, ArrayRef run, OperationFolder& constantFolder) { - assert(!run.empty() && "expected a non-empty regular chunk run"); - const RegularChunk& anchorChunk = run.front(); - RegularCompactionResult result; - - SmallVector inputs; - inputs.reserve(run.size()); - for (const RegularChunk& chunk : run) - inputs.push_back(chunk.input); - - rewriter.setInsertionPoint(anchorChunk.startOp); - Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc()); - if (!packedInput) - return result; - - auto inputType = cast(anchorChunk.input.getType()); - auto outputType = cast(anchorChunk.output.getType()); - auto packedOutputType = getPackedTensorType(outputType, static_cast(run.size())); - auto packedInit = tensor::EmptyOp::create( - rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType()); - auto zero = getOrCreateHostIndexConstant(anchorChunk.startOp, 0, constantFolder); - auto upper = getOrCreateHostIndexConstant(anchorChunk.startOp, static_cast(run.size()), constantFolder); - auto step = getOrCreateHostIndexConstant(anchorChunk.startOp, 1, constantFolder); - auto loop = - scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()}); - - { - OpBuilder::InsertionGuard guard(rewriter); - Block* loopBlock = loop.getBody(); - rewriter.setInsertionPointToStart(loopBlock); - Value iv = loopBlock->getArgument(0); - Value acc = loopBlock->getArgument(1); - - Value inputRowOffset = iv; - if (inputType.getDimSize(0) != 1) { - auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, inputType.getDimSize(0), constantFolder); - inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue); - } - - SmallVector extractOffsets; - SmallVector extractSizes; - SmallVector extractStrides; - extractOffsets.push_back(inputRowOffset); - extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(0))); - extractStrides.push_back(rewriter.getIndexAttr(1)); - for (int64_t dim = 1; dim < inputType.getRank(); ++dim) { - extractOffsets.push_back(rewriter.getIndexAttr(0)); - extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim))); - extractStrides.push_back(rewriter.getIndexAttr(1)); - } - auto inputSlice = tensor::ExtractSliceOp::create( - rewriter, anchorChunk.startOp->getLoc(), inputType, packedInput, extractOffsets, extractSizes, extractStrides); - - IRMapping mapping; - mapping.map(anchorChunk.input, inputSlice.getResult()); - for (Operation* op : anchorChunk.ops) { - Operation* cloned = rewriter.clone(*op, mapping); - for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults())) - mapping.map(oldResult, newResult); - } - - Value mappedOutput = mapping.lookup(anchorChunk.output); - Value outputRowOffset = iv; - if (outputType.getDimSize(0) != 1) { - auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, outputType.getDimSize(0), constantFolder); - outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue); - } - - SmallVector insertOffsets; - SmallVector insertSizes; - SmallVector insertStrides; - insertOffsets.push_back(outputRowOffset); - insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(0))); - insertStrides.push_back(rewriter.getIndexAttr(1)); - for (int64_t dim = 1; dim < outputType.getRank(); ++dim) { - insertOffsets.push_back(rewriter.getIndexAttr(0)); - insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(dim))); - insertStrides.push_back(rewriter.getIndexAttr(1)); - } - - auto inserted = tensor::InsertSliceOp::create( - rewriter, anchorChunk.startOp->getLoc(), mappedOutput, acc, insertOffsets, insertSizes, insertStrides); - scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult()); - } - - SmallVector outputs; - outputs.reserve(run.size()); - for (const RegularChunk& chunk : run) - outputs.push_back(chunk.output); - - unsigned concatStartIndex = 0; - auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex); - auto concatPackedType = concatOp - ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast(outputs.size())) - : RankedTensorType {}; - - if (concatOp && concatPackedType == packedOutputType) { - replaceConcatRunWithPackedValue( - concatOp, concatStartIndex, static_cast(outputs.size()), loop.getResult(0), rewriter); - } - else { - for (auto [index, chunk] : llvm::enumerate(run)) { - Value replacement = extractPackedChunk( - loop.getResult(0), outputType, static_cast(index), rewriter, chunk.startOp->getLoc()); - Value output = chunk.output; - output.replaceAllUsesWith(replacement); - } - } - - SmallVector opsToErase; - for (const RegularChunk& chunk : run) - llvm::append_range(opsToErase, chunk.ops); - for (Operation* op : llvm::reverse(opsToErase)) - rewriter.eraseOp(op); - - result.changed = true; - result.resumeAfter = loop.getOperation()->getNextNode(); - return result; -} - -} // namespace - -void orderBilateralChannelOps(func::FuncOp funcOp) { - for (auto compute : funcOp.getOps()) { - auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName); - if (!coreIdAttr) - continue; - - int32_t coreId = static_cast(coreIdAttr.getInt()); - Block& block = compute.getBody().front(); - SmallVector> moves; - DenseMap firstForwardedSendByEndpoint; - Operation* firstForwardedSend = nullptr; - - for (Operation& op : block) { - if (auto sendOp = dyn_cast(&op)) { - uint64_t channelId = 0; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - if (getScalarChannelMetadata(sendOp, channelId, sourceCoreId, targetCoreId) - && sourceCoreId == static_cast(coreId) && isForwardedChannelPayload(sendOp.getInput(), block)) { - if (!firstForwardedSend) - firstForwardedSend = sendOp.getOperation(); - uint64_t key = getEndpointKey(sourceCoreId, targetCoreId); - firstForwardedSendByEndpoint.try_emplace(key, sendOp.getOperation()); - } - continue; - } - - auto receiveOp = dyn_cast(&op); - uint64_t channelId = 0; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId) - || targetCoreId != static_cast(coreId) || sourceCoreId >= static_cast(coreId)) { - continue; - } - - uint64_t key = getEndpointKey(static_cast(coreId), sourceCoreId); - auto firstMatchingSend = firstForwardedSendByEndpoint.find(key); - if (firstMatchingSend != firstForwardedSendByEndpoint.end()) - moves.push_back({receiveOp, firstMatchingSend->second}); - else if (firstForwardedSend && firstForwardedSend->isBeforeInBlock(receiveOp)) - moves.push_back({receiveOp, firstForwardedSend}); - } - - for (auto [receiveOp, insertionPoint] : moves) - moveScalarChannelBundleBefore(receiveOp, insertionPoint); - - for (auto it = block.begin(); it != block.end();) { - auto receiveOp = dyn_cast(&*it); - uint64_t channelId = 0; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId) - || sourceCoreId >= static_cast(coreId)) { - ++it; - continue; - } - - Type outputType = receiveOp.getOutput().getType(); - auto run = collectConsecutiveRun( - it, block.end(), [&](spatial::SpatChannelReceiveOp current) { - uint64_t currentChannelId = 0; - uint32_t currentSourceCoreId = 0; - uint32_t currentTargetCoreId = 0; - return current.getOutput().getType() == outputType - && getScalarChannelMetadata(current, currentChannelId, currentSourceCoreId, currentTargetCoreId) - && currentSourceCoreId < static_cast(coreId); - }); - - if (run.ops.size() > 1) { - SmallVector sorted(run.ops); - llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) { - uint64_t lhsChannelId = 0; - uint32_t lhsSourceCoreId = 0; - uint32_t lhsTargetCoreId = 0; - uint64_t rhsChannelId = 0; - uint32_t rhsSourceCoreId = 0; - uint32_t rhsTargetCoreId = 0; - bool lhsHasMetadata = getScalarChannelMetadata(lhs, lhsChannelId, lhsSourceCoreId, lhsTargetCoreId); - bool rhsHasMetadata = getScalarChannelMetadata(rhs, rhsChannelId, rhsSourceCoreId, rhsTargetCoreId); - if (!lhsHasMetadata || !rhsHasMetadata) - return false; - return lhsSourceCoreId > rhsSourceCoreId; - }); - Block::iterator insertIt = run.end; - for (auto op : sorted) - moveScalarChannelBundleBefore(op, &block, insertIt); - } - - it = run.end; - } - } -} - -void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { - IRRewriter rewriter(funcOp.getContext()); - OperationFolder constantFolder(funcOp.getContext()); - - for (auto compute : funcOp.getOps()) { - Block& block = compute.getBody().front(); - for (auto it = block.begin(); it != block.end();) { - auto receiveOp = dyn_cast(&*it); - if (receiveOp) { - Type outputType = receiveOp.getOutput().getType(); - auto run = collectConsecutiveRun( - it, block.end(), [&](spatial::SpatChannelReceiveOp current) { - return current.getOutput().getType() == outputType; - }); - - bool hasRepeatedEndpoint = false; - DenseSet seenEndpoints; - for (auto op : run.ops) { - uint64_t channelId = 0; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) { - hasRepeatedEndpoint = true; - break; - } - uint64_t endpointKey = getEndpointKey(sourceCoreId, targetCoreId); - if (!seenEndpoints.insert(endpointKey).second) { - hasRepeatedEndpoint = true; - break; - } - } - - if (run.ops.size() > 1 && !hasRepeatedEndpoint) { - struct ReceiveEntry { - spatial::SpatChannelReceiveOp op; - size_t originalIndex = 0; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - uint64_t channelId = 0; - }; - SmallVector sortedEntries; - sortedEntries.reserve(run.ops.size()); - for (auto [originalIndex, op] : llvm::enumerate(run.ops)) { - uint64_t channelId = 0; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) { - sortedEntries.clear(); - break; - } - sortedEntries.push_back({op, originalIndex, sourceCoreId, targetCoreId, channelId}); - } - if (sortedEntries.empty()) { - ++it; - continue; - } - - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - channelIds.reserve(sortedEntries.size()); - sourceCoreIds.reserve(sortedEntries.size()); - targetCoreIds.reserve(sortedEntries.size()); - for (ReceiveEntry& entry : sortedEntries) { - channelIds.push_back(static_cast(entry.channelId)); - sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); - targetCoreIds.push_back(static_cast(entry.targetCoreId)); - } - - auto rowType = cast(run.ops.front().getOutput().getType()); - auto fallbackPackedType = getPackedTensorType(rowType, static_cast(sortedEntries.size())); - SmallVector sortedOutputs; - sortedOutputs.reserve(sortedEntries.size()); - for (ReceiveEntry& entry : sortedEntries) - sortedOutputs.push_back(entry.op.getOutput()); - - unsigned concatStartIndex = 0; - auto concatOp = getContiguousConcatUse(ValueRange(sortedOutputs), concatStartIndex); - auto concatPackedType = - concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast(sortedOutputs.size())) - : RankedTensorType {}; - auto packedType = concatPackedType ? concatPackedType : fallbackPackedType; - rewriter.setInsertionPoint(run.ops.front()); - SmallVector channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder); - SmallVector sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder); - SmallVector targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder); - auto compactReceive = spatial::SpatChannelReceiveTensorOp::create( - rewriter, run.ops.front().getLoc(), packedType, channelIdValues, sourceCoreIdValues, targetCoreIdValues); - if (concatOp && concatPackedType) { - replaceConcatRunWithPackedValue(concatOp, - concatStartIndex, - static_cast(sortedOutputs.size()), - compactReceive.getOutput(), - rewriter); - } - else { - for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries)) - entry.op.getOutput().replaceAllUsesWith(extractPackedChunk( - compactReceive.getOutput(), rowType, static_cast(sortedIndex), rewriter, entry.op.getLoc())); - } - for (auto op : run.ops) - rewriter.eraseOp(op); - - it = compactReceive->getIterator(); - ++it; - continue; - } - } - - auto sendOp = dyn_cast(&*it); - if (sendOp) { - Type inputType = sendOp.getInput().getType(); - auto run = - collectConsecutiveRun(it, block.end(), [&](spatial::SpatChannelSendOp current) { - return current.getInput().getType() == inputType; - }); - - if (run.ops.size() > 1) { - struct SendEntry { - spatial::SpatChannelSendOp op; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - uint64_t channelId = 0; - }; - SmallVector sortedEntries; - sortedEntries.reserve(run.ops.size()); - for (auto op : run.ops) { - uint64_t channelId = 0; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) { - sortedEntries.clear(); - break; - } - sortedEntries.push_back({op, sourceCoreId, targetCoreId, channelId}); - } - if (sortedEntries.empty()) { - ++it; - continue; - } - - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - SmallVector inputs; - channelIds.reserve(sortedEntries.size()); - sourceCoreIds.reserve(sortedEntries.size()); - targetCoreIds.reserve(sortedEntries.size()); - inputs.reserve(sortedEntries.size()); - for (SendEntry& entry : sortedEntries) { - channelIds.push_back(static_cast(entry.channelId)); - sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); - targetCoreIds.push_back(static_cast(entry.targetCoreId)); - inputs.push_back(entry.op.getInput()); - } - - rewriter.setInsertionPoint(run.ops.front()); - Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc()); - if (packedInput) { - SmallVector channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder); - SmallVector sourceCoreIdValues = - createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder); - SmallVector targetCoreIdValues = - createIndexConstants(run.ops.front(), targetCoreIds, constantFolder); - spatial::SpatChannelSendTensorOp::create( - rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput); - for (auto op : run.ops) - rewriter.eraseOp(op); - - it = run.end; - continue; - } - } - } - - ++it; - } - } -} - -void compactBatchChannelRuns(func::FuncOp funcOp) { - IRRewriter rewriter(funcOp.getContext()); - - for (auto batch : funcOp.getOps()) { - Block& block = batch.getBody().front(); - for (auto it = block.begin(); it != block.end();) { - auto receiveOp = dyn_cast(&*it); - if (receiveOp) { - Type outputType = receiveOp.getOutput().getType(); - auto run = collectConsecutiveRun( - it, block.end(), [&](spatial::SpatChannelReceiveBatchOp current) { - return current.getOutput().getType() == outputType; - }); - - if (run.ops.size() > 1) { - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - for (auto op : run.ops) { - llvm::append_range(channelIds, op.getChannelIds()); - llvm::append_range(sourceCoreIds, op.getSourceCoreIds()); - llvm::append_range(targetCoreIds, op.getTargetCoreIds()); - } - - auto rowType = cast(run.ops.front().getOutput().getType()); - auto fallbackPackedType = getPackedTensorType(rowType, static_cast(run.ops.size())); - SmallVector outputs; - outputs.reserve(run.ops.size()); - for (auto op : run.ops) - outputs.push_back(op.getOutput()); - - unsigned concatStartIndex = 0; - auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex); - auto concatPackedType = - concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast(outputs.size())) - : RankedTensorType {}; - auto packedType = concatPackedType ? concatPackedType : fallbackPackedType; - rewriter.setInsertionPoint(run.ops.front()); - auto compactReceive = spatial::SpatChannelReceiveTensorBatchOp::create( - rewriter, run.ops.front().getLoc(), packedType, channelIds, sourceCoreIds, targetCoreIds); - if (concatOp && concatPackedType) { - replaceConcatRunWithPackedValue( - concatOp, concatStartIndex, static_cast(outputs.size()), compactReceive.getOutput(), rewriter); - } - else { - for (auto [index, op] : llvm::enumerate(run.ops)) - op.getOutput().replaceAllUsesWith(extractPackedChunk( - compactReceive.getOutput(), rowType, static_cast(index), rewriter, op.getLoc())); - } - for (auto op : run.ops) - rewriter.eraseOp(op); - - it = compactReceive->getIterator(); - ++it; - continue; - } - } - - auto sendOp = dyn_cast(&*it); - if (sendOp) { - Type inputType = sendOp.getInput().getType(); - auto run = collectConsecutiveRun( - it, block.end(), [&](spatial::SpatChannelSendBatchOp current) { - return current.getInput().getType() == inputType; - }); - - if (run.ops.size() > 1) { - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - SmallVector inputs; - inputs.reserve(run.ops.size()); - for (auto op : run.ops) { - llvm::append_range(channelIds, op.getChannelIds()); - llvm::append_range(sourceCoreIds, op.getSourceCoreIds()); - llvm::append_range(targetCoreIds, op.getTargetCoreIds()); - inputs.push_back(op.getInput()); - } - - rewriter.setInsertionPoint(run.ops.front()); - Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc()); - if (packedInput) { - spatial::SpatChannelSendTensorBatchOp::create( - rewriter, run.ops.front().getLoc(), channelIds, sourceCoreIds, targetCoreIds, packedInput); - for (auto op : run.ops) - rewriter.eraseOp(op); - - it = run.end; - continue; - } - } - } - - ++it; - } - } -} - -void compactRegularOpRuns(func::FuncOp funcOp) { - IRRewriter rewriter(funcOp.getContext()); - OperationFolder constantFolder(funcOp.getContext()); - - auto compactInBlock = [&](Block& block) { - for (auto it = block.begin(); it != block.end();) { - auto startOp = dyn_cast(&*it); - if (!startOp) { - ++it; - continue; - } - - auto anchorChunk = analyzeRegularChunk(startOp); - if (failed(anchorChunk)) { - ++it; - continue; - } - - auto anchorEndIt = std::next(it, static_cast(anchorChunk->ops.size())); - SmallVector run {*anchorChunk}; - auto runIt = anchorEndIt; - while (runIt != block.end()) { - auto candidateStart = dyn_cast(&*runIt); - if (!candidateStart) - break; - - auto candidateChunk = analyzeRegularChunk(candidateStart); - if (failed(candidateChunk) || !areEquivalentRegularChunks(*anchorChunk, *candidateChunk)) - break; - - run.push_back(*candidateChunk); - runIt = std::next(runIt, static_cast(candidateChunk->ops.size())); - } - - if (run.size() <= 1) { - it = anchorEndIt; - continue; - } - - size_t originalOpCount = 0; - for (const RegularChunk& chunk : run) - originalOpCount += chunk.ops.size(); - - RegularCompactionResult result = compactRegularChunkRun(rewriter, run, constantFolder); - if (result.changed) { - assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run"); - if (!result.resumeAfter) { - it = block.end(); - continue; - } - it = result.resumeAfter->getIterator(); - continue; - } - - it = anchorEndIt; - } - }; - - for (auto compute : funcOp.getOps()) - compactInBlock(compute.getBody().front()); - for (auto batch : funcOp.getOps()) - compactInBlock(batch.getBody().front()); -} - -void compactRowWiseWvmmRuns(func::FuncOp funcOp) { - IRRewriter rewriter(funcOp.getContext()); - OperationFolder constantFolder(funcOp.getContext()); - - for (auto compute : funcOp.getOps()) { - Block& block = compute.getBody().front(); - for (auto it = block.begin(); it != block.end();) { - auto wvmmOp = dyn_cast(&*it); - if (!wvmmOp) { - ++it; - continue; - } - - auto extractRowsOp = wvmmOp.getInput().getDefiningOp(); - auto rowResult = dyn_cast(wvmmOp.getInput()); - auto outputType = dyn_cast(wvmmOp.getOutput().getType()); - if (!extractRowsOp || !rowResult || rowResult.getOwner() != extractRowsOp || !outputType - || !outputType.hasStaticShape() || outputType.getRank() != 2 || outputType.getShape()[0] != 1) { - ++it; - continue; - } - - int64_t expectedRow = static_cast(rowResult.getResultNumber()); - auto run = collectConsecutiveRun(it, block.end(), [&](spatial::SpatVMMOp current) { - if (current.getWeight() != wvmmOp.getWeight() - || current.getInput().getDefiningOp() != extractRowsOp - || current.getInput().getType() != wvmmOp.getInput().getType() - || current.getOutput().getType() != wvmmOp.getOutput().getType()) - return false; - - auto currentRow = dyn_cast(current.getInput()); - if (!currentRow || currentRow.getResultNumber() != static_cast(expectedRow)) - return false; - - ++expectedRow; - return true; - }); - - if (run.ops.size() <= 1) { - ++it; - continue; - } - - if (!run.ops.front().getOutput().hasOneUse()) { - ++it; - continue; - } - auto concatUse = run.ops.front().getOutput().getUses().begin(); - auto concatOp = dyn_cast(concatUse->getOwner()); - if (!concatOp) { - ++it; - continue; - } - - unsigned concatStartIndex = concatUse->getOperandNumber(); - bool validConcatRun = true; - for (auto [index, op] : llvm::enumerate(run.ops)) { - if (!op.getOutput().hasOneUse()) { - validConcatRun = false; - break; - } - OpOperand& use = *op.getOutput().getUses().begin(); - if (use.getOwner() != concatOp || use.getOperandNumber() != concatStartIndex + index) { - validConcatRun = false; - break; - } - } - if (!validConcatRun) { - ++it; - continue; - } - - auto inputType = dyn_cast(wvmmOp.getInput().getType()); - auto sourceType = dyn_cast(extractRowsOp.getInput().getType()); - if (!inputType || !sourceType || !inputType.hasStaticShape() || !sourceType.hasStaticShape()) { - ++it; - continue; - } - - int64_t inputCols = inputType.getShape()[1]; - int64_t outputCols = outputType.getShape()[1]; - if (ShapedType::isDynamic(inputCols) || ShapedType::isDynamic(outputCols)) { - ++it; - continue; - } - - int64_t firstRow = static_cast(rowResult.getResultNumber()); - int64_t runLength = static_cast(run.ops.size()); - auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType()); - - rewriter.setInsertionPoint(run.ops.front()); - auto zero = getOrCreateHostIndexConstant(run.ops.front(), 0, constantFolder); - auto upper = getOrCreateHostIndexConstant(run.ops.front(), runLength, constantFolder); - auto step = getOrCreateHostIndexConstant(run.ops.front(), 1, constantFolder); - auto packedInit = - tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType()); - auto loop = - scf::ForOp::create(rewriter, run.ops.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()}); - - { - OpBuilder::InsertionGuard guard(rewriter); - Block* loopBlock = loop.getBody(); - rewriter.setInsertionPointToStart(loopBlock); - Value iv = loopBlock->getArgument(0); - Value acc = loopBlock->getArgument(1); - - Value sourceRow = iv; - if (firstRow != 0) { - auto firstRowValue = getOrCreateHostIndexConstant(run.ops.front(), firstRow, constantFolder); - sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue); - } - - SmallVector extractOffsets = {sourceRow, rewriter.getIndexAttr(0)}; - SmallVector extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)}; - SmallVector extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - auto extractedRow = tensor::ExtractSliceOp::create(rewriter, - run.ops.front().getLoc(), - inputType, - extractRowsOp.getInput(), - extractOffsets, - extractSizes, - extractStrides); - auto loopWvmm = spatial::SpatVMMOp::create( - rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeight(), extractedRow.getResult()); - - SmallVector insertOffsets = {iv, rewriter.getIndexAttr(0)}; - SmallVector insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)}; - SmallVector insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - auto inserted = tensor::InsertSliceOp::create( - rewriter, run.ops.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides); - scf::YieldOp::create(rewriter, run.ops.front().getLoc(), inserted.getResult()); - } - - SmallVector newConcatInputs; - newConcatInputs.reserve(concatOp.getInputs().size() - run.ops.size() + 1); - for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) { - if (operandIndex == concatStartIndex) - newConcatInputs.push_back(loop.getResult(0)); - if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.ops.size()) - newConcatInputs.push_back(operand); - } - rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); }); - for (auto op : run.ops) - rewriter.eraseOp(op); - - it = loop->getIterator(); - ++it; - } - } -} - -} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp deleted file mode 100644 index 79cdf09..0000000 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include "mlir/Dialect/Func/IR/FuncOps.h" - -#include - -namespace onnx_mlir { - -void orderBilateralChannelOps(mlir::func::FuncOp funcOp); -void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId); -void compactBatchChannelRuns(mlir::func::FuncOp funcOp); -void compactRegularOpRuns(mlir::func::FuncOp funcOp); -void compactRowWiseWvmmRuns(mlir::func::FuncOp funcOp); - -} // namespace onnx_mlir