better MaterializeMergeSchedule.cpp with both send and receive compaction in for loops
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-05-23 11:17:36 +02:00
parent 7f3c7464b4
commit 76a37e198f
7 changed files with 291 additions and 1648 deletions
-2
View File
@@ -9,8 +9,6 @@ add_pim_library(SpatialOps
SpatialOpsCanonicalization.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
Transforms/MergeComputeNodes/PostMergeCompaction.cpp
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
@@ -1,5 +1,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
@@ -557,22 +558,6 @@ Value createIndexConstant(MaterializerState& state, Operation* anchor, int64_t v
return getOrCreateHostIndexConstant(anchor, value, state.constantFolder);
}
SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
SmallVector<Value, 8> constants;
constants.reserve(values.size());
for (int64_t value : values)
constants.push_back(createIndexConstant(state, anchor, value));
return constants;
}
SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation* anchor, ArrayRef<int32_t> values) {
SmallVector<int64_t, 8> widened;
widened.reserve(values.size());
for (int32_t value : values)
widened.push_back(value);
return createIndexConstants(state, anchor, ArrayRef<int64_t>(widened));
}
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
SmallVector<APInt, 8> elements;
elements.reserve(values.size());
@@ -644,6 +629,28 @@ Value createLaneIndexedIndexValue(MaterializerState& state,
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
}
Value createIndexedIndexValue(
MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values, Value index, Location loc) {
assert(!values.empty() && "expected at least one indexed value");
if (allEqual(values))
return createIndexConstant(state, anchor, values.front());
Value table = createIndexTensorConstant(state, anchor, values);
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult();
}
Value createIndexedIndexValue(
MaterializerState& state, Operation* anchor, ArrayRef<int32_t> values, Value index, Location loc) {
assert(!values.empty() && "expected at least one indexed value");
if (allEqual(values))
return createIndexConstant(state, anchor, values.front());
Value table = createIndexTensorConstant(state, anchor, values);
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult();
}
FailureOr<SmallVector<ComputeInstance, 8>>
getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) {
SmallVector<ComputeInstance, 8> peers;
@@ -808,6 +815,53 @@ ArrayRef<ClassId> getDestinationClasses(MaterializerState& state, ProducerKey ke
return it->second;
}
void appendScalarSend(MaterializerState& state,
MaterializedClass& sourceClass,
Value payload,
int64_t channelId,
int32_t sourceCoreId,
int32_t targetCoreId,
Location loc) {
assert(!sourceClass.isBatch && "scalar send helper expects a scalar source class");
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
Value channelIdValue = createIndexConstant(state, sourceClass.op, channelId);
Value sourceCoreIdValue = createIndexConstant(state, sourceClass.op, sourceCoreId);
Value targetCoreIdValue = createIndexConstant(state, sourceClass.op, targetCoreId);
SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload);
}
void appendScalarSendLoop(MaterializerState& state,
MaterializedClass& sourceClass,
Value payload,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class");
assert(channelIds.size() > 1 && "send loop is only useful for multiple sends");
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
Value lowerBound = createIndexConstant(state, sourceClass.op, 0);
Value upperBound = createIndexConstant(state, sourceClass.op, static_cast<int64_t>(channelIds.size()));
Value step = createIndexConstant(state, sourceClass.op, 1);
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {});
OpBuilder::InsertionGuard guard(state.rewriter);
state.rewriter.setInsertionPointToStart(loop.getBody());
Value index = loop.getInductionVar();
Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc);
Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc);
Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc);
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
}
void appendSend(MaterializerState& state,
MaterializedClass& sourceClass,
Value payload,
@@ -819,9 +873,9 @@ void appendSend(MaterializerState& state,
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
assert(!channelIds.empty() && "expected at least one send");
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
if (sourceClass.isBatch) {
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
Value channelId = createLaneIndexedIndexValue(state, sourceClass, channelIds, loc);
Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc);
Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc);
@@ -829,12 +883,13 @@ void appendSend(MaterializerState& state,
return;
}
for (auto index : llvm::seq<size_t>(0, channelIds.size())) {
Value channelId = createIndexConstant(state, sourceClass.op, channelIds[index]);
Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds[index]);
Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds[index]);
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
if (channelIds.size() == 1) {
appendScalarSend(
state, sourceClass, payload, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc);
return;
}
appendScalarSendLoop(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
}
Value appendScalarReceive(MaterializerState& state,
@@ -911,6 +966,212 @@ Value appendPackedScalarReceives(MaterializerState& state,
return packed;
}
std::optional<int64_t> getConstantIndexValue(Value value) {
if (auto constant = value.getDefiningOp<arith::ConstantIndexOp>())
return constant.value();
APInt constantValue;
if (matchPattern(value, m_ConstantInt(&constantValue)))
return constantValue.getSExtValue();
return std::nullopt;
}
bool getReceiveMetadata(SpatChannelReceiveOp receive,
int64_t& channelId,
int64_t& sourceCoreId,
int64_t& targetCoreId) {
// SpatChannelReceiveOp operands are: channel, source core, target core.
std::optional<int64_t> channel = getConstantIndexValue(receive->getOperand(0));
std::optional<int64_t> source = getConstantIndexValue(receive->getOperand(1));
std::optional<int64_t> target = getConstantIndexValue(receive->getOperand(2));
if (!channel || !source || !target)
return false;
channelId = *channel;
sourceCoreId = *source;
targetCoreId = *target;
return true;
}
bool hasCompatibleConcatTypes(RankedTensorType concatType, RankedTensorType fragmentType, size_t fragmentCount) {
if (!concatType.hasStaticShape() || !fragmentType.hasStaticShape())
return false;
if (concatType.getRank() != fragmentType.getRank())
return false;
if (concatType.getRank() == 0)
return false;
if (concatType.getElementType() != fragmentType.getElementType())
return false;
if (concatType.getDimSize(0) != fragmentType.getDimSize(0) * static_cast<int64_t>(fragmentCount))
return false;
for (int64_t dim = 1; dim < concatType.getRank(); ++dim)
if (concatType.getDimSize(dim) != fragmentType.getDimSize(dim))
return false;
return true;
}
Value createReceiveConcatLoop(MaterializerState& state,
Operation* anchor,
Operation* insertionPoint,
RankedTensorType concatType,
RankedTensorType fragmentType,
ArrayRef<int64_t> channelIds,
ArrayRef<int64_t> sourceCoreIds,
ArrayRef<int64_t> targetCoreIds,
Location loc) {
assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
assert(!channelIds.empty() && "expected at least one receive");
Value lowerBound = createIndexConstant(state, anchor, 0);
Value upperBound = createIndexConstant(state, anchor, static_cast<int64_t>(channelIds.size()));
Value step = createIndexConstant(state, anchor, 1);
state.rewriter.setInsertionPoint(insertionPoint);
Value init =
tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult();
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init});
Block* body = loop.getBody();
if (!body->empty())
if (auto yield = dyn_cast<scf::YieldOp>(body->back()))
state.rewriter.eraseOp(yield);
OpBuilder::InsertionGuard guard(state.rewriter);
state.rewriter.setInsertionPointToEnd(body);
Value index = loop.getInductionVar();
Value acc = body->getArgument(1);
Value channelId = createIndexedIndexValue(state, anchor, channelIds, index, loc);
Value sourceCoreId = createIndexedIndexValue(state, anchor, sourceCoreIds, index, loc);
Value targetCoreId = createIndexedIndexValue(state, anchor, targetCoreIds, index, loc);
Value received =
SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId).getOutput();
Value firstOffset = index;
if (fragmentType.getDimSize(0) != 1) {
Value rowsPerFragment = createIndexConstant(state, anchor, fragmentType.getDimSize(0));
firstOffset = arith::MulIOp::create(state.rewriter, loc, index, rowsPerFragment).getResult();
}
SmallVector<OpFoldResult, 4> offsets;
SmallVector<OpFoldResult, 4> sizes;
SmallVector<OpFoldResult, 4> strides;
offsets.reserve(fragmentType.getRank());
sizes.reserve(fragmentType.getRank());
strides.reserve(fragmentType.getRank());
offsets.push_back(firstOffset);
sizes.push_back(state.rewriter.getIndexAttr(fragmentType.getDimSize(0)));
strides.push_back(state.rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < fragmentType.getRank(); ++dim) {
offsets.push_back(state.rewriter.getIndexAttr(0));
sizes.push_back(state.rewriter.getIndexAttr(fragmentType.getDimSize(dim)));
strides.push_back(state.rewriter.getIndexAttr(1));
}
Value next = tensor::InsertSliceOp::create(state.rewriter, loc, received, acc, offsets, sizes, strides).getResult();
scf::YieldOp::create(state.rewriter, loc, next);
return loop.getResult(0);
}
bool compactReceiveConcat(MaterializerState& state, MaterializedClass& materializedClass, tensor::ConcatOp concat) {
auto dimAttr = concat->getAttrOfType<IntegerAttr>("dim");
if (!dimAttr || dimAttr.getInt() != 0)
return false;
OperandRange inputs = concat->getOperands();
if (inputs.size() < 2)
return false;
SmallVector<SpatChannelReceiveOp, 8> receives;
receives.reserve(inputs.size());
for (Value input : inputs) {
auto receive = input.getDefiningOp<SpatChannelReceiveOp>();
if (!receive)
return false;
if (receive->getBlock() != concat->getBlock())
return false;
if (!receive->getResult(0).hasOneUse())
return false;
receives.push_back(receive);
}
Operation* expected = concat.getOperation();
for (SpatChannelReceiveOp receive : llvm::reverse(receives)) {
Operation* previous = expected->getPrevNode();
if (previous != receive.getOperation())
return false;
expected = previous;
}
auto concatType = dyn_cast<RankedTensorType>(concat->getResult(0).getType());
auto fragmentType = dyn_cast<RankedTensorType>(receives.front()->getResult(0).getType());
if (!concatType || !fragmentType)
return false;
if (!hasCompatibleConcatTypes(concatType, fragmentType, receives.size()))
return false;
SmallVector<int64_t, 8> channelIds;
SmallVector<int64_t, 8> sourceCoreIds;
SmallVector<int64_t, 8> targetCoreIds;
channelIds.reserve(receives.size());
sourceCoreIds.reserve(receives.size());
targetCoreIds.reserve(receives.size());
for (SpatChannelReceiveOp receive : receives) {
if (receive->getResult(0).getType() != fragmentType)
return false;
int64_t channelId = 0;
int64_t sourceCoreId = 0;
int64_t targetCoreId = 0;
if (!getReceiveMetadata(receive, channelId, sourceCoreId, targetCoreId))
return false;
channelIds.push_back(channelId);
sourceCoreIds.push_back(sourceCoreId);
targetCoreIds.push_back(targetCoreId);
}
Value replacement = createReceiveConcatLoop(state,
materializedClass.op,
receives.front().getOperation(),
concatType,
fragmentType,
channelIds,
sourceCoreIds,
targetCoreIds,
concat.getLoc());
concat->getResult(0).replaceAllUsesWith(replacement);
state.rewriter.eraseOp(concat.getOperation());
for (SpatChannelReceiveOp receive : llvm::reverse(receives))
state.rewriter.eraseOp(receive.getOperation());
return true;
}
void compactReceiveConcats(MaterializerState& state) {
SmallVector<std::pair<MaterializedClass*, tensor::ConcatOp>, 16> concatOps;
for (MaterializedClass& materializedClass : state.classes)
materializedClass.op->walk([&](tensor::ConcatOp concat) { concatOps.push_back({&materializedClass, concat}); });
for (auto [materializedClass, concat] : concatOps)
compactReceiveConcat(state, *materializedClass, concat);
}
LogicalResult emitClassToClassCommunication(MaterializerState& state,
MaterializedClass& sourceClass,
MaterializedClass& targetClass,
@@ -1086,15 +1347,8 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val
return success();
}
LogicalResult emitHostCommunication(MaterializerState& state,
MaterializedClass& sourceClass,
ArrayRef<ProducerKey> keys,
Value payload,
Value originalOutput,
Location loc) {
(void) keys;
(void) loc;
LogicalResult
emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, Value payload, Value originalOutput) {
if (!hasLiveExternalUse(originalOutput, state.oldComputeOps))
return success();
@@ -1115,7 +1369,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
if (failed(
emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
return failure();
if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc)))
if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput)))
return failure();
state.availableValues[keys.front()][sourceClass.id] = payload;
return success();
@@ -1129,7 +1383,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
return failure();
if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc)))
if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput)))
return failure();
for (ProducerKey key : keys)
@@ -1146,7 +1400,7 @@ FailureOr<Value> materializeWholeBatchInput(
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
uint32_t batchLaneCount = static_cast<uint32_t>(batch.getLaneCount());
uint32_t batchLaneCount = batch.getLaneCount();
SmallVector<Value, 8> fragments;
uint32_t lane = 0;
@@ -1426,6 +1680,8 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch
if (failed(materializeInstanceSlot(state, instance)))
return failure();
compactReceiveConcats(state);
replaceHostUses(state);
if (failed(eraseOldComputeOps(state)))
return failure();
@@ -27,7 +27,6 @@
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iterator>
#include <memory>
#include <optional>
#include <tuple>
@@ -35,7 +34,6 @@
#include <vector>
#include "MaterializeMergeSchedule.hpp"
#include "PostMergeCompaction.hpp"
#include "Scheduling/ComputeInstanceUtils.hpp"
#include "Scheduling/MergeSchedulingAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
@@ -660,13 +658,6 @@ public:
emitMergeIrCounts("after-materialization", func);
/*if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) {
signalPassFailure();
return;
}
emitMergeIrCounts("after-post-merge-compaction", func);*/
{
ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
if (!sortTopologically(&func.getBody().front())) {
@@ -1,532 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <chrono>
#include <cstdlib>
#include <limits>
#include <optional>
#include "PostMergeCompaction.hpp"
#include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
using SpatCompute = spatial::SpatCompute;
using SpatComputeBatch = spatial::SpatComputeBatch;
bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; }
class ScopedMergePhaseTimer {
public:
explicit ScopedMergePhaseTimer(StringRef phaseName)
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
if (enabled)
start = std::chrono::steady_clock::now();
}
~ScopedMergePhaseTimer() {
if (!enabled)
return;
auto elapsed = std::chrono::steady_clock::now() - start;
double millis = std::chrono::duration<double, std::milli>(elapsed).count();
llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
}
private:
bool enabled = false;
std::string phase;
std::chrono::steady_clock::time_point start;
};
std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(coreIdAttr.getInt());
return std::nullopt;
}
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
static FailureOr<int64_t> getConstantI64Value(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return constantValue.getSExtValue();
}
static FailureOr<int32_t> getConstantI32Value(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return static_cast<int32_t>(constantValue.getSExtValue());
}
static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op,
uint64_t& channelId,
uint32_t& sourceCoreId,
uint32_t& targetCoreId) {
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
return false;
channelId = static_cast<uint64_t>(*constantChannelId);
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
return true;
}
static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op,
uint64_t& channelId,
uint32_t& sourceCoreId,
uint32_t& targetCoreId) {
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
return false;
channelId = static_cast<uint64_t>(*constantChannelId);
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
return true;
}
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
SmallVector<Value> constants;
constants.reserve(values.size());
for (int64_t value : values)
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
return constants;
}
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int32_t> values, OperationFolder& folder) {
SmallVector<Value> constants;
constants.reserve(values.size());
for (int32_t value : values)
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
return constants;
}
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
return static_cast<uint64_t>(phaseAttr.getInt());
return std::nullopt;
}
struct RebatchKey {
unsigned inputCount = 0;
unsigned resultCount = 0;
unsigned weightCount = 0;
uint64_t phase = 0;
bool hasPhase = false;
uint64_t structureHash = 0;
bool operator==(const RebatchKey& other) const {
return inputCount == other.inputCount && resultCount == other.resultCount && weightCount == other.weightCount
&& phase == other.phase && hasPhase == other.hasPhase && structureHash == other.structureHash;
}
};
struct RebatchKeyInfo {
static inline RebatchKey getEmptyKey() { return {std::numeric_limits<unsigned>::max(), 0, 0, 0, false, 0}; }
static inline RebatchKey getTombstoneKey() { return {std::numeric_limits<unsigned>::max() - 1, 0, 0, 0, false, 0}; }
static unsigned getHashValue(const RebatchKey& key) {
return static_cast<unsigned>(
llvm::hash_combine(key.inputCount, key.resultCount, key.weightCount, key.phase, key.hasPhase, key.structureHash));
}
static bool isEqual(const RebatchKey& lhs, const RebatchKey& rhs) { return lhs == rhs; }
};
uint64_t getTypeHash(Type type) { return reinterpret_cast<uintptr_t>(type.getAsOpaquePointer()); }
uint64_t getValueHash(Value value) { return reinterpret_cast<uintptr_t>(value.getAsOpaquePointer()); }
uint64_t getAttributeHash(Attribute attr) { return reinterpret_cast<uintptr_t>(attr.getAsOpaquePointer()); }
RebatchKey computeRebatchKey(SpatCompute compute) {
llvm::hash_code structureHash =
llvm::hash_combine(compute.getInputs().size(), compute.getResultTypes().size(), compute.getWeights().size());
for (Value weight : compute.getWeights())
structureHash = llvm::hash_combine(structureHash, getValueHash(weight));
if (std::optional<uint64_t> phase = getComputeRebatchPhase(compute))
structureHash = llvm::hash_combine(structureHash, *phase);
Block& body = compute.getBody().front();
structureHash = llvm::hash_combine(structureHash, body.getNumArguments());
for (BlockArgument arg : body.getArguments())
structureHash = llvm::hash_combine(structureHash, getTypeHash(arg.getType()));
for (Operation& op : body) {
structureHash = llvm::hash_combine(
structureHash, op.getName().getStringRef(), op.getNumOperands(), op.getNumResults(), op.getNumRegions());
for (Type type : op.getResultTypes())
structureHash = llvm::hash_combine(structureHash, getTypeHash(type));
for (NamedAttribute attr : op.getAttrs())
structureHash = llvm::hash_combine(structureHash, attr.getName().strref(), getAttributeHash(attr.getValue()));
}
std::optional<uint64_t> phase = getComputeRebatchPhase(compute);
return {static_cast<unsigned>(compute.getInputs().size()),
static_cast<unsigned>(compute.getResultTypes().size()),
static_cast<unsigned>(compute.getWeights().size()),
phase.value_or(0),
phase.has_value(),
static_cast<uint64_t>(structureHash)};
}
bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
if (!lhs || !rhs)
return false;
if (lhs.getInputs().size() != rhs.getInputs().size())
return false;
if (lhs.getResultTypes() != rhs.getResultTypes())
return false;
if (lhs.getWeights().size() != rhs.getWeights().size())
return false;
if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs))
return false;
if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
return false;
auto& lhsBlock = lhs.getBody().front();
auto& rhsBlock = rhs.getBody().front();
if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
return false;
DenseMap<Value, Value> mappedValues;
for (auto [lhsArg, rhsArg] : llvm::zip(lhsBlock.getArguments(), rhsBlock.getArguments())) {
if (lhsArg.getType() != rhsArg.getType())
return false;
mappedValues[lhsArg] = rhsArg;
}
auto lhsIt = lhsBlock.begin();
auto rhsIt = rhsBlock.begin();
for (; lhsIt != lhsBlock.end() && rhsIt != rhsBlock.end(); ++lhsIt, ++rhsIt) {
Operation& lhsOp = *lhsIt;
Operation& rhsOp = *rhsIt;
if (lhsOp.getName() != rhsOp.getName())
return false;
if (lhsOp.getNumOperands() != rhsOp.getNumOperands())
return false;
if (lhsOp.getNumResults() != rhsOp.getNumResults())
return false;
if (lhsOp.getNumRegions() != 0 || rhsOp.getNumRegions() != 0)
return false;
for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOp.getOperands(), rhsOp.getOperands())) {
auto mapped = mappedValues.find(lhsOperand);
if (mapped != mappedValues.end()) {
if (mapped->second != rhsOperand)
return false;
continue;
}
if (lhsOperand != rhsOperand)
return false;
}
if (auto lhsReceive = dyn_cast<spatial::SpatChannelReceiveOp>(lhsOp)) {
auto rhsReceive = cast<spatial::SpatChannelReceiveOp>(rhsOp);
if (lhsReceive.getOutput().getType() != rhsReceive.getOutput().getType())
return false;
}
else if (auto lhsSend = dyn_cast<spatial::SpatChannelSendOp>(lhsOp)) {
auto rhsSend = cast<spatial::SpatChannelSendOp>(rhsOp);
if (lhsSend.getInput().getType() != rhsSend.getInput().getType())
return false;
}
else if (lhsOp.getAttrs() != rhsOp.getAttrs()) {
return false;
}
if (lhsOp.getResultTypes() != rhsOp.getResultTypes())
return false;
for (auto [lhsResult, rhsResult] : llvm::zip(lhsOp.getResults(), rhsOp.getResults()))
mappedValues[lhsResult] = rhsResult;
}
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
}
void rebatchEquivalentComputes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext());
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
DenseSet<Operation*> consumed;
DenseMap<Operation*, size_t> computeOrder;
DenseMap<RebatchKey, SmallVector<SpatCompute>, RebatchKeyInfo> candidatesByKey;
for (auto [index, compute] : llvm::enumerate(computes)) {
computeOrder[compute.getOperation()] = index;
if (compute.getInputs().size() <= 1 && compute.getResults().empty())
candidatesByKey[computeRebatchKey(compute)].push_back(compute);
}
for (size_t index = 0; index < computes.size(); ++index) {
auto anchor = computes[index];
if (consumed.contains(anchor))
continue;
if (anchor.getInputs().size() > 1)
continue;
if (!anchor.getResults().empty())
continue;
SmallVector<SpatCompute> group {anchor};
llvm::SmallDenseSet<int32_t, 8> usedCoreIds;
if (auto coreId = getComputeCoreId(anchor))
usedCoreIds.insert(*coreId);
auto bucketIt = candidatesByKey.find(computeRebatchKey(anchor));
if (bucketIt == candidatesByKey.end())
continue;
for (auto candidate : bucketIt->second) {
if (computeOrder.lookup(candidate.getOperation()) <= index)
continue;
if (consumed.contains(candidate))
continue;
if (!areEquivalentForRebatch(anchor, candidate))
continue;
if (auto coreId = getComputeCoreId(candidate))
if (!usedCoreIds.insert(*coreId).second)
continue;
group.push_back(candidate);
}
if (group.size() <= 1)
continue;
auto insertionAnchor = group.front();
if (llvm::all_of(group, [](SpatCompute compute) { return getComputeCoreId(compute).has_value(); })) {
llvm::stable_sort(
group, [](SpatCompute lhs, SpatCompute rhs) { return *getComputeCoreId(lhs) < *getComputeCoreId(rhs); });
}
SmallVector<Value> weights;
weights.reserve(group.size() * anchor.getWeights().size());
SmallVector<Value> inputs;
inputs.reserve(group.size() * anchor.getInputs().size());
SmallVector<int32_t> coreIds;
coreIds.reserve(group.size());
bool haveAllCoreIds = true;
for (auto compute : group) {
llvm::append_range(weights, compute.getWeights());
llvm::append_range(inputs, compute.getInputs());
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
if (!coreIdAttr)
haveAllCoreIds = false;
else if (haveAllCoreIds)
coreIds.push_back(static_cast<int32_t>(coreIdAttr.getInt()));
}
rewriter.setInsertionPoint(insertionAnchor);
auto rebatched = SpatComputeBatch::create(rewriter,
insertionAnchor.getLoc(),
TypeRange {},
rewriter.getI32IntegerAttr(static_cast<int32_t>(group.size())),
ValueRange(weights),
ValueRange(inputs));
rebatched.getProperties().setOperandSegmentSizes(
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
if (haveAllCoreIds)
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : anchor.getBody().front().getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc());
}
auto* newBlock =
rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(newBlock);
IRMapping mapper;
auto& anchorBlock = anchor.getBody().front();
for (auto [oldArg, newArg] : llvm::zip(anchorBlock.getArguments(), newBlock->getArguments()))
mapper.map(oldArg, newArg);
auto opIts = llvm::map_to_vector(group, [](SpatCompute compute) { return compute.getBody().front().begin(); });
for (Operation& anchorOp : anchorBlock) {
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&anchorOp)) {
struct BatchReceiveEntry {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
};
SmallVector<BatchReceiveEntry> entries;
entries.reserve(group.size());
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]);
BatchReceiveEntry entry;
if (!getScalarChannelMetadata(groupReceive, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
return;
entries.push_back(entry);
++opIts[groupIndex];
}
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
channelIds.reserve(group.size());
sourceCoreIds.reserve(group.size());
targetCoreIds.reserve(group.size());
for (const BatchReceiveEntry& entry : entries) {
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
SmallVector<Value> channelIdValues = createIndexConstants(receiveOp, channelIds, constantFolder);
SmallVector<Value> sourceCoreIdValues = createIndexConstants(receiveOp, sourceCoreIds, constantFolder);
SmallVector<Value> targetCoreIdValues = createIndexConstants(receiveOp, targetCoreIds, constantFolder);
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
receiveOp.getLoc(),
receiveOp.getOutput().getType(),
channelIdValues,
sourceCoreIdValues,
targetCoreIdValues);
mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
continue;
}
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&anchorOp)) {
struct BatchSendEntry {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
};
SmallVector<BatchSendEntry> entries;
entries.reserve(group.size());
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]);
BatchSendEntry entry;
if (!getScalarChannelMetadata(groupSend, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
return;
entries.push_back(entry);
++opIts[groupIndex];
}
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
channelIds.reserve(group.size());
sourceCoreIds.reserve(group.size());
targetCoreIds.reserve(group.size());
for (const BatchSendEntry& entry : entries) {
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
SmallVector<Value> channelIdValues = createIndexConstants(sendOp, channelIds, constantFolder);
SmallVector<Value> sourceCoreIdValues = createIndexConstants(sendOp, sourceCoreIds, constantFolder);
SmallVector<Value> targetCoreIdValues = createIndexConstants(sendOp, targetCoreIds, constantFolder);
spatial::SpatChannelSendBatchOp::create(rewriter,
sendOp.getLoc(),
channelIdValues,
sourceCoreIdValues,
targetCoreIdValues,
mapper.lookup(sendOp.getInput()));
continue;
}
if (isa<spatial::SpatYieldOp>(anchorOp)) {
for (auto& opIt : opIts)
++opIt;
spatial::SpatYieldOp::create(rewriter, anchorOp.getLoc(), ValueRange {});
continue;
}
Operation* cloned = rewriter.clone(anchorOp, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(anchorOp.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
for (auto& opIt : opIts)
++opIt;
}
for (auto compute : group) {
compute->removeAttr(kRebatchPhaseAttrName);
consumed.insert(compute);
rewriter.eraseOp(compute);
}
}
for (auto compute : funcOp.getOps<SpatCompute>())
compute->removeAttr(kRebatchPhaseAttrName);
}
void cleanupDeadPackingOps(func::FuncOp funcOp) {
auto eraseUnusedOps = [&](auto tag) {
using OpTy = decltype(tag);
SmallVector<OpTy> ops;
funcOp.walk([&](OpTy op) { ops.push_back(op); });
for (auto op : llvm::reverse(ops))
if (op->use_empty())
op.erase();
};
eraseUnusedOps(tensor::ExtractSliceOp {});
eraseUnusedOps(spatial::SpatConcatOp {});
eraseUnusedOps(spatial::SpatExtractRowsOp {});
}
} // namespace
LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextChannelId) {
{
ScopedMergePhaseTimer timer("order-bilateral-channel-ops");
orderBilateralChannelOps(funcOp);
}
{
ScopedMergePhaseTimer timer("rebatch-equivalent-computes");
rebatchEquivalentComputes(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-1");
compactScalarChannelRuns(funcOp, nextChannelId);
}
{
ScopedMergePhaseTimer timer("compact-batch-channel-runs-1");
compactBatchChannelRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-regular-op-runs");
compactRegularOpRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-row-wise-wvmm-runs");
compactRowWiseWvmmRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-2");
compactScalarChannelRuns(funcOp, nextChannelId);
}
{
ScopedMergePhaseTimer timer("compact-batch-channel-runs-2");
compactBatchChannelRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("cleanup-dead-packing-ops");
cleanupDeadPackingOps(funcOp);
}
return success();
}
} // namespace onnx_mlir
@@ -1,12 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
#include <cstdint>
namespace onnx_mlir {
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
} // namespace onnx_mlir
File diff suppressed because it is too large Load Diff
@@ -1,15 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include <cstdint>
namespace onnx_mlir {
void orderBilateralChannelOps(mlir::func::FuncOp funcOp);
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
void compactRowWiseWvmmRuns(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir