add support for operations: reduceMean, add, mul, div, sigmoid
Some checks failed
Validate Operations / validate-operations (push) Failing after 51m52s
Some checks failed
Validate Operations / validate-operations (push) Failing after 51m52s
This commit is contained in:
@@ -121,6 +121,13 @@ json::Object PimCodeGen::createEmptyOffset() {
|
||||
return offset;
|
||||
}
|
||||
|
||||
/// Builds the "offset" JSON object with `offset_select` set to 1 and
/// `offset_value` set to 0.
/// NOTE(review): judging by the name, select value 1 presumably means "apply
/// the offset to rs1 only" -- confirm against the PIM ISA description.
static json::Object createRs1OnlyOffset() {
  return json::Object {
    {"offset_select", 1},
    {"offset_value",  0}
  };
}
|
||||
|
||||
/// Serializes `instruction` as a JSON value into the core file stream and
/// appends a ',' separator after it.
void PimCodeGen::emitInstruction(json::Object instruction) const {
  json::Value serialized(std::move(instruction));
  coreFileStream << serialized << ',';
}
|
||||
@@ -331,7 +338,8 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
|
||||
json["op"] = "vavg";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["rs2"] = 1;
|
||||
json["offset"] = createRs1OnlyOffset();
|
||||
json["len"] = getValueSizeInBytes(vavgOp.getInput());
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
@@ -4,10 +4,13 @@ add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||
|
||||
add_pim_library(OMONNXToSpatial
|
||||
Patterns/Math/Conv.cpp
|
||||
Patterns/Math/Elementwise.cpp
|
||||
Patterns/Math/Gemm.cpp
|
||||
Patterns/Math/MatMul.cpp
|
||||
Patterns/Math/ReduceMean.cpp
|
||||
Patterns/NN/Pool.cpp
|
||||
Patterns/NN/Relu.cpp
|
||||
Patterns/NN/Sigmoid.cpp
|
||||
Patterns/Tensor/Concat.cpp
|
||||
Patterns/Tensor/Reshape.cpp
|
||||
ONNXToSpatialPass.cpp
|
||||
|
||||
@@ -14,8 +14,6 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#define DEFINE_MAP_OP(opname) opname,
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <class ShapedType>
|
||||
|
||||
@@ -72,11 +72,15 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
|
||||
target.addDynamicallyLegalOp<ONNXMatMulOp>(
|
||||
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
|
||||
target.addIllegalOp<ONNXAddOp>();
|
||||
target.addIllegalOp<ONNXDivOp>();
|
||||
target.addIllegalOp<ONNXMulOp>();
|
||||
target.addIllegalOp<ONNXGemmOp>();
|
||||
target.addIllegalOp<ONNXConvOp>();
|
||||
target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
|
||||
target.addIllegalOp<ONNXAveragePoolOp>();
|
||||
target.addIllegalOp<ONNXReluOp>();
|
||||
target.addIllegalOp<ONNXSigmoidOp>();
|
||||
target.addIllegalOp<ONNXSoftmaxOp>();
|
||||
target.addIllegalOp<ONNXConcatOp>();
|
||||
target.addIllegalOp<ONNXReshapeOp>();
|
||||
@@ -86,10 +90,13 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<removeLRN>(ctx);
|
||||
|
||||
populateElementwisePatterns(patterns, ctx);
|
||||
populateGemmPatterns(patterns, ctx);
|
||||
populateConvPatterns(patterns, ctx);
|
||||
populatePoolPatterns(patterns, ctx);
|
||||
populateReduceMeanPatterns(patterns, ctx);
|
||||
populateReluPatterns(patterns, ctx);
|
||||
populateSigmoidPatterns(patterns, ctx);
|
||||
populateConcatPatterns(patterns, ctx);
|
||||
populateReshapePatterns(patterns, ctx);
|
||||
|
||||
|
||||
@@ -7,14 +7,20 @@ namespace onnx_mlir {
|
||||
|
||||
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateReduceMeanPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
204
src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp
Normal file
204
src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp
Normal file
@@ -0,0 +1,204 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Returns the row-major (innermost dimension contiguous) element strides for
/// `shape`: the last stride is 1 and every other stride is the product of all
/// dimension sizes to its right. An empty shape yields an empty vector.
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
  const size_t rank = shape.size();
  SmallVector<int64_t> strides(rank, 1);
  // Walk from the innermost dimension outwards; `dim` is the dimension whose
  // stride is already known, `dim - 1` the one being filled in.
  for (size_t dim = rank; dim > 1; --dim)
    strides[dim - 2] = strides[dim - 1] * shape[dim - 1];
  return strides;
}
|
||||
|
||||
/// Returns the dense payload of the constant that defines `value`, or a null
/// attribute when `value` is not produced by an `arith.constant` /
/// `ONNXConstantOp`, or when that op's value is not a DenseElementsAttr.
static DenseElementsAttr getDenseConstantAttr(Value value) {
  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return nullptr;

  if (auto arithConstant = dyn_cast<arith::ConstantOp>(definingOp))
    return dyn_cast<DenseElementsAttr>(arithConstant.getValue());

  // The ONNX constant's value attribute may be absent, hence the null-tolerant cast.
  if (auto onnxConstant = dyn_cast<ONNXConstantOp>(definingOp))
    return dyn_cast_or_null<DenseElementsAttr>(onnxConstant.getValueAttr());

  return nullptr;
}
|
||||
|
||||
/// Expands the constant payload `denseAttr` to `resultType` using right-aligned
/// broadcasting: source dimensions are matched against the trailing result
/// dimensions and a source dimension of size 1 is repeated along the result
/// dimension.
///
/// Returns `denseAttr` unchanged when it already has type `resultType`.
/// Fails when either type is not a statically shaped ranked tensor, when the
/// element types differ, when the source has a higher rank than the result, or
/// when a source dimension is neither 1 nor equal to the matching result
/// dimension.
static FailureOr<DenseElementsAttr> broadcastDenseConstantAttr(DenseElementsAttr denseAttr,
                                                               RankedTensorType resultType) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
    return failure();

  // DenseElementsAttr::get expects every element attribute to carry the element
  // type of the shaped type, so a mismatch must be rejected up front.
  if (sourceType.getElementType() != resultType.getElementType())
    return failure();

  if (sourceType == resultType)
    return denseAttr;

  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  ArrayRef<int64_t> resultShape = resultType.getShape();
  if (sourceShape.size() > resultShape.size())
    return failure();

  const int64_t resultRank = static_cast<int64_t>(resultShape.size());
  const int64_t rankOffset = resultRank - static_cast<int64_t>(sourceShape.size());

  // Leading result dimensions without a source counterpart behave like size 1
  // and are always compatible, so only the overlapping dimensions are checked.
  for (int64_t i = rankOffset; i < resultRank; ++i) {
    const int64_t sourceDim = sourceShape[i - rankOffset];
    if (sourceDim != 1 && sourceDim != resultShape[i])
      return failure();
  }

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceShape);
  SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultShape);

  const int64_t numResultElements = resultType.getNumElements();
  SmallVector<Attribute> resultValues;
  resultValues.reserve(numResultElements);

  for (int64_t flatIndex = 0; flatIndex < numResultElements; ++flatIndex) {
    // Decompose the flat result index into per-dimension indices and rebuild
    // the flat index of the source element that feeds this position.
    int64_t remaining = flatIndex;
    int64_t sourceFlatIndex = 0;

    for (int64_t i = 0; i < resultRank; ++i) {
      const int64_t resultIndex = remaining / resultStrides[i];
      remaining %= resultStrides[i];

      const int64_t sourceIndex = i - rankOffset;
      if (sourceIndex < 0)
        continue;

      // A broadcast (size 1) source dimension always reads index 0.
      if (sourceShape[sourceIndex] != 1)
        sourceFlatIndex += resultIndex * sourceStrides[sourceIndex];
    }

    resultValues.push_back(sourceValues[sourceFlatIndex]);
  }

  return DenseElementsAttr::get(resultType, resultValues);
}

