compact syntax for spatial tensor ops
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
better IR compaction after dcp merge remove pim.mvm op better memory report
This commit is contained in:
@@ -11,7 +11,6 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
@@ -105,12 +104,8 @@ static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp
|
||||
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
||||
|
||||
rewriter.setInsertionPoint(sendTensorOp);
|
||||
Value input = sendTensorOp.getInput();
|
||||
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>())
|
||||
if (concatOp.getDim() == 0)
|
||||
if (Value packedInput = createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorOp.getLoc()))
|
||||
input = packedInput;
|
||||
PimSendTensorOp::create(rewriter, sendTensorOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
PimSendTensorOp::create(
|
||||
rewriter, sendTensorOp.getLoc(), sendTensorOp.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
rewriter.eraseOp(sendTensorOp);
|
||||
}
|
||||
|
||||
@@ -152,38 +147,6 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
|
||||
rewriter.replaceOp(extractRowsOp, replacements);
|
||||
}
|
||||
|
||||
static Value createPackedExtractRowsSlice(
|
||||
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
|
||||
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
int64_t rowsPerValue = rowType.getDimSize(0);
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(inputType.getRank());
|
||||
sizes.reserve(inputType.getRank());
|
||||
strides.reserve(inputType.getRank());
|
||||
|
||||
offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
|
||||
sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<spatial::SpatConcatOp> concatOps;
|
||||
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
|
||||
@@ -262,11 +225,6 @@ static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter
|
||||
.getResult());
|
||||
rewriter.replaceOp(concatOp, newConcat.getOutput());
|
||||
}
|
||||
|
||||
RewritePatternSet tensorPackingPatterns(funcOp.getContext());
|
||||
populateTensorPackingPatterns(tensorPackingPatterns);
|
||||
(void) applyPatternsGreedily(funcOp, std::move(tensorPackingPatterns));
|
||||
|
||||
auto eraseUnusedOps = [&](auto tag) {
|
||||
using OpTy = decltype(tag);
|
||||
SmallVector<OpTy> ops;
|
||||
|
||||
Reference in New Issue
Block a user