less affine code and better affine helpers
Validate Operations / validate-operations (push) Waiting to run
Validate Operations / validate-operations (push) Waiting to run
This commit is contained in:
Submodule backend-simulators/pim/pimsim-nn updated: 3e3442b663...6d3b898e6b
+1
-1
Submodule onnx-mlir updated: 82018d7ce5...eb54c2afc4
@@ -9,6 +9,7 @@ add_pim_library(OMPimCommon
|
|||||||
IR/LoopUtils.cpp
|
IR/LoopUtils.cpp
|
||||||
IR/ShapeUtils.cpp
|
IR/ShapeUtils.cpp
|
||||||
IR/SubviewUtils.cpp
|
IR/SubviewUtils.cpp
|
||||||
|
IR/TensorSliceUtils.cpp
|
||||||
IR/WeightUtils.cpp
|
IR/WeightUtils.cpp
|
||||||
Support/CheckedArithmetic.cpp
|
Support/CheckedArithmetic.cpp
|
||||||
Support/DebugDump.cpp
|
Support/DebugDump.cpp
|
||||||
|
|||||||
@@ -69,6 +69,15 @@ Value affineMulConst(RewriterBase& rewriter, Location loc, Value value, int64_t
|
|||||||
return createOrFoldAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value}, constantAnchor);
|
return createOrFoldAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value}, constantAnchor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value affineAddConst(RewriterBase& rewriter, Location loc, Value value, int64_t offset, Operation* constantAnchor) {
|
||||||
|
assert(constantAnchor && "expected a valid constant anchor");
|
||||||
|
if (offset == 0)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
|
||||||
|
return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor);
|
||||||
|
}
|
||||||
|
|
||||||
Value affineModConst(RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
|
Value affineModConst(RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
|
||||||
assert(constantAnchor && "expected a valid constant anchor");
|
assert(constantAnchor && "expected a valid constant anchor");
|
||||||
assert(divisor > 0 && "expected a positive affine.mod divisor");
|
assert(divisor > 0 && "expected a positive affine.mod divisor");
|
||||||
@@ -90,6 +99,34 @@ Value affineFloorDivConst(
|
|||||||
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
|
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value affineAddModConst(
|
||||||
|
RewriterBase& rewriter, Location loc, Value value, int64_t offset, int64_t divisor, Operation* constantAnchor) {
|
||||||
|
assert(constantAnchor && "expected a valid constant anchor");
|
||||||
|
assert(divisor > 0 && "expected a positive affine.mod divisor");
|
||||||
|
if (divisor == 1)
|
||||||
|
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
|
||||||
|
|
||||||
|
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
|
||||||
|
AffineExpr expr = d0;
|
||||||
|
if (offset != 0)
|
||||||
|
expr = expr + offset;
|
||||||
|
return createOrFoldAffineApply(rewriter, loc, expr % divisor, ValueRange {value}, constantAnchor);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value affineAddFloorDivConst(
|
||||||
|
RewriterBase& rewriter, Location loc, Value value, int64_t offset, int64_t divisor, Operation* constantAnchor) {
|
||||||
|
assert(constantAnchor && "expected a valid constant anchor");
|
||||||
|
assert(divisor > 0 && "expected a positive affine.floor_div divisor");
|
||||||
|
if (divisor == 1)
|
||||||
|
return offset == 0 ? value : affineAddConst(rewriter, loc, value, offset, constantAnchor);
|
||||||
|
|
||||||
|
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
|
||||||
|
AffineExpr expr = d0;
|
||||||
|
if (offset != 0)
|
||||||
|
expr = expr + offset;
|
||||||
|
return createOrFoldAffineApply(rewriter, loc, expr.floorDiv(divisor), ValueRange {value}, constantAnchor);
|
||||||
|
}
|
||||||
|
|
||||||
FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
|
FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
|
||||||
if (auto constant = dyn_cast<AffineConstantExpr>(expr))
|
if (auto constant = dyn_cast<AffineConstantExpr>(expr))
|
||||||
return constant.getValue();
|
return constant.getValue();
|
||||||
|
|||||||
@@ -29,6 +29,12 @@ mlir::Value affineMulConst(mlir::RewriterBase& rewriter,
|
|||||||
int64_t multiplier,
|
int64_t multiplier,
|
||||||
mlir::Operation* constantAnchor);
|
mlir::Operation* constantAnchor);
|
||||||
|
|
||||||
|
mlir::Value affineAddConst(mlir::RewriterBase& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::Value value,
|
||||||
|
int64_t offset,
|
||||||
|
mlir::Operation* constantAnchor);
|
||||||
|
|
||||||
mlir::Value affineModConst(mlir::RewriterBase& rewriter,
|
mlir::Value affineModConst(mlir::RewriterBase& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
mlir::Value value,
|
mlir::Value value,
|
||||||
@@ -41,6 +47,20 @@ mlir::Value affineFloorDivConst(mlir::RewriterBase& rewriter,
|
|||||||
int64_t divisor,
|
int64_t divisor,
|
||||||
mlir::Operation* constantAnchor);
|
mlir::Operation* constantAnchor);
|
||||||
|
|
||||||
|
mlir::Value affineAddModConst(mlir::RewriterBase& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::Value value,
|
||||||
|
int64_t offset,
|
||||||
|
int64_t divisor,
|
||||||
|
mlir::Operation* constantAnchor);
|
||||||
|
|
||||||
|
mlir::Value affineAddFloorDivConst(mlir::RewriterBase& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::Value value,
|
||||||
|
int64_t offset,
|
||||||
|
int64_t divisor,
|
||||||
|
mlir::Operation* constantAnchor);
|
||||||
|
|
||||||
llvm::FailureOr<int64_t>
|
llvm::FailureOr<int64_t>
|
||||||
evaluateAffineExpr(mlir::AffineExpr expr, llvm::ArrayRef<int64_t> dims, llvm::ArrayRef<int64_t> symbols = {});
|
evaluateAffineExpr(mlir::AffineExpr expr, llvm::ArrayRef<int64_t> dims, llvm::ArrayRef<int64_t> symbols = {});
|
||||||
|
|
||||||
|
|||||||
@@ -218,6 +218,14 @@ getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr, int64_t
|
|||||||
return permutation;
|
return permutation;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::OpFoldResult> getStaticIndexAttrs(mlir::Builder& builder, llvm::ArrayRef<int64_t> values) {
|
||||||
|
llvm::SmallVector<mlir::OpFoldResult> attrs;
|
||||||
|
attrs.reserve(values.size());
|
||||||
|
for (int64_t value : values)
|
||||||
|
attrs.push_back(builder.getIndexAttr(value));
|
||||||
|
return attrs;
|
||||||
|
}
|
||||||
|
|
||||||
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank) {
|
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank) {
|
||||||
return llvm::SmallVector<mlir::OpFoldResult>(rank, rewriter.getIndexAttr(1));
|
return llvm::SmallVector<mlir::OpFoldResult>(rank, rewriter.getIndexAttr(1));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -101,6 +101,8 @@ llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation
|
|||||||
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
|
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
|
||||||
int64_t rank);
|
int64_t rank);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::OpFoldResult> getStaticIndexAttrs(mlir::Builder& builder, llvm::ArrayRef<int64_t> values);
|
||||||
|
|
||||||
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
|
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||||
|
|
||||||
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
|
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/TensorSliceUtils.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
Value extractAxisSlice(
|
||||||
|
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
|
||||||
|
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||||
|
SmallVector<int64_t> resultShape(sourceType.getShape());
|
||||||
|
resultShape[axis] = size;
|
||||||
|
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
|
||||||
|
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
|
||||||
|
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||||
|
sizes[axis] = rewriter.getIndexAttr(size);
|
||||||
|
return tensor::ExtractSliceOp::create(
|
||||||
|
rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value extractStaticSliceOrIdentity(RewriterBase& rewriter,
|
||||||
|
Location loc,
|
||||||
|
Value source,
|
||||||
|
RankedTensorType resultType,
|
||||||
|
ArrayRef<OpFoldResult> offsets,
|
||||||
|
ArrayRef<OpFoldResult> sizes,
|
||||||
|
ArrayRef<OpFoldResult> strides) {
|
||||||
|
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||||
|
size_t rank = static_cast<size_t>(sourceType.getRank());
|
||||||
|
|
||||||
|
bool isIdentitySlice =
|
||||||
|
sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank
|
||||||
|
&& strides.size() == rank;
|
||||||
|
if (isIdentitySlice) {
|
||||||
|
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
||||||
|
for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) {
|
||||||
|
std::optional<int64_t> staticOffset = mlir::getConstantIntValue(offset);
|
||||||
|
std::optional<int64_t> staticSize = mlir::getConstantIntValue(size);
|
||||||
|
std::optional<int64_t> staticStride = mlir::getConstantIntValue(stride);
|
||||||
|
if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim
|
||||||
|
|| *staticStride != 1) {
|
||||||
|
isIdentitySlice = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isIdentitySlice)
|
||||||
|
return source;
|
||||||
|
|
||||||
|
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value insertStaticSlice(
|
||||||
|
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
||||||
|
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||||
|
return tensor::InsertSliceOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
source,
|
||||||
|
dest,
|
||||||
|
offsets,
|
||||||
|
getStaticSizes(rewriter, sourceType.getShape()),
|
||||||
|
getUnitStrides(rewriter, sourceType.getRank()))
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/ValueRange.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::Value extractAxisSlice(
|
||||||
|
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||||
|
|
||||||
|
mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::Value source,
|
||||||
|
mlir::RankedTensorType resultType,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> offsets,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> sizes,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> strides);
|
||||||
|
|
||||||
|
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::Value source,
|
||||||
|
mlir::Value dest,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> offsets);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -5,5 +5,6 @@
|
|||||||
#include "MatrixProductLowering.hpp"
|
#include "MatrixProductLowering.hpp"
|
||||||
#include "ShapeTilingUtils.hpp"
|
#include "ShapeTilingUtils.hpp"
|
||||||
#include "WeightMaterialization.hpp"
|
#include "WeightMaterialization.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/TensorSliceUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|||||||
@@ -77,65 +77,4 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, PatternRewriter& rewri
|
|||||||
return slicesPerCore;
|
return slicesPerCore;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value extractAxisSlice(
|
|
||||||
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
|
|
||||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
|
||||||
SmallVector<int64_t> resultShape(sourceType.getShape());
|
|
||||||
resultShape[axis] = size;
|
|
||||||
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
|
|
||||||
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
|
|
||||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
|
||||||
sizes[axis] = rewriter.getIndexAttr(size);
|
|
||||||
return tensor::ExtractSliceOp::create(
|
|
||||||
rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
|
|
||||||
.getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
Value extractStaticSliceOrIdentity(RewriterBase& rewriter,
|
|
||||||
Location loc,
|
|
||||||
Value source,
|
|
||||||
RankedTensorType resultType,
|
|
||||||
ArrayRef<OpFoldResult> offsets,
|
|
||||||
ArrayRef<OpFoldResult> sizes,
|
|
||||||
ArrayRef<OpFoldResult> strides) {
|
|
||||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
|
||||||
size_t rank = static_cast<size_t>(sourceType.getRank());
|
|
||||||
|
|
||||||
bool isIdentitySlice =
|
|
||||||
sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank
|
|
||||||
&& strides.size() == rank;
|
|
||||||
if (isIdentitySlice) {
|
|
||||||
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
|
||||||
for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) {
|
|
||||||
std::optional<int64_t> staticOffset = mlir::getConstantIntValue(offset);
|
|
||||||
std::optional<int64_t> staticSize = mlir::getConstantIntValue(size);
|
|
||||||
std::optional<int64_t> staticStride = mlir::getConstantIntValue(stride);
|
|
||||||
if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim || *staticStride != 1) {
|
|
||||||
isIdentitySlice = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isIdentitySlice)
|
|
||||||
return source;
|
|
||||||
|
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
Value insertStaticSlice(
|
|
||||||
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
|
||||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
|
||||||
return tensor::InsertSliceOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
source,
|
|
||||||
dest,
|
|
||||||
offsets,
|
|
||||||
getStaticSizes(rewriter, sourceType.getShape()),
|
|
||||||
getUnitStrides(rewriter, sourceType.getRank()))
|
|
||||||
.getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -28,21 +28,4 @@ llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
|||||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||||
const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc);
|
const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc);
|
||||||
|
|
||||||
mlir::Value extractAxisSlice(
|
|
||||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
|
||||||
|
|
||||||
mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter,
|
|
||||||
mlir::Location loc,
|
|
||||||
mlir::Value source,
|
|
||||||
mlir::RankedTensorType resultType,
|
|
||||||
llvm::ArrayRef<mlir::OpFoldResult> offsets,
|
|
||||||
llvm::ArrayRef<mlir::OpFoldResult> sizes,
|
|
||||||
llvm::ArrayRef<mlir::OpFoldResult> strides);
|
|
||||||
|
|
||||||
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
|
||||||
mlir::Location loc,
|
|
||||||
mlir::Value source,
|
|
||||||
mlir::Value dest,
|
|
||||||
llvm::ArrayRef<mlir::OpFoldResult> offsets);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -1184,48 +1184,6 @@ static Value createZeroPaddedTensor(Value value,
|
|||||||
return padOp.getResult();
|
return padOp.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value affineAddConst(
|
|
||||||
PatternRewriter& rewriter, Location loc, Value value, int64_t offset, Operation* constantAnchor) {
|
|
||||||
if (offset == 0)
|
|
||||||
return value;
|
|
||||||
|
|
||||||
MLIRContext* context = rewriter.getContext();
|
|
||||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
|
||||||
return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value affineMulConst(
|
|
||||||
PatternRewriter& rewriter, Location loc, Value value, int64_t factor, Operation* constantAnchor) {
|
|
||||||
if (factor == 1)
|
|
||||||
return value;
|
|
||||||
|
|
||||||
MLIRContext* context = rewriter.getContext();
|
|
||||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
|
||||||
return createOrFoldAffineApply(rewriter, loc, d0 * factor, ValueRange {value}, constantAnchor);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value affineFloorDivConst(
|
|
||||||
PatternRewriter& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
|
|
||||||
assert(divisor > 0 && "expected positive affine floordiv divisor");
|
|
||||||
if (divisor == 1)
|
|
||||||
return value;
|
|
||||||
|
|
||||||
MLIRContext* context = rewriter.getContext();
|
|
||||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
|
||||||
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value affineModConst(
|
|
||||||
PatternRewriter& rewriter, Location loc, Value value, int64_t modulus, Operation* constantAnchor) {
|
|
||||||
assert(modulus > 0 && "expected positive affine mod divisor");
|
|
||||||
if (modulus == 1)
|
|
||||||
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
|
|
||||||
|
|
||||||
MLIRContext* context = rewriter.getContext();
|
|
||||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
|
||||||
return createOrFoldAffineApply(rewriter, loc, d0 % modulus, ValueRange {value}, constantAnchor);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value createConvInputPatch(Value input,
|
static Value createConvInputPatch(Value input,
|
||||||
RankedTensorType patchType,
|
RankedTensorType patchType,
|
||||||
Value batchIndex,
|
Value batchIndex,
|
||||||
@@ -2316,11 +2274,10 @@ static Value createIm2colRows(const ConvLoweringState& state,
|
|||||||
ValueRange {im2colInit},
|
ValueRange {im2colInit},
|
||||||
[&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
[&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||||
Value im2colAcc = iterArgs.front();
|
Value im2colAcc = iterArgs.front();
|
||||||
Value globalPatchIndex = affineAddConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, anchorOp);
|
|
||||||
Value batchIndex =
|
Value batchIndex =
|
||||||
affineFloorDivConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
|
affineAddFloorDivConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp);
|
||||||
Value batchPatchIndex =
|
Value batchPatchIndex =
|
||||||
affineModConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
|
affineAddModConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp);
|
||||||
Value outHeightIndex = affineFloorDivConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
|
Value outHeightIndex = affineFloorDivConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
|
||||||
Value outWidthIndex = affineModConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
|
Value outWidthIndex = affineModConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
|
||||||
Value inputHeightOffset =
|
Value inputHeightOffset =
|
||||||
|
|||||||
@@ -15,22 +15,6 @@ namespace raptor {
|
|||||||
|
|
||||||
} // namespace raptor
|
} // namespace raptor
|
||||||
|
|
||||||
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
|
|
||||||
SmallVector<OpFoldResult, 4> attrs;
|
|
||||||
attrs.reserve(values.size());
|
|
||||||
for (int64_t value : values)
|
|
||||||
attrs.push_back(builder.getIndexAttr(value));
|
|
||||||
return attrs;
|
|
||||||
}
|
|
||||||
|
|
||||||
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
|
|
||||||
SmallVector<OpFoldResult, 4> strides;
|
|
||||||
strides.reserve(rank);
|
|
||||||
for (int64_t dim = 0; dim < rank; ++dim)
|
|
||||||
strides.push_back(builder.getIndexAttr(1));
|
|
||||||
return strides;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct LowerFragmentAssemblyBlueprintPattern
|
struct LowerFragmentAssemblyBlueprintPattern
|
||||||
: OpConversionPattern<spatial::SpatBlueprintOp> {
|
: OpConversionPattern<spatial::SpatBlueprintOp> {
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|||||||
@@ -33,9 +33,9 @@
|
|||||||
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/TensorSliceUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|||||||
Reference in New Issue
Block a user