#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"

#include "WeightMaterialization.hpp"
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
/// Returns true when `value` is a matrix-shaped tensor whose defining chain,
/// followed backwards through pure view/layout ops, reaches a producer for
/// which hasWeightAlways() holds.
bool isWeightLikeComputeOperand(Value value) {
  // Only statically-ranked, matrix-shaped tensors qualify at all.
  auto tensorTy = dyn_cast<RankedTensorType>(value.getType());
  if (!tensorTy || !isMatrixShape(tensorTy.getShape()))
    return false;

  // Walk the defining chain. `seen` is a defensive guard against revisiting
  // the same op (SSA def chains should be acyclic, but bail out if not).
  llvm::SmallPtrSet<Operation*, 8> seen;
  for (Operation* producer = value.getDefiningOp(); producer;
       producer = value.getDefiningOp()) {
    if (!seen.insert(producer).second)
      return false;
    if (hasWeightAlways(producer))
      return true;

    // Step through the supported view/layout ops to their source operand;
    // any other producer is real computation and terminates the walk.
    Value next;
    if (auto slice = dyn_cast<tensor::ExtractSliceOp>(producer))
      next = slice.getSource();
    else if (auto expand = dyn_cast<tensor::ExpandShapeOp>(producer))
      next = expand.getSrc();
    else if (auto collapse = dyn_cast<tensor::CollapseShapeOp>(producer))
      next = collapse.getSrc();
    else if (auto transpose = dyn_cast<ONNXTransposeOp>(producer))
      next = transpose.getData();
    else
      return false;

    value = next;
  }

  // Chain ended at a block argument: not traceable to a weight.
  return false;
}
/// Materializes `value` at the rewriter's current insertion point, cloning the
/// chain of view/layout ops that produce it. Constants are not cloned: they
/// are referenced through a full-tensor extract_slice so the original stays
/// in place. `mapper` caches every materialized value across calls so shared
/// producers are only materialized once. Returns failure for unsupported
/// producers (block arguments, non-static constants, compute ops).
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
  // Already materialized during an earlier call? Reuse the cached copy.
  if (Value cached = mapper.lookupOrNull(value))
    return cached;

  Operation* producer = value.getDefiningOp();
  if (!producer)
    return failure();

  // Constants: reference via an identity extract_slice over the full extent
  // instead of cloning the (potentially large) constant data.
  if (isa<arith::ConstantOp, ONNXConstantOp>(producer)) {
    auto tensorTy = dyn_cast<RankedTensorType>(value.getType());
    if (!tensorTy || !tensorTy.hasStaticShape())
      return failure();

    const int64_t rank = tensorTy.getRank();
    SmallVector<OpFoldResult> zeroOffsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> unitStrides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> fullSizes;
    fullSizes.reserve(rank);
    for (int64_t extent : tensorTy.getShape())
      fullSizes.push_back(rewriter.getIndexAttr(extent));

    auto slice = tensor::ExtractSliceOp::create(rewriter, value.getLoc(),
        tensorTy, value, zeroOffsets, fullSizes, unitStrides);
    mapper.map(value, slice.getResult());
    return slice.getResult();
  }

  // Only the view/layout ops recognized by isWeightLikeComputeOperand can be
  // cloned; anything else cannot be materialized here.
  if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp,
          tensor::CollapseShapeOp, ONNXTransposeOp>(producer))
    return failure();

  // Build the operand mapping for the clone: prefer already-materialized
  // values, recursively materialize weight-like operands, and pass any other
  // operand through unchanged.
  IRMapping cloneMapper;
  for (Value operand : producer->getOperands()) {
    if (Value cached = mapper.lookupOrNull(operand)) {
      cloneMapper.map(operand, cached);
    } else if (isWeightLikeComputeOperand(operand)) {
      FailureOr<Value> materialized =
          materializeWeightLikeValueInBlock(operand, rewriter, mapper);
      if (failed(materialized))
        return failure();
      cloneMapper.map(operand, *materialized);
    } else {
      cloneMapper.map(operand, operand);
    }
  }

  Operation* clonedProducer = rewriter.clone(*producer, cloneMapper);
  // Record every result in the shared cache so later queries (including the
  // lookup below) resolve to the clone.
  for (auto [original, replacement] :
      llvm::zip(producer->getResults(), clonedProducer->getResults()))
    mapper.map(original, replacement);

  Value result = mapper.lookupOrNull(value);
  if (!result)
    return failure();
  return result;
}
} // namespace onnx_mlir