better MaterializeMergeSchedule.cpp with both send and receive compaction in for loops
Validate Operations / validate-operations (push) Waiting to run
Validate Operations / validate-operations (push) Waiting to run
This commit is contained in:
@@ -9,8 +9,6 @@ add_pim_library(SpatialOps
|
||||
SpatialOpsCanonicalization.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
||||
Transforms/MergeComputeNodes/PostMergeCompaction.cpp
|
||||
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
|
||||
|
||||
+291
-35
@@ -1,5 +1,6 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
@@ -557,22 +558,6 @@ Value createIndexConstant(MaterializerState& state, Operation* anchor, int64_t v
|
||||
return getOrCreateHostIndexConstant(anchor, value, state.constantFolder);
|
||||
}
|
||||
|
||||
// Creates one index constant per entry of `values` (via createIndexConstant)
// and returns them in the same order as the input.
SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
  const size_t count = values.size();

  SmallVector<Value, 8> result;
  result.reserve(count);
  for (size_t position = 0; position < count; ++position) {
    Value constant = createIndexConstant(state, anchor, values[position]);
    result.push_back(constant);
  }
  return result;
}
|
||||
|
||||
// 32-bit convenience overload: widens every value to int64_t and forwards to
// the ArrayRef<int64_t> overload.
SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation* anchor, ArrayRef<int32_t> values) {
  // The iterator-range constructor performs the int32_t -> int64_t widening element by element.
  SmallVector<int64_t, 8> asInt64(values.begin(), values.end());
  return createIndexConstants(state, anchor, ArrayRef<int64_t>(asInt64));
}
|
||||
|
||||
Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values) {
|
||||
SmallVector<APInt, 8> elements;
|
||||
elements.reserve(values.size());
|
||||
@@ -644,6 +629,28 @@ Value createLaneIndexedIndexValue(MaterializerState& state,
|
||||
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
|
||||
}
|
||||
|
||||
// Produces the index value `values[index]` for a runtime `index`.
// When every entry of `values` is identical no lookup is required and a plain
// index constant is returned; otherwise the values are materialized as an index
// tensor constant and the entry is read with tensor.extract at `loc`.
Value createIndexedIndexValue(
    MaterializerState& state, Operation* anchor, ArrayRef<int64_t> values, Value index, Location loc) {
  assert(!values.empty() && "expected at least one indexed value");

  const bool needsLookup = !allEqual(values);
  if (needsLookup) {
    Value lookupTable = createIndexTensorConstant(state, anchor, values);
    auto extract = tensor::ExtractOp::create(state.rewriter, loc, lookupTable, ValueRange {index});
    return extract.getResult();
  }

  return createIndexConstant(state, anchor, values.front());
}
|
||||
|
||||
// 32-bit variant of the indexed lookup: yields `values[index]` as an index value.
// A uniform `values` array collapses to a single index constant; any other
// array becomes an index tensor constant that is read with tensor.extract.
Value createIndexedIndexValue(
    MaterializerState& state, Operation* anchor, ArrayRef<int32_t> values, Value index, Location loc) {
  assert(!values.empty() && "expected at least one indexed value");

  const bool needsLookup = !allEqual(values);
  if (needsLookup) {
    Value lookupTable = createIndexTensorConstant(state, anchor, values);
    auto extract = tensor::ExtractOp::create(state.rewriter, loc, lookupTable, ValueRange {index});
    return extract.getResult();
  }

  return createIndexConstant(state, anchor, values.front());
}
|
||||
|
||||
FailureOr<SmallVector<ComputeInstance, 8>>
|
||||
getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) {
|
||||
SmallVector<ComputeInstance, 8> peers;
|
||||
@@ -808,6 +815,53 @@ ArrayRef<ClassId> getDestinationClasses(MaterializerState& state, ProducerKey ke
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Emits exactly one SpatChannelSendOp carrying `payload`, placed right before
// the terminator of the (non-batch) source class body. The channel id and both
// core ids are materialized as index constants anchored at the class op.
void appendScalarSend(MaterializerState& state,
                      MaterializedClass& sourceClass,
                      Value payload,
                      int64_t channelId,
                      int32_t sourceCoreId,
                      int32_t targetCoreId,
                      Location loc) {
  assert(!sourceClass.isBatch && "scalar send helper expects a scalar source class");

  auto* bodyTerminator = sourceClass.body->getTerminator();
  state.rewriter.setInsertionPoint(bodyTerminator);

  // Constants are created in channel, source, target order.
  Value channel = createIndexConstant(state, sourceClass.op, channelId);
  Value sender = createIndexConstant(state, sourceClass.op, sourceCoreId);
  Value receiver = createIndexConstant(state, sourceClass.op, targetCoreId);

  SpatChannelSendOp::create(state.rewriter, loc, channel, sender, receiver, payload);
}
|
||||
|
||||
// Emits an scf.for loop, placed before the terminator of the (non-batch) source
// class body, that sends `payload` once per entry of `channelIds`.
// Iteration i uses channelIds[i], sourceCoreIds[i] and targetCoreIds[i]; each of
// these is obtained through createIndexedIndexValue with the induction variable.
// All three arrays must have the same length, and that length must exceed one.
void appendScalarSendLoop(MaterializerState& state,
                          MaterializedClass& sourceClass,
                          Value payload,
                          ArrayRef<int64_t> channelIds,
                          ArrayRef<int32_t> sourceCoreIds,
                          ArrayRef<int32_t> targetCoreIds,
                          Location loc) {
  assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class");
  assert(channelIds.size() > 1 && "send loop is only useful for multiple sends");
  assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
  assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");

  auto* bodyTerminator = sourceClass.body->getTerminator();
  state.rewriter.setInsertionPoint(bodyTerminator);

  // Loop bounds: [0, number of sends) with unit step.
  const auto sendCount = static_cast<int64_t>(channelIds.size());
  Value zero = createIndexConstant(state, sourceClass.op, 0);
  Value tripCount = createIndexConstant(state, sourceClass.op, sendCount);
  Value one = createIndexConstant(state, sourceClass.op, 1);
  auto sendLoop = scf::ForOp::create(state.rewriter, loc, zero, tripCount, one, ValueRange {});

  // Everything below is built at the start of the loop body; the guard puts the
  // insertion point back to where it was right after the loop was created.
  OpBuilder::InsertionGuard restoreAfterLoop(state.rewriter);
  state.rewriter.setInsertionPointToStart(sendLoop.getBody());

  Value iteration = sendLoop.getInductionVar();
  Value channel = createIndexedIndexValue(state, sourceClass.op, channelIds, iteration, loc);
  Value sender = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, iteration, loc);
  Value receiver = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, iteration, loc);

  SpatChannelSendOp::create(state.rewriter, loc, channel, sender, receiver, payload);
}
|
||||
|
||||
void appendSend(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
Value payload,
|
||||
@@ -819,9 +873,9 @@ void appendSend(MaterializerState& state,
|
||||
assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
|
||||
assert(!channelIds.empty() && "expected at least one send");
|
||||
|
||||
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
||||
|
||||
if (sourceClass.isBatch) {
|
||||
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
||||
|
||||
Value channelId = createLaneIndexedIndexValue(state, sourceClass, channelIds, loc);
|
||||
Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc);
|
||||
Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc);
|
||||
@@ -829,12 +883,13 @@ void appendSend(MaterializerState& state,
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto index : llvm::seq<size_t>(0, channelIds.size())) {
|
||||
Value channelId = createIndexConstant(state, sourceClass.op, channelIds[index]);
|
||||
Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds[index]);
|
||||
Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds[index]);
|
||||
SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload);
|
||||
if (channelIds.size() == 1) {
|
||||
appendScalarSend(
|
||||
state, sourceClass, payload, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc);
|
||||
return;
|
||||
}
|
||||
|
||||
appendScalarSendLoop(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
}
|
||||
|
||||
Value appendScalarReceive(MaterializerState& state,
|
||||
@@ -911,6 +966,212 @@ Value appendPackedScalarReceives(MaterializerState& state,
|
||||
return packed;
|
||||
}
|
||||
|
||||
// Returns the integer behind `value` when it is a compile-time constant:
// either the result of an arith::ConstantIndexOp or anything accepted by the
// m_ConstantInt matcher (sign-extended). Returns std::nullopt otherwise.
std::optional<int64_t> getConstantIndexValue(Value value) {
  auto indexConstant = value.getDefiningOp<arith::ConstantIndexOp>();
  if (indexConstant)
    return indexConstant.value();

  APInt matched;
  const bool isConstantInt = matchPattern(value, m_ConstantInt(&matched));
  if (!isConstantInt)
    return std::nullopt;

  return matched.getSExtValue();
}
|
||||
|
||||
// Reads the constant channel id, source core id and target core id of `receive`.
// Returns true and fills all three outputs when every operand is a constant;
// returns false and leaves the outputs untouched otherwise.
bool getReceiveMetadata(SpatChannelReceiveOp receive,
                        int64_t& channelId,
                        int64_t& sourceCoreId,
                        int64_t& targetCoreId) {
  // SpatChannelReceiveOp operands are: channel, source core, target core.
  constexpr unsigned kOperandCount = 3;
  std::optional<int64_t> operandValues[kOperandCount];

  for (unsigned operandIndex = 0; operandIndex < kOperandCount; ++operandIndex) {
    operandValues[operandIndex] = getConstantIndexValue(receive->getOperand(operandIndex));
    if (!operandValues[operandIndex])
      return false;
  }

  // Only write the outputs once all three operands are known to be constant.
  channelId = *operandValues[0];
  sourceCoreId = *operandValues[1];
  targetCoreId = *operandValues[2];
  return true;
}
|
||||
|
||||
// Checks that `concatType` is what results from stacking `fragmentCount`
// tensors of `fragmentType` along dimension 0: both types are statically
// shaped, have the same non-zero rank and element type, the leading dimension
// of the concat equals fragmentCount times the fragment's leading dimension,
// and all remaining dimensions match.
bool hasCompatibleConcatTypes(RankedTensorType concatType, RankedTensorType fragmentType, size_t fragmentCount) {
  const bool bothStatic = concatType.hasStaticShape() && fragmentType.hasStaticShape();
  if (!bothStatic)
    return false;

  // Rank-0 tensors have no dimension 0 to concatenate along.
  const int64_t rank = concatType.getRank();
  if (rank != fragmentType.getRank() || rank == 0)
    return false;

  if (concatType.getElementType() != fragmentType.getElementType())
    return false;

  const int64_t stackedRows = fragmentType.getDimSize(0) * static_cast<int64_t>(fragmentCount);
  if (concatType.getDimSize(0) != stackedRows)
    return false;

  int64_t dim = 1;
  while (dim < rank && concatType.getDimSize(dim) == fragmentType.getDimSize(dim))
    ++dim;
  return dim >= rank;
}
|
||||
|
||||
// Builds, immediately before `insertionPoint`, an scf.for loop that receives
// `channelIds.size()` fragments of `fragmentType` and assembles them into one
// tensor of `concatType`:
//   - the accumulator starts as a tensor.empty of `concatType`;
//   - iteration i receives on (channelIds[i], sourceCoreIds[i], targetCoreIds[i])
//     and inserts the fragment at row offset i * fragmentType.getDimSize(0)
//     along dimension 0 (offset 0, full size, stride 1 on every other dimension).
// Returns the loop result, i.e. the fully assembled tensor.
//
// The three id arrays must be non-empty and of equal length. The rewriter's
// insertion point is restored to the caller's value on return, so the caller
// may erase `insertionPoint` afterwards without leaving the rewriter pointing
// at a deleted operation.
Value createReceiveConcatLoop(MaterializerState& state,
                              Operation* anchor,
                              Operation* insertionPoint,
                              RankedTensorType concatType,
                              RankedTensorType fragmentType,
                              ArrayRef<int64_t> channelIds,
                              ArrayRef<int64_t> sourceCoreIds,
                              ArrayRef<int64_t> targetCoreIds,
                              Location loc) {
  assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch");
  assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch");
  assert(!channelIds.empty() && "expected at least one receive");

  // Taken before any insertion-point change so that the caller's position is
  // what gets restored (not the position in front of `insertionPoint`).
  OpBuilder::InsertionGuard guard(state.rewriter);

  Value lowerBound = createIndexConstant(state, anchor, 0);
  Value upperBound = createIndexConstant(state, anchor, static_cast<int64_t>(channelIds.size()));
  Value step = createIndexConstant(state, anchor, 1);

  state.rewriter.setInsertionPoint(insertionPoint);
  Value init =
      tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult();
  auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init});

  // Drop a pre-existing yield, if the builder created one; the loop gets its
  // own yield of the updated accumulator below.
  Block* body = loop.getBody();
  if (!body->empty())
    if (auto yield = dyn_cast<scf::YieldOp>(body->back()))
      state.rewriter.eraseOp(yield);

  state.rewriter.setInsertionPointToEnd(body);

  Value index = loop.getInductionVar();
  // Block argument 0 is the induction variable, argument 1 the accumulator.
  Value acc = body->getArgument(1);

  Value channelId = createIndexedIndexValue(state, anchor, channelIds, index, loc);
  Value sourceCoreId = createIndexedIndexValue(state, anchor, sourceCoreIds, index, loc);
  Value targetCoreId = createIndexedIndexValue(state, anchor, targetCoreIds, index, loc);

  Value received =
      SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId).getOutput();

  // Row offset of this fragment; the multiplication is skipped for single-row fragments.
  Value firstOffset = index;
  if (fragmentType.getDimSize(0) != 1) {
    Value rowsPerFragment = createIndexConstant(state, anchor, fragmentType.getDimSize(0));
    firstOffset = arith::MulIOp::create(state.rewriter, loc, index, rowsPerFragment).getResult();
  }

  SmallVector<OpFoldResult, 4> offsets;
  SmallVector<OpFoldResult, 4> sizes;
  SmallVector<OpFoldResult, 4> strides;
  offsets.reserve(fragmentType.getRank());
  sizes.reserve(fragmentType.getRank());
  strides.reserve(fragmentType.getRank());

  offsets.push_back(firstOffset);
  sizes.push_back(state.rewriter.getIndexAttr(fragmentType.getDimSize(0)));
  strides.push_back(state.rewriter.getIndexAttr(1));

  for (int64_t dim = 1; dim < fragmentType.getRank(); ++dim) {
    offsets.push_back(state.rewriter.getIndexAttr(0));
    sizes.push_back(state.rewriter.getIndexAttr(fragmentType.getDimSize(dim)));
    strides.push_back(state.rewriter.getIndexAttr(1));
  }

  Value next = tensor::InsertSliceOp::create(state.rewriter, loc, received, acc, offsets, sizes, strides).getResult();
  scf::YieldOp::create(state.rewriter, loc, next);

  return loop.getResult(0);
}
|
||||
|
||||
// Tries to replace a tensor.concat along dimension 0, whose operands are all
// channel receives, with a single receive loop (see createReceiveConcatLoop).
//
// The rewrite only fires when all of the following hold:
//   - the concat has a "dim" attribute equal to 0 and at least two operands;
//   - every operand is produced by a SpatChannelReceiveOp in the same block as
//     the concat, and that result has no other use;
//   - the receives sit directly in front of the concat, in operand order, with
//     no other operation in between;
//   - all receives produce the same ranked tensor type, and stacking them along
//     dimension 0 yields the concat's result type;
//   - channel, source core and target core of every receive are constants.
//
// Returns true when the concat and its receives were replaced, false when the
// IR was left untouched.
bool compactReceiveConcat(MaterializerState& state, MaterializedClass& materializedClass, tensor::ConcatOp concat) {
  auto dimAttr = concat->getAttrOfType<IntegerAttr>("dim");
  if (!dimAttr || dimAttr.getInt() != 0)
    return false;

  OperandRange inputs = concat->getOperands();
  if (inputs.size() < 2)
    return false;

  SmallVector<SpatChannelReceiveOp, 8> receives;
  receives.reserve(inputs.size());

  for (Value input : inputs) {
    auto receive = input.getDefiningOp<SpatChannelReceiveOp>();
    if (!receive)
      return false;
    if (receive->getBlock() != concat->getBlock())
      return false;
    if (!receive->getResult(0).hasOneUse())
      return false;
    receives.push_back(receive);
  }

  // Walk backwards from the concat: the receives must form a contiguous run
  // ending right before it, in the same order as the concat operands.
  Operation* expected = concat.getOperation();
  for (SpatChannelReceiveOp receive : llvm::reverse(receives)) {
    Operation* previous = expected->getPrevNode();
    if (previous != receive.getOperation())
      return false;
    expected = previous;
  }

  auto concatType = dyn_cast<RankedTensorType>(concat->getResult(0).getType());
  auto fragmentType = dyn_cast<RankedTensorType>(receives.front()->getResult(0).getType());
  if (!concatType || !fragmentType)
    return false;
  if (!hasCompatibleConcatTypes(concatType, fragmentType, receives.size()))
    return false;

  SmallVector<int64_t, 8> channelIds;
  SmallVector<int64_t, 8> sourceCoreIds;
  SmallVector<int64_t, 8> targetCoreIds;
  channelIds.reserve(receives.size());
  sourceCoreIds.reserve(receives.size());
  targetCoreIds.reserve(receives.size());

  for (SpatChannelReceiveOp receive : receives) {
    if (receive->getResult(0).getType() != fragmentType)
      return false;

    int64_t channelId = 0;
    int64_t sourceCoreId = 0;
    int64_t targetCoreId = 0;
    if (!getReceiveMetadata(receive, channelId, sourceCoreId, targetCoreId))
      return false;

    channelIds.push_back(channelId);
    sourceCoreIds.push_back(sourceCoreId);
    targetCoreIds.push_back(targetCoreId);
  }

  // From here on the IR is modified; all bail-outs are above this point.
  Value replacement = createReceiveConcatLoop(state,
                                              materializedClass.op,
                                              receives.front().getOperation(),
                                              concatType,
                                              fragmentType,
                                              channelIds,
                                              sourceCoreIds,
                                              targetCoreIds,
                                              concat.getLoc());

  // Replace through the rewriter (rather than a raw replaceAllUsesWith followed
  // by eraseOp) so that the use replacement and the erasure are both reported
  // to whatever is tracking the rewriter's changes.
  state.rewriter.replaceOp(concat.getOperation(), replacement);

  // The receives are now dead; erase them last-to-first.
  for (SpatChannelReceiveOp receive : llvm::reverse(receives))
    state.rewriter.eraseOp(receive.getOperation());

  return true;
}
|
||||
|
||||
// Runs compactReceiveConcat on every tensor.concat nested in any materialized
// class. The concats are collected first and rewritten afterwards, so the IR
// is not modified while it is being walked. Concats that do not match the
// receive pattern are simply left as they are.
void compactReceiveConcats(MaterializerState& state) {
  using ClassAndConcat = std::pair<MaterializedClass*, tensor::ConcatOp>;
  SmallVector<ClassAndConcat, 16> worklist;

  for (MaterializedClass& owner : state.classes) {
    MaterializedClass* ownerPtr = &owner;
    owner.op->walk([&worklist, ownerPtr](tensor::ConcatOp concat) { worklist.push_back({ownerPtr, concat}); });
  }

  for (const ClassAndConcat& entry : worklist)
    compactReceiveConcat(state, *entry.first, entry.second);
}
|
||||
|
||||
LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
MaterializedClass& targetClass,
|
||||
@@ -1086,15 +1347,8 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult emitHostCommunication(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
ArrayRef<ProducerKey> keys,
|
||||
Value payload,
|
||||
Value originalOutput,
|
||||
Location loc) {
|
||||
(void) keys;
|
||||
(void) loc;
|
||||
|
||||
LogicalResult
|
||||
emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, Value payload, Value originalOutput) {
|
||||
if (!hasLiveExternalUse(originalOutput, state.oldComputeOps))
|
||||
return success();
|
||||
|
||||
@@ -1115,7 +1369,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
||||
if (failed(
|
||||
emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
||||
return failure();
|
||||
if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc)))
|
||||
if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput)))
|
||||
return failure();
|
||||
state.availableValues[keys.front()][sourceClass.id] = payload;
|
||||
return success();
|
||||
@@ -1129,7 +1383,7 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
||||
if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
||||
return failure();
|
||||
|
||||
if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc)))
|
||||
if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput)))
|
||||
return failure();
|
||||
|
||||
for (ProducerKey key : keys)
|
||||
@@ -1146,7 +1400,7 @@ FailureOr<Value> materializeWholeBatchInput(
|
||||
|
||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||
|
||||
uint32_t batchLaneCount = static_cast<uint32_t>(batch.getLaneCount());
|
||||
uint32_t batchLaneCount = batch.getLaneCount();
|
||||
SmallVector<Value, 8> fragments;
|
||||
uint32_t lane = 0;
|
||||
|
||||
@@ -1426,6 +1680,8 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch
|
||||
if (failed(materializeInstanceSlot(state, instance)))
|
||||
return failure();
|
||||
|
||||
compactReceiveConcats(state);
|
||||
|
||||
replaceHostUses(state);
|
||||
if (failed(eraseOldComputeOps(state)))
|
||||
return failure();
|
||||
|
||||
@@ -27,7 +27,6 @@
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <tuple>
|
||||
@@ -35,7 +34,6 @@
|
||||
#include <vector>
|
||||
|
||||
#include "MaterializeMergeSchedule.hpp"
|
||||
#include "PostMergeCompaction.hpp"
|
||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||
#include "Scheduling/MergeSchedulingAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||
@@ -660,13 +658,6 @@ public:
|
||||
|
||||
emitMergeIrCounts("after-materialization", func);
|
||||
|
||||
/*if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
emitMergeIrCounts("after-post-merge-compaction", func);*/
|
||||
|
||||
{
|
||||
ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
|
||||
if (!sortTopologically(&func.getBody().front())) {
|
||||
|
||||
@@ -1,532 +0,0 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <cstdlib>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
|
||||
#include "PostMergeCompaction.hpp"
|
||||
#include "RegularOpCompaction.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
using SpatCompute = spatial::SpatCompute;
|
||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
||||
|
||||
// Merge profiling is enabled whenever the RAPTOR_PROFILE_MERGE environment
// variable is set, regardless of its value.
bool isMergeProfilingEnabled() {
  const char* setting = std::getenv("RAPTOR_PROFILE_MERGE");
  return setting != nullptr;
}
|
||||
|
||||
class ScopedMergePhaseTimer {
|
||||
public:
|
||||
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
||||
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
|
||||
if (enabled)
|
||||
start = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
~ScopedMergePhaseTimer() {
|
||||
if (!enabled)
|
||||
return;
|
||||
auto elapsed = std::chrono::steady_clock::now() - start;
|
||||
double millis = std::chrono::duration<double, std::milli>(elapsed).count();
|
||||
llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
|
||||
}
|
||||
|
||||
private:
|
||||
bool enabled = false;
|
||||
std::string phase;
|
||||
std::chrono::steady_clock::time_point start;
|
||||
};
|
||||
|
||||
// Returns the core id stored in the compute's kCoreIdAttrName integer
// attribute (narrowed to int32_t), or std::nullopt when the attribute is absent.
std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
  auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
  if (!coreIdAttr)
    return std::nullopt;

  const int64_t rawCoreId = coreIdAttr.getInt();
  return static_cast<int32_t>(rawCoreId);
}
|
||||
|
||||
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
|
||||
|
||||
// Returns the sign-extended integer behind `value` when the m_ConstantInt
// matcher accepts it, and failure() otherwise.
static FailureOr<int64_t> getConstantI64Value(Value value) {
  APInt matched;
  if (matchPattern(value, m_ConstantInt(&matched)))
    return matched.getSExtValue();
  return failure();
}
|
||||
|
||||
// Returns the integer behind `value`, sign-extended and then narrowed to
// int32_t, when the m_ConstantInt matcher accepts it; failure() otherwise.
static FailureOr<int32_t> getConstantI32Value(Value value) {
  APInt matched;
  if (matchPattern(value, m_ConstantInt(&matched))) {
    const int64_t wide = matched.getSExtValue();
    return static_cast<int32_t>(wide);
  }
  return failure();
}
|
||||
|
||||
// Extracts the constant channel id, source core id and target core id of a
// send op. Returns true and fills all three outputs when each of them is a
// constant; returns false without touching the outputs otherwise.
static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op,
                                     uint64_t& channelId,
                                     uint32_t& sourceCoreId,
                                     uint32_t& targetCoreId) {
  FailureOr<int64_t> channel = getConstantI64Value(op.getChannelId());
  if (failed(channel))
    return false;

  FailureOr<int32_t> sender = getConstantI32Value(op.getSourceCoreId());
  if (failed(sender))
    return false;

  FailureOr<int32_t> receiver = getConstantI32Value(op.getTargetCoreId());
  if (failed(receiver))
    return false;

  // Outputs are written only after all three lookups succeeded.
  channelId = static_cast<uint64_t>(*channel);
  sourceCoreId = static_cast<uint32_t>(*sender);
  targetCoreId = static_cast<uint32_t>(*receiver);
  return true;
}
|
||||
|
||||
// Extracts the constant channel id, source core id and target core id of a
// receive op. Returns true and fills all three outputs when each of them is a
// constant; returns false without touching the outputs otherwise.
static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op,
                                     uint64_t& channelId,
                                     uint32_t& sourceCoreId,
                                     uint32_t& targetCoreId) {
  FailureOr<int64_t> channel = getConstantI64Value(op.getChannelId());
  if (failed(channel))
    return false;

  FailureOr<int32_t> sender = getConstantI32Value(op.getSourceCoreId());
  if (failed(sender))
    return false;

  FailureOr<int32_t> receiver = getConstantI32Value(op.getTargetCoreId());
  if (failed(receiver))
    return false;

  // Outputs are written only after all three lookups succeeded.
  channelId = static_cast<uint64_t>(*channel);
  sourceCoreId = static_cast<uint32_t>(*sender);
  targetCoreId = static_cast<uint32_t>(*receiver);
  return true;
}
|
||||
|
||||
// Returns one host index constant per entry of `values`, in input order, each
// obtained from getOrCreateHostIndexConstant with the given anchor and folder.
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
  const size_t count = values.size();

  SmallVector<Value> result;
  result.reserve(count);
  for (size_t position = 0; position < count; ++position) {
    Value constant = getOrCreateHostIndexConstant(anchorOp, values[position], folder);
    result.push_back(constant);
  }
  return result;
}
|
||||
|
||||
// 32-bit overload: returns one host index constant per entry of `values`, in
// input order, each obtained from getOrCreateHostIndexConstant.
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int32_t> values, OperationFolder& folder) {
  const size_t count = values.size();

  SmallVector<Value> result;
  result.reserve(count);
  for (size_t position = 0; position < count; ++position) {
    Value constant = getOrCreateHostIndexConstant(anchorOp, values[position], folder);
    result.push_back(constant);
  }
  return result;
}
|
||||
|
||||
// Returns the value of the compute's kRebatchPhaseAttrName integer attribute
// as uint64_t, or std::nullopt when the attribute is not present.
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
  auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName);
  if (!phaseAttr)
    return std::nullopt;

  const int64_t rawPhase = phaseAttr.getInt();
  return static_cast<uint64_t>(rawPhase);
}
|
||||
|
||||
// Bucketing key for rebatch candidates. Two keys are equal only when every
// field matches. The member order is part of the contract: the key is built
// with positional aggregate initialization.
struct RebatchKey {
  unsigned inputCount = 0;
  unsigned resultCount = 0;
  unsigned weightCount = 0;
  uint64_t phase = 0;
  bool hasPhase = false;       // Distinguishes "no phase" from an explicit phase of 0.
  uint64_t structureHash = 0;

