Dynamic gemm/conv
This commit is contained in:
@@ -1,8 +1,12 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
@@ -12,6 +16,72 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Turns an OpFoldResult into an SSA Value.
///
/// If `result` holds an attribute, it is expected to be an IntegerAttr (the
/// cast asserts otherwise) and an `arith.constant` index op carrying its value
/// is created at the rewriter's current insertion point. If `result` already
/// holds a Value, that Value is returned and no operation is created.
static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) {
  auto attribute = dyn_cast<Attribute>(result);
  if (!attribute)
    return cast<Value>(result);

  int64_t constant = cast<IntegerAttr>(attribute).getInt();
  auto constantOp = arith::ConstantIndexOp::create(rewriter, loc, constant);
  return constantOp.getResult();
}
|
||||
|
||||
/// Returns a Value equal to `lhs + rhs`, skipping the addition when one side
/// is a constant zero.
///
/// `lhs` is inspected first: if it matches a constant integer zero, `rhs` is
/// returned unchanged; otherwise, if `rhs` matches a constant zero, `lhs` is
/// returned. Only when neither operand is a known zero is an `arith.addi`
/// created at the rewriter's current insertion point.
static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) {
  auto isConstantZero = [](Value operand) -> bool {
    APInt constant;
    return matchPattern(operand, m_ConstantInt(&constant)) && constant.isZero();
  };

  if (isConstantZero(lhs))
    return rhs;
  if (isConstantZero(rhs))
    return lhs;

  auto sum = arith::AddIOp::create(rewriter, loc, lhs, rhs);
  return sum.getResult();
}
|
||||
|
||||
/// Returns a Value equal to `value * factor`, simplifying when `factor` is a
/// known constant.
///
/// `factor` may be an IntegerAttr or a Value. When it is a Value that does not
/// match a constant integer, a plain `arith.muli` of `value` and that Value is
/// emitted. For a known constant factor:
///   - 0 yields a freshly created constant index 0,
///   - 1 yields `value` itself (no op is created),
///   - anything else is materialized as a constant index op and multiplied
///     with `value`.
static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatternRewriter& rewriter, Location loc) {
  APInt constantFactor;
  if (auto attribute = dyn_cast<Attribute>(factor)) {
    constantFactor = cast<IntegerAttr>(attribute).getValue();
  }
  else {
    Value dynamicFactor = cast<Value>(factor);
    // Not a compile-time constant: nothing to simplify, multiply directly.
    if (!matchPattern(dynamicFactor, m_ConstantInt(&constantFactor)))
      return arith::MulIOp::create(rewriter, loc, value, dynamicFactor).getResult();
  }

  if (constantFactor.isZero()) {
    auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    return zero.getResult();
  }

  if (constantFactor.isOne())
    return value;

  int64_t factorAsInt = constantFactor.getSExtValue();
  Value factorConstant = arith::ConstantIndexOp::create(rewriter, loc, factorAsInt).getResult();
  auto product = arith::MulIOp::create(rewriter, loc, value, factorConstant);
  return product.getResult();
}
|
||||
|
||||
/// Decides whether a slice of `source` with shape `resultType` and the given
/// `strides` can be treated as one contiguous block.
///
/// Returns false unless all of the following hold:
///   - `source` is a ranked tensor, and both it and `resultType` have static
///     shapes of the same rank;
///   - every stride is a constant equal to 1 (either an IntegerAttr or a Value
///     matching a constant integer);
///   - walking the dimensions from the last one towards the first, a run of
///     dimensions whose slice size equals the source size is followed by at
///     most one dimension of arbitrary size, and every dimension further out
///     has slice size 1.
///
/// Offsets are not examined here.
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
  auto sourceType = dyn_cast<RankedTensorType>(source.getType());
  if (!sourceType)
    return false;
  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
    return false;
  if (sourceType.getRank() != resultType.getRank())
    return false;

  auto isUnitStride = [](OpFoldResult stride) -> bool {
    if (auto attribute = dyn_cast<Attribute>(stride))
      return cast<IntegerAttr>(attribute).getInt() == 1;
    APInt constant;
    return matchPattern(cast<Value>(stride), m_ConstantInt(&constant)) && constant.isOne();
  };
  if (!std::all_of(strides.begin(), strides.end(), isUnitStride))
    return false;

  ArrayRef<int64_t> sliceShape = resultType.getShape();
  ArrayRef<int64_t> sourceShape = sourceType.getShape();

  // Skip the trailing dimensions that the slice covers completely.
  int64_t dim = resultType.getRank() - 1;
  while (dim >= 0 && sliceShape[dim] == sourceShape[dim])
    --dim;

  // `dim` is now the innermost partially covered dimension (or -1 if the
  // slice spans the whole source). Everything outside of it must be size 1.
  for (int64_t outer = dim - 1; outer >= 0; --outer)
    if (sliceShape[outer] != 1)
      return false;

  return true;
}
|
||||
|
||||
SmallVector<Value> sliceTensor(
|
||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||
@@ -123,4 +193,87 @@ Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatte
|
||||
return broadcastCompute.getResult(0);
|
||||
}
|
||||
|
||||
/// Produces a tensor of type `resultType` holding the slice of `source`
/// described by `offsets` and `strides` (the slice sizes are the static shape
/// of `resultType`).
///
/// When isContiguousTensorSlice accepts the slice, or when the result is
/// rank-0, a single `tensor.extract_slice` is emitted. Otherwise the slice is
/// assembled element by element: a `tensor.empty` of the result type is
/// threaded as the iter-arg through a nest of `scf.for` loops (one per result
/// dimension, bounds [0, dimSize), step 1), and the innermost body extracts
/// the unit-sized slice at `offset + index * stride` from `source` and inserts
/// it into the accumulator at the loop indices.
///
/// On return the rewriter's insertion point is right after the outermost loop
/// (or after the single extract_slice on the fast path).
///
/// \param source      Tensor to read from.
/// \param resultType  Statically shaped type of the slice (asserted).
/// \param offsets     Per-dimension start offsets; size must equal the rank.
/// \param strides     Per-dimension strides; size must equal the rank.
/// \param rewriter    Rewriter used to create all operations.
/// \param loc         Location attached to every created operation.
/// \return The Value holding the materialized slice.
Value materializeContiguousTensorSlice(Value source,
                                       RankedTensorType resultType,
                                       ArrayRef<OpFoldResult> offsets,
                                       ArrayRef<OpFoldResult> strides,
                                       ConversionPatternRewriter& rewriter,
                                       Location loc) {
  assert(resultType.hasStaticShape() && "expected static result type");
  const size_t rank = static_cast<size_t>(resultType.getRank());
  assert(offsets.size() == rank && "expected rank-matching offsets");
  assert(strides.size() == rank && "expected rank-matching strides");

  SmallVector<OpFoldResult> sizes;
  sizes.reserve(rank);
  for (int64_t extent : resultType.getShape())
    sizes.push_back(rewriter.getIndexAttr(extent));

  // Fast path: one extract_slice covers contiguous slices and rank-0 results.
  if (isContiguousTensorSlice(source, resultType, strides) || rank == 0)
    return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();

  Value accumulator =
      tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();

  // One zero constant per dimension, used as that dimension's loop lower bound.
  SmallVector<Value> lowerBounds;
  lowerBounds.reserve(rank);
  for (size_t dim = 0; dim < rank; ++dim)
    lowerBounds.push_back(arith::ConstantIndexOp::create(rewriter, loc, 0).getResult());

  // Open the loop nest from the outermost dimension inwards. Each loop carries
  // the partially filled result tensor as its single iter-arg.
  SmallVector<scf::ForOp> loopNest;
  SmallVector<Value> resultIndices;
  loopNest.reserve(rank);
  resultIndices.reserve(rank);
  for (size_t dim = 0; dim < rank; ++dim) {
    Value upperBound = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(dim)).getResult();
    Value unitStep = arith::ConstantIndexOp::create(rewriter, loc, 1).getResult();
    auto loop = scf::ForOp::create(rewriter, loc, lowerBounds[dim], upperBound, unitStep, ValueRange {accumulator});
    rewriter.setInsertionPointToStart(loop.getBody());
    resultIndices.push_back(loop.getInductionVar());
    accumulator = loop.getRegionIterArgs().front();
    loopNest.push_back(loop);
  }

  // Innermost body: copy the single element addressed by the loop indices.
  SmallVector<OpFoldResult> sourceOffsets;
  SmallVector<OpFoldResult> destinationOffsets;
  sourceOffsets.reserve(rank);
  destinationOffsets.reserve(rank);
  for (size_t dim = 0; dim < rank; ++dim) {
    Value base = getIndexValue(offsets[dim], rewriter, loc);
    Value scaled = multiplyIndexValue(resultIndices[dim], strides[dim], rewriter, loc);
    Value sourceIndex = addIndexValues(base, scaled, rewriter, loc);
    sourceOffsets.push_back(sourceIndex);
  }
  for (Value index : resultIndices)
    destinationOffsets.push_back(index);

  OpFoldResult unitAttr = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> unitSizes(rank, unitAttr);
  SmallVector<OpFoldResult> unitStrides(rank, unitAttr);

  SmallVector<int64_t> elementShape(rank, 1);
  auto elementTensorType = RankedTensorType::get(elementShape, resultType.getElementType());
  auto elementSlice =
      tensor::ExtractSliceOp::create(rewriter, loc, elementTensorType, source, sourceOffsets, unitSizes, unitStrides);
  Value updated = tensor::InsertSliceOp::create(
                      rewriter, loc, elementSlice.getResult(), accumulator, destinationOffsets, unitSizes, unitStrides)
                      .getResult();

  // Close the nest from the innermost loop outwards: yield the updated tensor,
  // step out of the loop, and hand its result to the enclosing loop.
  while (!loopNest.empty()) {
    scf::ForOp loop = loopNest.pop_back_val();
    scf::YieldOp::create(rewriter, loc, updated);
    rewriter.setInsertionPointAfter(loop);
    updated = loop.getResult(0);
  }

  return updated;
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user