/// Returns a value of type `resultType` holding the constant behind `value`
/// broadcast to that type.
///
/// `value` itself is returned when its constant already has type `resultType`;
/// otherwise a new `arith.constant` is created at the rewriter's insertion
/// point. Fails when `value` is not a dense constant or cannot be broadcast
/// (see broadcastDenseConstantAttr).
static FailureOr<Value> materializeBroadcastedConstantTensor(Value value,
                                                             RankedTensorType resultType,
                                                             ConversionPatternRewriter& rewriter,
                                                             Location loc) {
  DenseElementsAttr denseAttr = getDenseConstantAttr(value);
  if (!denseAttr)
    return failure();

  FailureOr<DenseElementsAttr> broadcastedAttr = broadcastDenseConstantAttr(denseAttr, resultType);
  if (failed(broadcastedAttr))
    return failure();

  // The helper hands back the very same attribute when no broadcast was
  // needed; reuse the existing constant instead of cloning it.
  if (*broadcastedAttr == denseAttr)
    return value;

  DenseElementsAttr expandedAttr = *broadcastedAttr;
  return arith::ConstantOp::create(rewriter, loc, resultType, expandedAttr).getResult();
}

/// Prepares one operand of an element-wise op whose result has `resultType`.
///
/// An operand that already has `resultType` is returned as is. Any other
/// operand must be a constant that can be broadcast to `resultType`; broadcast
/// of non-constant operands is not handled here and yields failure.
static FailureOr<Value> prepareElementwiseOperand(Value value,
                                                  RankedTensorType resultType,
                                                  ConversionPatternRewriter& rewriter,
                                                  Location loc) {
  auto valueType = dyn_cast<RankedTensorType>(value.getType());
  if (!valueType || !valueType.hasStaticShape())
    return failure();

  if (valueType == resultType)
    return value;

  return materializeBroadcastedConstantTensor(value, resultType, rewriter, loc);
}