  // Field-by-field comparison; bails out on the first mismatch.
  bool operator==(const RebatchKey& other) const {
    if (inputCount != other.inputCount)
      return false;
    if (resultCount != other.resultCount)
      return false;
    if (weightCount != other.weightCount)
      return false;
    if (phase != other.phase)
      return false;
    if (hasPhase != other.hasPhase)
      return false;
    return structureHash == other.structureHash;
  }
};
|
||||
|
||||
struct RebatchKeyInfo {
|
||||
static inline RebatchKey getEmptyKey() { return {std::numeric_limits<unsigned>::max(), 0, 0, 0, false, 0}; }
|
||||
|
||||
static inline RebatchKey getTombstoneKey() { return {std::numeric_limits<unsigned>::max() - 1, 0, 0, 0, false, 0}; }
|
||||
|
||||
static unsigned getHashValue(const RebatchKey& key) {
|
||||
return static_cast<unsigned>(
|
||||
llvm::hash_combine(key.inputCount, key.resultCount, key.weightCount, key.phase, key.hasPhase, key.structureHash));
|
||||
}
|
||||
|
||||
static bool isEqual(const RebatchKey& lhs, const RebatchKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
// Hash of a type: the numeric value of its opaque pointer.
uint64_t getTypeHash(Type type) {
  const void* opaque = type.getAsOpaquePointer();
  return reinterpret_cast<uintptr_t>(opaque);
}
|
||||
|
||||
// Hash of an SSA value: the numeric value of its opaque pointer.
uint64_t getValueHash(Value value) {
  const void* opaque = value.getAsOpaquePointer();
  return reinterpret_cast<uintptr_t>(opaque);
}
|
||||
|
||||
// Hash of an attribute: the numeric value of its opaque pointer.
uint64_t getAttributeHash(Attribute attr) {
  const void* opaque = attr.getAsOpaquePointer();
  return reinterpret_cast<uintptr_t>(opaque);
}
|
||||
|
||||
// Computes the bucketing key of a compute op. Besides the operand/result/weight
// counts and the optional rebatch phase, the key carries a structural hash that
// folds in, in this order: the three counts, the identity of every weight
// value, the phase (when present), the body's argument count and argument
// types, and for every operation of the first body block its name, operand /
// result / region counts, result types and attributes.
RebatchKey computeRebatchKey(SpatCompute compute) {
  const auto inputCount = compute.getInputs().size();
  const auto resultCount = compute.getResultTypes().size();
  const auto weightCount = compute.getWeights().size();
  const std::optional<uint64_t> phase = getComputeRebatchPhase(compute);

  llvm::hash_code structureHash = llvm::hash_combine(inputCount, resultCount, weightCount);

  for (Value weight : compute.getWeights())
    structureHash = llvm::hash_combine(structureHash, getValueHash(weight));

  // An absent phase contributes nothing to the hash.
  if (phase.has_value())
    structureHash = llvm::hash_combine(structureHash, *phase);

  Block& body = compute.getBody().front();
  structureHash = llvm::hash_combine(structureHash, body.getNumArguments());
  for (BlockArgument arg : body.getArguments())
    structureHash = llvm::hash_combine(structureHash, getTypeHash(arg.getType()));

  for (Operation& op : body) {
    structureHash = llvm::hash_combine(
        structureHash, op.getName().getStringRef(), op.getNumOperands(), op.getNumResults(), op.getNumRegions());

    for (Type resultType : op.getResultTypes())
      structureHash = llvm::hash_combine(structureHash, getTypeHash(resultType));

    for (NamedAttribute attr : op.getAttrs())
      structureHash = llvm::hash_combine(structureHash, attr.getName().strref(), getAttributeHash(attr.getValue()));
  }

  RebatchKey key;
  key.inputCount = static_cast<unsigned>(inputCount);
  key.resultCount = static_cast<unsigned>(resultCount);
  key.weightCount = static_cast<unsigned>(weightCount);
  key.phase = phase.value_or(0);
  key.hasPhase = phase.has_value();
  key.structureHash = static_cast<uint64_t>(structureHash);
  return key;
}
|
||||
|
||||
// Decides whether two compute ops have the same structure, so that they can be
// grouped for rebatching. Both must be non-null and agree on input count,
// result types, rebatch phase and the exact weight values. Their first body
// blocks are then compared operation by operation:
//   - same operation name, operand count, result count and result types;
//   - neither operation may own a region;
//   - each operand pair must either correspond through earlier block arguments
//     / results, or be the very same value;
//   - channel sends / receives are compared by payload type; every other
//     operation must carry identical attributes.
// Both blocks must also contain the same number of operations.
bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
  if (!lhs || !rhs)
    return false;

  // Signature-level checks.
  if (lhs.getInputs().size() != rhs.getInputs().size())
    return false;
  if (lhs.getResultTypes() != rhs.getResultTypes())
    return false;
  if (lhs.getWeights().size() != rhs.getWeights().size())
    return false;
  if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs))
    return false;
  if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
    return false;

  Block& leftBody = lhs.getBody().front();
  Block& rightBody = rhs.getBody().front();
  const unsigned argumentCount = leftBody.getNumArguments();
  if (argumentCount != rightBody.getNumArguments())
    return false;

  // Maps every value defined inside the left body to its counterpart on the right.
  DenseMap<Value, Value> correspondence;
  for (unsigned argIndex = 0; argIndex < argumentCount; ++argIndex) {
    BlockArgument leftArg = leftBody.getArgument(argIndex);
    BlockArgument rightArg = rightBody.getArgument(argIndex);
    if (leftArg.getType() != rightArg.getType())
      return false;
    correspondence[leftArg] = rightArg;
  }

  // Operand check for one pair of operations with equal operand counts: a left
  // operand with a known counterpart must be paired with it; any other left
  // operand must be the identical value on the right.
  auto operandsCorrespond = [&correspondence](Operation& leftOp, Operation& rightOp) {
    for (unsigned operandIndex = 0, end = leftOp.getNumOperands(); operandIndex != end; ++operandIndex) {
      Value leftOperand = leftOp.getOperand(operandIndex);
      Value rightOperand = rightOp.getOperand(operandIndex);
      auto known = correspondence.find(leftOperand);
      Value expected = known != correspondence.end() ? known->second : leftOperand;
      if (expected != rightOperand)
        return false;
    }
    return true;
  };

  auto leftIt = leftBody.begin();
  auto rightIt = rightBody.begin();
  while (leftIt != leftBody.end() && rightIt != rightBody.end()) {
    Operation& leftOp = *leftIt;
    Operation& rightOp = *rightIt;

    if (leftOp.getName() != rightOp.getName())
      return false;
    if (leftOp.getNumOperands() != rightOp.getNumOperands())
      return false;
    if (leftOp.getNumResults() != rightOp.getNumResults())
      return false;
    if (leftOp.getNumRegions() != 0 || rightOp.getNumRegions() != 0)
      return false;

    if (!operandsCorrespond(leftOp, rightOp))
      return false;

    // The names already match, so the right-hand cast<> below cannot fail.
    if (auto leftReceive = dyn_cast<spatial::SpatChannelReceiveOp>(leftOp)) {
      auto rightReceive = cast<spatial::SpatChannelReceiveOp>(rightOp);
      if (leftReceive.getOutput().getType() != rightReceive.getOutput().getType())
        return false;
    }
    else if (auto leftSend = dyn_cast<spatial::SpatChannelSendOp>(leftOp)) {
      auto rightSend = cast<spatial::SpatChannelSendOp>(rightOp);
      if (leftSend.getInput().getType() != rightSend.getInput().getType())
        return false;
    }
    else if (leftOp.getAttrs() != rightOp.getAttrs()) {
      return false;
    }

    if (leftOp.getResultTypes() != rightOp.getResultTypes())
      return false;

    for (unsigned resultIndex = 0, end = leftOp.getNumResults(); resultIndex != end; ++resultIndex)
      correspondence[leftOp.getResult(resultIndex)] = rightOp.getResult(resultIndex);

    ++leftIt;
    ++rightIt;
  }

  // Equivalent only if both blocks were exhausted at the same time.
  return leftIt == leftBody.end() && rightIt == rightBody.end();
}
|
||||
|
||||
/// Fuses groups of equivalent `SpatCompute` ops in `funcOp` into single `SpatComputeBatch` ops.
///
/// Only computes with at most one input and no results take part. Computes are bucketed by
/// `computeRebatchKey`; for each not-yet-consumed anchor (in function order) the later members of
/// its bucket that pass `areEquivalentForRebatch` and do not repeat a core id already present in
/// the group are collected. A group of two or more members is replaced by one batch op inserted
/// before the anchor:
///   - weights and inputs of all members are concatenated (in core-id order when every member has
///     a core id, otherwise in discovery order);
///   - the body is cloned from the anchor, with every scalar channel receive/send replaced by its
///     batch counterpart carrying the channel metadata of all members, and the terminator replaced
///     by an empty `SpatYieldOp`.
///
/// If the scalar channel metadata of any member cannot be read, the partially built batch op is
/// erased again and the group is left untouched. `kRebatchPhaseAttrName` is removed from every
/// remaining `SpatCompute` before returning, on every path.
void rebatchEquivalentComputes(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  OperationFolder constantFolder(funcOp.getContext());
  SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
  // Erased computes stay in `computes`; this set guards against touching them again.
  DenseSet<Operation*> consumed;
  DenseMap<Operation*, size_t> computeOrder;
  DenseMap<RebatchKey, SmallVector<SpatCompute>, RebatchKeyInfo> candidatesByKey;

  for (auto [index, compute] : llvm::enumerate(computes)) {
    computeOrder[compute.getOperation()] = index;
    if (compute.getInputs().size() <= 1 && compute.getResults().empty())
      candidatesByKey[computeRebatchKey(compute)].push_back(compute);
  }

  for (size_t index = 0; index < computes.size(); ++index) {
    auto anchor = computes[index];
    if (consumed.contains(anchor))
      continue;
    if (anchor.getInputs().size() > 1)
      continue;
    if (!anchor.getResults().empty())
      continue;

    SmallVector<SpatCompute> group {anchor};
    llvm::SmallDenseSet<int32_t, 8> usedCoreIds;
    if (auto coreId = getComputeCoreId(anchor))
      usedCoreIds.insert(*coreId);

    auto bucketIt = candidatesByKey.find(computeRebatchKey(anchor));
    if (bucketIt == candidatesByKey.end())
      continue;

    for (auto candidate : bucketIt->second) {
      // Only look forward: earlier computes were already tried as anchors themselves.
      if (computeOrder.lookup(candidate.getOperation()) <= index)
        continue;
      if (consumed.contains(candidate))
        continue;
      if (!areEquivalentForRebatch(anchor, candidate))
        continue;

      // A core id may appear at most once per group.
      if (auto coreId = getComputeCoreId(candidate))
        if (!usedCoreIds.insert(*coreId).second)
          continue;

      group.push_back(candidate);
    }

    if (group.size() <= 1)
      continue;

    // Taken before sorting, so the batch is inserted at the earliest member (the anchor).
    auto insertionAnchor = group.front();
    if (llvm::all_of(group, [](SpatCompute compute) { return getComputeCoreId(compute).has_value(); })) {
      llvm::stable_sort(
          group, [](SpatCompute lhs, SpatCompute rhs) { return *getComputeCoreId(lhs) < *getComputeCoreId(rhs); });
    }

    SmallVector<Value> weights;
    weights.reserve(group.size() * anchor.getWeights().size());
    SmallVector<Value> inputs;
    inputs.reserve(group.size() * anchor.getInputs().size());
    SmallVector<int32_t> coreIds;
    coreIds.reserve(group.size());
    bool haveAllCoreIds = true;
    for (auto compute : group) {
      llvm::append_range(weights, compute.getWeights());
      llvm::append_range(inputs, compute.getInputs());
      auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
      if (!coreIdAttr)
        haveAllCoreIds = false;
      else if (haveAllCoreIds)
        coreIds.push_back(static_cast<int32_t>(coreIdAttr.getInt()));
    }

    // NOTE(review): operands of later group members are used by an op placed before the anchor;
    // this presumes they are defined above the anchor - confirm against areEquivalentForRebatch.
    rewriter.setInsertionPoint(insertionAnchor);
    auto rebatched = SpatComputeBatch::create(rewriter,
                                              insertionAnchor.getLoc(),
                                              TypeRange {},
                                              rewriter.getI32IntegerAttr(static_cast<int32_t>(group.size())),
                                              ValueRange(weights),
                                              ValueRange(inputs));
    rebatched.getProperties().setOperandSegmentSizes(
        {static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
    if (haveAllCoreIds)
      rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));

    // The batch body reuses the anchor's block signature.
    SmallVector<Type> blockArgTypes;
    SmallVector<Location> blockArgLocs;
    for (BlockArgument arg : anchor.getBody().front().getArguments()) {
      blockArgTypes.push_back(arg.getType());
      blockArgLocs.push_back(arg.getLoc());
    }
    auto* newBlock =
        rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
    rewriter.setInsertionPointToEnd(newBlock);

    IRMapping mapper;
    auto& anchorBlock = anchor.getBody().front();
    for (auto [oldArg, newArg] : llvm::zip(anchorBlock.getArguments(), newBlock->getArguments()))
      mapper.map(oldArg, newArg);
    // One cursor per group member, advanced in lock-step with the walk over the anchor block.
    auto opIts = llvm::map_to_vector(group, [](SpatCompute compute) { return compute.getBody().front().begin(); });

    // Channel metadata of the op under each member's cursor, one entry per group member.
    struct BatchChannelMetadata {
      SmallVector<int64_t> channelIds;
      SmallVector<int32_t> sourceCoreIds;
      SmallVector<int32_t> targetCoreIds;
    };
    // Reads the scalar channel metadata of the `ChannelOpTy` under every member's cursor and
    // advances the cursors. Returns false as soon as one member's metadata cannot be read.
    auto collectChannelMetadata = [&](auto tag, BatchChannelMetadata& metadata) {
      using ChannelOpTy = decltype(tag);
      metadata.channelIds.reserve(group.size());
      metadata.sourceCoreIds.reserve(group.size());
      metadata.targetCoreIds.reserve(group.size());
      for (auto& opIt : opIts) {
        auto groupOp = cast<ChannelOpTy>(&*opIt);
        uint64_t channelId = 0;
        uint32_t sourceCoreId = 0;
        uint32_t targetCoreId = 0;
        if (!getScalarChannelMetadata(groupOp, channelId, sourceCoreId, targetCoreId))
          return false;
        metadata.channelIds.push_back(static_cast<int64_t>(channelId));
        metadata.sourceCoreIds.push_back(static_cast<int32_t>(sourceCoreId));
        metadata.targetCoreIds.push_back(static_cast<int32_t>(targetCoreId));
        ++opIt;
      }
      return true;
    };

    bool materialized = true;
    for (Operation& anchorOp : anchorBlock) {
      if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&anchorOp)) {
        BatchChannelMetadata metadata;
        if (!collectChannelMetadata(spatial::SpatChannelReceiveOp {}, metadata)) {
          materialized = false;
          break;
        }
        SmallVector<Value> channelIdValues = createIndexConstants(receiveOp, metadata.channelIds, constantFolder);
        SmallVector<Value> sourceCoreIdValues =
            createIndexConstants(receiveOp, metadata.sourceCoreIds, constantFolder);
        SmallVector<Value> targetCoreIdValues =
            createIndexConstants(receiveOp, metadata.targetCoreIds, constantFolder);
        auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
                                                                       receiveOp.getLoc(),
                                                                       receiveOp.getOutput().getType(),
                                                                       channelIdValues,
                                                                       sourceCoreIdValues,
                                                                       targetCoreIdValues);
        mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
        continue;
      }

      if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&anchorOp)) {
        BatchChannelMetadata metadata;
        if (!collectChannelMetadata(spatial::SpatChannelSendOp {}, metadata)) {
          materialized = false;
          break;
        }
        SmallVector<Value> channelIdValues = createIndexConstants(sendOp, metadata.channelIds, constantFolder);
        SmallVector<Value> sourceCoreIdValues = createIndexConstants(sendOp, metadata.sourceCoreIds, constantFolder);
        SmallVector<Value> targetCoreIdValues = createIndexConstants(sendOp, metadata.targetCoreIds, constantFolder);
        spatial::SpatChannelSendBatchOp::create(rewriter,
                                                sendOp.getLoc(),
                                                channelIdValues,
                                                sourceCoreIdValues,
                                                targetCoreIdValues,
                                                mapper.lookup(sendOp.getInput()));
        continue;
      }

      if (isa<spatial::SpatYieldOp>(anchorOp)) {
        for (auto& opIt : opIts)
          ++opIt;
        spatial::SpatYieldOp::create(rewriter, anchorOp.getLoc(), ValueRange {});
        continue;
      }

      Operation* cloned = rewriter.clone(anchorOp, mapper);
      for (auto [originalResult, clonedResult] : llvm::zip(anchorOp.getResults(), cloned->getResults()))
        mapper.map(originalResult, clonedResult);
      for (auto& opIt : opIts)
        ++opIt;
    }

    if (!materialized) {
      // Roll back: the batch op has no results and nothing outside refers to its body, so erasing
      // it restores the function to its previous shape. The original computes are kept as they
      // are and the remaining anchors are still processed.
      rewriter.eraseOp(rebatched);
      continue;
    }

    for (auto compute : group) {
      compute->removeAttr(kRebatchPhaseAttrName);
      consumed.insert(compute);
      rewriter.eraseOp(compute);
    }
  }

  for (auto compute : funcOp.getOps<SpatCompute>())
    compute->removeAttr(kRebatchPhaseAttrName);
}
|
||||
|
||||
/// Erases every `tensor::ExtractSliceOp`, `spatial::SpatConcatOp` and `spatial::SpatExtractRowsOp`
/// in `funcOp` that has no uses, handling the three op kinds one after another in that order.
void cleanupDeadPackingOps(func::FuncOp funcOp) {
  const auto dropUnusedOpsOfKind = [&funcOp](auto kindTag) {
    using PackingOpTy = decltype(kindTag);

    // Snapshot first so the walk never runs over ops that are being erased.
    SmallVector<PackingOpTy> pending;
    funcOp.walk([&pending](PackingOpTy packingOp) { pending.push_back(packingOp); });

    // Consume the snapshot back to front, i.e. in reverse walk order, so an op whose only user is
    // a later op of the same kind is already unused by the time it is examined.
    while (!pending.empty()) {
      PackingOpTy packingOp = pending.pop_back_val();
      if (!packingOp->use_empty())
        continue;
      packingOp.erase();
    }
  };

  dropUnusedOpsOfKind(tensor::ExtractSliceOp {});
  dropUnusedOpsOfKind(spatial::SpatConcatOp {});
  dropUnusedOpsOfKind(spatial::SpatExtractRowsOp {});
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Runs the post-merge compaction phases on `funcOp` in a fixed order, each one under its own
/// `ScopedMergePhaseTimer` labelled with the phase name.
///
/// `nextChannelId` is handed by reference to both `compactScalarChannelRuns` invocations. The
/// function always returns `success()`.
LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextChannelId) {
  // Wraps exactly one phase in one timer; the timer object lives only for the duration of the
  // phase body, as with the equivalent brace-delimited scope.
  const auto runTimedPhase = [](const auto& phaseLabel, auto&& phaseBody) {
    ScopedMergePhaseTimer timer(phaseLabel);
    phaseBody();
  };

  runTimedPhase("order-bilateral-channel-ops", [&] { orderBilateralChannelOps(funcOp); });
  runTimedPhase("rebatch-equivalent-computes", [&] { rebatchEquivalentComputes(funcOp); });

  // First channel compaction round, before regular-op and row-wise compaction.
  runTimedPhase("compact-scalar-channel-runs-1", [&] { compactScalarChannelRuns(funcOp, nextChannelId); });
  runTimedPhase("compact-batch-channel-runs-1", [&] { compactBatchChannelRuns(funcOp); });

  runTimedPhase("compact-regular-op-runs", [&] { compactRegularOpRuns(funcOp); });
  runTimedPhase("compact-row-wise-wvmm-runs", [&] { compactRowWiseWvmmRuns(funcOp); });

  // Second channel compaction round, repeated after the two phases above.
  runTimedPhase("compact-scalar-channel-runs-2", [&] { compactScalarChannelRuns(funcOp, nextChannelId); });
  runTimedPhase("compact-batch-channel-runs-2", [&] { compactBatchChannelRuns(funcOp); });

  runTimedPhase("cleanup-dead-packing-ops", [&] { cleanupDeadPackingOps(funcOp); });

  return success();
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,12 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Entry point of the post-merge compaction pipeline for `funcOp`.
/// Its definition runs channel ordering, compute rebatching, two rounds of scalar/batch channel
/// compaction around regular-op and row-wise compaction, and a final dead-op cleanup, then
/// returns success. `nextChannelId` is passed by mutable reference to the scalar channel
/// compaction phases.
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,15 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
// Compaction phases invoked by runPostMergeCompactionPipeline. Their definitions are not part of
// this header; the notes below are inferred from names and call sites only - TODO confirm.
// Presumably reorders channel send/receive ops inside `funcOp`; runs as the first phase.
void orderBilateralChannelOps(mlir::func::FuncOp funcOp);
|
||||
// Takes `nextChannelId` by mutable reference; presumably advances it when new channels are
// allocated - verify against the definition. Invoked twice by the pipeline.
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
||||
// Invoked right after each compactScalarChannelRuns call in the pipeline.
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
|
||||
// Invoked once, between the two channel compaction rounds.
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
|
||||
// Invoked once, directly after compactRegularOpRuns.
void compactRowWiseWvmmRuns(mlir::func::FuncOp funcOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user