compact syntax for spatial tensor ops
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
better IR compaction after dcp merge remove pim.mvm op better memory report
This commit is contained in:
@@ -88,7 +88,14 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
|
|||||||
|
|
||||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||||
assert(root && "expected valid root op");
|
assert(root && "expected valid root op");
|
||||||
root->walk([&](pim::PimCoreOp coreOp) { walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback); });
|
root->walk([&](pim::PimCoreOp coreOp) {
|
||||||
|
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||||
|
auto weights = coreOp.getWeights();
|
||||||
|
unsigned weightIndex = vmmOp.getWeightIndex();
|
||||||
|
if (weightIndex < weights.size())
|
||||||
|
callback(coreOp->getOpOperand(weightIndex));
|
||||||
|
});
|
||||||
|
});
|
||||||
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||||
auto weights = coreBatchOp.getWeights();
|
auto weights = coreBatchOp.getWeights();
|
||||||
for (auto weight : weights)
|
for (auto weight : weights)
|
||||||
|
|||||||
@@ -134,6 +134,15 @@ static void printMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
|
|||||||
os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n";
|
os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) {
|
||||||
|
MemoryReportRow result = lhs;
|
||||||
|
result.numAlloca += rhs.numAlloca;
|
||||||
|
result.sizeAlloca += rhs.sizeAlloca;
|
||||||
|
result.numGlobal += rhs.numGlobal;
|
||||||
|
result.sizeGlobal += rhs.sizeGlobal;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
MemoryReportRow PimMemory::getReportRow() const {
|
MemoryReportRow PimMemory::getReportRow() const {
|
||||||
MemoryReportRow row;
|
MemoryReportRow row;
|
||||||
for (auto& [val, memEntry] : globalMemEntriesMap) {
|
for (auto& [val, memEntry] : globalMemEntriesMap) {
|
||||||
@@ -201,8 +210,17 @@ void PimAcceleratorMemory::reportHost() {
|
|||||||
hostReportRow = hostMem.getReportRow();
|
hostReportRow = hostMem.getReportRow();
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimAcceleratorMemory::reportCore(size_t coreId) {
|
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
|
||||||
coreReportRows.push_back({coreId, deviceMem.at(coreId).getReportRow()});
|
reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row});
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId, ArrayRef<int32_t> coreIds, const MemoryReportRow& row) {
|
||||||
|
MemoryReportEntry entry;
|
||||||
|
entry.kind = MemoryReportEntry::Kind::Batch;
|
||||||
|
entry.id = batchId;
|
||||||
|
llvm::append_range(entry.coreIds, coreIds);
|
||||||
|
entry.row = row;
|
||||||
|
reportEntries.push_back(std::move(entry));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimAcceleratorMemory::flushReport() {
|
void PimAcceleratorMemory::flushReport() {
|
||||||
@@ -215,13 +233,16 @@ void PimAcceleratorMemory::flushReport() {
|
|||||||
printMemoryReportRow(os, *hostReportRow);
|
printMemoryReportRow(os, *hostReportRow);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!coreReportRows.empty()) {
|
if (!reportEntries.empty()) {
|
||||||
if (hostReportRow.has_value())
|
if (hostReportRow.has_value())
|
||||||
os << "\n";
|
os << "\n";
|
||||||
|
|
||||||
llvm::stable_sort(coreReportRows, [](const auto& lhs, const auto& rhs) {
|
llvm::stable_sort(reportEntries, [](const MemoryReportEntry& lhs, const MemoryReportEntry& rhs) {
|
||||||
const MemoryReportRow& lhsRow = lhs.second;
|
if (lhs.kind != rhs.kind)
|
||||||
const MemoryReportRow& rhsRow = rhs.second;
|
return lhs.kind == MemoryReportEntry::Kind::Batch;
|
||||||
|
|
||||||
|
const MemoryReportRow& lhsRow = lhs.row;
|
||||||
|
const MemoryReportRow& rhsRow = rhs.row;
|
||||||
if (lhsRow.sizeAlloca != rhsRow.sizeAlloca)
|
if (lhsRow.sizeAlloca != rhsRow.sizeAlloca)
|
||||||
return lhsRow.sizeAlloca > rhsRow.sizeAlloca;
|
return lhsRow.sizeAlloca > rhsRow.sizeAlloca;
|
||||||
if (lhsRow.numAlloca != rhsRow.numAlloca)
|
if (lhsRow.numAlloca != rhsRow.numAlloca)
|
||||||
@@ -230,24 +251,36 @@ void PimAcceleratorMemory::flushReport() {
|
|||||||
return lhsRow.sizeGlobal > rhsRow.sizeGlobal;
|
return lhsRow.sizeGlobal > rhsRow.sizeGlobal;
|
||||||
if (lhsRow.numGlobal != rhsRow.numGlobal)
|
if (lhsRow.numGlobal != rhsRow.numGlobal)
|
||||||
return lhsRow.numGlobal > rhsRow.numGlobal;
|
return lhsRow.numGlobal > rhsRow.numGlobal;
|
||||||
return lhs.first < rhs.first;
|
return lhs.id < rhs.id;
|
||||||
});
|
});
|
||||||
|
|
||||||
for (size_t index = 0; index < coreReportRows.size();) {
|
for (size_t index = 0; index < reportEntries.size();) {
|
||||||
size_t runEnd = index + 1;
|
size_t runEnd = index + 1;
|
||||||
while (runEnd < coreReportRows.size() && coreReportRows[runEnd].second == coreReportRows[index].second)
|
while (runEnd < reportEntries.size() && reportEntries[runEnd].kind == reportEntries[index].kind
|
||||||
|
&& reportEntries[runEnd].row == reportEntries[index].row) {
|
||||||
++runEnd;
|
++runEnd;
|
||||||
|
}
|
||||||
|
|
||||||
llvm::SmallVector<size_t, 8> coreIds;
|
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch) {
|
||||||
coreIds.reserve(runEnd - index);
|
os << "Batch ";
|
||||||
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
|
for (size_t batchIndex = index; batchIndex < runEnd; ++batchIndex) {
|
||||||
coreIds.push_back(coreReportRows[coreIndex].first);
|
if (batchIndex != index)
|
||||||
|
os << ",\n ";
|
||||||
os << "Core ";
|
os << reportEntries[batchIndex].id << " (cores ";
|
||||||
printCompressedIntegerEntries(os, ArrayRef<size_t>(coreIds));
|
printCompressedIntegerEntries(os, ArrayRef<int32_t>(reportEntries[batchIndex].coreIds));
|
||||||
|
os << ")";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
llvm::SmallVector<int32_t, 8> coreIds;
|
||||||
|
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
|
||||||
|
coreIds.push_back(reportEntries[coreIndex].coreIds.front());
|
||||||
|
os << "Core ";
|
||||||
|
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
|
||||||
|
}
|
||||||
os << ":\n";
|
os << ":\n";
|
||||||
printMemoryReportRow(os, coreReportRows[index].second);
|
printMemoryReportRow(os, reportEntries[index].row);
|
||||||
if (runEnd < coreReportRows.size())
|
if (runEnd < reportEntries.size())
|
||||||
os << "\n";
|
os << "\n";
|
||||||
|
|
||||||
index = runEnd;
|
index = runEnd;
|
||||||
@@ -678,7 +711,6 @@ static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
|||||||
indices.push_back(weightIndex);
|
indices.push_back(weightIndex);
|
||||||
};
|
};
|
||||||
|
|
||||||
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
|
|
||||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||||
llvm::sort(indices);
|
llvm::sort(indices);
|
||||||
return indices;
|
return indices;
|
||||||
@@ -753,8 +785,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
|||||||
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
|
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
|
||||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
||||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
|
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
|
||||||
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
|
|
||||||
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false, knowledge);
|
|
||||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||||
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
||||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||||
@@ -816,6 +846,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
// This implementation always assigns one crossbar per group.
|
// This implementation always assigns one crossbar per group.
|
||||||
json::Object xbarsPerArrayGroup;
|
json::Object xbarsPerArrayGroup;
|
||||||
size_t maxCoreId = 0;
|
size_t maxCoreId = 0;
|
||||||
|
uint64_t nextBatchReportId = 0;
|
||||||
|
|
||||||
// Create Weight Folder
|
// Create Weight Folder
|
||||||
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
||||||
@@ -859,7 +890,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
|
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
|
||||||
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
|
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
|
||||||
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
|
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
|
||||||
memory.reportCore(coreId);
|
|
||||||
|
|
||||||
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
|
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
|
||||||
if (processedOperations < 0)
|
if (processedOperations < 0)
|
||||||
@@ -905,18 +935,31 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||||
if (auto err = emitCore(coreOp, false))
|
if (auto err = emitCore(coreOp, false))
|
||||||
return err;
|
return err;
|
||||||
|
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
|
||||||
|
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())))
|
||||||
|
.getReportRow());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
||||||
|
auto batchCoreIds = getBatchCoreIds(coreBatchOp);
|
||||||
|
SmallVector<int32_t> reportedCoreIds;
|
||||||
|
reportedCoreIds.reserve(batchCoreIds.size());
|
||||||
|
MemoryReportRow batchRow;
|
||||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
||||||
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
||||||
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
|
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
|
||||||
|
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||||
|
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||||
|
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
|
||||||
laneResult = emitCore(coreOp, true);
|
laneResult = emitCore(coreOp, true);
|
||||||
|
if (laneResult == CompilerSuccess)
|
||||||
|
batchRow = addMemoryReportRows(batchRow, memory.getOrCreateDeviceMem(coreId).getReportRow());
|
||||||
return laneResult == CompilerSuccess ? success() : failure();
|
return laneResult == CompilerSuccess ? success() : failure();
|
||||||
})))
|
})))
|
||||||
return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
|
return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
|
||||||
}
|
}
|
||||||
|
memory.recordBatchReport(nextBatchReportId++, reportedCoreIds, batchRow);
|
||||||
}
|
}
|
||||||
|
|
||||||
memory.flushReport();
|
memory.flushReport();
|
||||||
|
|||||||
@@ -33,6 +33,18 @@ struct MemoryReportRow {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MemoryReportEntry {
|
||||||
|
enum class Kind {
|
||||||
|
Core,
|
||||||
|
Batch
|
||||||
|
};
|
||||||
|
|
||||||
|
Kind kind = Kind::Core;
|
||||||
|
uint64_t id = 0;
|
||||||
|
llvm::SmallVector<int32_t, 8> coreIds;
|
||||||
|
MemoryReportRow row;
|
||||||
|
};
|
||||||
|
|
||||||
class PimMemory {
|
class PimMemory {
|
||||||
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
||||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
||||||
@@ -66,7 +78,7 @@ private:
|
|||||||
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
|
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
|
||||||
std::fstream fileReport;
|
std::fstream fileReport;
|
||||||
std::optional<MemoryReportRow> hostReportRow;
|
std::optional<MemoryReportRow> hostReportRow;
|
||||||
llvm::SmallVector<std::pair<size_t, MemoryReportRow>, 32> coreReportRows;
|
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PimAcceleratorMemory()
|
PimAcceleratorMemory()
|
||||||
@@ -86,7 +98,8 @@ public:
|
|||||||
|
|
||||||
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
||||||
void reportHost();
|
void reportHost();
|
||||||
void reportCore(size_t coreId);
|
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
|
||||||
|
void recordBatchReport(uint64_t batchId, llvm::ArrayRef<int32_t> coreIds, const MemoryReportRow& row);
|
||||||
void flushReport();
|
void flushReport();
|
||||||
void clean(mlir::Operation* op);
|
void clean(mlir::Operation* op);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -103,7 +103,6 @@ SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
|||||||
indices.push_back(weightIndex);
|
indices.push_back(weightIndex);
|
||||||
};
|
};
|
||||||
|
|
||||||
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
|
|
||||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||||
llvm::sort(indices);
|
llvm::sort(indices);
|
||||||
return indices;
|
return indices;
|
||||||
|
|||||||
@@ -7,7 +7,6 @@
|
|||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -37,15 +36,10 @@ static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp se
|
|||||||
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
|
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
|
||||||
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
||||||
|
|
||||||
Value input = mapper.lookup(sendTensorBatchOp.getInput());
|
pim::PimSendTensorBatchOp::create(rewriter,
|
||||||
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>())
|
sendTensorBatchOp.getLoc(),
|
||||||
if (concatOp.getDim() == 0)
|
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||||
if (Value packedInput =
|
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||||
createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorBatchOp.getLoc()))
|
|
||||||
input = packedInput;
|
|
||||||
|
|
||||||
pim::PimSendTensorBatchOp::create(
|
|
||||||
rewriter, sendTensorBatchOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
||||||
|
|||||||
@@ -21,12 +21,6 @@ def spatToPimVMM : Pat<
|
|||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def spatToPimMVM : Pat<
|
|
||||||
(SpatMVMOp:$srcOpRes $weightIndex, $vector),
|
|
||||||
(PimMVMOp $weightIndex, $vector,
|
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
|
||||||
>;
|
|
||||||
|
|
||||||
def spatToPimVVAdd : Pat<
|
def spatToPimVVAdd : Pat<
|
||||||
(SpatVAddOp:$srcOpRes $a, $b),
|
(SpatVAddOp:$srcOpRes $a, $b),
|
||||||
(PimVVAddOp $a, $b,
|
(PimVVAddOp $a, $b,
|
||||||
|
|||||||
@@ -11,7 +11,6 @@
|
|||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
||||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||||
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
@@ -105,12 +104,8 @@ static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp
|
|||||||
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
||||||
|
|
||||||
rewriter.setInsertionPoint(sendTensorOp);
|
rewriter.setInsertionPoint(sendTensorOp);
|
||||||
Value input = sendTensorOp.getInput();
|
PimSendTensorOp::create(
|
||||||
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>())
|
rewriter, sendTensorOp.getLoc(), sendTensorOp.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||||
if (concatOp.getDim() == 0)
|
|
||||||
if (Value packedInput = createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorOp.getLoc()))
|
|
||||||
input = packedInput;
|
|
||||||
PimSendTensorOp::create(rewriter, sendTensorOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
|
||||||
rewriter.eraseOp(sendTensorOp);
|
rewriter.eraseOp(sendTensorOp);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,38 +147,6 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
|
|||||||
rewriter.replaceOp(extractRowsOp, replacements);
|
rewriter.replaceOp(extractRowsOp, replacements);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createPackedExtractRowsSlice(
|
|
||||||
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
|
||||||
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
|
|
||||||
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
|
||||||
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
|
|
||||||
return {};
|
|
||||||
|
|
||||||
int64_t rowsPerValue = rowType.getDimSize(0);
|
|
||||||
if (ShapedType::isDynamic(rowsPerValue))
|
|
||||||
return {};
|
|
||||||
|
|
||||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
|
|
||||||
SmallVector<OpFoldResult> offsets;
|
|
||||||
SmallVector<OpFoldResult> sizes;
|
|
||||||
SmallVector<OpFoldResult> strides;
|
|
||||||
offsets.reserve(inputType.getRank());
|
|
||||||
sizes.reserve(inputType.getRank());
|
|
||||||
strides.reserve(inputType.getRank());
|
|
||||||
|
|
||||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
|
|
||||||
strides.push_back(rewriter.getIndexAttr(1));
|
|
||||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
|
||||||
offsets.push_back(rewriter.getIndexAttr(0));
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
|
|
||||||
strides.push_back(rewriter.getIndexAttr(1));
|
|
||||||
}
|
|
||||||
|
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
|
|
||||||
.getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
SmallVector<spatial::SpatConcatOp> concatOps;
|
SmallVector<spatial::SpatConcatOp> concatOps;
|
||||||
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
|
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
|
||||||
@@ -262,11 +225,6 @@ static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter
|
|||||||
.getResult());
|
.getResult());
|
||||||
rewriter.replaceOp(concatOp, newConcat.getOutput());
|
rewriter.replaceOp(concatOp, newConcat.getOutput());
|
||||||
}
|
}
|
||||||
|
|
||||||
RewritePatternSet tensorPackingPatterns(funcOp.getContext());
|
|
||||||
populateTensorPackingPatterns(tensorPackingPatterns);
|
|
||||||
(void) applyPatternsGreedily(funcOp, std::move(tensorPackingPatterns));
|
|
||||||
|
|
||||||
auto eraseUnusedOps = [&](auto tag) {
|
auto eraseUnusedOps = [&](auto tag) {
|
||||||
using OpTy = decltype(tag);
|
using OpTy = decltype(tag);
|
||||||
SmallVector<OpTy> ops;
|
SmallVector<OpTy> ops;
|
||||||
|
|||||||
@@ -3,26 +3,6 @@
|
|||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Replaces concat-of-adjacent-slices with one packed slice to keep batch sends compact.
|
|
||||||
struct FoldConcatOfContiguousSlices : OpRewritePattern<tensor::ConcatOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(tensor::ConcatOp op, PatternRewriter& rewriter) const override {
|
|
||||||
if (op.getDim() != 0)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
Value packed = createPackedExtractSliceTensor(op.getInputs(), rewriter, op.getLoc());
|
|
||||||
if (!packed)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, packed);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
|
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
|
||||||
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
|
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
|
||||||
@@ -30,6 +10,67 @@ RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count
|
|||||||
return RankedTensorType::get(packedShape, elementType.getElementType());
|
return RankedTensorType::get(packedShape, elementType.getElementType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value extractPackedChunk(
|
||||||
|
Value packedValue, RankedTensorType chunkType, unsigned index, OpBuilder& builder, Location loc) {
|
||||||
|
auto packedType = dyn_cast<RankedTensorType>(packedValue.getType());
|
||||||
|
if (packedType && packedType == chunkType && index == 0)
|
||||||
|
return packedValue;
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets;
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
SmallVector<OpFoldResult> strides;
|
||||||
|
offsets.reserve(chunkType.getRank());
|
||||||
|
sizes.reserve(chunkType.getRank());
|
||||||
|
strides.reserve(chunkType.getRank());
|
||||||
|
|
||||||
|
offsets.push_back(builder.getIndexAttr(static_cast<int64_t>(index) * chunkType.getDimSize(0)));
|
||||||
|
sizes.push_back(builder.getIndexAttr(chunkType.getDimSize(0)));
|
||||||
|
strides.push_back(builder.getIndexAttr(1));
|
||||||
|
for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) {
|
||||||
|
offsets.push_back(builder.getIndexAttr(0));
|
||||||
|
sizes.push_back(builder.getIndexAttr(chunkType.getDimSize(dim)));
|
||||||
|
strides.push_back(builder.getIndexAttr(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
return tensor::ExtractSliceOp::create(builder, loc, chunkType, packedValue, offsets, sizes, strides).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value createPackedExtractRowsSlice(
|
||||||
|
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, OpBuilder& builder, Location loc) {
|
||||||
|
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
|
||||||
|
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||||
|
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
int64_t rowsPerValue = rowType.getDimSize(0);
|
||||||
|
if (ShapedType::isDynamic(rowsPerValue))
|
||||||
|
return {};
|
||||||
|
|
||||||
|
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
|
||||||
|
SmallVector<OpFoldResult> offsets;
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
SmallVector<OpFoldResult> strides;
|
||||||
|
offsets.reserve(inputType.getRank());
|
||||||
|
sizes.reserve(inputType.getRank());
|
||||||
|
strides.reserve(inputType.getRank());
|
||||||
|
|
||||||
|
offsets.push_back(builder.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
|
||||||
|
sizes.push_back(builder.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
|
||||||
|
strides.push_back(builder.getIndexAttr(1));
|
||||||
|
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||||
|
offsets.push_back(builder.getIndexAttr(0));
|
||||||
|
sizes.push_back(builder.getIndexAttr(inputType.getDimSize(dim)));
|
||||||
|
strides.push_back(builder.getIndexAttr(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool coversWholeSource = packedType == inputType && startIndex == 0;
|
||||||
|
if (coversWholeSource)
|
||||||
|
return extractRowsOp.getInput();
|
||||||
|
|
||||||
|
return tensor::ExtractSliceOp::create(builder, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) {
|
Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) {
|
||||||
if (values.empty())
|
if (values.empty())
|
||||||
return {};
|
return {};
|
||||||
@@ -105,9 +146,4 @@ Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Loca
|
|||||||
return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
|
return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
|
||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
|
|
||||||
patterns.add<FoldConcatOfContiguousSlices>(patterns.getContext());
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -3,11 +3,21 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
|
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
|
||||||
|
mlir::Value extractPackedChunk(mlir::Value packedValue,
|
||||||
|
mlir::RankedTensorType chunkType,
|
||||||
|
unsigned index,
|
||||||
|
mlir::OpBuilder& builder,
|
||||||
|
mlir::Location loc);
|
||||||
|
mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsOp,
|
||||||
|
unsigned startIndex,
|
||||||
|
unsigned count,
|
||||||
|
mlir::OpBuilder& builder,
|
||||||
|
mlir::Location loc);
|
||||||
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
|
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
|
||||||
|
|
||||||
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -394,30 +394,6 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimMVMOp : PimOp<"mvm", [DestinationStyleOpInterface]> {
|
|
||||||
let summary = "Matrix-vector multiplication: c = a * b";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I32Attr:$weightIndex,
|
|
||||||
PimTensor:$input,
|
|
||||||
PimTensor:$outputBuffer
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutputBufferMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> {
|
def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> {
|
||||||
let summary = "Element-wise addition: c = a + b";
|
let summary = "Element-wise addition: c = a + b";
|
||||||
|
|
||||||
|
|||||||
@@ -538,33 +538,6 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface, PimMVMOp> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto mvmOp = cast<PimMVMOp>(op);
|
|
||||||
|
|
||||||
auto inputOpt = getBufferOrValue(rewriter, mvmOp.getInput(), options, state);
|
|
||||||
if (failed(inputOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto outputBufferOpt = getBufferOrValue(rewriter, mvmOp.getOutputBuffer(), options, state);
|
|
||||||
if (failed(outputBufferOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
|
||||||
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> {
|
struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
@@ -655,7 +628,6 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
|||||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||||
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
||||||
PimMVMOp::attachInterface<MVMOpInterface>(*ctx);
|
|
||||||
|
|
||||||
PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx);
|
PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx);
|
||||||
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
|
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
|
||||||
|
|||||||
@@ -150,10 +150,7 @@ def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
let assemblyFormat = [{
|
|
||||||
$input attr-dict `:` type($input)
|
|
||||||
}];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
||||||
@@ -170,10 +167,7 @@ def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
let assemblyFormat = [{
|
|
||||||
attr-dict `:` type($output)
|
|
||||||
}];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
||||||
@@ -201,10 +195,7 @@ def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
let assemblyFormat = [{
|
|
||||||
$input attr-dict `:` type($input)
|
|
||||||
}];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||||
@@ -238,10 +229,7 @@ def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []>
|
|||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
let assemblyFormat = [{
|
|
||||||
attr-dict `:` type($output)
|
|
||||||
}];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|||||||
@@ -47,6 +47,95 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
|||||||
return parser.getBuilder().getI32IntegerAttr(value);
|
return parser.getBuilder().getI32IntegerAttr(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename TensorSendOpTy>
|
||||||
|
static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
||||||
|
printer << " ";
|
||||||
|
printer.printOperand(op.getInput());
|
||||||
|
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||||
|
printer.printOptionalAttrDict(
|
||||||
|
op->getAttrs(),
|
||||||
|
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||||
|
printer << " : ";
|
||||||
|
printer.printType(op.getInput().getType());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename TensorReceiveOpTy>
|
||||||
|
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
|
||||||
|
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||||
|
printer.printOptionalAttrDict(
|
||||||
|
op->getAttrs(),
|
||||||
|
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||||
|
printer << " : ";
|
||||||
|
printer.printType(op.getOutput().getType());
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) {
|
||||||
|
OpAsmParser::UnresolvedOperand input;
|
||||||
|
Type inputType;
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
|
||||||
|
if (parser.parseOperand(input))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||||
|
if (hasMetadata) {
|
||||||
|
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||||
|
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||||
|
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (hasMetadata
|
||||||
|
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||||
|
|| result.attributes.get("targetCoreIds")))
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||||
|
if (hasMetadata) {
|
||||||
|
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||||
|
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||||
|
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||||
|
}
|
||||||
|
|
||||||
|
return parser.resolveOperand(input, inputType, result.operands);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) {
|
||||||
|
Type outputType;
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
|
||||||
|
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||||
|
if (hasMetadata) {
|
||||||
|
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||||
|
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||||
|
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (hasMetadata
|
||||||
|
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||||
|
|| result.attributes.get("targetCoreIds")))
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||||
|
if (hasMetadata) {
|
||||||
|
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||||
|
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||||
|
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||||
|
}
|
||||||
|
|
||||||
|
result.addTypes(outputType);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
||||||
@@ -316,6 +405,12 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
|||||||
return parser.parseRegion(*body, regionArgs);
|
return parser.parseRegion(*body, regionArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||||
|
|
||||||
|
ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
|
return parseTensorSendOp(parser, result);
|
||||||
|
}
|
||||||
|
|
||||||
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
|
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printer.printOperand(getInput());
|
printer.printOperand(getInput());
|
||||||
@@ -362,6 +457,18 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
|
|||||||
return parser.resolveOperand(input, inputType, result.operands);
|
return parser.resolveOperand(input, inputType, result.operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||||
|
|
||||||
|
ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
|
return parseTensorSendOp(parser, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||||
|
|
||||||
|
ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
|
return parseTensorReceiveOp(parser, result);
|
||||||
|
}
|
||||||
|
|
||||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||||
printer.printOptionalAttrDict(
|
printer.printOptionalAttrDict(
|
||||||
@@ -403,5 +510,11 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||||
|
|
||||||
|
ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
|
return parseTensorReceiveOp(parser, result);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -1436,6 +1436,21 @@ public:
|
|||||||
compactBatchChannelRuns(func);
|
compactBatchChannelRuns(func);
|
||||||
compactRegularOpRuns(func);
|
compactRegularOpRuns(func);
|
||||||
compactRowWiseWvmmRuns(func);
|
compactRowWiseWvmmRuns(func);
|
||||||
|
compactScalarChannelRuns(func, nextChannelId);
|
||||||
|
compactBatchChannelRuns(func);
|
||||||
|
|
||||||
|
auto eraseUnusedOps = [&](auto tag) {
|
||||||
|
using OpTy = decltype(tag);
|
||||||
|
SmallVector<OpTy> ops;
|
||||||
|
func.walk([&](OpTy op) { ops.push_back(op); });
|
||||||
|
for (auto op : llvm::reverse(ops))
|
||||||
|
if (op->use_empty())
|
||||||
|
op.erase();
|
||||||
|
};
|
||||||
|
eraseUnusedOps(tensor::ExtractSliceOp {});
|
||||||
|
eraseUnusedOps(spatial::SpatConcatOp {});
|
||||||
|
eraseUnusedOps(spatial::SpatExtractRowsOp {});
|
||||||
|
|
||||||
if (!sortTopologically(&func.getBody().front())) {
|
if (!sortTopologically(&func.getBody().front())) {
|
||||||
func.emitOpError("failed to topologically order merged Spatial IR");
|
func.emitOpError("failed to topologically order merged Spatial IR");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
||||||
#include "RegularOpCompaction.hpp"
|
#include "RegularOpCompaction.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -41,133 +42,75 @@ struct RegularChunk {
|
|||||||
Value output;
|
Value output;
|
||||||
};
|
};
|
||||||
|
|
||||||
static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
|
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
|
||||||
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
|
if (values.empty() || !values.front().hasOneUse())
|
||||||
packedShape[0] *= count;
|
return {};
|
||||||
return RankedTensorType::get(packedShape, elementType.getElementType());
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value
|
OpOperand& firstUse = *values.front().getUses().begin();
|
||||||
extractPackedChunk(Value packedValue, RankedTensorType chunkType, unsigned index, IRRewriter& rewriter, Location loc) {
|
auto concatOp = dyn_cast<spatial::SpatConcatOp>(firstUse.getOwner());
|
||||||
SmallVector<OpFoldResult> offsets;
|
if (!concatOp)
|
||||||
SmallVector<OpFoldResult> sizes;
|
return {};
|
||||||
SmallVector<OpFoldResult> strides;
|
|
||||||
offsets.reserve(chunkType.getRank());
|
|
||||||
sizes.reserve(chunkType.getRank());
|
|
||||||
strides.reserve(chunkType.getRank());
|
|
||||||
|
|
||||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(index) * chunkType.getDimSize(0)));
|
startOperandIndex = firstUse.getOperandNumber();
|
||||||
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(0)));
|
for (auto [index, value] : llvm::enumerate(values)) {
|
||||||
strides.push_back(rewriter.getIndexAttr(1));
|
if (!value.hasOneUse())
|
||||||
for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) {
|
return {};
|
||||||
offsets.push_back(rewriter.getIndexAttr(0));
|
OpOperand& use = *value.getUses().begin();
|
||||||
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(dim)));
|
if (use.getOwner() != concatOp || use.getOperandNumber() != startOperandIndex + index)
|
||||||
strides.push_back(rewriter.getIndexAttr(1));
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, chunkType, packedValue, offsets, sizes, strides).getResult();
|
return concatOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createPackedExtractRowsSlice(
|
static void replaceConcatRunWithPackedValue(spatial::SpatConcatOp concatOp,
|
||||||
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
unsigned startOperandIndex,
|
||||||
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
|
unsigned operandCount,
|
||||||
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
Value packedValue,
|
||||||
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
|
IRRewriter& rewriter) {
|
||||||
return {};
|
SmallVector<Value> newInputs;
|
||||||
|
newInputs.reserve(concatOp.getInputs().size() - operandCount + 1);
|
||||||
int64_t rowsPerValue = rowType.getDimSize(0);
|
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
|
||||||
if (ShapedType::isDynamic(rowsPerValue))
|
if (operandIndex == startOperandIndex)
|
||||||
return {};
|
newInputs.push_back(packedValue);
|
||||||
|
if (operandIndex < startOperandIndex || operandIndex >= startOperandIndex + operandCount)
|
||||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
|
newInputs.push_back(operand);
|
||||||
SmallVector<OpFoldResult> offsets;
|
|
||||||
SmallVector<OpFoldResult> sizes;
|
|
||||||
SmallVector<OpFoldResult> strides;
|
|
||||||
offsets.reserve(inputType.getRank());
|
|
||||||
sizes.reserve(inputType.getRank());
|
|
||||||
strides.reserve(inputType.getRank());
|
|
||||||
|
|
||||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
|
|
||||||
strides.push_back(rewriter.getIndexAttr(1));
|
|
||||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
|
||||||
offsets.push_back(rewriter.getIndexAttr(0));
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
|
|
||||||
strides.push_back(rewriter.getIndexAttr(1));
|
|
||||||
}
|
}
|
||||||
|
if (newInputs.size() == 1 && newInputs.front().getType() == concatOp.getOutput().getType()) {
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
|
rewriter.replaceOp(concatOp, newInputs.front());
|
||||||
.getResult();
|
return;
|
||||||
|
}
|
||||||
|
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newInputs); });
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createPackedExtractSliceTensor(ValueRange values, IRRewriter& rewriter, Location loc) {
|
static RankedTensorType
|
||||||
if (values.empty())
|
getPackedConcatSliceType(spatial::SpatConcatOp concatOp, unsigned startOperandIndex, unsigned operandCount) {
|
||||||
return {};
|
auto firstType = dyn_cast<RankedTensorType>(concatOp.getInputs()[startOperandIndex].getType());
|
||||||
if (values.size() == 1)
|
if (!firstType || !firstType.hasStaticShape())
|
||||||
return values.front();
|
|
||||||
|
|
||||||
auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
|
|
||||||
if (!firstSliceOp)
|
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
|
int64_t axis = concatOp.getAxis();
|
||||||
auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType());
|
if (axis < 0 || axis >= firstType.getRank())
|
||||||
if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
|
|
||||||
|| firstType.getRank() == 0)
|
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
auto hasStaticValues = [](ArrayRef<int64_t> values) {
|
SmallVector<int64_t> shape(firstType.getShape().begin(), firstType.getShape().end());
|
||||||
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
|
shape[axis] = 0;
|
||||||
};
|
for (unsigned index = 0; index < operandCount; ++index) {
|
||||||
if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes())
|
auto operandType = dyn_cast<RankedTensorType>(concatOp.getInputs()[startOperandIndex + index].getType());
|
||||||
|| !hasStaticValues(firstSliceOp.getStaticStrides()))
|
if (!operandType || !operandType.hasStaticShape() || operandType.getRank() != firstType.getRank())
|
||||||
return {};
|
|
||||||
|
|
||||||
ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
|
|
||||||
ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
|
|
||||||
ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
|
|
||||||
int64_t rowsPerValue = firstSizes[0];
|
|
||||||
if (ShapedType::isDynamic(rowsPerValue))
|
|
||||||
return {};
|
|
||||||
|
|
||||||
for (size_t index = 1; index < values.size(); ++index) {
|
|
||||||
auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
|
|
||||||
if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource()
|
|
||||||
|| sliceOp.getResult().getType() != firstSliceOp.getResult().getType()
|
|
||||||
|| !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes())
|
|
||||||
|| !hasStaticValues(sliceOp.getStaticStrides()))
|
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
|
for (int64_t dim = 0; dim < firstType.getRank(); ++dim) {
|
||||||
return {};
|
if (dim == axis)
|
||||||
|
continue;
|
||||||
if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast<int64_t>(index) * rowsPerValue)
|
if (operandType.getShape()[dim] != shape[dim])
|
||||||
return {};
|
|
||||||
|
|
||||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim)
|
|
||||||
if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim])
|
|
||||||
return {};
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
shape[axis] += operandType.getShape()[axis];
|
||||||
}
|
}
|
||||||
|
|
||||||
auto packedType = getPackedTensorType(firstType, static_cast<int64_t>(values.size()));
|
return RankedTensorType::get(shape, firstType.getElementType());
|
||||||
SmallVector<OpFoldResult> offsets;
|
|
||||||
SmallVector<OpFoldResult> sizes;
|
|
||||||
SmallVector<OpFoldResult> strides;
|
|
||||||
offsets.reserve(firstType.getRank());
|
|
||||||
sizes.reserve(firstType.getRank());
|
|
||||||
strides.reserve(firstType.getRank());
|
|
||||||
|
|
||||||
offsets.push_back(rewriter.getIndexAttr(firstOffsets[0]));
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(rowsPerValue * static_cast<int64_t>(values.size())));
|
|
||||||
strides.push_back(rewriter.getIndexAttr(firstStrides[0]));
|
|
||||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim) {
|
|
||||||
offsets.push_back(rewriter.getIndexAttr(firstOffsets[dim]));
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(firstSizes[dim]));
|
|
||||||
strides.push_back(rewriter.getIndexAttr(firstStrides[dim]));
|
|
||||||
}
|
|
||||||
|
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
|
|
||||||
.getResult();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
|
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
|
||||||
@@ -207,8 +150,7 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter
|
|||||||
return {};
|
return {};
|
||||||
if (!llvm::all_of(values.drop_front(), [&](Value value) { return value.getType() == firstType; }))
|
if (!llvm::all_of(values.drop_front(), [&](Value value) { return value.getType() == firstType; }))
|
||||||
return {};
|
return {};
|
||||||
|
return {};
|
||||||
return tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, values).getResult();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
|
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
|
||||||
@@ -346,11 +288,28 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
|||||||
scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult());
|
scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto [index, chunk] : llvm::enumerate(run)) {
|
SmallVector<Value> outputs;
|
||||||
Value replacement = extractPackedChunk(
|
outputs.reserve(run.size());
|
||||||
loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc());
|
for (const RegularChunk& chunk : run)
|
||||||
Value output = chunk.output;
|
outputs.push_back(chunk.output);
|
||||||
output.replaceAllUsesWith(replacement);
|
|
||||||
|
unsigned concatStartIndex = 0;
|
||||||
|
auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex);
|
||||||
|
auto concatPackedType = concatOp
|
||||||
|
? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
|
||||||
|
: RankedTensorType {};
|
||||||
|
|
||||||
|
if (concatOp && concatPackedType == packedOutputType) {
|
||||||
|
replaceConcatRunWithPackedValue(
|
||||||
|
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), loop.getResult(0), rewriter);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (auto [index, chunk] : llvm::enumerate(run)) {
|
||||||
|
Value replacement = extractPackedChunk(
|
||||||
|
loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc());
|
||||||
|
Value output = chunk.output;
|
||||||
|
output.replaceAllUsesWith(replacement);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Operation*> opsToErase;
|
SmallVector<Operation*> opsToErase;
|
||||||
@@ -412,7 +371,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
||||||
|
SmallVector<Value> sortedOutputs;
|
||||||
|
sortedOutputs.reserve(sortedEntries.size());
|
||||||
|
for (ReceiveEntry& entry : sortedEntries)
|
||||||
|
sortedOutputs.push_back(entry.op.getOutput());
|
||||||
|
|
||||||
|
unsigned concatStartIndex = 0;
|
||||||
|
auto concatOp = getContiguousConcatUse(ValueRange(sortedOutputs), concatStartIndex);
|
||||||
|
auto concatPackedType =
|
||||||
|
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
|
||||||
|
: RankedTensorType {};
|
||||||
|
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||||
rewriter.setInsertionPoint(run.front());
|
rewriter.setInsertionPoint(run.front());
|
||||||
auto compactReceive =
|
auto compactReceive =
|
||||||
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||||
@@ -421,9 +391,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||||
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
|
if (concatOp && concatPackedType) {
|
||||||
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
replaceConcatRunWithPackedValue(concatOp,
|
||||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
concatStartIndex,
|
||||||
|
static_cast<unsigned>(sortedOutputs.size()),
|
||||||
|
compactReceive.getOutput(),
|
||||||
|
rewriter);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
|
||||||
|
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||||
|
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
||||||
|
}
|
||||||
for (auto op : run)
|
for (auto op : run)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
@@ -531,7 +510,18 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
|
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
|
||||||
|
SmallVector<Value> outputs;
|
||||||
|
outputs.reserve(run.size());
|
||||||
|
for (auto op : run)
|
||||||
|
outputs.push_back(op.getOutput());
|
||||||
|
|
||||||
|
unsigned concatStartIndex = 0;
|
||||||
|
auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex);
|
||||||
|
auto concatPackedType =
|
||||||
|
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
|
||||||
|
: RankedTensorType {};
|
||||||
|
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||||
rewriter.setInsertionPoint(run.front());
|
rewriter.setInsertionPoint(run.front());
|
||||||
auto compactReceive =
|
auto compactReceive =
|
||||||
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
||||||
@@ -540,9 +530,15 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||||
for (auto [index, op] : llvm::enumerate(run))
|
if (concatOp && concatPackedType) {
|
||||||
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
replaceConcatRunWithPackedValue(
|
||||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (auto [index, op] : llvm::enumerate(run))
|
||||||
|
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||||
|
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
||||||
|
}
|
||||||
for (auto op : run)
|
for (auto op : run)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user