/// Creates an `arith.constant` of type `resultType` holding the element-wise
/// reciprocal (1 / x) of the floating-point constant behind `value`, after
/// broadcasting it to `resultType`.
///
/// The broadcast is done on the attribute, so no intermediate constant op is
/// emitted. Fails when `value` is not a dense floating-point constant, cannot
/// be broadcast, or when a division reports `opInvalidOp`. Other APFloat status
/// flags (including division by zero) are not treated as errors.
static FailureOr<Value> materializeReciprocalTensor(Value value,
                                                    RankedTensorType resultType,
                                                    ConversionPatternRewriter& rewriter,
                                                    Location loc) {
  DenseElementsAttr sourceAttr = getDenseConstantAttr(value);
  if (!sourceAttr)
    return failure();

  FailureOr<DenseElementsAttr> broadcastedAttr = broadcastDenseConstantAttr(sourceAttr, resultType);
  if (failed(broadcastedAttr))
    return failure();

  auto denseAttr = dyn_cast<DenseFPElementsAttr>(*broadcastedAttr);
  if (!denseAttr)
    return failure();

  SmallVector<APFloat> reciprocalValues;
  reciprocalValues.reserve(denseAttr.getNumElements());
  for (const APFloat& valueAttr : denseAttr.getValues<APFloat>()) {
    // Start from 1 in the element's own float semantics, then divide in place.
    APFloat reciprocal(valueAttr.getSemantics(), 1);
    auto status = reciprocal.divide(valueAttr, APFloat::rmNearestTiesToEven);
    if (status & APFloat::opInvalidOp)
      return failure();
    reciprocalValues.push_back(std::move(reciprocal));
  }

  auto reciprocalAttr = DenseFPElementsAttr::get(resultType, reciprocalValues);
  return arith::ConstantOp::create(rewriter, loc, resultType, reciprocalAttr).getResult();
}
|
||||
|
||||
template <typename OnnxOp, typename SpatialOp>
|
||||
struct BinaryElementwiseToSpatialCompute : OpConversionPattern<OnnxOp> {
|
||||
using OpConversionPattern<OnnxOp>::OpConversionPattern;
|
||||
using Adaptor = typename OnnxOp::Adaptor;
|
||||
|
||||
LogicalResult matchAndRewrite(OnnxOp op, Adaptor adaptor, ConversionPatternRewriter& rewriter) const override {
|
||||
auto resultType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto lhs = prepareElementwiseOperand(adaptor.getOperands()[0], resultType, rewriter, loc);
|
||||
if (failed(lhs))
|
||||
return failure();
|
||||
|
||||
auto rhs = prepareElementwiseOperand(adaptor.getOperands()[1], resultType, rewriter, loc);
|
||||
if (failed(rhs))
|
||||
return failure();
|
||||
|
||||
constexpr size_t numInputs = 2;
|
||||
auto computeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {*lhs, *rhs}, [&](Value x, Value y) {
|
||||
auto loweredOp = SpatialOp::create(rewriter, loc, resultType, x, y);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, loweredOp.getResult());
|
||||
});
|
||||
|
||||
rewriter.replaceOp(op, computeOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Lowers `ONNXDivOp` to a spatial compute region that multiplies the dividend
/// by the element-wise reciprocal of the divisor (`SpatVMulOp`).
///
/// The divisor must be a dense floating-point constant so that its reciprocal
/// can be computed at compile time (see materializeReciprocalTensor); the
/// dividend follows the same rules as any other element-wise operand. The
/// match fails with a diagnostic reason otherwise.
struct DivToSpatialCompute : OpConversionPattern<ONNXDivOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ONNXDivOp op, ONNXDivOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType());
    if (!resultType || !resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "expected a statically shaped ranked tensor result");

    Location loc = op.getLoc();
    auto lhs = prepareElementwiseOperand(adaptor.getA(), resultType, rewriter, loc);
    if (failed(lhs))
      return rewriter.notifyMatchFailure(
          op, "dividend must have the result type or be a constant broadcastable to it");

    auto reciprocalRhs = materializeReciprocalTensor(adaptor.getB(), resultType, rewriter, loc);
    if (failed(reciprocalRhs))
      return rewriter.notifyMatchFailure(
          op, "divisor must be a dense floating-point constant broadcastable to the result type");

    constexpr size_t numInputs = 2;
    auto computeOp = createSpatCompute<numInputs>(
        rewriter, loc, resultType, {}, ValueRange {*lhs, *reciprocalRhs}, [&](Value x, Value reciprocal) {
          auto mulOp = spatial::SpatVMulOp::create(rewriter, loc, resultType, x, reciprocal);
          spatial::SpatYieldOp::create(rewriter, loc, mulOp.getResult());
        });

    rewriter.replaceOp(op, computeOp);
    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Registers the ONNX Add, Mul and Div lowerings to spatial compute ops.
void populateElementwisePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  using AddLowering = BinaryElementwiseToSpatialCompute<ONNXAddOp, spatial::SpatVAddOp>;
  using MulLowering = BinaryElementwiseToSpatialCompute<ONNXMulOp, spatial::SpatVMulOp>;
  patterns.add<AddLowering, MulLowering, DivToSpatialCompute>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
163
src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp
Normal file
163
src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp
Normal file
@@ -0,0 +1,163 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Converts the `axes` attribute of a reduction into a sorted, duplicate-free
/// list of axis indices.
///
/// A null attribute selects every axis in [0, rank). Negative entries are
/// shifted by `rank`; no range check is performed here, so entries that are
/// still out of range after the shift are returned as they are.
static SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
  SmallVector<int64_t> axes;

  if (axesAttr) {
    axes.reserve(axesAttr.size());
    for (Attribute axisAttr : axesAttr) {
      int64_t axis = cast<IntegerAttr>(axisAttr).getInt();
      if (axis < 0)
        axis += rank;
      axes.push_back(axis);
    }

    llvm::sort(axes);
    auto firstDuplicate = std::unique(axes.begin(), axes.end());
    axes.erase(firstDuplicate, axes.end());
    return axes;
  }

  // No attribute: reduce over all axes.
  axes.reserve(rank);
  for (int64_t axis = 0; axis < rank; ++axis)
    axes.push_back(axis);
  return axes;
}
|
||||
|
||||
/// Builds a per-axis mask of length `rank` whose entries are true for the axes
/// listed in `axes`. Returns an empty vector as soon as an axis outside
/// [0, rank) is encountered.
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
  SmallVector<bool> mask(rank, false);
  for (int64_t axis : axes) {
    const bool inRange = 0 <= axis && axis < rank;
    if (!inRange)
      return {};
    mask[axis] = true;
  }
  return mask;
}
|
||||
|
||||
/// Returns a ranked tensor type with the same rank as `inputType`, every
/// dimension equal to 1, and `elementType` as element type.
static RankedTensorType getAllOnesType(RankedTensorType inputType, Type elementType) {
  SmallVector<int64_t> unitShape(inputType.getRank(), 1);
  return RankedTensorType::get(unitShape, elementType);
}
|
||||
|
||||
/// Builds the reassociation groups used to collapse the reduced axes away.
///
/// Axes are visited in order; every kept (non-reduced) axis closes a group that
/// also contains the reduced axes immediately preceding it. Reduced axes left
/// over at the end are appended to the last group, or form the only group when
/// no axis is kept.
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) {
  SmallVector<ReassociationIndices> groups;
  ReassociationIndices pending;

  for (size_t axis = 0, rank = reducedAxes.size(); axis < rank; ++axis) {
    pending.push_back(axis);
    if (reducedAxes[axis])
      continue;
    groups.push_back(pending);
    pending.clear();
  }

  if (pending.empty())
    return groups;

  if (groups.empty())
    groups.push_back(std::move(pending));
  else
    groups.back().append(pending.begin(), pending.end());

  return groups;
}
|
||||
|
||||
/// Wraps a single `SpatVAvgOp` over `input` in a one-input spatial compute op
/// and returns the compute op's first result, typed `resultType`.
static Value createAverageCompute(Value input,
                                  RankedTensorType resultType,
                                  ConversionPatternRewriter& rewriter,
                                  Location loc) {
  constexpr size_t numInputs = 1;

  // Body of the compute region: average the block argument and yield it.
  auto emitBody = [&](Value x) {
    auto averaged = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x);
    spatial::SpatYieldOp::create(rewriter, loc, averaged.getResult());
  };

  auto averageCompute = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, emitBody);
  return averageCompute.getResult(0);
}
|
||||
|
||||
/// Recursively builds the keep-dims form of the mean reduction, starting at
/// dimension `axis` of `input`.
///
/// Reduced axes are left intact. Every kept axis is split into size-1 slices,
/// each slice is processed recursively, and the partial results are
/// concatenated back along that axis. Once all axes have been visited the
/// remaining tensor is averaged into a value of type `leafType`.
static Value buildReduceMeanKeepdims(Value input,
                                     ArrayRef<bool> reducedAxes,
                                     int64_t axis,
                                     RankedTensorType leafType,
                                     ConversionPatternRewriter& rewriter,
                                     Location loc) {
  const int64_t rank = cast<RankedTensorType>(input.getType()).getRank();

  // Reduced axes need no slicing: move on to the next kept axis (or the end).
  while (axis != rank && reducedAxes[axis])
    ++axis;

  if (axis == rank)
    return createAverageCompute(input, leafType, rewriter, loc);

  SmallVector<Value> partialResults;
  SmallVector<Value> unitSlices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
  partialResults.reserve(unitSlices.size());
  for (Value unitSlice : unitSlices) {
    Value partial = buildReduceMeanKeepdims(unitSlice, reducedAxes, axis + 1, leafType, rewriter, loc);
    partialResults.push_back(partial);
  }

  // A single slice needs no concatenation.
  if (partialResults.size() == 1)
    return partialResults.front();
  return tensor::ConcatOp::create(rewriter, loc, axis, partialResults).getResult();
}
|
||||
|
||||
/// Converts the keep-dims reduction result `keepdimsValue` into `resultType`.
///
/// For a non-scalar result the tensor is collapsed with the groups from
/// buildCollapseReassociation. For a rank-0 result the element at index 0 in
/// every dimension is extracted and re-wrapped with `tensor.from_elements`.
static Value squeezeReducedAxes(Value keepdimsValue,
                                RankedTensorType resultType,
                                ArrayRef<bool> reducedAxes,
                                ConversionPatternRewriter& rewriter,
                                Location loc) {
  if (resultType.getRank() != 0) {
    auto collapseOp = tensor::CollapseShapeOp::create(
        rewriter, loc, resultType, keepdimsValue, buildCollapseReassociation(reducedAxes));
    return collapseOp.getResult();
  }

  // Rank-0 result: one shared zero index is reused for every dimension.
  const int64_t keepdimsRank = cast<RankedTensorType>(keepdimsValue.getType()).getRank();
  SmallVector<Value> zeroIndices(keepdimsRank, arith::ConstantIndexOp::create(rewriter, loc, 0));
  Value scalar = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, zeroIndices);
  return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {scalar});
}
|
||||
|
||||
/// Lowers `ONNXReduceMeanV13Op` to spatial compute ops.
///
/// The reduction is first built in keep-dims form (see
/// buildReduceMeanKeepdims). When the op's `keepdims` attribute is 0 the
/// reduced axes are then squeezed away (see squeezeReducedAxes). Input and
/// result must be statically shaped ranked tensors and every axis must lie
/// within the input rank; otherwise the match fails with a diagnostic reason.
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMeanOp,
                                ONNXReduceMeanV13OpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
    auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
    if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(reduceMeanOp,
                                         "expected statically shaped ranked tensor input and result");

    const int64_t inputRank = inputType.getRank();
    SmallVector<int64_t> axes = normalizeAxes(reduceMeanOp.getAxesAttr(), inputRank);
    SmallVector<bool> reducedAxes = buildReducedAxesMask(axes, inputRank);
    // For a non-scalar input the mask is only empty when an axis was rejected.
    if (reducedAxes.empty() && inputRank != 0)
      return rewriter.notifyMatchFailure(reduceMeanOp, "reduction axis is out of range for the input rank");

    Location loc = reduceMeanOp.getLoc();
    RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
    Value reducedKeepdims =
        buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);

    if (reduceMeanOp.getKeepdims() != 0) {
      rewriter.replaceOp(reduceMeanOp, reducedKeepdims);
      return success();
    }

    Value reduced = squeezeReducedAxes(reducedKeepdims, resultType, reducedAxes, rewriter, loc);
    rewriter.replaceOp(reduceMeanOp, reduced);
    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Registers the ONNX ReduceMean (V13) lowering to spatial compute ops.
void populateReduceMeanPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ReduceMeanToSpatialCompute>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
36
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Sigmoid.cpp
Normal file
36
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Sigmoid.cpp
Normal file
@@ -0,0 +1,36 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Lowers `ONNXSigmoidOp` to a one-input spatial compute op whose body applies
/// `SpatSigmoidOp` to the block argument and yields the result. The compute op
/// takes the result type of the original op.
struct SigmoidToSpatialCompute : OpConversionPattern<ONNXSigmoidOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXSigmoidOp sigmoidOp,
                                ONNXSigmoidOpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    constexpr size_t numInputs = 1;
    Location loc = sigmoidOp.getLoc();
    Type outputType = sigmoidOp.getResult().getType();

    // Body of the compute region.
    auto emitBody = [&](Value x) {
      auto activated = spatial::SpatSigmoidOp::create(rewriter, loc, outputType, x);
      spatial::SpatYieldOp::create(rewriter, loc, activated.getResult());
    };

    auto sigmoidCompute = createSpatCompute<numInputs>(rewriter, loc, outputType, {}, adaptor.getX(), emitBody);
    rewriter.replaceOp(sigmoidOp, sigmoidCompute);
    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Registers the ONNX Sigmoid lowering to spatial compute ops.
void populateSigmoidPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<SigmoidToSpatialCompute>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -39,6 +39,12 @@ def spatToPimVVMul : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVAvg : Pat<
|
||||
(SpatVAvgOp:$srcOpRes $input),
|
||||
(PimVAvgOp $input,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVMax : Pat<
|
||||
(SpatVMaxOp:$srcOpRes $a, $b),
|
||||
(PimVVMaxOp $a, $b,
|
||||
@@ -51,4 +57,10 @@ def spatToPimVRelu : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVSigm : Pat<
|
||||
(SpatSigmoidOp:$srcOpRes $input),
|
||||
(PimVSigmOp $input,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
#endif // SPATIAL_TO_PIM
|
||||
|
||||
@@ -161,26 +161,41 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
|
||||
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
|
||||
operationsToRemove.push_back(receiveOp);
|
||||
markOpToRemove(receiveOp);
|
||||
runOnReceiveOp(receiveOp, rewriter);
|
||||
}
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatWeightedCompute>()) {
|
||||
operationsToRemove.push_back(computeOp);
|
||||
markOpToRemove(computeOp);
|
||||
runOnComputeOp(computeOp, rewriter);
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
replaceReturnOpOperands(returnOp, rewriter);
|
||||
|
||||
// Remove all ComputeOps
|
||||
for (auto opToRemove : llvm::reverse(operationsToRemove)) {
|
||||
if (!opToRemove->use_empty()) {
|
||||
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
|
||||
while (!pendingRemovals.empty()) {
|
||||
bool erasedAnyOp = false;
|
||||
for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) {
|
||||
Operation* opToRemove = *it;
|
||||
if (!opToRemove->use_empty()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
rewriter.eraseOp(opToRemove);
|
||||
it = pendingRemovals.erase(it);
|
||||
erasedAnyOp = true;
|
||||
}
|
||||
|
||||
if (erasedAnyOp)
|
||||
continue;
|
||||
|
||||
for (auto opToRemove : pendingRemovals) {
|
||||
opToRemove->dump();
|
||||
for (auto user : opToRemove->getUsers())
|
||||
user->dump();
|
||||
assert(false && "opToRemove should be unused at this point");
|
||||
}
|
||||
rewriter.eraseOp(opToRemove);
|
||||
assert(false && "tracked op removal reached a cycle or missed dependency");
|
||||
}
|
||||
|
||||
// Dump to file for debug
|
||||
@@ -284,10 +299,19 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
||||
auto concatUses = concatValue.getUses();
|
||||
auto numConcatUses = rangeLength(concatUses);
|
||||
if (numConcatUses == 1) {
|
||||
OpOperand& concatUse = *concatUses.begin();
|
||||
Operation* concatUser = concatUse.getOwner();
|
||||
Value chainedValue = concatValue;
|
||||
Operation* concatUser = concatUses.begin()->getOwner();
|
||||
|
||||
while (isChannelUseChainOp(concatUser)) {
|
||||
auto chainUses = concatUser->getResult(0).getUses();
|
||||
if (rangeLength(chainUses) != 1)
|
||||
break;
|
||||
chainedValue = concatUser->getResult(0);
|
||||
concatUser = chainUses.begin()->getOwner();
|
||||
}
|
||||
|
||||
if (isa<func::ReturnOp>(concatUser)) {
|
||||
size_t concatIndexInReturn = concatUse.getOperandNumber();
|
||||
size_t concatIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();
|
||||
size_t resultIndexInConcat = resultUses.begin()->getOperandNumber();
|
||||
size_t offset = 0;
|
||||
for (auto operand : concatOp->getOperands().take_front(resultIndexInConcat))
|
||||
@@ -602,10 +626,22 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
||||
rewriter.modifyOpInPlace(returnOp,
|
||||
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
|
||||
|
||||
if (isa<tensor::ConcatOp>(returnOperand)) {
|
||||
auto returnOperandUses = it.value().getUses();
|
||||
if (rangeLength(returnOperandUses) == 0)
|
||||
rewriter.eraseOp(returnOperand);
|
||||
Operation* opToErase = returnOperand;
|
||||
while (opToErase) {
|
||||
bool isExclusivelyOwnedByReturnChain = opToErase->use_empty() || opToErase->hasOneUse();
|
||||
if (!isExclusivelyOwnedByReturnChain)
|
||||
break;
|
||||
|
||||
if (isChannelUseChainOp(opToErase)) {
|
||||
Value source = opToErase->getOperand(0);
|
||||
markOpToRemove(opToErase);
|
||||
opToErase = source.getDefiningOp();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isa<tensor::ConcatOp>(opToErase))
|
||||
markOpToRemove(opToErase);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -239,6 +239,22 @@ def SpatSumOp : SpatOp<"sum", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVAvgOp : SpatOp<"vavg", []> {
|
||||
let summary = "Average all elements of the input tensor to a single scalar wrapped in a tensor";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
|
||||
let summary = "Element-wise sigmoid activation";
|
||||
|
||||
|
||||
@@ -361,7 +361,7 @@ struct ChannelBroadcastReceiveOpInterface
|
||||
}
|
||||
|
||||
/*
|
||||
* Turn the channel receive into a pim.load by creating a new global buffer
|
||||
* Turn the broadcast receive into a regular pim.receive from the broadcaster.
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
@@ -370,8 +370,21 @@ struct ChannelBroadcastReceiveOpInterface
|
||||
|
||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
auto outputSize = outputType.getNumElements() * outputType.getElementTypeBitWidth() / 8;
|
||||
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
|
||||
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
|
||||
|
||||
auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
|
||||
if (precomputedOtherCoreId) {
|
||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
cast<IntegerAttr>(precomputedOtherCoreId))
|
||||
.getOutput();
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
@@ -379,31 +392,30 @@ struct ChannelBroadcastReceiveOpInterface
|
||||
return failure();
|
||||
}
|
||||
|
||||
// The first 'broadcast' operation creates the buffer just after the
|
||||
// channelNewOp, while the other 'broadcast' operation need to find this
|
||||
// buffer allocation just after the channelNewOp
|
||||
Value bufferAllocation;
|
||||
if (auto allocOpAfterChannel = dyn_cast<memref::AllocOp>(channelNewOp->getNextNode())) {
|
||||
// Buffer already allocated, load from this buffer
|
||||
bufferAllocation = allocOpAfterChannel;
|
||||
}
|
||||
else {
|
||||
// Buffer was not allocated previously, allocate it after channelNewOp
|
||||
rewriter.setInsertionPointAfter(channelNewOp);
|
||||
bufferAllocation = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
}
|
||||
auto srcCoreId = [&]() -> FailureOr<uint32_t> {
|
||||
for (Operation* user : channelNewOp->getUsers()) {
|
||||
auto sendOp = dyn_cast<SpatChannelBroadcastSendOp>(user);
|
||||
if (!sendOp)
|
||||
continue;
|
||||
auto sendCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
|
||||
op->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, sendCoreIdAttr);
|
||||
return cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreId();
|
||||
}
|
||||
op->emitError("ChannelBroadcastReceiveOp has no matching ChannelBroadcastSendOp");
|
||||
return failure();
|
||||
}();
|
||||
if (failed(srcCoreId))
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
auto memCopyHostToDevOp = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
bufferAllocation,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(outputSize));
|
||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
||||
.getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getOutput());
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -428,8 +440,7 @@ struct ChannelBroadcastSendOpInterface
|
||||
}
|
||||
|
||||
/*
|
||||
* Turn the channel send into a device-to-host copy into the shared
|
||||
* broadcast buffer that receive ops load from later.
|
||||
* Turn the broadcast send into one pim.send per broadcast receiver.
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
@@ -448,32 +459,32 @@ struct ChannelBroadcastSendOpInterface
|
||||
return failure();
|
||||
}
|
||||
|
||||
// The first 'broadcast' operation creates the buffer just after the
|
||||
// channelNewOp, while the other 'broadcast' operation need to find this
|
||||
// buffer allocation just after the channelNewOp
|
||||
Value bufferAllocation;
|
||||
if (auto allocOpAfterChannel = dyn_cast<memref::AllocOp>(channelNewOp->getNextNode())) {
|
||||
// Buffer already allocated, load from this buffer
|
||||
bufferAllocation = allocOpAfterChannel;
|
||||
}
|
||||
else {
|
||||
// Buffer was not allocated previously, allocate it after channelNewOp
|
||||
rewriter.setInsertionPointAfter(channelNewOp);
|
||||
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
|
||||
}
|
||||
|
||||
auto srcType = cast<ShapedType>(srcTensor.getType());
|
||||
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
|
||||
auto srcCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
bufferAllocation.getType(),
|
||||
bufferAllocation,
|
||||
srcMemRef,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes));
|
||||
bool foundReceiver = false;
|
||||
for (Operation* user : channelNewOp->getUsers()) {
|
||||
auto receiveOp = dyn_cast<SpatChannelBroadcastReceiveOp>(user);
|
||||
if (!receiveOp)
|
||||
continue;
|
||||
|
||||
foundReceiver = true;
|
||||
auto dstCoreId = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreId();
|
||||
receiveOp->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, srcCoreIdAttr);
|
||||
pim::PimSendOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
srcMemRef,
|
||||
rewriter.getI32IntegerAttr(sizeInBytes),
|
||||
rewriter.getI32IntegerAttr(dstCoreId));
|
||||
}
|
||||
|
||||
if (!foundReceiver) {
|
||||
op->emitError("SpatChannelBroadcastSendOp has no matching ChannelBroadcastReceiveOp");
|
||||
return failure();
|
||||
}
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user