compact syntax for spatial tensor ops
Validate Operations / validate-operations (push) Has been cancelled

better IR compaction after dcp merge
remove pim.mvm op
better memory report
This commit is contained in:
NiccoloN
2026-05-12 13:35:25 +02:00
parent 80a7298552
commit 628dc630a4
15 changed files with 419 additions and 305 deletions
+8 -1
View File
@@ -88,7 +88,14 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) { walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback); });
root->walk([&](pim::PimCoreOp coreOp) {
coreOp.walk([&](pim::PimVMMOp vmmOp) {
auto weights = coreOp.getWeights();
unsigned weightIndex = vmmOp.getWeightIndex();
if (weightIndex < weights.size())
callback(coreOp->getOpOperand(weightIndex));
});
});
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
auto weights = coreBatchOp.getWeights();
for (auto weight : weights)
+65 -22
View File
@@ -134,6 +134,15 @@ static void printMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n";
}
/// Field-wise accumulation of two memory report rows; used to fold the
/// per-lane rows of a core batch into a single aggregate row.
static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) {
  MemoryReportRow sum = lhs;
  sum.sizeAlloca += rhs.sizeAlloca;
  sum.numAlloca += rhs.numAlloca;
  sum.sizeGlobal += rhs.sizeGlobal;
  sum.numGlobal += rhs.numGlobal;
  return sum;
}
MemoryReportRow PimMemory::getReportRow() const {
MemoryReportRow row;
for (auto& [val, memEntry] : globalMemEntriesMap) {
@@ -201,8 +210,17 @@ void PimAcceleratorMemory::reportHost() {
hostReportRow = hostMem.getReportRow();
}
void PimAcceleratorMemory::reportCore(size_t coreId) {
coreReportRows.push_back({coreId, deviceMem.at(coreId).getReportRow()});
/// Records the memory report row of a single scalar core. The entry carries
/// the core id both as its sort id and as its one-element core-id list.
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
  MemoryReportEntry entry;
  entry.kind = MemoryReportEntry::Kind::Core;
  entry.id = coreId;
  entry.coreIds.push_back(static_cast<int32_t>(coreId));
  entry.row = row;
  reportEntries.push_back(std::move(entry));
}
/// Records the aggregated memory report row of a core batch, remembering
/// which cores contributed to it.
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId, ArrayRef<int32_t> coreIds, const MemoryReportRow& row) {
  reportEntries.push_back({MemoryReportEntry::Kind::Batch,
      batchId,
      llvm::SmallVector<int32_t, 8>(coreIds.begin(), coreIds.end()),
      row});
}
void PimAcceleratorMemory::flushReport() {
@@ -215,13 +233,16 @@ void PimAcceleratorMemory::flushReport() {
printMemoryReportRow(os, *hostReportRow);
}
if (!coreReportRows.empty()) {
if (!reportEntries.empty()) {
if (hostReportRow.has_value())
os << "\n";
llvm::stable_sort(coreReportRows, [](const auto& lhs, const auto& rhs) {
const MemoryReportRow& lhsRow = lhs.second;
const MemoryReportRow& rhsRow = rhs.second;
llvm::stable_sort(reportEntries, [](const MemoryReportEntry& lhs, const MemoryReportEntry& rhs) {
if (lhs.kind != rhs.kind)
return lhs.kind == MemoryReportEntry::Kind::Batch;
const MemoryReportRow& lhsRow = lhs.row;
const MemoryReportRow& rhsRow = rhs.row;
if (lhsRow.sizeAlloca != rhsRow.sizeAlloca)
return lhsRow.sizeAlloca > rhsRow.sizeAlloca;
if (lhsRow.numAlloca != rhsRow.numAlloca)
@@ -230,24 +251,36 @@ void PimAcceleratorMemory::flushReport() {
return lhsRow.sizeGlobal > rhsRow.sizeGlobal;
if (lhsRow.numGlobal != rhsRow.numGlobal)
return lhsRow.numGlobal > rhsRow.numGlobal;
return lhs.first < rhs.first;
return lhs.id < rhs.id;
});
for (size_t index = 0; index < coreReportRows.size();) {
for (size_t index = 0; index < reportEntries.size();) {
size_t runEnd = index + 1;
while (runEnd < coreReportRows.size() && coreReportRows[runEnd].second == coreReportRows[index].second)
while (runEnd < reportEntries.size() && reportEntries[runEnd].kind == reportEntries[index].kind
&& reportEntries[runEnd].row == reportEntries[index].row) {
++runEnd;
}
llvm::SmallVector<size_t, 8> coreIds;
coreIds.reserve(runEnd - index);
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
coreIds.push_back(coreReportRows[coreIndex].first);
os << "Core ";
printCompressedIntegerEntries(os, ArrayRef<size_t>(coreIds));
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch) {
os << "Batch ";
for (size_t batchIndex = index; batchIndex < runEnd; ++batchIndex) {
if (batchIndex != index)
os << ",\n ";
os << reportEntries[batchIndex].id << " (cores ";
printCompressedIntegerEntries(os, ArrayRef<int32_t>(reportEntries[batchIndex].coreIds));
os << ")";
}
}
else {
llvm::SmallVector<int32_t, 8> coreIds;
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
coreIds.push_back(reportEntries[coreIndex].coreIds.front());
os << "Core ";
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
}
os << ":\n";
printMemoryReportRow(os, coreReportRows[index].second);
if (runEnd < coreReportRows.size())
printMemoryReportRow(os, reportEntries[index].row);
if (runEnd < reportEntries.size())
os << "\n";
index = runEnd;
@@ -678,7 +711,6 @@ static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
indices.push_back(weightIndex);
};
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices);
return indices;
@@ -753,8 +785,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false, knowledge);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
@@ -816,6 +846,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
// This implementation always assigns one crossbar per group.
json::Object xbarsPerArrayGroup;
size_t maxCoreId = 0;
uint64_t nextBatchReportId = 0;
// Create Weight Folder
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
@@ -859,7 +890,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
memory.reportCore(coreId);
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
if (processedOperations < 0)
@@ -905,18 +935,31 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
if (auto err = emitCore(coreOp, false))
return err;
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())))
.getReportRow());
continue;
}
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
auto batchCoreIds = getBatchCoreIds(coreBatchOp);
SmallVector<int32_t> reportedCoreIds;
reportedCoreIds.reserve(batchCoreIds.size());
MemoryReportRow batchRow;
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
size_t coreId = emittedCoreIds.lookup(originalCoreId);
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
laneResult = emitCore(coreOp, true);
if (laneResult == CompilerSuccess)
batchRow = addMemoryReportRows(batchRow, memory.getOrCreateDeviceMem(coreId).getReportRow());
return laneResult == CompilerSuccess ? success() : failure();
})))
return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
}
memory.recordBatchReport(nextBatchReportId++, reportedCoreIds, batchRow);
}
memory.flushReport();
+15 -2
View File
@@ -33,6 +33,18 @@ struct MemoryReportRow {
}
};
// One entry of the flushed memory report: either a single scalar core or a
// whole core batch, together with its memory statistics row.
struct MemoryReportEntry {
enum class Kind {
Core,
Batch
};
Kind kind = Kind::Core;
// For Kind::Core this is the core id; for Kind::Batch a batch report counter.
uint64_t id = 0;
// Cores covered by this entry (a single id for Kind::Core).
llvm::SmallVector<int32_t, 8> coreIds;
MemoryReportRow row;
};
class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
@@ -66,7 +78,7 @@ private:
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
std::fstream fileReport;
std::optional<MemoryReportRow> hostReportRow;
llvm::SmallVector<std::pair<size_t, MemoryReportRow>, 32> coreReportRows;
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
public:
PimAcceleratorMemory()
@@ -86,7 +98,8 @@ public:
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
void reportHost();
void reportCore(size_t coreId);
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
void recordBatchReport(uint64_t batchId, llvm::ArrayRef<int32_t> coreIds, const MemoryReportRow& row);
void flushReport();
void clean(mlir::Operation* op);
};
-1
View File
@@ -103,7 +103,6 @@ SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
indices.push_back(weightIndex);
};
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices);
return indices;
@@ -7,7 +7,6 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
@@ -37,15 +36,10 @@ static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp se
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
Value input = mapper.lookup(sendTensorBatchOp.getInput());
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>())
if (concatOp.getDim() == 0)
if (Value packedInput =
createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorBatchOp.getLoc()))
input = packedInput;
pim::PimSendTensorBatchOp::create(
rewriter, sendTensorBatchOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
pim::PimSendTensorBatchOp::create(rewriter,
sendTensorBatchOp.getLoc(),
mapper.lookup(sendTensorBatchOp.getInput()),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
}
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
@@ -21,12 +21,6 @@ def spatToPimVMM : Pat<
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimMVM : Pat<
(SpatMVMOp:$srcOpRes $weightIndex, $vector),
(PimMVMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVVAdd : Pat<
(SpatVAddOp:$srcOpRes $a, $b),
(PimVVAddOp $a, $b,
@@ -11,7 +11,6 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
@@ -105,12 +104,8 @@ static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
rewriter.setInsertionPoint(sendTensorOp);
Value input = sendTensorOp.getInput();
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>())
if (concatOp.getDim() == 0)
if (Value packedInput = createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorOp.getLoc()))
input = packedInput;
PimSendTensorOp::create(rewriter, sendTensorOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
PimSendTensorOp::create(
rewriter, sendTensorOp.getLoc(), sendTensorOp.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.eraseOp(sendTensorOp);
}
@@ -152,38 +147,6 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
rewriter.replaceOp(extractRowsOp, replacements);
}
static Value createPackedExtractRowsSlice(
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
return {};
int64_t rowsPerValue = rowType.getDimSize(0);
if (ShapedType::isDynamic(rowsPerValue))
return {};
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(inputType.getRank());
sizes.reserve(inputType.getRank());
strides.reserve(inputType.getRank());
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
strides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
strides.push_back(rewriter.getIndexAttr(1));
}
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
.getResult();
}
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatConcatOp> concatOps;
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
@@ -262,11 +225,6 @@ static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter
.getResult());
rewriter.replaceOp(concatOp, newConcat.getOutput());
}
RewritePatternSet tensorPackingPatterns(funcOp.getContext());
populateTensorPackingPatterns(tensorPackingPatterns);
(void) applyPatternsGreedily(funcOp, std::move(tensorPackingPatterns));
auto eraseUnusedOps = [&](auto tag) {
using OpTy = decltype(tag);
SmallVector<OpTy> ops;
@@ -3,26 +3,6 @@
using namespace mlir;
namespace onnx_mlir {
namespace {
// Replaces concat-of-adjacent-slices with one packed slice to keep batch sends compact.
struct FoldConcatOfContiguousSlices : OpRewritePattern<tensor::ConcatOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ConcatOp op, PatternRewriter& rewriter) const override {
if (op.getDim() != 0)
return failure();
Value packed = createPackedExtractSliceTensor(op.getInputs(), rewriter, op.getLoc());
if (!packed)
return failure();
rewriter.replaceOp(op, packed);
return success();
}
};
} // namespace
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
@@ -30,6 +10,67 @@ RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count
return RankedTensorType::get(packedShape, elementType.getElementType());
}
/// Extracts the `index`-th chunk of shape `chunkType` from a packed tensor
/// value via tensor.extract_slice. The chunks are laid out contiguously along
/// dimension 0, so the offset of chunk `index` is `index * chunkType dim 0`.
/// When the packed value already has exactly the requested chunk type and the
/// first chunk is asked for, the value is returned unchanged.
Value extractPackedChunk(
    Value packedValue, RankedTensorType chunkType, unsigned index, OpBuilder& builder, Location loc) {
  auto packedType = dyn_cast<RankedTensorType>(packedValue.getType());
  if (index == 0 && packedType && packedType == chunkType)
    return packedValue;
  int64_t rank = chunkType.getRank();
  // All dims but the packing dim 0 are taken whole, with unit strides.
  SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
  SmallVector<OpFoldResult> sizes;
  sizes.reserve(rank);
  offsets[0] = builder.getIndexAttr(static_cast<int64_t>(index) * chunkType.getDimSize(0));
  for (int64_t dim = 0; dim < rank; ++dim)
    sizes.push_back(builder.getIndexAttr(chunkType.getDimSize(dim)));
  return tensor::ExtractSliceOp::create(builder, loc, chunkType, packedValue, offsets, sizes, strides).getResult();
}
/// Builds a tensor.extract_slice selecting `count` consecutive row groups of
/// `extractRowsOp`'s input, starting at result `startIndex`. Returns a null
/// Value when the row/input types are not statically shaped ranked tensors
/// (rank >= 1) or the per-result row count is dynamic. When the requested
/// slice covers the whole input, the input itself is returned so no slice op
/// is created.
Value createPackedExtractRowsSlice(
    spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, OpBuilder& builder, Location loc) {
  auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
  auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
  if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
    return {};
  int64_t rowsPerValue = rowType.getDimSize(0);
  if (ShapedType::isDynamic(rowsPerValue))
    return {};
  auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
  // Check this before building the slice parameters: if the packed region is
  // exactly the whole source there is nothing to slice.
  if (startIndex == 0 && packedType == inputType)
    return extractRowsOp.getInput();
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(inputType.getRank());
  sizes.reserve(inputType.getRank());
  strides.reserve(inputType.getRank());
  // Dimension 0 selects the contiguous run of row groups; all other
  // dimensions are taken whole with unit strides.
  offsets.push_back(builder.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
  sizes.push_back(builder.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
  strides.push_back(builder.getIndexAttr(1));
  for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
    offsets.push_back(builder.getIndexAttr(0));
    sizes.push_back(builder.getIndexAttr(inputType.getDimSize(dim)));
    strides.push_back(builder.getIndexAttr(1));
  }
  return tensor::ExtractSliceOp::create(builder, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
      .getResult();
}
Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) {
if (values.empty())
return {};
@@ -105,9 +146,4 @@ Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Loca
return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
.getResult();
}
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
patterns.add<FoldConcatOfContiguousSlices>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -3,11 +3,21 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
mlir::Value extractPackedChunk(mlir::Value packedValue,
mlir::RankedTensorType chunkType,
unsigned index,
mlir::OpBuilder& builder,
mlir::Location loc);
mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsOp,
unsigned startIndex,
unsigned count,
mlir::OpBuilder& builder,
mlir::Location loc);
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir
-24
View File
@@ -394,30 +394,6 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
}];
}
def PimMVMOp : PimOp<"mvm", [DestinationStyleOpInterface]> {
let summary = "Matrix-vector multiplication: c = a * b";
let arguments = (ins
I32Attr:$weightIndex,
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> {
let summary = "Element-wise addition: c = a + b";
@@ -538,33 +538,6 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
}
};
struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface, PimMVMOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto mvmOp = cast<PimMVMOp>(op);
auto inputOpt = getBufferOrValue(rewriter, mvmOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
auto outputBufferOpt = getBufferOrValue(rewriter, mvmOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimMVMOp>(
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
return success();
}
};
template <typename OpTy>
struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpInterface<OpTy>, OpTy> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
@@ -655,7 +628,6 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
PimMVMOp::attachInterface<MVMOpInterface>(*ctx);
PimVVAddOp::attachInterface<BinaryDstOpInterface<PimVVAddOp>>(*ctx);
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
+4 -16
View File
@@ -150,10 +150,7 @@ def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
);
let hasVerifier = 1;
let assemblyFormat = [{
$input attr-dict `:` type($input)
}];
let hasCustomAssemblyFormat = 1;
}
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
@@ -170,10 +167,7 @@ def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
);
let hasVerifier = 1;
let assemblyFormat = [{
attr-dict `:` type($output)
}];
let hasCustomAssemblyFormat = 1;
}
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
@@ -201,10 +195,7 @@ def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
);
let hasVerifier = 1;
let assemblyFormat = [{
$input attr-dict `:` type($input)
}];
let hasCustomAssemblyFormat = 1;
}
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
@@ -238,10 +229,7 @@ def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []>
);
let hasVerifier = 1;
let assemblyFormat = [{
attr-dict `:` type($output)
}];
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
+113
View File
@@ -47,6 +47,95 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
return parser.getBuilder().getI32IntegerAttr(value);
}
// Prints a tensor send op as: ` %input [channels ... from ... to ...] {attrs} : type`.
// Shared by the scalar and batch channel-send-tensor ops.
template <typename TensorSendOpTy>
static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
printer << " ";
printer.printOperand(op.getInput());
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
// The channel metadata attributes are printed positionally above, so elide
// them from the generic attr-dict to avoid printing them twice.
printer.printOptionalAttrDict(
op->getAttrs(),
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(op.getInput().getType());
}
// Prints a tensor receive op as: `[channels ... from ... to ...] {attrs} : type`.
// Shared by the scalar and batch channel-receive-tensor ops; the result type
// is printed after the colon.
template <typename TensorReceiveOpTy>
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
// The channel metadata attributes are printed positionally above, so elide
// them from the generic attr-dict to avoid printing them twice.
printer.printOptionalAttrDict(
op->getAttrs(),
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(op.getOutput().getType());
}
// Parses the form printed by printTensorSendOp:
//   %input [channels <ids> from <ids> to <ids>] {attrs} : type
// The channel metadata is optional; if present it must not also appear in the
// attr-dict.
static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input))
return failure();
// The optional "channels" keyword introduces the positional metadata lists.
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
// Reject metadata given both positionally and in the attr-dict, so the two
// sources can never disagree silently.
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperand(input, inputType, result.operands);
}
// Parses the form printed by printTensorReceiveOp:
//   [channels <ids> from <ids> to <ids>] {attrs} : type
// Mirrors parseTensorSendOp but has no operand; the parsed type becomes the
// op's result type.
static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) {
Type outputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
// The optional "channels" keyword introduces the positional metadata lists.
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
return failure();
// Reject metadata given both positionally and in the attr-dict, so the two
// sources can never disagree silently.
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
result.addTypes(outputType);
return success();
}
} // namespace
void SpatYieldOp::print(OpAsmPrinter& printer) {
@@ -316,6 +405,12 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
return parser.parseRegion(*body, regionArgs);
}
// Custom assembly hooks for spat.channel_send_tensor: forward to the shared
// tensor-send printer/parser.
void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorSendOp(parser, result);
}
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
@@ -362,6 +457,18 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
return parser.resolveOperand(input, inputType, result.operands);
}
// Custom assembly hooks for spat.channel_send_tensor_batch: forward to the
// shared tensor-send printer/parser.
void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorSendOp(parser, result);
}
// Custom assembly hooks for spat.channel_receive_tensor: forward to the
// shared tensor-receive printer/parser.
void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorReceiveOp(parser, result);
}
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
@@ -403,5 +510,11 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
return success();
}
// Custom assembly hooks for spat.channel_receive_tensor_batch: forward to the
// shared tensor-receive printer/parser.
void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorReceiveOp(parser, result);
}
} // namespace spatial
} // namespace onnx_mlir
@@ -1436,6 +1436,21 @@ public:
compactBatchChannelRuns(func);
compactRegularOpRuns(func);
compactRowWiseWvmmRuns(func);
compactScalarChannelRuns(func, nextChannelId);
compactBatchChannelRuns(func);
auto eraseUnusedOps = [&](auto tag) {
using OpTy = decltype(tag);
SmallVector<OpTy> ops;
func.walk([&](OpTy op) { ops.push_back(op); });
for (auto op : llvm::reverse(ops))
if (op->use_empty())
op.erase();
};
eraseUnusedOps(tensor::ExtractSliceOp {});
eraseUnusedOps(spatial::SpatConcatOp {});
eraseUnusedOps(spatial::SpatExtractRowsOp {});
if (!sortTopologically(&func.getBody().front())) {
func.emitOpError("failed to topologically order merged Spatial IR");
signalPassFailure();
@@ -13,6 +13,7 @@
#include <tuple>
#include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -41,133 +42,75 @@ struct RegularChunk {
Value output;
};
static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
packedShape[0] *= count;
return RankedTensorType::get(packedShape, elementType.getElementType());
}
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
if (values.empty() || !values.front().hasOneUse())
return {};
static Value
extractPackedChunk(Value packedValue, RankedTensorType chunkType, unsigned index, IRRewriter& rewriter, Location loc) {
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(chunkType.getRank());
sizes.reserve(chunkType.getRank());
strides.reserve(chunkType.getRank());
OpOperand& firstUse = *values.front().getUses().begin();
auto concatOp = dyn_cast<spatial::SpatConcatOp>(firstUse.getOwner());
if (!concatOp)
return {};
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(index) * chunkType.getDimSize(0)));
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(0)));
strides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) {
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(dim)));
strides.push_back(rewriter.getIndexAttr(1));
startOperandIndex = firstUse.getOperandNumber();
for (auto [index, value] : llvm::enumerate(values)) {
if (!value.hasOneUse())
return {};
OpOperand& use = *value.getUses().begin();
if (use.getOwner() != concatOp || use.getOperandNumber() != startOperandIndex + index)
return {};
}
return tensor::ExtractSliceOp::create(rewriter, loc, chunkType, packedValue, offsets, sizes, strides).getResult();
return concatOp;
}
static Value createPackedExtractRowsSlice(
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
return {};
int64_t rowsPerValue = rowType.getDimSize(0);
if (ShapedType::isDynamic(rowsPerValue))
return {};
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(inputType.getRank());
sizes.reserve(inputType.getRank());
strides.reserve(inputType.getRank());
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
strides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
strides.push_back(rewriter.getIndexAttr(1));
static void replaceConcatRunWithPackedValue(spatial::SpatConcatOp concatOp,
unsigned startOperandIndex,
unsigned operandCount,
Value packedValue,
IRRewriter& rewriter) {
SmallVector<Value> newInputs;
newInputs.reserve(concatOp.getInputs().size() - operandCount + 1);
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
if (operandIndex == startOperandIndex)
newInputs.push_back(packedValue);
if (operandIndex < startOperandIndex || operandIndex >= startOperandIndex + operandCount)
newInputs.push_back(operand);
}
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
.getResult();
if (newInputs.size() == 1 && newInputs.front().getType() == concatOp.getOutput().getType()) {
rewriter.replaceOp(concatOp, newInputs.front());
return;
}
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newInputs); });
}
static Value createPackedExtractSliceTensor(ValueRange values, IRRewriter& rewriter, Location loc) {
if (values.empty())
return {};
if (values.size() == 1)
return values.front();
auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
if (!firstSliceOp)
static RankedTensorType
getPackedConcatSliceType(spatial::SpatConcatOp concatOp, unsigned startOperandIndex, unsigned operandCount) {
auto firstType = dyn_cast<RankedTensorType>(concatOp.getInputs()[startOperandIndex].getType());
if (!firstType || !firstType.hasStaticShape())
return {};
auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType());
if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
|| firstType.getRank() == 0)
int64_t axis = concatOp.getAxis();
if (axis < 0 || axis >= firstType.getRank())
return {};
auto hasStaticValues = [](ArrayRef<int64_t> values) {
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
};
if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes())
|| !hasStaticValues(firstSliceOp.getStaticStrides()))
return {};
ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
int64_t rowsPerValue = firstSizes[0];
if (ShapedType::isDynamic(rowsPerValue))
return {};
for (size_t index = 1; index < values.size(); ++index) {
auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource()
|| sliceOp.getResult().getType() != firstSliceOp.getResult().getType()
|| !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes())
|| !hasStaticValues(sliceOp.getStaticStrides()))
SmallVector<int64_t> shape(firstType.getShape().begin(), firstType.getShape().end());
shape[axis] = 0;
for (unsigned index = 0; index < operandCount; ++index) {
auto operandType = dyn_cast<RankedTensorType>(concatOp.getInputs()[startOperandIndex + index].getType());
if (!operandType || !operandType.hasStaticShape() || operandType.getRank() != firstType.getRank())
return {};
if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
return {};
if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast<int64_t>(index) * rowsPerValue)
return {};
for (int64_t dim = 1; dim < firstType.getRank(); ++dim)
if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim])
for (int64_t dim = 0; dim < firstType.getRank(); ++dim) {
if (dim == axis)
continue;
if (operandType.getShape()[dim] != shape[dim])
return {};
}
shape[axis] += operandType.getShape()[axis];
}
auto packedType = getPackedTensorType(firstType, static_cast<int64_t>(values.size()));
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(firstType.getRank());
sizes.reserve(firstType.getRank());
strides.reserve(firstType.getRank());
offsets.push_back(rewriter.getIndexAttr(firstOffsets[0]));
sizes.push_back(rewriter.getIndexAttr(rowsPerValue * static_cast<int64_t>(values.size())));
strides.push_back(rewriter.getIndexAttr(firstStrides[0]));
for (int64_t dim = 1; dim < firstType.getRank(); ++dim) {
offsets.push_back(rewriter.getIndexAttr(firstOffsets[dim]));
sizes.push_back(rewriter.getIndexAttr(firstSizes[dim]));
strides.push_back(rewriter.getIndexAttr(firstStrides[dim]));
}
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
.getResult();
return RankedTensorType::get(shape, firstType.getElementType());
}
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
@@ -207,8 +150,7 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter
return {};
if (!llvm::all_of(values.drop_front(), [&](Value value) { return value.getType() == firstType; }))
return {};
return tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, values).getResult();
return {};
}
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
@@ -346,11 +288,28 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult());
}
for (auto [index, chunk] : llvm::enumerate(run)) {
Value replacement = extractPackedChunk(
loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc());
Value output = chunk.output;
output.replaceAllUsesWith(replacement);
SmallVector<Value> outputs;
outputs.reserve(run.size());
for (const RegularChunk& chunk : run)
outputs.push_back(chunk.output);
unsigned concatStartIndex = 0;
auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex);
auto concatPackedType = concatOp
? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
: RankedTensorType {};
if (concatOp && concatPackedType == packedOutputType) {
replaceConcatRunWithPackedValue(
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), loop.getResult(0), rewriter);
}
else {
for (auto [index, chunk] : llvm::enumerate(run)) {
Value replacement = extractPackedChunk(
loop.getResult(0), outputType, static_cast<unsigned>(index), rewriter, chunk.startOp->getLoc());
Value output = chunk.output;
output.replaceAllUsesWith(replacement);
}
}
SmallVector<Operation*> opsToErase;
@@ -412,7 +371,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
}
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
SmallVector<Value> sortedOutputs;
sortedOutputs.reserve(sortedEntries.size());
for (ReceiveEntry& entry : sortedEntries)
sortedOutputs.push_back(entry.op.getOutput());
unsigned concatStartIndex = 0;
auto concatOp = getContiguousConcatUse(ValueRange(sortedOutputs), concatStartIndex);
auto concatPackedType =
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
: RankedTensorType {};
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
rewriter.setInsertionPoint(run.front());
auto compactReceive =
spatial::SpatChannelReceiveTensorOp::create(rewriter,
@@ -421,9 +391,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
if (concatOp && concatPackedType) {
replaceConcatRunWithPackedValue(concatOp,
concatStartIndex,
static_cast<unsigned>(sortedOutputs.size()),
compactReceive.getOutput(),
rewriter);
}
else {
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
}
for (auto op : run)
rewriter.eraseOp(op);
@@ -531,7 +510,18 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
}
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
SmallVector<Value> outputs;
outputs.reserve(run.size());
for (auto op : run)
outputs.push_back(op.getOutput());
unsigned concatStartIndex = 0;
auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex);
auto concatPackedType =
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
: RankedTensorType {};
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
rewriter.setInsertionPoint(run.front());
auto compactReceive =
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
@@ -540,9 +530,15 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
for (auto [index, op] : llvm::enumerate(run))
op.getOutput().replaceAllUsesWith(extractPackedChunk(
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
if (concatOp && concatPackedType) {
replaceConcatRunWithPackedValue(
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
}
else {
for (auto [index, op] : llvm::enumerate(run))
op.getOutput().replaceAllUsesWith(extractPackedChunk(
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
}
for (auto op : run)
rewriter.eraseOp(op);