compact syntax for spatial tensor ops
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
better IR compaction after dcp merge remove pim.mvm op better memory report
This commit is contained in:
@@ -88,7 +88,14 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
|
||||
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||
assert(root && "expected valid root op");
|
||||
root->walk([&](pim::PimCoreOp coreOp) { walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback); });
|
||||
root->walk([&](pim::PimCoreOp coreOp) {
|
||||
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
auto weights = coreOp.getWeights();
|
||||
unsigned weightIndex = vmmOp.getWeightIndex();
|
||||
if (weightIndex < weights.size())
|
||||
callback(coreOp->getOpOperand(weightIndex));
|
||||
});
|
||||
});
|
||||
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto weights = coreBatchOp.getWeights();
|
||||
for (auto weight : weights)
|
||||
|
||||
@@ -134,6 +134,15 @@ static void printMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
|
||||
os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n";
|
||||
}
|
||||
|
||||
static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) {
|
||||
MemoryReportRow result = lhs;
|
||||
result.numAlloca += rhs.numAlloca;
|
||||
result.sizeAlloca += rhs.sizeAlloca;
|
||||
result.numGlobal += rhs.numGlobal;
|
||||
result.sizeGlobal += rhs.sizeGlobal;
|
||||
return result;
|
||||
}
|
||||
|
||||
MemoryReportRow PimMemory::getReportRow() const {
|
||||
MemoryReportRow row;
|
||||
for (auto& [val, memEntry] : globalMemEntriesMap) {
|
||||
@@ -201,8 +210,17 @@ void PimAcceleratorMemory::reportHost() {
|
||||
hostReportRow = hostMem.getReportRow();
|
||||
}
|
||||
|
||||
void PimAcceleratorMemory::reportCore(size_t coreId) {
|
||||
coreReportRows.push_back({coreId, deviceMem.at(coreId).getReportRow()});
|
||||
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
|
||||
reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row});
|
||||
}
|
||||
|
||||
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId, ArrayRef<int32_t> coreIds, const MemoryReportRow& row) {
|
||||
MemoryReportEntry entry;
|
||||
entry.kind = MemoryReportEntry::Kind::Batch;
|
||||
entry.id = batchId;
|
||||
llvm::append_range(entry.coreIds, coreIds);
|
||||
entry.row = row;
|
||||
reportEntries.push_back(std::move(entry));
|
||||
}
|
||||
|
||||
void PimAcceleratorMemory::flushReport() {
|
||||
@@ -215,13 +233,16 @@ void PimAcceleratorMemory::flushReport() {
|
||||
printMemoryReportRow(os, *hostReportRow);
|
||||
}
|
||||
|
||||
if (!coreReportRows.empty()) {
|
||||
if (!reportEntries.empty()) {
|
||||
if (hostReportRow.has_value())
|
||||
os << "\n";
|
||||
|
||||
llvm::stable_sort(coreReportRows, [](const auto& lhs, const auto& rhs) {
|
||||
const MemoryReportRow& lhsRow = lhs.second;
|
||||
const MemoryReportRow& rhsRow = rhs.second;
|
||||
llvm::stable_sort(reportEntries, [](const MemoryReportEntry& lhs, const MemoryReportEntry& rhs) {
|
||||
if (lhs.kind != rhs.kind)
|
||||
return lhs.kind == MemoryReportEntry::Kind::Batch;
|
||||
|
||||
const MemoryReportRow& lhsRow = lhs.row;
|
||||
const MemoryReportRow& rhsRow = rhs.row;
|
||||
if (lhsRow.sizeAlloca != rhsRow.sizeAlloca)
|
||||
return lhsRow.sizeAlloca > rhsRow.sizeAlloca;
|
||||
if (lhsRow.numAlloca != rhsRow.numAlloca)
|
||||
@@ -230,24 +251,36 @@ void PimAcceleratorMemory::flushReport() {
|
||||
return lhsRow.sizeGlobal > rhsRow.sizeGlobal;
|
||||
if (lhsRow.numGlobal != rhsRow.numGlobal)
|
||||
return lhsRow.numGlobal > rhsRow.numGlobal;
|
||||
return lhs.first < rhs.first;
|
||||
return lhs.id < rhs.id;
|
||||
});
|
||||
|
||||
for (size_t index = 0; index < coreReportRows.size();) {
|
||||
for (size_t index = 0; index < reportEntries.size();) {
|
||||
size_t runEnd = index + 1;
|
||||
while (runEnd < coreReportRows.size() && coreReportRows[runEnd].second == coreReportRows[index].second)
|
||||
while (runEnd < reportEntries.size() && reportEntries[runEnd].kind == reportEntries[index].kind
|
||||
&& reportEntries[runEnd].row == reportEntries[index].row) {
|
||||
++runEnd;
|
||||
}
|
||||
|
||||
llvm::SmallVector<size_t, 8> coreIds;
|
||||
coreIds.reserve(runEnd - index);
|
||||
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
|
||||
coreIds.push_back(coreReportRows[coreIndex].first);
|
||||
|
||||
os << "Core ";
|
||||
printCompressedIntegerEntries(os, ArrayRef<size_t>(coreIds));
|
||||
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch) {
|
||||
os << "Batch ";
|
||||
for (size_t batchIndex = index; batchIndex < runEnd; ++batchIndex) {
|
||||
if (batchIndex != index)
|
||||
os << ",\n ";
|
||||
os << reportEntries[batchIndex].id << " (cores ";
|
||||
printCompressedIntegerEntries(os, ArrayRef<int32_t>(reportEntries[batchIndex].coreIds));
|
||||
os << ")";
|
||||
}
|
||||
}
|
||||
else {
|
||||
llvm::SmallVector<int32_t, 8> coreIds;
|
||||
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
|
||||
coreIds.push_back(reportEntries[coreIndex].coreIds.front());
|
||||
os << "Core ";
|
||||
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
|
||||
}
|
||||
os << ":\n";
|
||||
printMemoryReportRow(os, coreReportRows[index].second);
|
||||
if (runEnd < coreReportRows.size())
|
||||
printMemoryReportRow(os, reportEntries[index].row);
|
||||
if (runEnd < reportEntries.size())
|
||||
os << "\n";
|
||||
|
||||
index = runEnd;
|
||||
@@ -678,7 +711,6 @@ static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||
indices.push_back(weightIndex);
|
||||
};
|
||||
|
||||
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
@@ -753,8 +785,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
|
||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
|
||||
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false, knowledge);
|
||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||
@@ -816,6 +846,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
// This implementation always assigns one crossbar per group.
|
||||
json::Object xbarsPerArrayGroup;
|
||||
size_t maxCoreId = 0;
|
||||
uint64_t nextBatchReportId = 0;
|
||||
|
||||
// Create Weight Folder
|
||||
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
||||
@@ -859,7 +890,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
|
||||
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
|
||||
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
|
||||
memory.reportCore(coreId);
|
||||
|
||||
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
|
||||
if (processedOperations < 0)
|
||||
@@ -905,18 +935,31 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
if (auto err = emitCore(coreOp, false))
|
||||
return err;
|
||||
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
|
||||
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())))
|
||||
.getReportRow());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
||||
auto batchCoreIds = getBatchCoreIds(coreBatchOp);
|
||||
SmallVector<int32_t> reportedCoreIds;
|
||||
reportedCoreIds.reserve(batchCoreIds.size());
|
||||
MemoryReportRow batchRow;
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
||||
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
||||
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
|
||||
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
|
||||
laneResult = emitCore(coreOp, true);
|
||||
if (laneResult == CompilerSuccess)
|
||||
batchRow = addMemoryReportRows(batchRow, memory.getOrCreateDeviceMem(coreId).getReportRow());
|
||||
return laneResult == CompilerSuccess ? success() : failure();
|
||||
})))
|
||||
return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
|
||||
}
|
||||
memory.recordBatchReport(nextBatchReportId++, reportedCoreIds, batchRow);
|
||||
}
|
||||
|
||||
memory.flushReport();
|
||||
|
||||
@@ -33,6 +33,18 @@ struct MemoryReportRow {
|
||||
}
|
||||
};
|
||||
|
||||
struct MemoryReportEntry {
|
||||
enum class Kind {
|
||||
Core,
|
||||
Batch
|
||||
};
|
||||
|
||||
Kind kind = Kind::Core;
|
||||
uint64_t id = 0;
|
||||
llvm::SmallVector<int32_t, 8> coreIds;
|
||||
MemoryReportRow row;
|
||||
};
|
||||
|
||||
class PimMemory {
|
||||
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
||||
@@ -66,7 +78,7 @@ private:
|
||||
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
|
||||
std::fstream fileReport;
|
||||
std::optional<MemoryReportRow> hostReportRow;
|
||||
llvm::SmallVector<std::pair<size_t, MemoryReportRow>, 32> coreReportRows;
|
||||
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
|
||||
|
||||
public:
|
||||
PimAcceleratorMemory()
|
||||
@@ -86,7 +98,8 @@ public:
|
||||
|
||||
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
||||
void reportHost();
|
||||
void reportCore(size_t coreId);
|
||||
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
|
||||
void recordBatchReport(uint64_t batchId, llvm::ArrayRef<int32_t> coreIds, const MemoryReportRow& row);
|
||||
void flushReport();
|
||||
void clean(mlir::Operation* op);
|
||||
};
|
||||
|
||||
@@ -103,7 +103,6 @@ SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||
indices.push_back(weightIndex);
|
||||
};
|
||||
|
||||
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -37,15 +36,10 @@ static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp se
|
||||
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
|
||||
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
||||
|
||||
Value input = mapper.lookup(sendTensorBatchOp.getInput());
|
||||
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>())
|
||||
if (concatOp.getDim() == 0)
|
||||
if (Value packedInput =
|
||||
createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorBatchOp.getLoc()))
|
||||
input = packedInput;
|
||||
|
||||
pim::PimSendTensorBatchOp::create(
|
||||
rewriter, sendTensorBatchOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
pim::PimSendTensorBatchOp::create(rewriter,
|
||||
sendTensorBatchOp.getLoc(),
|
||||
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
}
|
||||
|
||||
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
||||
|
||||
@@ -21,12 +21,6 @@ def spatToPimVMM : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimMVM : Pat<
|
||||
(SpatMVMOp:$srcOpRes $weightIndex, $vector),
|
||||
(PimMVMOp $weightIndex, $vector,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVAdd : Pat<
|
||||
(SpatVAddOp:$srcOpRes $a, $b),
|
||||
(PimVVAddOp $a, $b,
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
@@ -105,12 +104,8 @@ static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp
|
||||
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
||||
|
||||
rewriter.setInsertionPoint(sendTensorOp);
|
||||
Value input = sendTensorOp.getInput();
|
||||
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>())
|
||||
if (concatOp.getDim() == 0)
|
||||
if (Value packedInput = createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorOp.getLoc()))
|
||||
input = packedInput;
|
||||
PimSendTensorOp::create(rewriter, sendTensorOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
PimSendTensorOp::create(
|
||||
rewriter, sendTensorOp.getLoc(), sendTensorOp.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
rewriter.eraseOp(sendTensorOp);
|
||||
}
|
||||
|
||||
@@ -152,38 +147,6 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
|
||||
rewriter.replaceOp(extractRowsOp, replacements);
|
||||
}
|
||||
|
||||
static Value createPackedExtractRowsSlice(
|
||||
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
|
||||
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
int64_t rowsPerValue = rowType.getDimSize(0);
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(inputType.getRank());
|
||||
sizes.reserve(inputType.getRank());
|
||||
strides.reserve(inputType.getRank());
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
|
||||
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<spatial::SpatConcatOp> concatOps;
|
||||
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
|
||||
@@ -262,11 +225,6 @@ static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter
|
||||
.getResult());
|
||||
rewriter.replaceOp(concatOp, newConcat.getOutput());
|
||||
}
|
||||
|
||||
RewritePatternSet tensorPackingPatterns(funcOp.getContext());
|
||||
populateTensorPackingPatterns(tensorPackingPatterns);
|
||||
(void) applyPatternsGreedily(funcOp, std::move(tensorPackingPatterns));
|
||||
|
||||
auto eraseUnusedOps = [&](auto tag) {
|
||||
using OpTy = decltype(tag);
|
||||
SmallVector<OpTy> ops;
|
||||
|
||||
@@ -3,26 +3,6 @@
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
// Replaces concat-of-adjacent-slices with one packed slice to keep batch sends compact.
|
||||
struct FoldConcatOfContiguousSlices : OpRewritePattern<tensor::ConcatOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tensor::ConcatOp op, PatternRewriter& rewriter) const override {
|
||||
if (op.getDim() != 0)
|
||||
return failure();
|
||||
|
||||
Value packed = createPackedExtractSliceTensor(op.getInputs(), rewriter, op.getLoc());
|
||||
if (!packed)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(op, packed);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
|
||||
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
|
||||
@@ -30,6 +10,67 @@ RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count
|
||||
return RankedTensorType::get(packedShape, elementType.getElementType());
|
||||
}
|
||||
|
||||
Value extractPackedChunk(
|
||||
Value packedValue, RankedTensorType chunkType, unsigned index, OpBuilder& builder, Location loc) {
|
||||
auto packedType = dyn_cast<RankedTensorType>(packedValue.getType());
|
||||
if (packedType && packedType == chunkType && index == 0)
|
||||
return packedValue;
|
||||
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(chunkType.getRank());
|
||||
sizes.reserve(chunkType.getRank());
|
||||
strides.reserve(chunkType.getRank());
|
||||
|
||||
offsets.push_back(builder.getIndexAttr(static_cast<int64_t>(index) * chunkType.getDimSize(0)));
|
||||
sizes.push_back(builder.getIndexAttr(chunkType.getDimSize(0)));
|
||||
strides.push_back(builder.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) {
|
||||
offsets.push_back(builder.getIndexAttr(0));
|
||||
sizes.push_back(builder.getIndexAttr(chunkType.getDimSize(dim)));
|
||||
strides.push_back(builder.getIndexAttr(1));
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(builder, loc, chunkType, packedValue, offsets, sizes, strides).getResult();
|
||||
}
|
||||
|
||||
Value createPackedExtractRowsSlice(
|
||||
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, OpBuilder& builder, Location loc) {
|
||||
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
|
||||
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
int64_t rowsPerValue = rowType.getDimSize(0);
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(inputType.getRank());
|
||||
sizes.reserve(inputType.getRank());
|
||||
strides.reserve(inputType.getRank());
|
||||
|
||||
offsets.push_back(builder.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
|
||||
sizes.push_back(builder.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
|
||||
strides.push_back(builder.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||
offsets.push_back(builder.getIndexAttr(0));
|
||||
sizes.push_back(builder.getIndexAttr(inputType.getDimSize(dim)));
|
||||
strides.push_back(builder.getIndexAttr(1));
|
||||
}
|
||||
|
||||
bool coversWholeSource = packedType == inputType && startIndex == 0;
|
||||
if (coversWholeSource)
|
||||
return extractRowsOp.getInput();
|
||||
|
||||
return tensor::ExtractSliceOp::create(builder, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) {
|
||||
if (values.empty())
|
||||
return {};
|
||||
@@ -105,9 +146,4 @@ Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Loca
|
||||
return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<FoldConcatOfContiguousSlices>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -3,11 +3,21 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
|
||||
mlir::Value extractPackedChunk(mlir::Value packedValue,
|
||||
mlir::RankedTensorType chunkType,
|
||||
unsigned index,
|
||||
mlir::OpBuilder& builder,
|
||||
mlir::Location loc);
|
||||
mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsOp,
|
||||
unsigned startIndex,
|
||||
unsigned count,
|
||||
mlir::OpBuilder& builder,
|
||||
mlir::Location loc);
|
||||
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
|
||||
|
||||
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -394,30 +394,6 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMVMOp : PimOp<"mvm", [DestinationStyleOpInterface]> {
|
||||
let summary = "Matrix-vector multiplication: c = a * b";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$weightIndex,
|
||||
PimTensor:$input,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getOutputBufferMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> {
|
||||
let summary = "Element-wise addition: c = a + b";
|
||||
|
||||
|
||||
@@ -538,33 +538,6 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
||||
}
|
||||
};
|
||||
|
||||
struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface, PimMVMOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto mvmOp = cast<PimMVMOp>(op);
|
||||
|
||||
auto inputOpt = getBufferOrValue(rewriter, mvmOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
|
||||
auto outputBufferOpt = getBufferOrValue(rewriter, mvmOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
@@ -655,7 +628,6 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
||||
PimMVMOp::attachInterface<MVMOpInterface>(*ctx);
|
||||
|
||||
PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx);
|
||||
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
|
||||
|
||||
@@ -150,10 +150,7 @@ def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input)
|
||||
}];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
||||
@@ -170,10 +167,7 @@ def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict `:` type($output)
|
||||
}];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
||||
@@ -201,10 +195,7 @@ def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input)
|
||||
}];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||
@@ -238,10 +229,7 @@ def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []>
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict `:` type($output)
|
||||
}];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -47,6 +47,95 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
template <typename TensorSendOpTy>
|
||||
static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
||||
printer << " ";
|
||||
printer.printOperand(op.getInput());
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getInput().getType());
|
||||
}
|
||||
|
||||
template <typename TensorReceiveOpTy>
|
||||
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getOutput().getType());
|
||||
}
|
||||
|
||||
static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) {
|
||||
Type outputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
||||
@@ -316,6 +405,12 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorSendOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
@@ -362,6 +457,18 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorSendOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorReceiveOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
@@ -403,5 +510,11 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorReceiveOp(parser, result);
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1436,6 +1436,21 @@ public:
|
||||
compactBatchChannelRuns(func);
|
||||
compactRegularOpRuns(func);
|
||||
compactRowWiseWvmmRuns(func);
|
||||
compactScalarChannelRuns(func, nextChannelId);
|
||||
compactBatchChannelRuns(func);
|
||||
|
||||
auto eraseUnusedOps = [&](auto tag) {
|
||||
using OpTy = decltype(tag);
|
||||
SmallVector<OpTy> ops;
|
||||
func.walk([&](OpTy op) { ops.push_back(op); });
|
||||
for (auto op : llvm::reverse(ops))
|
||||
if (op->use_empty())
|
||||
op.erase();
|
||||
};
|
||||
eraseUnusedOps(tensor::ExtractSliceOp {});
|
||||
eraseUnusedOps(spatial::SpatConcatOp {});
|
||||
eraseUnusedOps(spatial::SpatExtractRowsOp {});
|
||||
|
||||
if (!sortTopologically(&func.getBody().front())) {
|
||||
func.emitOpError("failed to topologically order merged Spatial IR");
|
||||
signalPassFailure();
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <tuple>
|
||||
|
||||
#include "RegularOpCompaction.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -41,133 +42,75 @@ struct RegularChunk {
|
||||
Value output;
|
||||
};
|
||||
|
||||
static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
|
||||
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
|
||||
packedShape[0] *= count;
|
||||
return RankedTensorType::get(packedShape, elementType.getElementType());
|
||||
}
|
||||
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
|
||||
if (values.empty() || !values.front().hasOneUse())
|
||||
return {};
|
||||
|
||||
static Value
|
||||
extractPackedChunk(Value packedValue, RankedTensorType chunkType, unsigned index, IRRewriter& rewriter, Location loc) {
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(chunkType.getRank());
|
||||
sizes.reserve(chunkType.getRank());
|
||||
strides.reserve(chunkType.getRank());
|
||||
OpOperand& firstUse = *values.front().getUses().begin();
|
||||
auto concatOp = dyn_cast<spatial::SpatConcatOp>(firstUse.getOwner());
|
||||
if (!concatOp)
|
||||
return {};
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(index) * chunkType.getDimSize(0)));
|
||||
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(0)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(dim)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
startOperandIndex = firstUse.getOperandNumber();
|
||||
for (auto [index, value] : llvm::enumerate(values)) {
|
||||
if (!value.hasOneUse())
|
||||
return {};
|
||||
OpOperand& use = *value.getUses().begin();
|
||||
if (use.getOwner() != concatOp || use.getOperandNumber() != startOperandIndex + index)
|
||||
return {};
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, chunkType, packedValue, offsets, sizes, strides).getResult();
|
||||
return concatOp;
|
||||
}
|
||||
|
||||
static Value createPackedExtractRowsSlice(
|
||||
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
|
||||
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
int64_t rowsPerValue = rowType.getDimSize(0);
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(inputType.getRank());
|
||||
sizes.reserve(inputType.getRank());
|
||||
strides.reserve(inputType.getRank());
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
|
||||
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
static void replaceConcatRunWithPackedValue(spatial::SpatConcatOp concatOp,
|
||||
unsigned startOperandIndex,
|
||||
unsigned operandCount,
|
||||
Value packedValue,
|
||||
IRRewriter& rewriter) {
|
||||
SmallVector<Value> newInputs;
|
||||
newInputs.reserve(concatOp.getInputs().size() - operandCount + 1);
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
|
||||
if (operandIndex == startOperandIndex)
|
||||
newInputs.push_back(packedValue);
|
||||
if (operandIndex < startOperandIndex || operandIndex >= startOperandIndex + operandCount)
|
||||
newInputs.push_back(operand);
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
if (newInputs.size() == 1 && newInputs.front().getType() == concatOp.getOutput().getType()) {
|
||||
rewriter.replaceOp(concatOp, newInputs.front());
|
||||
return;
|
||||
}
|
||||
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newInputs); });
|
||||
}
|
||||
|
||||
static Value createPackedExtractSliceTensor(ValueRange values, IRRewriter& rewriter, Location loc) {
|
||||
if (values.empty())
|
||||
return {};
|
||||
if (values.size() == 1)
|
||||
return values.front();
|
||||
|
||||
auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
|
||||
if (!firstSliceOp)
|
||||
static RankedTensorType
|
||||
getPackedConcatSliceType(spatial::SpatConcatOp concatOp, unsigned startOperandIndex, unsigned operandCount) {
|
||||
auto firstType = dyn_cast<RankedTensorType>(concatOp.getInputs()[startOperandIndex].getType());
|
||||
if (!firstType || !firstType.hasStaticShape())
|
||||
return {};
|
||||
|
||||
auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
|
||||
auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType());
|
||||
if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
|
||||
|| firstType.getRank() == 0)
|
||||
int64_t axis = concatOp.getAxis();
|
||||
if (axis < 0 || axis >= firstType.getRank())
|
||||
return {};
|
||||
|
||||
auto hasStaticValues = [](ArrayRef<int64_t> values) {
|
||||
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
};
|
||||
if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(firstSliceOp.getStaticStrides()))
|
||||
return {};
|
||||
|
||||
ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
|
||||
ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
|
||||
int64_t rowsPerValue = firstSizes[0];
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
for (size_t index = 1; index < values.size(); ++index) {
|
||||
auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
|
||||
if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource()
|
||||
|| sliceOp.getResult().getType() != firstSliceOp.getResult().getType()
|
||||
|| !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(sliceOp.getStaticStrides()))
|
||||
SmallVector<int64_t> shape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
shape[axis] = 0;
|
||||
for (unsigned index = 0; index < operandCount; ++index) {
|
||||
auto operandType = dyn_cast<RankedTensorType>(concatOp.getInputs()[startOperandIndex + index].getType());
|
||||
if (!operandType || !operandType.hasStaticShape() || operandType.getRank() != firstType.getRank())
|
||||
return {};
|
||||
|
||||
if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
|
||||
return {};
|
||||
|
||||
if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast<int64_t>(index) * rowsPerValue)
|
||||
return {};
|
||||
|
||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim)
|
||||
if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim])
|
||||
for (int64_t dim = 0; dim < firstType.getRank(); ++dim) {
|
||||
if (dim == axis)
|
||||
continue;
|
||||
if (operandType.getShape()[dim] != shape[dim])
|
||||
return {};
|
||||
}
|
||||
|
||||
shape[axis] += operandType.getShape()[axis];
|
||||
}
|
||||
|
||||
auto packedType = getPackedTensorType(firstType, static_cast<int64_t>(values.size()));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(firstType.getRank());
|
||||
sizes.reserve(firstType.getRank());
|
||||
strides.reserve(firstType.getRank());
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(firstOffsets[0]));
|
||||
sizes.push_back(rewriter.getIndexAttr(rowsPerValue * static_cast<int64_t>(values.size())));
|
||||
strides.push_back(rewriter.getIndexAttr(firstStrides[0]));
|
||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(firstOffsets[dim]));
|
||||
sizes.push_back(rewriter.getIndexAttr(firstSizes[dim]));
|
||||
strides.push_back(rewriter.getIndexAttr(firstStrides[dim]));
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
return RankedTensorType::get(shape, firstType.getElementType());
|
||||
}
|
||||
|
||||
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
|
||||
@@ -207,8 +150,7 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter
|
||||
return {};
|
||||
if (!llvm::all_of(values.drop_front(), [&](Value value) { return value.getType() == firstType; }))
|
||||
return {};
|
||||
|
||||
return tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, values).getResult();
|
||||
return {};
|
||||
}
|
||||
|
||||
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
|
||||
@@ -346,11 +288,28 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
||||
scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult());
|
||||
}
|
||||
|
||||
for (auto [index, chunk] : llvm::enumerate(run)) {
|
||||
Value replacement = extractPackedChunk(
|
||||
loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc());
|
||||
Value output = chunk.output;
|
||||
output.replaceAllUsesWith(replacement);
|
||||
SmallVector<Value> outputs;
|
||||
outputs.reserve(run.size());
|
||||
for (const RegularChunk& chunk : run)
|
||||
outputs.push_back(chunk.output);
|
||||
|
||||
unsigned concatStartIndex = 0;
|
||||
auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex);
|
||||
auto concatPackedType = concatOp
|
||||
? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
|
||||
: RankedTensorType {};
|
||||
|
||||
if (concatOp && concatPackedType == packedOutputType) {
|
||||
replaceConcatRunWithPackedValue(
|
||||
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), loop.getResult(0), rewriter);
|
||||
}
|
||||
else {
|
||||
for (auto [index, chunk] : llvm::enumerate(run)) {
|
||||
Value replacement = extractPackedChunk(
|
||||
loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc());
|
||||
Value output = chunk.output;
|
||||
output.replaceAllUsesWith(replacement);
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Operation*> opsToErase;
|
||||
@@ -412,7 +371,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
||||
SmallVector<Value> sortedOutputs;
|
||||
sortedOutputs.reserve(sortedEntries.size());
|
||||
for (ReceiveEntry& entry : sortedEntries)
|
||||
sortedOutputs.push_back(entry.op.getOutput());
|
||||
|
||||
unsigned concatStartIndex = 0;
|
||||
auto concatOp = getContiguousConcatUse(ValueRange(sortedOutputs), concatStartIndex);
|
||||
auto concatPackedType =
|
||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
|
||||
: RankedTensorType {};
|
||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||
@@ -421,9 +391,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
|
||||
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
||||
if (concatOp && concatPackedType) {
|
||||
replaceConcatRunWithPackedValue(concatOp,
|
||||
concatStartIndex,
|
||||
static_cast<unsigned>(sortedOutputs.size()),
|
||||
compactReceive.getOutput(),
|
||||
rewriter);
|
||||
}
|
||||
else {
|
||||
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
|
||||
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
||||
}
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
@@ -531,7 +510,18 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
|
||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
|
||||
SmallVector<Value> outputs;
|
||||
outputs.reserve(run.size());
|
||||
for (auto op : run)
|
||||
outputs.push_back(op.getOutput());
|
||||
|
||||
unsigned concatStartIndex = 0;
|
||||
auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex);
|
||||
auto concatPackedType =
|
||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
|
||||
: RankedTensorType {};
|
||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
||||
@@ -540,9 +530,15 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
for (auto [index, op] : llvm::enumerate(run))
|
||||
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
||||
if (concatOp && concatPackedType) {
|
||||
replaceConcatRunWithPackedValue(
|
||||
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
|
||||
}
|
||||
else {
|
||||
for (auto [index, op] : llvm::enumerate(run))
|
||||
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
||||
}
|
||||
for (auto op : run)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user