better spatial IR compaction with better custom syntax, scf.for and
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
spat.map
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
#include "mlir/Analysis/TopologicalSortUtils.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
@@ -35,6 +37,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "DCPGraph/DCPAnalysis.hpp"
|
||||
#include "RegularOpCompaction.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
@@ -147,7 +150,7 @@ static SmallVector<int32_t> getMaterializedBatchCoreIds(size_t startCpu, size_t
|
||||
}
|
||||
|
||||
static SmallVector<int32_t> getBatchCoreIds(Operation* op, size_t laneCount) {
|
||||
if (auto coreIdsAttr = op->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
if (auto coreIdsAttr = op->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
if (auto coreIdAttr = op->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return SmallVector<int32_t>(laneCount, static_cast<int32_t>(coreIdAttr.getInt()));
|
||||
@@ -304,7 +307,7 @@ static void sinkChannelsIntoBatchComputes(func::FuncOp funcOp,
|
||||
|
||||
SmallVector<int32_t> coreIds = getBatchCoreIds(batch, static_cast<size_t>(batch.getLaneCount()));
|
||||
if (!coreIds.empty())
|
||||
newBatch->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange {}, ArrayRef<Location> {});
|
||||
@@ -548,141 +551,6 @@ void sinkChannelsIntoComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
}
|
||||
}
|
||||
|
||||
static void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
|
||||
for (auto compute : funcOp.getOps<SpatCompute>()) {
|
||||
Block& block = compute.getBody().front();
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||
if (receiveOp) {
|
||||
SmallVector<spatial::SpatChannelReceiveOp> run;
|
||||
Type outputType = receiveOp.getOutput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
|
||||
if (!current || current.getOutput().getType() != outputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
|
||||
if (run.size() > 1) {
|
||||
struct ReceiveEntry {
|
||||
spatial::SpatChannelReceiveOp op;
|
||||
size_t originalIndex = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
uint64_t channelId = 0;
|
||||
};
|
||||
SmallVector<ReceiveEntry> sortedEntries;
|
||||
sortedEntries.reserve(run.size());
|
||||
for (auto [originalIndex, op] : llvm::enumerate(run))
|
||||
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Type> outputTypes;
|
||||
channelIds.reserve(sortedEntries.size());
|
||||
sourceCoreIds.reserve(sortedEntries.size());
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
outputTypes.reserve(sortedEntries.size());
|
||||
for (ReceiveEntry& entry : sortedEntries) {
|
||||
(void) entry;
|
||||
channelIds.push_back(nextChannelId++);
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
outputTypes.push_back(entry.op.getOutput().getType());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
TypeRange(outputTypes),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
|
||||
entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = compactReceive->getIterator();
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
|
||||
if (sendOp) {
|
||||
SmallVector<spatial::SpatChannelSendOp> run;
|
||||
Type inputType = sendOp.getInput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
|
||||
if (!current || current.getInput().getType() != inputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
|
||||
if (run.size() > 1) {
|
||||
struct SendEntry {
|
||||
spatial::SpatChannelSendOp op;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
uint64_t channelId = 0;
|
||||
};
|
||||
SmallVector<SendEntry> sortedEntries;
|
||||
sortedEntries.reserve(run.size());
|
||||
for (auto op : run)
|
||||
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Value> inputs;
|
||||
channelIds.reserve(sortedEntries.size());
|
||||
sourceCoreIds.reserve(sortedEntries.size());
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
inputs.reserve(sortedEntries.size());
|
||||
for (SendEntry& entry : sortedEntries) {
|
||||
(void) entry;
|
||||
channelIds.push_back(nextChannelId++);
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
inputs.push_back(entry.op.getInput());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
spatial::SpatChannelSendManyOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
ValueRange(inputs));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = runIt;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
||||
@@ -755,7 +623,7 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
rebatched.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
||||
if (haveAllCoreIds)
|
||||
rebatched->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
@@ -1879,6 +1747,9 @@ public:
|
||||
|
||||
rebatchEquivalentComputes(func, nextChannelId);
|
||||
compactScalarChannelRuns(func, nextChannelId);
|
||||
compactBatchChannelRuns(func);
|
||||
compactRegularOpRuns(func);
|
||||
compactRowWiseWvmmRuns(func);
|
||||
if (!sortTopologically(&func.getBody().front())) {
|
||||
func.emitOpError("failed to topologically order merged Spatial IR");
|
||||
signalPassFailure();
|
||||
@@ -2049,7 +1920,7 @@ private:
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
|
||||
ValueRange(weights),
|
||||
ValueRange(inputs));
|
||||
rebatched->setAttr(onnx_mlir::kCoreIdAttrName,
|
||||
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName,
|
||||
rewriter.getDenseI32ArrayAttr(getMaterializedBatchCoreIds(currentCpu, laneCount)));
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
|
||||
@@ -0,0 +1,577 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "RegularOpCompaction.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
enum class RegularStepKind {
|
||||
Wvmm,
|
||||
VAddLhs,
|
||||
VAddRhs,
|
||||
};
|
||||
|
||||
struct RegularStep {
|
||||
RegularStepKind kind;
|
||||
int32_t weightIndex = 0;
|
||||
Value invariantOperand;
|
||||
Type resultType;
|
||||
};
|
||||
|
||||
struct RegularChunk {
|
||||
Operation* startOp = nullptr;
|
||||
SmallVector<Operation*> ops;
|
||||
SmallVector<RegularStep> steps;
|
||||
Value input;
|
||||
Value output;
|
||||
};
|
||||
|
||||
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
|
||||
return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand
|
||||
&& lhs.resultType == rhs.resultType;
|
||||
}
|
||||
|
||||
static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChunk& rhs) {
|
||||
if (lhs.input.getType() != rhs.input.getType() || lhs.output.getType() != rhs.output.getType()
|
||||
|| lhs.steps.size() != rhs.steps.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return llvm::all_of(llvm::zip_equal(lhs.steps, rhs.steps),
|
||||
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
|
||||
}
|
||||
|
||||
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatWeightedVMMOp startOp) {
|
||||
RegularChunk chunk;
|
||||
chunk.startOp = startOp.getOperation();
|
||||
chunk.input = startOp.getInput();
|
||||
chunk.output = startOp.getOutput();
|
||||
chunk.ops.push_back(startOp.getOperation());
|
||||
chunk.steps.push_back(
|
||||
{RegularStepKind::Wvmm, static_cast<int32_t>(startOp.getWeightIndex()), Value(), startOp.getOutput().getType()});
|
||||
|
||||
Value currentValue = startOp.getOutput();
|
||||
while (currentValue.hasOneUse()) {
|
||||
Operation* user = *currentValue.getUsers().begin();
|
||||
if (user->getBlock() != startOp->getBlock())
|
||||
break;
|
||||
|
||||
auto vaddOp = dyn_cast<spatial::SpatVAddOp>(user);
|
||||
if (!vaddOp)
|
||||
break;
|
||||
|
||||
if (vaddOp.getLhs() == currentValue)
|
||||
chunk.steps.push_back({RegularStepKind::VAddLhs, 0, vaddOp.getRhs(), vaddOp.getOutput().getType()});
|
||||
else if (vaddOp.getRhs() == currentValue)
|
||||
chunk.steps.push_back({RegularStepKind::VAddRhs, 0, vaddOp.getLhs(), vaddOp.getOutput().getType()});
|
||||
else
|
||||
break;
|
||||
|
||||
chunk.ops.push_back(vaddOp);
|
||||
chunk.output = vaddOp.getOutput();
|
||||
currentValue = vaddOp.getOutput();
|
||||
}
|
||||
|
||||
return chunk;
|
||||
}
|
||||
|
||||
static void buildRegularMapBody(spatial::SpatMapOp mapOp, const RegularChunk& anchorChunk, IRRewriter& rewriter) {
|
||||
auto* block = rewriter.createBlock(
|
||||
&mapOp.getBody(), mapOp.getBody().end(), TypeRange {anchorChunk.input.getType()}, {anchorChunk.startOp->getLoc()});
|
||||
rewriter.setInsertionPointToEnd(block);
|
||||
|
||||
IRMapping mapping;
|
||||
mapping.map(anchorChunk.input, block->getArgument(0));
|
||||
|
||||
for (Operation* op : anchorChunk.ops) {
|
||||
Operation* cloned = rewriter.clone(*op, mapping);
|
||||
for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults()))
|
||||
mapping.map(oldResult, newResult);
|
||||
}
|
||||
|
||||
spatial::SpatYieldOp::create(
|
||||
rewriter, anchorChunk.startOp->getLoc(), ValueRange {mapping.lookup(anchorChunk.output)});
|
||||
}
|
||||
|
||||
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
||||
assert(!run.empty() && "expected a non-empty regular chunk run");
|
||||
const RegularChunk& anchorChunk = run.front();
|
||||
|
||||
SmallVector<Value> inputs;
|
||||
SmallVector<Type> outputTypes;
|
||||
inputs.reserve(run.size());
|
||||
outputTypes.reserve(run.size());
|
||||
for (const RegularChunk& chunk : run) {
|
||||
inputs.push_back(chunk.input);
|
||||
outputTypes.push_back(chunk.output.getType());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(anchorChunk.startOp);
|
||||
auto mapOp =
|
||||
spatial::SpatMapOp::create(rewriter, anchorChunk.startOp->getLoc(), TypeRange(outputTypes), ValueRange(inputs));
|
||||
buildRegularMapBody(mapOp, anchorChunk, rewriter);
|
||||
|
||||
for (auto [index, chunk] : llvm::enumerate(run)) {
|
||||
Value output = chunk.output;
|
||||
output.replaceAllUsesWith(mapOp.getResult(index));
|
||||
}
|
||||
|
||||
SmallVector<Operation*> opsToErase;
|
||||
for (const RegularChunk& chunk : run)
|
||||
llvm::append_range(opsToErase, chunk.ops);
|
||||
for (Operation* op : llvm::reverse(opsToErase))
|
||||
rewriter.eraseOp(op);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
|
||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
Block& block = compute.getBody().front();
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||
if (receiveOp) {
|
||||
SmallVector<spatial::SpatChannelReceiveOp> run;
|
||||
Type outputType = receiveOp.getOutput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
|
||||
if (!current || current.getOutput().getType() != outputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
|
||||
if (run.size() > 1) {
|
||||
struct ReceiveEntry {
|
||||
spatial::SpatChannelReceiveOp op;
|
||||
size_t originalIndex = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
uint64_t channelId = 0;
|
||||
};
|
||||
SmallVector<ReceiveEntry> sortedEntries;
|
||||
sortedEntries.reserve(run.size());
|
||||
for (auto [originalIndex, op] : llvm::enumerate(run))
|
||||
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Type> outputTypes;
|
||||
channelIds.reserve(sortedEntries.size());
|
||||
sourceCoreIds.reserve(sortedEntries.size());
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
outputTypes.reserve(sortedEntries.size());
|
||||
for (ReceiveEntry& entry : sortedEntries) {
|
||||
(void) entry;
|
||||
channelIds.push_back(nextChannelId++);
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
outputTypes.push_back(entry.op.getOutput().getType());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
TypeRange(outputTypes),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
|
||||
entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = compactReceive->getIterator();
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
|
||||
if (sendOp) {
|
||||
SmallVector<spatial::SpatChannelSendOp> run;
|
||||
Type inputType = sendOp.getInput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
|
||||
if (!current || current.getInput().getType() != inputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
|
||||
if (run.size() > 1) {
|
||||
struct SendEntry {
|
||||
spatial::SpatChannelSendOp op;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
uint64_t channelId = 0;
|
||||
};
|
||||
SmallVector<SendEntry> sortedEntries;
|
||||
sortedEntries.reserve(run.size());
|
||||
for (auto op : run)
|
||||
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Value> inputs;
|
||||
channelIds.reserve(sortedEntries.size());
|
||||
sourceCoreIds.reserve(sortedEntries.size());
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
inputs.reserve(sortedEntries.size());
|
||||
for (SendEntry& entry : sortedEntries) {
|
||||
(void) entry;
|
||||
channelIds.push_back(nextChannelId++);
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
inputs.push_back(entry.op.getInput());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
spatial::SpatChannelSendManyOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
ValueRange(inputs));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = runIt;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
|
||||
for (auto batch : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||
Block& block = batch.getBody().front();
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it);
|
||||
if (receiveOp) {
|
||||
SmallVector<spatial::SpatChannelReceiveBatchOp> run;
|
||||
Type outputType = receiveOp.getOutput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*runIt);
|
||||
if (!current || current.getOutput().getType() != outputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
|
||||
if (run.size() > 1) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Type> outputTypes;
|
||||
outputTypes.reserve(run.size());
|
||||
for (auto op : run) {
|
||||
llvm::append_range(channelIds, op.getChannelIds());
|
||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||
outputTypes.push_back(op.getOutput().getType());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveManyBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
TypeRange(outputTypes),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
for (auto [index, op] : llvm::enumerate(run))
|
||||
op.getOutput().replaceAllUsesWith(compactReceive.getResult(index));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = compactReceive->getIterator();
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it);
|
||||
if (sendOp) {
|
||||
SmallVector<spatial::SpatChannelSendBatchOp> run;
|
||||
Type inputType = sendOp.getInput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelSendBatchOp>(&*runIt);
|
||||
if (!current || current.getInput().getType() != inputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
|
||||
if (run.size() > 1) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Value> inputs;
|
||||
inputs.reserve(run.size());
|
||||
for (auto op : run) {
|
||||
llvm::append_range(channelIds, op.getChannelIds());
|
||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||
inputs.push_back(op.getInput());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
spatial::SpatChannelSendManyBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
ValueRange(inputs));
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = runIt;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
|
||||
auto compactInBlock = [&](Block& block) {
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto startOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it);
|
||||
if (!startOp) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto anchorChunk = analyzeRegularChunk(startOp);
|
||||
if (failed(anchorChunk)) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<RegularChunk> run {*anchorChunk};
|
||||
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
||||
while (runIt != block.end()) {
|
||||
auto candidateStart = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt);
|
||||
if (!candidateStart)
|
||||
break;
|
||||
|
||||
auto candidateChunk = analyzeRegularChunk(candidateStart);
|
||||
if (failed(candidateChunk) || !areEquivalentRegularChunks(*anchorChunk, *candidateChunk))
|
||||
break;
|
||||
|
||||
run.push_back(*candidateChunk);
|
||||
runIt = std::next(runIt, static_cast<std::ptrdiff_t>(candidateChunk->ops.size()));
|
||||
}
|
||||
|
||||
if (run.size() <= 1) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
compactRegularChunkRun(rewriter, run);
|
||||
it = runIt;
|
||||
}
|
||||
};
|
||||
|
||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>())
|
||||
compactInBlock(compute.getBody().front());
|
||||
for (auto batch : funcOp.getOps<spatial::SpatComputeBatch>())
|
||||
compactInBlock(batch.getBody().front());
|
||||
}
|
||||
|
||||
void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
|
||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
Block& block = compute.getBody().front();
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto wvmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it);
|
||||
if (!wvmmOp) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto extractRowsOp = wvmmOp.getInput().getDefiningOp<spatial::SpatExtractRowsOp>();
|
||||
auto rowResult = dyn_cast<OpResult>(wvmmOp.getInput());
|
||||
auto outputType = dyn_cast<RankedTensorType>(wvmmOp.getOutput().getType());
|
||||
if (!extractRowsOp || !rowResult || rowResult.getOwner() != extractRowsOp || !outputType
|
||||
|| !outputType.hasStaticShape() || outputType.getRank() != 2 || outputType.getShape()[0] != 1) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatWeightedVMMOp> run;
|
||||
auto runIt = it;
|
||||
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt);
|
||||
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|
||||
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
||||
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
||||
|| current.getOutput().getType() != wvmmOp.getOutput().getType()) {
|
||||
break;
|
||||
}
|
||||
|
||||
auto currentRow = dyn_cast<OpResult>(current.getInput());
|
||||
if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow))
|
||||
break;
|
||||
|
||||
run.push_back(current);
|
||||
++expectedRow;
|
||||
++runIt;
|
||||
}
|
||||
|
||||
if (run.size() <= 1) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!run.front().getOutput().hasOneUse()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
auto concatUse = run.front().getOutput().getUses().begin();
|
||||
auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner());
|
||||
if (!concatOp) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
unsigned concatStartIndex = concatUse->getOperandNumber();
|
||||
bool validConcatRun = true;
|
||||
for (auto [index, op] : llvm::enumerate(run)) {
|
||||
if (!op.getOutput().hasOneUse()) {
|
||||
validConcatRun = false;
|
||||
break;
|
||||
}
|
||||
OpOperand& use = *op.getOutput().getUses().begin();
|
||||
if (use.getOwner() != concatOp || use.getOperandNumber() != concatStartIndex + index) {
|
||||
validConcatRun = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!validConcatRun) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto inputType = dyn_cast<RankedTensorType>(wvmmOp.getInput().getType());
|
||||
auto sourceType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||
if (!inputType || !sourceType || !inputType.hasStaticShape() || !sourceType.hasStaticShape()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t inputCols = inputType.getShape()[1];
|
||||
int64_t outputCols = outputType.getShape()[1];
|
||||
if (ShapedType::isDynamic(inputCols) || ShapedType::isDynamic(outputCols)) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||
int64_t runLength = static_cast<int64_t>(run.size());
|
||||
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 0);
|
||||
auto upper = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), runLength);
|
||||
auto step = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 1);
|
||||
auto packedInit =
|
||||
tensor::EmptyOp::create(rewriter, run.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
||||
auto loop =
|
||||
scf::ForOp::create(rewriter, run.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
Block* loopBlock = loop.getBody();
|
||||
rewriter.setInsertionPointToStart(loopBlock);
|
||||
Value iv = loopBlock->getArgument(0);
|
||||
Value acc = loopBlock->getArgument(1);
|
||||
|
||||
Value sourceRow = iv;
|
||||
if (firstRow != 0) {
|
||||
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), firstRow);
|
||||
sourceRow = arith::AddIOp::create(rewriter, run.front().getLoc(), iv, firstRowValue);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)};
|
||||
SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto extractedRow = tensor::ExtractSliceOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
inputType,
|
||||
extractRowsOp.getInput(),
|
||||
extractOffsets,
|
||||
extractSizes,
|
||||
extractStrides);
|
||||
auto loopWvmm = spatial::SpatWeightedVMMOp::create(
|
||||
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
||||
|
||||
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
|
||||
SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto inserted = tensor::InsertSliceOp::create(
|
||||
rewriter, run.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
|
||||
scf::YieldOp::create(rewriter, run.front().getLoc(), inserted.getResult());
|
||||
}
|
||||
|
||||
SmallVector<Value> newConcatInputs;
|
||||
newConcatInputs.reserve(concatOp.getInputs().size() - run.size() + 1);
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
|
||||
if (operandIndex == concatStartIndex)
|
||||
newConcatInputs.push_back(loop.getResult(0));
|
||||
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.size())
|
||||
newConcatInputs.push_back(operand);
|
||||
}
|
||||
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); });
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = loop->getIterator();
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
||||
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
|
||||
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
|
||||
void compactRowWiseWvmmRuns(mlir::func::FuncOp funcOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user