better transpose pattern and cleanup
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -10,8 +10,6 @@
|
|||||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -68,14 +66,6 @@ FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<Arr
|
|||||||
return permutation;
|
return permutation;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value transposeMaybeInCompute(
|
|
||||||
Value value, RankedTensorType resultType, ArrayRef<int64_t> permutation, PatternRewriter& rewriter, Location loc) {
|
|
||||||
auto buildTranspose = [&](Value input) -> Value {
|
|
||||||
return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult();
|
|
||||||
};
|
|
||||||
return materializeOrComputeUnary(value, resultType, rewriter, loc, buildTranspose);
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
|
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
|
||||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
|
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,12 +78,6 @@ llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation
|
|||||||
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
|
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
|
||||||
int64_t rank);
|
int64_t rank);
|
||||||
|
|
||||||
mlir::Value transposeMaybeInCompute(mlir::Value value,
|
|
||||||
mlir::RankedTensorType resultType,
|
|
||||||
mlir::ArrayRef<int64_t> permutation,
|
|
||||||
mlir::PatternRewriter& rewriter,
|
|
||||||
mlir::Location loc);
|
|
||||||
|
|
||||||
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
|
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||||
|
|
||||||
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
|
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||||
|
|||||||
@@ -8,7 +8,6 @@
|
|||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
@@ -114,21 +113,6 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
RewritePatternSet matmulPatterns(ctx);
|
|
||||||
populateMatMulRewritePatterns(matmulPatterns, ctx);
|
|
||||||
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
|
|
||||||
|
|
||||||
bool hasUnloweredMatMul = false;
|
|
||||||
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
|
|
||||||
hasUnloweredMatMul = true;
|
|
||||||
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
|
|
||||||
});
|
|
||||||
if (hasUnloweredMatMul) {
|
|
||||||
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConversionTarget target(*ctx);
|
ConversionTarget target(*ctx);
|
||||||
target.addLegalDialect<spatial::SpatialDialect,
|
target.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
@@ -165,10 +149,6 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
RewritePatternSet transposePatterns(ctx);
|
|
||||||
populateTransposePatterns(transposePatterns, ctx);
|
|
||||||
walkAndApplyPatterns(moduleOp, std::move(transposePatterns));
|
|
||||||
|
|
||||||
ConversionTarget earlyPostTarget(*ctx);
|
ConversionTarget earlyPostTarget(*ctx);
|
||||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
@@ -206,15 +186,14 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
|
||||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
populateEmptyFunction(*entryFunc);
|
populateEmptyFunction(*entryFunc);
|
||||||
|
|
||||||
dumpModule(moduleOp, "spatial0");
|
dumpModule(moduleOp, "spatial0");
|
||||||
|
|
||||||
|
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
||||||
|
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { popula
|
|||||||
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
populateGeneratedConversionPatterns(patterns, ctx);
|
populateGeneratedConversionPatterns(patterns, ctx);
|
||||||
populateElementwisePatterns(patterns, ctx);
|
populateElementwisePatterns(patterns, ctx);
|
||||||
|
populateMatMulRewritePatterns(patterns, ctx);
|
||||||
populateGemmPatterns(patterns, ctx);
|
populateGemmPatterns(patterns, ctx);
|
||||||
populateConvPatterns(patterns, ctx);
|
populateConvPatterns(patterns, ctx);
|
||||||
populatePoolPatterns(patterns, ctx);
|
populatePoolPatterns(patterns, ctx);
|
||||||
|
|||||||
@@ -331,14 +331,6 @@ static Value extractDynamicGemmBColumn(
|
|||||||
return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
|
return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value extractTransposedBRow(
|
|
||||||
Value transposedB, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
|
|
||||||
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
|
||||||
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, vectorType, transposedB, offsets, sizes, strides).getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value extractDynamicGemmRowVector(
|
static Value extractDynamicGemmRowVector(
|
||||||
Value matrix, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
Value matrix, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
|
||||||
@@ -424,7 +416,6 @@ static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
|
|||||||
RankedTensorType bType,
|
RankedTensorType bType,
|
||||||
RankedTensorType scalarPiecesType,
|
RankedTensorType scalarPiecesType,
|
||||||
RankedTensorType outType,
|
RankedTensorType outType,
|
||||||
bool bAlreadyTransposed,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
ConversionPatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
const int64_t numOutRows = outType.getDimSize(0);
|
const int64_t numOutRows = outType.getDimSize(0);
|
||||||
@@ -446,8 +437,7 @@ static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
|
|||||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||||
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
|
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
|
||||||
Value bVector = bAlreadyTransposed ? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
|
Value bVector = extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
|
||||||
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
|
|
||||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||||
|
|
||||||
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||||
@@ -739,6 +729,13 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
const int64_t numOutCols = outType.getDimSize(1);
|
const int64_t numOutCols = outType.getDimSize(1);
|
||||||
const int64_t reductionSize = aType.getDimSize(1);
|
const int64_t reductionSize = aType.getDimSize(1);
|
||||||
|
|
||||||
|
if (gemmOpAdaptor.getTransB()) {
|
||||||
|
auto bShape = bType.getShape();
|
||||||
|
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType(), bType.getEncoding());
|
||||||
|
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})).getResult();
|
||||||
|
bType = transposedType;
|
||||||
|
}
|
||||||
|
|
||||||
if (!isCompileTimeComputable(b)) {
|
if (!isCompileTimeComputable(b)) {
|
||||||
bool hasC = hasGemmBias(c);
|
bool hasC = hasGemmBias(c);
|
||||||
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
|
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
|
||||||
@@ -758,10 +755,8 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
biasType = *verifiedBiasType;
|
biasType = *verifiedBiasType;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t expectedBRows = gemmOpAdaptor.getTransB() ? numOutCols : reductionSize;
|
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize
|
||||||
const int64_t expectedBCols = gemmOpAdaptor.getTransB() ? reductionSize : numOutCols;
|
|| bType.getDimSize(1) != numOutCols) {
|
||||||
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != expectedBRows
|
|
||||||
|| bType.getDimSize(1) != expectedBCols) {
|
|
||||||
gemmOp.emitOpError("has inconsistent A, B, and output shapes");
|
gemmOp.emitOpError("has inconsistent A, B, and output shapes");
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
@@ -773,8 +768,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
|
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
|
||||||
auto batchOp =
|
auto batchOp = createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, rewriter, loc);
|
||||||
createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
|
|
||||||
if (failed(batchOp))
|
if (failed(batchOp))
|
||||||
return failure();
|
return failure();
|
||||||
auto outputCompute = createDynamicGemmOutputCompute(
|
auto outputCompute = createDynamicGemmOutputCompute(
|
||||||
@@ -793,13 +787,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
b = *scaledB;
|
b = *scaledB;
|
||||||
bType = cast<RankedTensorType>(b.getType());
|
bType = cast<RankedTensorType>(b.getType());
|
||||||
|
|
||||||
if (gemmOpAdaptor.getTransB()) {
|
|
||||||
auto bShape = bType.getShape();
|
|
||||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
|
||||||
b = transposeMaybeInCompute(b, transposedType, {1, 0}, rewriter, loc);
|
|
||||||
bType = cast<RankedTensorType>(b.getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize || bType.getDimSize(1) != numOutCols) {
|
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize || bType.getDimSize(1) != numOutCols) {
|
||||||
gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling");
|
gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling");
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -123,39 +123,16 @@ static Value extractBatchMatrix(Value value,
|
|||||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||||
auto type = cast<RankedTensorType>(value.getType());
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
auto shape = type.getShape();
|
auto shape = type.getShape();
|
||||||
RankedTensorType transposedType;
|
auto createONNXTranspose = [&](RankedTensorType resultType, ArrayRef<int64_t> permutation) {
|
||||||
SmallVector<int64_t> perm;
|
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation)).getResult();
|
||||||
|
};
|
||||||
if (type.getRank() == 2) {
|
if (type.getRank() == 2) {
|
||||||
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
auto resultType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType(), type.getEncoding());
|
||||||
perm = {1, 0};
|
return createONNXTranspose(resultType, {1, 0});
|
||||||
}
|
|
||||||
else {
|
|
||||||
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
|
||||||
perm = {0, 2, 1};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc);
|
auto resultType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType(), type.getEncoding());
|
||||||
}
|
return createONNXTranspose(resultType, {0, 2, 1});
|
||||||
|
|
||||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
|
||||||
auto type = cast<RankedTensorType>(value.getType());
|
|
||||||
auto shape = type.getShape();
|
|
||||||
RankedTensorType transposedType;
|
|
||||||
SmallVector<int64_t> perm;
|
|
||||||
if (type.getRank() == 2) {
|
|
||||||
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
|
||||||
perm = {1, 0};
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
|
||||||
perm = {0, 2, 1};
|
|
||||||
}
|
|
||||||
|
|
||||||
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
|
||||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
|
||||||
});
|
|
||||||
return transposeCompute.getResult(0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
|
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
|
||||||
@@ -372,32 +349,6 @@ static Value extractDynamicBatchedBColumn(Value matrix,
|
|||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value extractDynamicBatchedBRow(Value matrix,
|
|
||||||
int64_t sourceBatchCount,
|
|
||||||
Value batch,
|
|
||||||
Value row,
|
|
||||||
RankedTensorType vectorType,
|
|
||||||
PatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
|
|
||||||
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
|
|
||||||
: OpFoldResult(batch),
|
|
||||||
row,
|
|
||||||
rewriter.getIndexAttr(0)};
|
|
||||||
SmallVector<OpFoldResult> sizes {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
|
||||||
auto rowSlice =
|
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3));
|
|
||||||
return tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
vectorType,
|
|
||||||
rowSlice,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1},
|
|
||||||
{2}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value extractDynamicBatchedRowVector(Value matrix,
|
static Value extractDynamicBatchedRowVector(Value matrix,
|
||||||
int64_t sourceBatchCount,
|
int64_t sourceBatchCount,
|
||||||
Value batch,
|
Value batch,
|
||||||
@@ -432,7 +383,6 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
|
|||||||
RankedTensorType bType,
|
RankedTensorType bType,
|
||||||
RankedTensorType scalarPiecesType,
|
RankedTensorType scalarPiecesType,
|
||||||
RankedTensorType outType,
|
RankedTensorType outType,
|
||||||
bool bAlreadyTransposed,
|
|
||||||
PatternRewriter& rewriter,
|
PatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
const int64_t numBatches = outType.getDimSize(0);
|
const int64_t numBatches = outType.getDimSize(0);
|
||||||
@@ -459,9 +409,7 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
|
|||||||
Value aVector =
|
Value aVector =
|
||||||
extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc);
|
extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc);
|
||||||
Value bVector =
|
Value bVector =
|
||||||
bAlreadyTransposed
|
extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
|
||||||
? extractDynamicBatchedBRow(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc)
|
|
||||||
: extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
|
|
||||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||||
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
@@ -537,15 +485,6 @@ static FailureOr<Value> createBatchedDynamicOutputCompute(Value scalarPieces,
|
|||||||
return computeOp->getResult(0);
|
return computeOp->getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) {
|
|
||||||
auto transposeCompute =
|
|
||||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
|
||||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, outputType, input, rewriter.getI64ArrayAttr({0, 2, 1}));
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
|
||||||
});
|
|
||||||
return transposeCompute.getResult(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value extractBatchedReductionPiece(Value partialPiecesArg,
|
static Value extractBatchedReductionPiece(Value partialPiecesArg,
|
||||||
Value batch,
|
Value batch,
|
||||||
Value hSlice,
|
Value hSlice,
|
||||||
@@ -764,7 +703,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
int64_t gemmK = shapeInfo->k;
|
int64_t gemmK = shapeInfo->k;
|
||||||
int64_t gemmN = shapeInfo->n;
|
int64_t gemmN = shapeInfo->n;
|
||||||
if (useTransposedForm) {
|
if (useTransposedForm) {
|
||||||
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
|
||||||
lhsBatchForGemm = shapeInfo->rhsBatch;
|
lhsBatchForGemm = shapeInfo->rhsBatch;
|
||||||
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||||
rhsBatchForGemm = shapeInfo->lhsBatch;
|
rhsBatchForGemm = shapeInfo->lhsBatch;
|
||||||
@@ -787,15 +726,10 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
rewriter.getBoolAttr(false),
|
rewriter.getBoolAttr(false),
|
||||||
rewriter.getBoolAttr(false))
|
rewriter.getBoolAttr(false))
|
||||||
.getY();
|
.getY();
|
||||||
if (useTransposedForm) {
|
if (useTransposedForm)
|
||||||
auto transposeCompute =
|
gemmResult =
|
||||||
createSpatCompute<1>(rewriter, loc, TypeRange {shapeInfo->outType}, {}, gemmResult, [&](Value input) {
|
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}))
|
||||||
Value transposed =
|
.getResult();
|
||||||
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, input, rewriter.getI64ArrayAttr({1, 0}));
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
|
||||||
});
|
|
||||||
gemmResult = transposeCompute.getResult(0);
|
|
||||||
}
|
|
||||||
rewriter.replaceOp(matmulOp, gemmResult);
|
rewriter.replaceOp(matmulOp, gemmResult);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -822,7 +756,7 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
int64_t gemmK = shapeInfo->k;
|
int64_t gemmK = shapeInfo->k;
|
||||||
int64_t gemmN = shapeInfo->n;
|
int64_t gemmN = shapeInfo->n;
|
||||||
if (useTransposedForm) {
|
if (useTransposedForm) {
|
||||||
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
|
||||||
lhsBatchForGemm = shapeInfo->rhsBatch;
|
lhsBatchForGemm = shapeInfo->rhsBatch;
|
||||||
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||||
rhsBatchForGemm = shapeInfo->lhsBatch;
|
rhsBatchForGemm = shapeInfo->lhsBatch;
|
||||||
@@ -880,12 +814,14 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
if (failed(result))
|
if (failed(result))
|
||||||
return failure();
|
return failure();
|
||||||
Value finalResult = *result;
|
Value finalResult = *result;
|
||||||
if (useTransposedForm)
|
if (useTransposedForm) {
|
||||||
finalResult = transposeBatchedOutput(
|
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
|
||||||
finalResult,
|
shapeInfo->outType.getElementType(),
|
||||||
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
|
shapeInfo->outType.getEncoding());
|
||||||
rewriter,
|
finalResult =
|
||||||
loc);
|
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||||
rewriter.replaceOp(matmulOp, finalResult);
|
rewriter.replaceOp(matmulOp, finalResult);
|
||||||
return success();
|
return success();
|
||||||
@@ -901,7 +837,6 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
rhsBatchedType,
|
rhsBatchedType,
|
||||||
scalarPiecesType,
|
scalarPiecesType,
|
||||||
directOutType,
|
directOutType,
|
||||||
false,
|
|
||||||
rewriter,
|
rewriter,
|
||||||
loc);
|
loc);
|
||||||
if (failed(batchOp))
|
if (failed(batchOp))
|
||||||
@@ -911,12 +846,14 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
if (failed(result))
|
if (failed(result))
|
||||||
return failure();
|
return failure();
|
||||||
Value finalResult = *result;
|
Value finalResult = *result;
|
||||||
if (useTransposedForm)
|
if (useTransposedForm) {
|
||||||
finalResult = transposeBatchedOutput(
|
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
|
||||||
finalResult,
|
shapeInfo->outType.getElementType(),
|
||||||
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
|
shapeInfo->outType.getEncoding());
|
||||||
rewriter,
|
finalResult =
|
||||||
loc);
|
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||||
rewriter.replaceOp(matmulOp, finalResult);
|
rewriter.replaceOp(matmulOp, finalResult);
|
||||||
return success();
|
return success();
|
||||||
|
|||||||
@@ -137,11 +137,17 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
|||||||
|
|
||||||
auto transposedType = RankedTensorType::get(
|
auto transposedType = RankedTensorType::get(
|
||||||
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
|
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
|
||||||
Value transposedInput = transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
|
Value transposedInput =
|
||||||
|
ONNXTransposeOp::create(
|
||||||
|
rewriter, softmaxOp.getLoc(), transposedType, input, rewriter.getI64ArrayAttr(permutation))
|
||||||
|
.getResult();
|
||||||
auto transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
auto transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||||
if (failed(transposedResult))
|
if (failed(transposedResult))
|
||||||
return failure();
|
return failure();
|
||||||
result = transposeMaybeInCompute(*transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
|
result =
|
||||||
|
ONNXTransposeOp::create(
|
||||||
|
rewriter, softmaxOp.getLoc(), inputType, *transposedResult, rewriter.getI64ArrayAttr(inversePermutation))
|
||||||
|
.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(softmaxOp, result);
|
rewriter.replaceOp(softmaxOp, result);
|
||||||
|
|||||||
@@ -15,6 +15,10 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
static bool isInsideSpatialComputeRegion(Operation* op) {
|
||||||
|
return op->getParentOfType<spatial::SpatCompute>() || op->getParentOfType<spatial::SpatComputeBatch>();
|
||||||
|
}
|
||||||
|
|
||||||
static Value createTransposeInit(Value input,
|
static Value createTransposeInit(Value input,
|
||||||
RankedTensorType resultType,
|
RankedTensorType resultType,
|
||||||
ArrayRef<int64_t> permutation,
|
ArrayRef<int64_t> permutation,
|
||||||
@@ -102,10 +106,22 @@ struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
|
|
||||||
Value transposed =
|
auto buildTranspose = [&](Value input) -> Value {
|
||||||
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation).getResult()[0];
|
Value init = createTransposeInit(input, resultType, *permutation, rewriter, transposeOp.getLoc());
|
||||||
rewriter.replaceOp(transposeOp, transposed);
|
return linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), input, init, *permutation).getResult()[0];
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isInsideSpatialComputeRegion(transposeOp.getOperation())) {
|
||||||
|
rewriter.replaceOp(transposeOp, buildTranspose(adaptor.getData()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto computeOp = createSpatCompute<1>(
|
||||||
|
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {adaptor.getData()}, [&](Value input) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), buildTranspose(input));
|
||||||
|
});
|
||||||
|
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user