This commit is contained in:
@@ -10,15 +10,11 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp);
|
mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp);
|
||||||
|
|
||||||
mlir::Value getOrCreateConstant(mlir::OperationFolder& folder,
|
mlir::Value
|
||||||
mlir::Operation* anchorOp,
|
getOrCreateConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
|
||||||
mlir::Attribute value,
|
|
||||||
mlir::Type type);
|
|
||||||
|
|
||||||
mlir::Value getOrCreateConstant(mlir::RewriterBase& rewriter,
|
mlir::Value
|
||||||
mlir::Operation* anchorOp,
|
getOrCreateConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
|
||||||
mlir::Attribute value,
|
|
||||||
mlir::Type type);
|
|
||||||
|
|
||||||
mlir::Value getOrCreateConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp);
|
mlir::Value getOrCreateConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp);
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "AttributeUtils.hpp"
|
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
|
||||||
|
#include "AttributeUtils.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/Block.h"
|
#include "mlir/IR/Block.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/ValueRange.h"
|
#include "mlir/IR/ValueRange.h"
|
||||||
@@ -11,8 +12,6 @@
|
|||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
@@ -209,8 +208,7 @@ auto createSpatComputeBatch(RewriterT& rewriter,
|
|||||||
block->getArgument(0),
|
block->getArgument(0),
|
||||||
mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
|
mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
|
||||||
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
|
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
|
||||||
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())
|
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())};
|
||||||
};
|
|
||||||
|
|
||||||
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
|
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
|
||||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||||
@@ -252,8 +250,8 @@ mlir::Value materializeOrComputeUnary(mlir::Value input,
|
|||||||
if (isCompileTimeComputable(input))
|
if (isCompileTimeComputable(input))
|
||||||
return buildFn(input);
|
return buildFn(input);
|
||||||
|
|
||||||
auto computeOp =
|
auto computeOp = createSpatCompute<1>(
|
||||||
createSpatCompute<1>(rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
|
rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
|
||||||
mlir::Value result = buildFn(computeInput);
|
mlir::Value result = buildFn(computeInput);
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
#include "IndexingUtils.hpp"
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
||||||
#include "llvm/ADT/APInt.h"
|
#include "llvm/ADT/APInt.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "IndexingUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -85,7 +84,8 @@ Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value val
|
|||||||
|
|
||||||
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) {
|
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) {
|
||||||
if (auto attr = dyn_cast<Attribute>(value))
|
if (auto attr = dyn_cast<Attribute>(value))
|
||||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
|
return getOrCreateIndexConstant(
|
||||||
|
rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
|
||||||
return cast<Value>(value);
|
return cast<Value>(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,8 +26,10 @@ mlir::Value createAffineApplyOrFoldedConstant(mlir::PatternRewriter& rewriter,
|
|||||||
mlir::AffineExpr expr,
|
mlir::AffineExpr expr,
|
||||||
mlir::ValueRange operands);
|
mlir::ValueRange operands);
|
||||||
|
|
||||||
mlir::Value
|
mlir::Value multiplyIndexByConstant(mlir::PatternRewriter& rewriter,
|
||||||
multiplyIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Operation* anchorOp, mlir::Value value, int64_t multiplier);
|
mlir::Operation* anchorOp,
|
||||||
|
mlir::Value value,
|
||||||
|
int64_t multiplier);
|
||||||
|
|
||||||
mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
|
mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
|
||||||
|
|
||||||
|
|||||||
@@ -8,8 +8,8 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
#include "ShapeTilingUtils.hpp"
|
|
||||||
#include "IndexingUtils.hpp"
|
#include "IndexingUtils.hpp"
|
||||||
|
#include "ShapeTilingUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
@@ -53,7 +53,9 @@ bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
|
|||||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||||
}
|
}
|
||||||
|
|
||||||
bool hasStaticPositiveShape(RankedTensorType type) { return type.hasStaticShape() && hasStaticPositiveShape(type.getShape()); }
|
bool hasStaticPositiveShape(RankedTensorType type) {
|
||||||
|
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
|
||||||
|
}
|
||||||
|
|
||||||
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||||
@@ -98,11 +100,8 @@ FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<Arr
|
|||||||
return permutation;
|
return permutation;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value transposeMaybeInCompute(Value value,
|
Value transposeMaybeInCompute(
|
||||||
RankedTensorType resultType,
|
Value value, RankedTensorType resultType, ArrayRef<int64_t> permutation, PatternRewriter& rewriter, Location loc) {
|
||||||
ArrayRef<int64_t> permutation,
|
|
||||||
PatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
auto buildTranspose = [&](Value input) -> Value {
|
auto buildTranspose = [&](Value input) -> Value {
|
||||||
return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult();
|
return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult();
|
||||||
};
|
};
|
||||||
@@ -127,7 +126,8 @@ SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int
|
|||||||
|
|
||||||
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
|
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
|
||||||
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
||||||
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() || sourceType.getRank() != resultType.getRank())
|
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()
|
||||||
|
|| sourceType.getRank() != resultType.getRank())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
for (OpFoldResult stride : strides) {
|
for (OpFoldResult stride : strides) {
|
||||||
@@ -290,7 +290,8 @@ Value materializeContiguousTensorSlice(Value source,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Value lower = zeroIndices[dim];
|
Value lower = zeroIndices[dim];
|
||||||
Value upper = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultType.getDimSize(dim));
|
Value upper =
|
||||||
|
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultType.getDimSize(dim));
|
||||||
Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||||
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator});
|
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator});
|
||||||
rewriter.setInsertionPointToStart(loop.getBody());
|
rewriter.setInsertionPointToStart(loop.getBody());
|
||||||
@@ -316,7 +317,8 @@ Value extractAxisSlice(
|
|||||||
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
|
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
|
||||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||||
sizes[axis] = rewriter.getIndexAttr(size);
|
sizes[axis] = rewriter.getIndexAttr(size);
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
|
return tensor::ExtractSliceOp::create(
|
||||||
|
rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
|
||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -115,12 +115,8 @@ mlir::Value materializeContiguousTensorSlice(mlir::Value source,
|
|||||||
mlir::ConversionPatternRewriter& rewriter,
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
mlir::Location loc);
|
mlir::Location loc);
|
||||||
|
|
||||||
mlir::Value extractAxisSlice(mlir::PatternRewriter& rewriter,
|
mlir::Value extractAxisSlice(
|
||||||
mlir::Location loc,
|
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||||
mlir::Value source,
|
|
||||||
int64_t axis,
|
|
||||||
int64_t offset,
|
|
||||||
int64_t size);
|
|
||||||
|
|
||||||
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
|
|||||||
@@ -15,8 +15,8 @@
|
|||||||
#include "Common/Common.hpp"
|
#include "Common/Common.hpp"
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -69,9 +69,9 @@ LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
|
|||||||
continue;
|
continue;
|
||||||
|
|
||||||
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
|
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
|
||||||
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " input #" << currentInputIndex
|
InFlightDiagnostic diagnostic = illegalOp->emitOpError()
|
||||||
<< (allowChannelReceiveInputs
|
<< kind << " input #" << currentInputIndex
|
||||||
? " must come from the host or an explicit "
|
<< (allowChannelReceiveInputs ? " must come from the host or an explicit "
|
||||||
"spat.channel_receive"
|
"spat.channel_receive"
|
||||||
: " must come from the host");
|
: " must come from the host");
|
||||||
if (definingOp)
|
if (definingOp)
|
||||||
|
|||||||
@@ -5,9 +5,7 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { populateGeneratedPrePatterns(patterns, ctx); }
|
||||||
populateGeneratedPrePatterns(patterns, ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
populateGeneratedConversionPatterns(patterns, ctx);
|
populateGeneratedConversionPatterns(patterns, ctx);
|
||||||
|
|||||||
@@ -51,8 +51,8 @@ static Value createPaddedRows(Value tensorValue,
|
|||||||
if (tensorType.getDimSize(0) == paddedRows)
|
if (tensorType.getDimSize(0) == paddedRows)
|
||||||
return tensorValue;
|
return tensorValue;
|
||||||
|
|
||||||
auto paddedType =
|
auto paddedType = RankedTensorType::get(
|
||||||
RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
|
{paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
|
||||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
|
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
|
||||||
rewriter.getIndexAttr(0)};
|
rewriter.getIndexAttr(0)};
|
||||||
@@ -62,20 +62,15 @@ static Value createPaddedRows(Value tensorValue,
|
|||||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||||
padOp.getRegion().push_back(padBlock);
|
padOp.getRegion().push_back(padBlock);
|
||||||
rewriter.setInsertionPointToStart(padBlock);
|
rewriter.setInsertionPointToStart(padBlock);
|
||||||
auto zero = getOrCreateConstant(rewriter,
|
auto zero = getOrCreateConstant(
|
||||||
padOp.getOperation(),
|
rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()), tensorType.getElementType());
|
||||||
rewriter.getZeroAttr(tensorType.getElementType()),
|
|
||||||
tensorType.getElementType());
|
|
||||||
tensor::YieldOp::create(rewriter, loc, zero);
|
tensor::YieldOp::create(rewriter, loc, zero);
|
||||||
rewriter.setInsertionPointAfter(padOp);
|
rewriter.setInsertionPointAfter(padOp);
|
||||||
return padOp.getResult();
|
return padOp.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value packRowsForParallelGemm(Value rows,
|
static Value packRowsForParallelGemm(
|
||||||
RankedTensorType rowsType,
|
Value rows, RankedTensorType rowsType, int64_t packFactor, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
int64_t packFactor,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
if (packFactor == 1)
|
if (packFactor == 1)
|
||||||
return rows;
|
return rows;
|
||||||
|
|
||||||
@@ -118,10 +113,8 @@ static Value unpackRowsFromParallelGemm(Value packedRows,
|
|||||||
|
|
||||||
const int64_t packedNumRows = packedRowsType.getDimSize(0);
|
const int64_t packedNumRows = packedRowsType.getDimSize(0);
|
||||||
const int64_t paddedNumRows = packedNumRows * packFactor;
|
const int64_t paddedNumRows = packedNumRows * packFactor;
|
||||||
auto expandedType =
|
auto expandedType = RankedTensorType::get(
|
||||||
RankedTensorType::get({packedNumRows, packFactor, rowWidth},
|
{packedNumRows, packFactor, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
|
||||||
packedRowsType.getElementType(),
|
|
||||||
packedRowsType.getEncoding());
|
|
||||||
auto paddedType =
|
auto paddedType =
|
||||||
RankedTensorType::get({paddedNumRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
|
RankedTensorType::get({paddedNumRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
|
||||||
auto unpackedType =
|
auto unpackedType =
|
||||||
@@ -193,11 +186,8 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
|||||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
|
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createConvWeightMatrix(Value w,
|
static Value createConvWeightMatrix(
|
||||||
RankedTensorType wFlatType,
|
Value w, RankedTensorType wFlatType, RankedTensorType wTransType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
RankedTensorType wTransType,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
auto buildWeightMatrix = [&](Value weight) -> Value {
|
auto buildWeightMatrix = [&](Value weight) -> Value {
|
||||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
@@ -360,9 +350,8 @@ static Value createIm2colRowComputes(Value x,
|
|||||||
Value im2col = im2colLoop.getResult(0);
|
Value im2col = im2colLoop.getResult(0);
|
||||||
|
|
||||||
Value gemmInputRows = im2col;
|
Value gemmInputRows = im2col;
|
||||||
if (packFactor != 1) {
|
if (packFactor != 1)
|
||||||
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
|
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
|
||||||
}
|
|
||||||
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
||||||
});
|
});
|
||||||
@@ -387,8 +376,13 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
|||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||||
gemmOut = unpackRowsFromParallelGemm(
|
gemmOut = unpackRowsFromParallelGemm(packedOutput,
|
||||||
packedOutput, cast<RankedTensorType>(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc);
|
cast<RankedTensorType>(packedOutput.getType()),
|
||||||
|
numPatches,
|
||||||
|
numChannelsOut,
|
||||||
|
packFactor,
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restore to NCHW layout:
|
// Restore to NCHW layout:
|
||||||
|
|||||||
@@ -252,7 +252,13 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
|
|||||||
Location loc) {
|
Location loc) {
|
||||||
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
||||||
auto batchOp = createSpatComputeBatch(
|
auto batchOp = createSpatComputeBatch(
|
||||||
rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
|
rewriter,
|
||||||
|
loc,
|
||||||
|
TypeRange {partialPiecesType},
|
||||||
|
laneCount,
|
||||||
|
ValueRange {b},
|
||||||
|
ValueRange {a},
|
||||||
|
[&](detail::SpatComputeBatchBodyArgs args) {
|
||||||
Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows);
|
Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows);
|
||||||
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
|
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
|
||||||
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||||
@@ -284,8 +290,8 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
|
|||||||
return *batchOp;
|
return *batchOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createDynamicGemmBatchRow(
|
static Value
|
||||||
Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
|
createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
if (numOutCols == 1)
|
if (numOutCols == 1)
|
||||||
return lane;
|
return lane;
|
||||||
|
|
||||||
@@ -294,17 +300,21 @@ static Value createDynamicGemmBatchRow(
|
|||||||
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value
|
static Value extractDynamicGemmBColumn(
|
||||||
extractDynamicGemmBColumn(Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), column};
|
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), column};
|
||||||
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType());
|
auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType());
|
||||||
Value columnSlice = materializeContiguousTensorSlice(matrix, columnSliceType, offsets, strides, rewriter, loc);
|
Value columnSlice = materializeContiguousTensorSlice(matrix, columnSliceType, offsets, strides, rewriter, loc);
|
||||||
SmallVector<ReassociationIndices> collapseReassociation {ReassociationIndices {0, 1}};
|
SmallVector<ReassociationIndices> collapseReassociation {
|
||||||
|
ReassociationIndices {0, 1}
|
||||||
|
};
|
||||||
auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType());
|
auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType());
|
||||||
Value collapsed =
|
Value collapsed =
|
||||||
tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, columnSlice, collapseReassociation).getResult();
|
tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, columnSlice, collapseReassociation).getResult();
|
||||||
SmallVector<ReassociationIndices> expandReassociation {ReassociationIndices {0, 1}};
|
SmallVector<ReassociationIndices> expandReassociation {
|
||||||
|
ReassociationIndices {0, 1}
|
||||||
|
};
|
||||||
return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
|
return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -371,13 +381,15 @@ static Value createBroadcastedBiasScalar(Value bias,
|
|||||||
Location loc) {
|
Location loc) {
|
||||||
SmallVector<OpFoldResult> unitStrides(biasType.getRank(), rewriter.getIndexAttr(1));
|
SmallVector<OpFoldResult> unitStrides(biasType.getRank(), rewriter.getIndexAttr(1));
|
||||||
if (biasType.getRank() == 1) {
|
if (biasType.getRank() == 1) {
|
||||||
SmallVector<OpFoldResult> offsets {
|
SmallVector<OpFoldResult> offsets {biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
|
||||||
biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(column)};
|
: OpFoldResult(column)};
|
||||||
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1)};
|
||||||
auto vectorType = RankedTensorType::get({1}, scalarType.getElementType());
|
auto vectorType = RankedTensorType::get({1}, scalarType.getElementType());
|
||||||
Value vector = tensor::ExtractSliceOp::create(rewriter, loc, vectorType, bias, offsets, sizes, unitStrides)
|
Value vector =
|
||||||
.getResult();
|
tensor::ExtractSliceOp::create(rewriter, loc, vectorType, bias, offsets, sizes, unitStrides).getResult();
|
||||||
SmallVector<ReassociationIndices> reassociation {ReassociationIndices {0, 1}};
|
SmallVector<ReassociationIndices> reassociation {
|
||||||
|
ReassociationIndices {0, 1}
|
||||||
|
};
|
||||||
return tensor::ExpandShapeOp::create(rewriter, loc, scalarType, vector, reassociation).getResult();
|
return tensor::ExpandShapeOp::create(rewriter, loc, scalarType, vector, reassociation).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -407,15 +419,20 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
|
|||||||
const int64_t reductionSize = aType.getDimSize(1);
|
const int64_t reductionSize = aType.getDimSize(1);
|
||||||
const int64_t laneCount = numOutRows * numOutCols;
|
const int64_t laneCount = numOutRows * numOutCols;
|
||||||
auto batchOp = createSpatComputeBatch(
|
auto batchOp = createSpatComputeBatch(
|
||||||
rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
|
rewriter,
|
||||||
|
loc,
|
||||||
|
TypeRange {scalarPiecesType},
|
||||||
|
laneCount,
|
||||||
|
ValueRange {},
|
||||||
|
ValueRange {a, b},
|
||||||
|
[&](detail::SpatComputeBatchBodyArgs args) {
|
||||||
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
|
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
|
||||||
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols);
|
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols);
|
||||||
|
|
||||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||||
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
|
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
|
||||||
Value bVector = bAlreadyTransposed
|
Value bVector = bAlreadyTransposed ? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
|
||||||
? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
|
|
||||||
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
|
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
|
||||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||||
|
|
||||||
@@ -578,9 +595,8 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
|
|||||||
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
|
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
|
||||||
Value reduced =
|
Value reduced =
|
||||||
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
|
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
|
||||||
Value hOffset =
|
Value hOffset = onnx_mlir::multiplyIndexByConstant(
|
||||||
onnx_mlir::multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice,
|
rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());
|
||||||
crossbarSize.getValue());
|
|
||||||
if (biasArg) {
|
if (biasArg) {
|
||||||
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
|
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
|
||||||
Value biasSlice =
|
Value biasSlice =
|
||||||
@@ -721,8 +737,8 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
|
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
|
||||||
auto batchOp = createVvdmulBatch(
|
auto batchOp =
|
||||||
a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
|
createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
|
||||||
auto outputCompute = createDynamicGemmOutputCompute(
|
auto outputCompute = createDynamicGemmOutputCompute(
|
||||||
batchOp.getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc);
|
batchOp.getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc);
|
||||||
rewriter.replaceOp(gemmOp, outputCompute.getResults());
|
rewriter.replaceOp(gemmOp, outputCompute.getResults());
|
||||||
|
|||||||
@@ -70,11 +70,8 @@ static SmallVector<int64_t> getKeptAxes(ArrayRef<bool> reducedAxes) {
|
|||||||
return keptAxes;
|
return keptAxes;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value computeLaneIndex(Value lane,
|
static Value
|
||||||
int64_t stride,
|
computeLaneIndex(Value lane, int64_t stride, int64_t dimSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
int64_t dimSize,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
if (dimSize == 1)
|
if (dimSize == 1)
|
||||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||||
|
|
||||||
@@ -119,8 +116,14 @@ static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
|
|||||||
sliceSizes.reserve(inputType.getRank());
|
sliceSizes.reserve(inputType.getRank());
|
||||||
insertOffsets.reserve(inputType.getRank());
|
insertOffsets.reserve(inputType.getRank());
|
||||||
|
|
||||||
auto batchOp = createSpatComputeBatch(
|
auto batchOp =
|
||||||
rewriter, loc, TypeRange {batchType}, laneCount, {}, ValueRange {input}, [&](detail::SpatComputeBatchBodyArgs args) {
|
createSpatComputeBatch(rewriter,
|
||||||
|
loc,
|
||||||
|
TypeRange {batchType},
|
||||||
|
laneCount,
|
||||||
|
{},
|
||||||
|
ValueRange {input},
|
||||||
|
[&](detail::SpatComputeBatchBodyArgs args) {
|
||||||
size_t keptAxisIndex = 0;
|
size_t keptAxisIndex = 0;
|
||||||
sliceOffsets.clear();
|
sliceOffsets.clear();
|
||||||
sliceSizes.clear();
|
sliceSizes.clear();
|
||||||
@@ -132,8 +135,8 @@ static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value axisIndex =
|
Value axisIndex = computeLaneIndex(
|
||||||
computeLaneIndex(args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc);
|
args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc);
|
||||||
++keptAxisIndex;
|
++keptAxisIndex;
|
||||||
sliceOffsets.push_back(axisIndex);
|
sliceOffsets.push_back(axisIndex);
|
||||||
sliceSizes.push_back(rewriter.getIndexAttr(1));
|
sliceSizes.push_back(rewriter.getIndexAttr(1));
|
||||||
@@ -142,8 +145,8 @@ static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
|
|||||||
insertOffsets.push_back(args.lane);
|
insertOffsets.push_back(args.lane);
|
||||||
insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0));
|
insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0));
|
||||||
|
|
||||||
Value slice =
|
Value slice = tensor::ExtractSliceOp::create(
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides);
|
rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides);
|
||||||
Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult();
|
Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult();
|
||||||
createParallelInsertSliceIntoBatchOutput(
|
createParallelInsertSliceIntoBatchOutput(
|
||||||
rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides);
|
rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides);
|
||||||
@@ -193,15 +196,15 @@ static Value buildKeepdimsFromLanePackedBatch(Value batchValue,
|
|||||||
|
|
||||||
auto reshapeCompute =
|
auto reshapeCompute =
|
||||||
createSpatCompute<1>(rewriter, loc, TypeRange {keepdimsType}, {}, ValueRange {batchValue}, [&](Value input) {
|
createSpatCompute<1>(rewriter, loc, TypeRange {keepdimsType}, {}, ValueRange {batchValue}, [&](Value input) {
|
||||||
auto flatType = RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding());
|
auto flatType =
|
||||||
|
RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding());
|
||||||
Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, input, collapseToFlat);
|
Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, input, collapseToFlat);
|
||||||
Value compact = flat;
|
Value compact = flat;
|
||||||
if (compactKeptType != flatType)
|
if (compactKeptType != flatType)
|
||||||
compact = tensor::ExpandShapeOp::create(rewriter, loc, compactKeptType, flat, expandFlatToCompact);
|
compact = tensor::ExpandShapeOp::create(rewriter, loc, compactKeptType, flat, expandFlatToCompact);
|
||||||
Value keepdims = compact;
|
Value keepdims = compact;
|
||||||
if (keepdimsType != compactKeptType)
|
if (keepdimsType != compactKeptType)
|
||||||
keepdims =
|
keepdims = tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims);
|
||||||
tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, keepdims);
|
spatial::SpatYieldOp::create(rewriter, loc, keepdims);
|
||||||
});
|
});
|
||||||
return reshapeCompute.getResult(0);
|
return reshapeCompute.getResult(0);
|
||||||
|
|||||||
@@ -121,11 +121,9 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
|||||||
|
|
||||||
auto transposedType = RankedTensorType::get(
|
auto transposedType = RankedTensorType::get(
|
||||||
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
|
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
|
||||||
Value transposedInput =
|
Value transposedInput = transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
|
||||||
transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
|
|
||||||
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||||
result = transposeMaybeInCompute(
|
result = transposeMaybeInCompute(transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
|
||||||
transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(softmaxOp, result);
|
rewriter.replaceOp(softmaxOp, result);
|
||||||
|
|||||||
@@ -127,8 +127,8 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
|||||||
Block& oldBlock = compute.getBody().front();
|
Block& oldBlock = compute.getBody().front();
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute);
|
rewriter.setInsertionPointAfter(compute);
|
||||||
auto newCompute =
|
auto newCompute = spatial::SpatCompute::create(
|
||||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
|
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
|
||||||
SmallVector<Type> newBlockArgTypes;
|
SmallVector<Type> newBlockArgTypes;
|
||||||
SmallVector<Location> newBlockArgLocs;
|
SmallVector<Location> newBlockArgLocs;
|
||||||
for (Value weight : promoted->newWeights) {
|
for (Value weight : promoted->newWeights) {
|
||||||
@@ -155,7 +155,12 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
|||||||
mapper.map(*oldWeightArg, *newWeightArg);
|
mapper.map(*oldWeightArg, *newWeightArg);
|
||||||
}
|
}
|
||||||
if (failed(mapPromotedInputArguments(
|
if (failed(mapPromotedInputArguments(
|
||||||
compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter)))
|
compute,
|
||||||
|
*promoted,
|
||||||
|
bodyRewriter,
|
||||||
|
mapper,
|
||||||
|
[&](size_t index) { return newCompute.getInputArgument(index); },
|
||||||
|
rewriter)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
for (Operation& op : oldBlock.without_terminator())
|
for (Operation& op : oldBlock.without_terminator())
|
||||||
@@ -199,7 +204,8 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
|||||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
||||||
SmallVector<Type> newBlockArgTypes;
|
SmallVector<Type> newBlockArgTypes;
|
||||||
SmallVector<Location> newBlockArgLocs;
|
SmallVector<Location> newBlockArgLocs;
|
||||||
newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size() + compute.getNumResults());
|
newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size()
|
||||||
|
+ compute.getNumResults());
|
||||||
newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults());
|
newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults());
|
||||||
newBlockArgTypes.push_back(laneArg->getType());
|
newBlockArgTypes.push_back(laneArg->getType());
|
||||||
newBlockArgLocs.push_back(laneArg->getLoc());
|
newBlockArgLocs.push_back(laneArg->getLoc());
|
||||||
@@ -239,7 +245,12 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
|||||||
mapper.map(*oldWeightArg, *newWeightArg);
|
mapper.map(*oldWeightArg, *newWeightArg);
|
||||||
}
|
}
|
||||||
if (failed(mapPromotedInputArguments(
|
if (failed(mapPromotedInputArguments(
|
||||||
compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter)))
|
compute,
|
||||||
|
*promoted,
|
||||||
|
bodyRewriter,
|
||||||
|
mapper,
|
||||||
|
[&](size_t index) { return newCompute.getInputArgument(index); },
|
||||||
|
rewriter)))
|
||||||
return failure();
|
return failure();
|
||||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
|
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
|
||||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||||
|
|||||||
@@ -111,7 +111,8 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
|
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
|
||||||
Value reshaped = materializeOrComputeUnary(adaptor.getData(), resultType, rewriter, reshapeOp.getLoc(), buildReshape);
|
Value reshaped =
|
||||||
|
materializeOrComputeUnary(adaptor.getData(), resultType, rewriter, reshapeOp.getLoc(), buildReshape);
|
||||||
rewriter.replaceOp(reshapeOp, reshaped);
|
rewriter.replaceOp(reshapeOp, reshaped);
|
||||||
return success();
|
return success();
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -44,8 +44,7 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
|||||||
|
|
||||||
if (isCompileTimeComputable(adaptor.getInput())) {
|
if (isCompileTimeComputable(adaptor.getInput())) {
|
||||||
for (int64_t sliceSize : sliceSizes) {
|
for (int64_t sliceSize : sliceSizes) {
|
||||||
outputs.push_back(
|
outputs.push_back(extractAxisSlice(rewriter, splitOp.getLoc(), adaptor.getInput(), *axis, offset, sliceSize));
|
||||||
extractAxisSlice(rewriter, splitOp.getLoc(), adaptor.getInput(), *axis, offset, sliceSize));
|
|
||||||
offset += sliceSize;
|
offset += sliceSize;
|
||||||
}
|
}
|
||||||
rewriter.replaceOp(splitOp, outputs);
|
rewriter.replaceOp(splitOp, outputs);
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
@@ -104,8 +104,7 @@ struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
|
|||||||
}
|
}
|
||||||
Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
|
Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
|
||||||
Value transposed =
|
Value transposed =
|
||||||
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation)
|
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation).getResult()[0];
|
||||||
.getResult()[0];
|
|
||||||
rewriter.replaceOp(transposeOp, transposed);
|
rewriter.replaceOp(transposeOp, transposed);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,8 +7,8 @@
|
|||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
|
||||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -334,8 +334,8 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
|
|||||||
loc,
|
loc,
|
||||||
tensorType,
|
tensorType,
|
||||||
getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), 0),
|
getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), 0),
|
||||||
getOrCreateIndexConstant(constantFolder,
|
getOrCreateIndexConstant(
|
||||||
deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize) ),
|
constantFolder, deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize)),
|
||||||
deviceTensor,
|
deviceTensor,
|
||||||
inputTensor,
|
inputTensor,
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
|
||||||
|
|||||||
@@ -1482,7 +1482,8 @@ void appendScalarSendLoop(MaterializerState& state,
|
|||||||
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
||||||
|
|
||||||
Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0);
|
Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0);
|
||||||
Value upperBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size()));
|
Value upperBound =
|
||||||
|
getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size()));
|
||||||
Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1);
|
Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1);
|
||||||
|
|
||||||
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {});
|
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {});
|
||||||
@@ -1577,7 +1578,8 @@ void appendProjectedScalarSendLoop(MaterializerState& state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0);
|
Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0);
|
||||||
Value upperBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size()));
|
Value upperBound =
|
||||||
|
getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast<int64_t>(channelIds.size()));
|
||||||
Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1);
|
Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1);
|
||||||
|
|
||||||
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {});
|
auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {});
|
||||||
@@ -2342,7 +2344,8 @@ FailureOr<Value> insertPackedScalarRunIntoWholeBatch(MaterializerState& state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0);
|
Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0);
|
||||||
Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast<int64_t>(run.slots.size()));
|
Value upperBound =
|
||||||
|
getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast<int64_t>(run.slots.size()));
|
||||||
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
|
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1);
|
||||||
|
|
||||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||||
|
|||||||
@@ -99,8 +99,7 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
|||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
copiedValue =
|
copiedValue = pim::PimMemCopyHostToDevOp::create(
|
||||||
pim::PimMemCopyHostToDevOp::create(
|
|
||||||
rewriter,
|
rewriter,
|
||||||
op->getLoc(),
|
op->getLoc(),
|
||||||
originalType,
|
originalType,
|
||||||
|
|||||||
Reference in New Issue
Block a user