compact syntax for spatial tensor ops
Validate Operations / validate-operations (push) Has been cancelled

better IR compaction after dcp merge
remove pim.mvm op
better memory report
This commit is contained in:
NiccoloN
2026-05-12 13:35:25 +02:00
parent 80a7298552
commit 628dc630a4
15 changed files with 419 additions and 305 deletions
+8 -1
View File
@@ -88,7 +88,14 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) { void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
assert(root && "expected valid root op"); assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) { walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback); }); root->walk([&](pim::PimCoreOp coreOp) {
coreOp.walk([&](pim::PimVMMOp vmmOp) {
auto weights = coreOp.getWeights();
unsigned weightIndex = vmmOp.getWeightIndex();
if (weightIndex < weights.size())
callback(coreOp->getOpOperand(weightIndex));
});
});
root->walk([&](pim::PimCoreBatchOp coreBatchOp) { root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
auto weights = coreBatchOp.getWeights(); auto weights = coreBatchOp.getWeights();
for (auto weight : weights) for (auto weight : weights)
+65 -22
View File
@@ -134,6 +134,15 @@ static void printMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n"; os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n";
} }
static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) {
MemoryReportRow result = lhs;
result.numAlloca += rhs.numAlloca;
result.sizeAlloca += rhs.sizeAlloca;
result.numGlobal += rhs.numGlobal;
result.sizeGlobal += rhs.sizeGlobal;
return result;
}
MemoryReportRow PimMemory::getReportRow() const { MemoryReportRow PimMemory::getReportRow() const {
MemoryReportRow row; MemoryReportRow row;
for (auto& [val, memEntry] : globalMemEntriesMap) { for (auto& [val, memEntry] : globalMemEntriesMap) {
@@ -201,8 +210,17 @@ void PimAcceleratorMemory::reportHost() {
hostReportRow = hostMem.getReportRow(); hostReportRow = hostMem.getReportRow();
} }
void PimAcceleratorMemory::reportCore(size_t coreId) { void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
coreReportRows.push_back({coreId, deviceMem.at(coreId).getReportRow()}); reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row});
}
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId, ArrayRef<int32_t> coreIds, const MemoryReportRow& row) {
MemoryReportEntry entry;
entry.kind = MemoryReportEntry::Kind::Batch;
entry.id = batchId;
llvm::append_range(entry.coreIds, coreIds);
entry.row = row;
reportEntries.push_back(std::move(entry));
} }
void PimAcceleratorMemory::flushReport() { void PimAcceleratorMemory::flushReport() {
@@ -215,13 +233,16 @@ void PimAcceleratorMemory::flushReport() {
printMemoryReportRow(os, *hostReportRow); printMemoryReportRow(os, *hostReportRow);
} }
if (!coreReportRows.empty()) { if (!reportEntries.empty()) {
if (hostReportRow.has_value()) if (hostReportRow.has_value())
os << "\n"; os << "\n";
llvm::stable_sort(coreReportRows, [](const auto& lhs, const auto& rhs) { llvm::stable_sort(reportEntries, [](const MemoryReportEntry& lhs, const MemoryReportEntry& rhs) {
const MemoryReportRow& lhsRow = lhs.second; if (lhs.kind != rhs.kind)
const MemoryReportRow& rhsRow = rhs.second; return lhs.kind == MemoryReportEntry::Kind::Batch;
const MemoryReportRow& lhsRow = lhs.row;
const MemoryReportRow& rhsRow = rhs.row;
if (lhsRow.sizeAlloca != rhsRow.sizeAlloca) if (lhsRow.sizeAlloca != rhsRow.sizeAlloca)
return lhsRow.sizeAlloca > rhsRow.sizeAlloca; return lhsRow.sizeAlloca > rhsRow.sizeAlloca;
if (lhsRow.numAlloca != rhsRow.numAlloca) if (lhsRow.numAlloca != rhsRow.numAlloca)
@@ -230,24 +251,36 @@ void PimAcceleratorMemory::flushReport() {
return lhsRow.sizeGlobal > rhsRow.sizeGlobal; return lhsRow.sizeGlobal > rhsRow.sizeGlobal;
if (lhsRow.numGlobal != rhsRow.numGlobal) if (lhsRow.numGlobal != rhsRow.numGlobal)
return lhsRow.numGlobal > rhsRow.numGlobal; return lhsRow.numGlobal > rhsRow.numGlobal;
return lhs.first < rhs.first; return lhs.id < rhs.id;
}); });
for (size_t index = 0; index < coreReportRows.size();) { for (size_t index = 0; index < reportEntries.size();) {
size_t runEnd = index + 1; size_t runEnd = index + 1;
while (runEnd < coreReportRows.size() && coreReportRows[runEnd].second == coreReportRows[index].second) while (runEnd < reportEntries.size() && reportEntries[runEnd].kind == reportEntries[index].kind
&& reportEntries[runEnd].row == reportEntries[index].row) {
++runEnd; ++runEnd;
}
llvm::SmallVector<size_t, 8> coreIds; if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch) {
coreIds.reserve(runEnd - index); os << "Batch ";
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex) for (size_t batchIndex = index; batchIndex < runEnd; ++batchIndex) {
coreIds.push_back(coreReportRows[coreIndex].first); if (batchIndex != index)
os << ",\n ";
os << "Core "; os << reportEntries[batchIndex].id << " (cores ";
printCompressedIntegerEntries(os, ArrayRef<size_t>(coreIds)); printCompressedIntegerEntries(os, ArrayRef<int32_t>(reportEntries[batchIndex].coreIds));
os << ")";
}
}
else {
llvm::SmallVector<int32_t, 8> coreIds;
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
coreIds.push_back(reportEntries[coreIndex].coreIds.front());
os << "Core ";
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
}
os << ":\n"; os << ":\n";
printMemoryReportRow(os, coreReportRows[index].second); printMemoryReportRow(os, reportEntries[index].row);
if (runEnd < coreReportRows.size()) if (runEnd < reportEntries.size())
os << "\n"; os << "\n";
index = runEnd; index = runEnd;
@@ -678,7 +711,6 @@ static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
indices.push_back(weightIndex); indices.push_back(weightIndex);
}; };
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); }); block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices); llvm::sort(indices);
return indices; return indices;
@@ -753,8 +785,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenConcatOp(concatOp, knowledge); coreCodeGen.codeGenConcatOp(concatOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge); coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false, knowledge);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op)) else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge); coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op)) else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
@@ -816,6 +846,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
// This implementation always assigns one crossbar per group. // This implementation always assigns one crossbar per group.
json::Object xbarsPerArrayGroup; json::Object xbarsPerArrayGroup;
size_t maxCoreId = 0; size_t maxCoreId = 0;
uint64_t nextBatchReportId = 0;
// Create Weight Folder // Create Weight Folder
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath); auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
@@ -859,7 +890,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds); PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory); aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp); memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
memory.reportCore(coreId);
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen); int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
if (processedOperations < 0) if (processedOperations < 0)
@@ -905,18 +935,31 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) { if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
if (auto err = emitCore(coreOp, false)) if (auto err = emitCore(coreOp, false))
return err; return err;
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())))
.getReportRow());
continue; continue;
} }
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op); auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
auto batchCoreIds = getBatchCoreIds(coreBatchOp);
SmallVector<int32_t> reportedCoreIds;
reportedCoreIds.reserve(batchCoreIds.size());
MemoryReportRow batchRow;
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) { for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess; OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) { if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
size_t coreId = emittedCoreIds.lookup(originalCoreId);
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
laneResult = emitCore(coreOp, true); laneResult = emitCore(coreOp, true);
if (laneResult == CompilerSuccess)
batchRow = addMemoryReportRows(batchRow, memory.getOrCreateDeviceMem(coreId).getReportRow());
return laneResult == CompilerSuccess ? success() : failure(); return laneResult == CompilerSuccess ? success() : failure();
}))) })))
return laneResult == CompilerSuccess ? CompilerFailure : laneResult; return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
} }
memory.recordBatchReport(nextBatchReportId++, reportedCoreIds, batchRow);
} }
memory.flushReport(); memory.flushReport();
+15 -2
View File
@@ -33,6 +33,18 @@ struct MemoryReportRow {
} }
}; };
struct MemoryReportEntry {
enum class Kind {
Core,
Batch
};
Kind kind = Kind::Core;
uint64_t id = 0;
llvm::SmallVector<int32_t, 8> coreIds;
MemoryReportRow row;
};
class PimMemory { class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries; llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap; llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
@@ -66,7 +78,7 @@ private:
llvm::SmallDenseMap<size_t, PimMemory> deviceMem; llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
std::fstream fileReport; std::fstream fileReport;
std::optional<MemoryReportRow> hostReportRow; std::optional<MemoryReportRow> hostReportRow;
llvm::SmallVector<std::pair<size_t, MemoryReportRow>, 32> coreReportRows; llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
public: public:
PimAcceleratorMemory() PimAcceleratorMemory()
@@ -86,7 +98,8 @@ public:
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const; size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
void reportHost(); void reportHost();
void reportCore(size_t coreId); void recordCoreReport(size_t coreId, const MemoryReportRow& row);
void recordBatchReport(uint64_t batchId, llvm::ArrayRef<int32_t> coreIds, const MemoryReportRow& row);
void flushReport(); void flushReport();
void clean(mlir::Operation* op); void clean(mlir::Operation* op);
}; };
-1
View File
@@ -103,7 +103,6 @@ SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
indices.push_back(weightIndex); indices.push_back(weightIndex);
}; };
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); }); block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices); llvm::sort(indices);
return indices; return indices;
@@ -7,7 +7,6 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
@@ -37,15 +36,10 @@ static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp se
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds()) for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
Value input = mapper.lookup(sendTensorBatchOp.getInput()); pim::PimSendTensorBatchOp::create(rewriter,
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>()) sendTensorBatchOp.getLoc(),
if (concatOp.getDim() == 0) mapper.lookup(sendTensorBatchOp.getInput()),
if (Value packedInput = rewriter.getDenseI32ArrayAttr(targetCoreIds));
createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorBatchOp.getLoc()))
input = packedInput;
pim::PimSendTensorBatchOp::create(
rewriter, sendTensorBatchOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
} }
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp, static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
@@ -21,12 +21,6 @@ def spatToPimVMM : Pat<
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
def spatToPimMVM : Pat<
(SpatMVMOp:$srcOpRes $weightIndex, $vector),
(PimMVMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVAdd : Pat< def spatToPimVVAdd : Pat<
(SpatVAddOp:$srcOpRes $a, $b), (SpatVAddOp:$srcOpRes $a, $b),
(PimVVAddOp $a, $b, (PimVVAddOp $a, $b,
@@ -11,7 +11,6 @@
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
@@ -105,12 +104,8 @@ static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
rewriter.setInsertionPoint(sendTensorOp); rewriter.setInsertionPoint(sendTensorOp);
Value input = sendTensorOp.getInput(); PimSendTensorOp::create(
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>()) rewriter, sendTensorOp.getLoc(), sendTensorOp.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
if (concatOp.getDim() == 0)
if (Value packedInput = createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorOp.getLoc()))
input = packedInput;
PimSendTensorOp::create(rewriter, sendTensorOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.eraseOp(sendTensorOp); rewriter.eraseOp(sendTensorOp);
} }
@@ -152,38 +147,6 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
rewriter.replaceOp(extractRowsOp, replacements); rewriter.replaceOp(extractRowsOp, replacements);
} }
static Value createPackedExtractRowsSlice(
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
return {};
int64_t rowsPerValue = rowType.getDimSize(0);
if (ShapedType::isDynamic(rowsPerValue))
return {};
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(inputType.getRank());
sizes.reserve(inputType.getRank());
strides.reserve(inputType.getRank());
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
strides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
strides.push_back(rewriter.getIndexAttr(1));
}
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
.getResult();
}
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) { static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatConcatOp> concatOps; SmallVector<spatial::SpatConcatOp> concatOps;
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); }); funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
@@ -262,11 +225,6 @@ static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter
.getResult()); .getResult());
rewriter.replaceOp(concatOp, newConcat.getOutput()); rewriter.replaceOp(concatOp, newConcat.getOutput());
} }
RewritePatternSet tensorPackingPatterns(funcOp.getContext());
populateTensorPackingPatterns(tensorPackingPatterns);
(void) applyPatternsGreedily(funcOp, std::move(tensorPackingPatterns));
auto eraseUnusedOps = [&](auto tag) { auto eraseUnusedOps = [&](auto tag) {
using OpTy = decltype(tag); using OpTy = decltype(tag);
SmallVector<OpTy> ops; SmallVector<OpTy> ops;
@@ -3,26 +3,6 @@
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace {
// Replaces concat-of-adjacent-slices with one packed slice to keep batch sends compact.
struct FoldConcatOfContiguousSlices : OpRewritePattern<tensor::ConcatOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ConcatOp op, PatternRewriter& rewriter) const override {
if (op.getDim() != 0)
return failure();
Value packed = createPackedExtractSliceTensor(op.getInputs(), rewriter, op.getLoc());
if (!packed)
return failure();
rewriter.replaceOp(op, packed);
return success();
}
};
} // namespace
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) { RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end()); SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
@@ -30,6 +10,67 @@ RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count
return RankedTensorType::get(packedShape, elementType.getElementType()); return RankedTensorType::get(packedShape, elementType.getElementType());
} }
Value extractPackedChunk(
Value packedValue, RankedTensorType chunkType, unsigned index, OpBuilder& builder, Location loc) {
auto packedType = dyn_cast<RankedTensorType>(packedValue.getType());
if (packedType && packedType == chunkType && index == 0)
return packedValue;
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(chunkType.getRank());
sizes.reserve(chunkType.getRank());
strides.reserve(chunkType.getRank());
offsets.push_back(builder.getIndexAttr(static_cast<int64_t>(index) * chunkType.getDimSize(0)));
sizes.push_back(builder.getIndexAttr(chunkType.getDimSize(0)));
strides.push_back(builder.getIndexAttr(1));
for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) {
offsets.push_back(builder.getIndexAttr(0));
sizes.push_back(builder.getIndexAttr(chunkType.getDimSize(dim)));
strides.push_back(builder.getIndexAttr(1));
}
return tensor::ExtractSliceOp::create(builder, loc, chunkType, packedValue, offsets, sizes, strides).getResult();
}
Value createPackedExtractRowsSlice(
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, OpBuilder& builder, Location loc) {
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
return {};
int64_t rowsPerValue = rowType.getDimSize(0);
if (ShapedType::isDynamic(rowsPerValue))
return {};
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(inputType.getRank());
sizes.reserve(inputType.getRank());
strides.reserve(inputType.getRank());
offsets.push_back(builder.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
sizes.push_back(builder.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
strides.push_back(builder.getIndexAttr(1));
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
offsets.push_back(builder.getIndexAttr(0));
sizes.push_back(builder.getIndexAttr(inputType.getDimSize(dim)));
strides.push_back(builder.getIndexAttr(1));
}
bool coversWholeSource = packedType == inputType && startIndex == 0;
if (coversWholeSource)
return extractRowsOp.getInput();
return tensor::ExtractSliceOp::create(builder, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
.getResult();
}
Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) { Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) {
if (values.empty()) if (values.empty())
return {}; return {};
@@ -105,9 +146,4 @@ Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Loca
return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides) return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
.getResult(); .getResult();
} }
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
patterns.add<FoldConcatOfContiguousSlices>(patterns.getContext());
}
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -3,11 +3,21 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count); mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
mlir::Value extractPackedChunk(mlir::Value packedValue,
mlir::RankedTensorType chunkType,
unsigned index,
mlir::OpBuilder& builder,
mlir::Location loc);
mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsOp,
unsigned startIndex,
unsigned count,
mlir::OpBuilder& builder,
mlir::Location loc);
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc); mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir } // namespace onnx_mlir
-24
View File
@@ -394,30 +394,6 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
}]; }];
} }
def PimMVMOp : PimOp<"mvm", [DestinationStyleOpInterface]> {
let summary = "Matrix-vector multiplication: c = a * b";
let arguments = (ins
I32Attr:$weightIndex,
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> { def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> {
let summary = "Element-wise addition: c = a + b"; let summary = "Element-wise addition: c = a + b";
@@ -538,33 +538,6 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
} }
}; };
struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface, PimMVMOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto mvmOp = cast<PimMVMOp>(op);
auto inputOpt = getBufferOrValue(rewriter, mvmOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
auto outputBufferOpt = getBufferOrValue(rewriter, mvmOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimMVMOp>(
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
return success();
}
};
template <typename OpTy> template <typename OpTy>
struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> { struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
@@ -655,7 +628,6 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx); PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx); PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpInterface>(*ctx); PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpInterface>(*ctx);
PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx); PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx);
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx); PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
+4 -16
View File
@@ -150,10 +150,7 @@ def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
); );
let hasVerifier = 1; let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
$input attr-dict `:` type($input)
}];
} }
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> { def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
@@ -170,10 +167,7 @@ def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
); );
let hasVerifier = 1; let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
attr-dict `:` type($output)
}];
} }
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> { def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
@@ -201,10 +195,7 @@ def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
); );
let hasVerifier = 1; let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
$input attr-dict `:` type($input)
}];
} }
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> { def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
@@ -238,10 +229,7 @@ def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []>
); );
let hasVerifier = 1; let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
attr-dict `:` type($output)
}];
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
+113
View File
@@ -47,6 +47,95 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
return parser.getBuilder().getI32IntegerAttr(value); return parser.getBuilder().getI32IntegerAttr(value);
} }
template <typename TensorSendOpTy>
static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
printer << " ";
printer.printOperand(op.getInput());
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
printer.printOptionalAttrDict(
op->getAttrs(),
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(op.getInput().getType());
}
template <typename TensorReceiveOpTy>
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
printer.printOptionalAttrDict(
op->getAttrs(),
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(op.getOutput().getType());
}
static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input))
return failure();
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperand(input, inputType, result.operands);
}
/// Parses a tensor-receive op of the form:
///   [`channels` <ids> `from` <ids> `to` <ids>] attr-dict `:` <type>
/// When the optional positional channel metadata is present it is stored as
/// dense array attributes; specifying the metadata both positionally and in
/// the attr-dict is a parse error.
static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) {
  Type outputType;
  SmallVector<int64_t> channelIds;
  SmallVector<int32_t> sourceCoreIds;
  SmallVector<int32_t> targetCoreIds;
  bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
  if (hasMetadata) {
    if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
        || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
        || parseCompressedIntegerList(parser, targetCoreIds))
      return failure();
  }
  // Capture the attr-dict location up front so the duplicate-metadata
  // diagnostic below points at the offending attr-dict instead of at the
  // already-consumed trailing type.
  auto attrDictLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
    return failure();
  if (hasMetadata
      && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
          || result.attributes.get("targetCoreIds")))
    return parser.emitError(
        attrDictLoc, "channel metadata cannot be specified both positionally and in attr-dict");
  if (hasMetadata) {
    result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
  }
  result.addTypes(outputType);
  return success();
}
} // namespace } // namespace
void SpatYieldOp::print(OpAsmPrinter& printer) { void SpatYieldOp::print(OpAsmPrinter& printer) {
@@ -316,6 +405,12 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
return parser.parseRegion(*body, regionArgs); return parser.parseRegion(*body, regionArgs);
} }
// Custom assembly for SpatChannelSendTensorOp delegates to the shared
// tensor-send printer/parser helpers.
void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) {
  printTensorSendOp(printer, *this);
}
ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
  return parseTensorSendOp(parser, result);
}
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) { void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
printer << " "; printer << " ";
printer.printOperand(getInput()); printer.printOperand(getInput());
@@ -362,6 +457,18 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
return parser.resolveOperand(input, inputType, result.operands); return parser.resolveOperand(input, inputType, result.operands);
} }
// Custom assembly for SpatChannelSendTensorBatchOp delegates to the shared
// tensor-send printer/parser helpers.
void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) {
  printTensorSendOp(printer, *this);
}
ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  return parseTensorSendOp(parser, result);
}
// Custom assembly for SpatChannelReceiveTensorOp delegates to the shared
// tensor-receive printer/parser helpers.
void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) {
  printTensorReceiveOp(printer, *this);
}
ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
  return parseTensorReceiveOp(parser, result);
}
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) { void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict( printer.printOptionalAttrDict(
@@ -403,5 +510,11 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
return success(); return success();
} }
// Custom assembly for SpatChannelReceiveTensorBatchOp delegates to the shared
// tensor-receive printer/parser helpers.
void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) {
  printTensorReceiveOp(printer, *this);
}
ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
  return parseTensorReceiveOp(parser, result);
}
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -1436,6 +1436,21 @@ public:
compactBatchChannelRuns(func); compactBatchChannelRuns(func);
compactRegularOpRuns(func); compactRegularOpRuns(func);
compactRowWiseWvmmRuns(func); compactRowWiseWvmmRuns(func);
compactScalarChannelRuns(func, nextChannelId);
compactBatchChannelRuns(func);
auto eraseUnusedOps = [&](auto tag) {
using OpTy = decltype(tag);
SmallVector<OpTy> ops;
func.walk([&](OpTy op) { ops.push_back(op); });
for (auto op : llvm::reverse(ops))
if (op->use_empty())
op.erase();
};
eraseUnusedOps(tensor::ExtractSliceOp {});
eraseUnusedOps(spatial::SpatConcatOp {});
eraseUnusedOps(spatial::SpatExtractRowsOp {});
if (!sortTopologically(&func.getBody().front())) { if (!sortTopologically(&func.getBody().front())) {
func.emitOpError("failed to topologically order merged Spatial IR"); func.emitOpError("failed to topologically order merged Spatial IR");
signalPassFailure(); signalPassFailure();
@@ -13,6 +13,7 @@
#include <tuple> #include <tuple>
#include "RegularOpCompaction.hpp" #include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir; using namespace mlir;
@@ -41,133 +42,75 @@ struct RegularChunk {
Value output; Value output;
}; };
static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) { static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end()); if (values.empty() || !values.front().hasOneUse())
packedShape[0] *= count; return {};
return RankedTensorType::get(packedShape, elementType.getElementType());
}
static Value OpOperand& firstUse = *values.front().getUses().begin();
extractPackedChunk(Value packedValue, RankedTensorType chunkType, unsigned index, IRRewriter& rewriter, Location loc) { auto concatOp = dyn_cast<spatial::SpatConcatOp>(firstUse.getOwner());
SmallVector<OpFoldResult> offsets; if (!concatOp)
SmallVector<OpFoldResult> sizes; return {};
SmallVector<OpFoldResult> strides;
offsets.reserve(chunkType.getRank());
sizes.reserve(chunkType.getRank());
strides.reserve(chunkType.getRank());
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(index) * chunkType.getDimSize(0))); startOperandIndex = firstUse.getOperandNumber();
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(0))); for (auto [index, value] : llvm::enumerate(values)) {
strides.push_back(rewriter.getIndexAttr(1)); if (!value.hasOneUse())
for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) { return {};
offsets.push_back(rewriter.getIndexAttr(0)); OpOperand& use = *value.getUses().begin();
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(dim))); if (use.getOwner() != concatOp || use.getOperandNumber() != startOperandIndex + index)
strides.push_back(rewriter.getIndexAttr(1)); return {};
} }
return tensor::ExtractSliceOp::create(rewriter, loc, chunkType, packedValue, offsets, sizes, strides).getResult(); return concatOp;
} }
static Value createPackedExtractRowsSlice( static void replaceConcatRunWithPackedValue(spatial::SpatConcatOp concatOp,
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { unsigned startOperandIndex,
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType()); unsigned operandCount,
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType()); Value packedValue,
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0) IRRewriter& rewriter) {
return {}; SmallVector<Value> newInputs;
newInputs.reserve(concatOp.getInputs().size() - operandCount + 1);
int64_t rowsPerValue = rowType.getDimSize(0); for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
if (ShapedType::isDynamic(rowsPerValue)) if (operandIndex == startOperandIndex)
return {}; newInputs.push_back(packedValue);
if (operandIndex < startOperandIndex || operandIndex >= startOperandIndex + operandCount)
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count)); newInputs.push_back(operand);
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(inputType.getRank());
sizes.reserve(inputType.getRank());
strides.reserve(inputType.getRank());
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
strides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
strides.push_back(rewriter.getIndexAttr(1));
} }
if (newInputs.size() == 1 && newInputs.front().getType() == concatOp.getOutput().getType()) {
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides) rewriter.replaceOp(concatOp, newInputs.front());
.getResult(); return;
}
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newInputs); });
} }
static Value createPackedExtractSliceTensor(ValueRange values, IRRewriter& rewriter, Location loc) { static RankedTensorType
if (values.empty()) getPackedConcatSliceType(spatial::SpatConcatOp concatOp, unsigned startOperandIndex, unsigned operandCount) {
return {}; auto firstType = dyn_cast<RankedTensorType>(concatOp.getInputs()[startOperandIndex].getType());
if (values.size() == 1) if (!firstType || !firstType.hasStaticShape())
return values.front();
auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
if (!firstSliceOp)
return {}; return {};
auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType()); int64_t axis = concatOp.getAxis();
auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType()); if (axis < 0 || axis >= firstType.getRank())
if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
|| firstType.getRank() == 0)
return {}; return {};
auto hasStaticValues = [](ArrayRef<int64_t> values) { SmallVector<int64_t> shape(firstType.getShape().begin(), firstType.getShape().end());
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); }); shape[axis] = 0;
}; for (unsigned index = 0; index < operandCount; ++index) {
if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes()) auto operandType = dyn_cast<RankedTensorType>(concatOp.getInputs()[startOperandIndex + index].getType());
|| !hasStaticValues(firstSliceOp.getStaticStrides())) if (!operandType || !operandType.hasStaticShape() || operandType.getRank() != firstType.getRank())
return {};
ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
int64_t rowsPerValue = firstSizes[0];
if (ShapedType::isDynamic(rowsPerValue))
return {};
for (size_t index = 1; index < values.size(); ++index) {
auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource()
|| sliceOp.getResult().getType() != firstSliceOp.getResult().getType()
|| !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes())
|| !hasStaticValues(sliceOp.getStaticStrides()))
return {}; return {};
if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides) for (int64_t dim = 0; dim < firstType.getRank(); ++dim) {
return {}; if (dim == axis)
continue;
if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast<int64_t>(index) * rowsPerValue) if (operandType.getShape()[dim] != shape[dim])
return {};
for (int64_t dim = 1; dim < firstType.getRank(); ++dim)
if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim])
return {}; return {};
}
shape[axis] += operandType.getShape()[axis];
} }
auto packedType = getPackedTensorType(firstType, static_cast<int64_t>(values.size())); return RankedTensorType::get(shape, firstType.getElementType());
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(firstType.getRank());
sizes.reserve(firstType.getRank());
strides.reserve(firstType.getRank());
offsets.push_back(rewriter.getIndexAttr(firstOffsets[0]));
sizes.push_back(rewriter.getIndexAttr(rowsPerValue * static_cast<int64_t>(values.size())));
strides.push_back(rewriter.getIndexAttr(firstStrides[0]));
for (int64_t dim = 1; dim < firstType.getRank(); ++dim) {
offsets.push_back(rewriter.getIndexAttr(firstOffsets[dim]));
sizes.push_back(rewriter.getIndexAttr(firstSizes[dim]));
strides.push_back(rewriter.getIndexAttr(firstStrides[dim]));
}
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
.getResult();
} }
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) { static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
@@ -207,8 +150,7 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter
return {}; return {};
if (!llvm::all_of(values.drop_front(), [&](Value value) { return value.getType() == firstType; })) if (!llvm::all_of(values.drop_front(), [&](Value value) { return value.getType() == firstType; }))
return {}; return {};
return {};
return tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, values).getResult();
} }
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) { static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
@@ -346,11 +288,28 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult()); scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult());
} }
for (auto [index, chunk] : llvm::enumerate(run)) { SmallVector<Value> outputs;
Value replacement = extractPackedChunk( outputs.reserve(run.size());
loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc()); for (const RegularChunk& chunk : run)
Value output = chunk.output; outputs.push_back(chunk.output);
output.replaceAllUsesWith(replacement);
unsigned concatStartIndex = 0;
auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex);
auto concatPackedType = concatOp
? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
: RankedTensorType {};
if (concatOp && concatPackedType == packedOutputType) {
replaceConcatRunWithPackedValue(
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), loop.getResult(0), rewriter);
}
else {
for (auto [index, chunk] : llvm::enumerate(run)) {
Value replacement = extractPackedChunk(
loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc());
Value output = chunk.output;
output.replaceAllUsesWith(replacement);
}
} }
SmallVector<Operation*> opsToErase; SmallVector<Operation*> opsToErase;
@@ -412,7 +371,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
} }
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType()); auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size())); auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
SmallVector<Value> sortedOutputs;
sortedOutputs.reserve(sortedEntries.size());
for (ReceiveEntry& entry : sortedEntries)
sortedOutputs.push_back(entry.op.getOutput());
unsigned concatStartIndex = 0;
auto concatOp = getContiguousConcatUse(ValueRange(sortedOutputs), concatStartIndex);
auto concatPackedType =
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
: RankedTensorType {};
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
rewriter.setInsertionPoint(run.front()); rewriter.setInsertionPoint(run.front());
auto compactReceive = auto compactReceive =
spatial::SpatChannelReceiveTensorOp::create(rewriter, spatial::SpatChannelReceiveTensorOp::create(rewriter,
@@ -421,9 +391,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
rewriter.getDenseI64ArrayAttr(channelIds), rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds), rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds)); rewriter.getDenseI32ArrayAttr(targetCoreIds));
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries)) if (concatOp && concatPackedType) {
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk( replaceConcatRunWithPackedValue(concatOp,
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc())); concatStartIndex,
static_cast<unsigned>(sortedOutputs.size()),
compactReceive.getOutput(),
rewriter);
}
else {
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
}
for (auto op : run) for (auto op : run)
rewriter.eraseOp(op); rewriter.eraseOp(op);
@@ -531,7 +510,18 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
} }
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType()); auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size())); auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
SmallVector<Value> outputs;
outputs.reserve(run.size());
for (auto op : run)
outputs.push_back(op.getOutput());
unsigned concatStartIndex = 0;
auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex);
auto concatPackedType =
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
: RankedTensorType {};
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
rewriter.setInsertionPoint(run.front()); rewriter.setInsertionPoint(run.front());
auto compactReceive = auto compactReceive =
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter, spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
@@ -540,9 +530,15 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
rewriter.getDenseI64ArrayAttr(channelIds), rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds), rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds)); rewriter.getDenseI32ArrayAttr(targetCoreIds));
for (auto [index, op] : llvm::enumerate(run)) if (concatOp && concatPackedType) {
op.getOutput().replaceAllUsesWith(extractPackedChunk( replaceConcatRunWithPackedValue(
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc())); concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
}
else {
for (auto [index, op] : llvm::enumerate(run))
op.getOutput().replaceAllUsesWith(extractPackedChunk(
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
}
for (auto op : run) for (auto op : run)
rewriter.eraseOp(op); rewriter.eraseOp(op);