automatic code-reformat
All checks were successful
Validate Operations / validate-operations (push) Successful in 18m22s
All checks were successful
Validate Operations / validate-operations (push) Successful in 18m22s
This commit is contained in:
@@ -175,7 +175,7 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
|
||||
llvm::SmallVector<Type> sourceTypes;
|
||||
llvm::SmallVector<Location> sourceLoc;
|
||||
for (auto source : sources){
|
||||
for (auto source : sources) {
|
||||
sourceTypes.push_back(source.getType());
|
||||
sourceLoc.push_back(loc);
|
||||
}
|
||||
@@ -183,7 +183,7 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
for(auto [source,bbArg] : llvm::zip(sources, BB->getArguments()))
|
||||
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
||||
mapper.map(source, bbArg);
|
||||
auto newConcat = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));
|
||||
|
||||
@@ -92,10 +92,8 @@ static FailureOr<Value> materializeBroadcastedConstantTensor(Value value,
|
||||
return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult();
|
||||
}
|
||||
|
||||
static FailureOr<Value> prepareElementwiseOperand(Value value,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static FailureOr<Value>
|
||||
prepareElementwiseOperand(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto valueType = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!valueType || !valueType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
@@ -280,8 +280,8 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
|
||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||
|
||||
auto computeOp =
|
||||
createSpatCompute(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
|
||||
auto computeOp = createSpatCompute(
|
||||
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(aHSlicesArgs.size());
|
||||
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
||||
|
||||
@@ -71,10 +71,8 @@ static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<boo
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
static Value createAverageCompute(Value input,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value
|
||||
createAverageCompute(Value input, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) {
|
||||
auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x);
|
||||
@@ -141,7 +139,8 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||
|
||||
Location loc = reduceMeanOp.getLoc();
|
||||
RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
|
||||
Value reducedKeepdims = buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);
|
||||
Value reducedKeepdims =
|
||||
buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);
|
||||
|
||||
if (reduceMeanOp.getKeepdims() != 0) {
|
||||
rewriter.replaceOp(reduceMeanOp, reducedKeepdims);
|
||||
|
||||
@@ -24,18 +24,16 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
|
||||
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
|
||||
auto computeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
|
||||
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value buildSoftmax(Value input,
|
||||
int64_t softmaxAxis,
|
||||
int64_t axis,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value
|
||||
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
if (axis == inputType.getRank())
|
||||
return createSoftmaxCompute(input, rewriter, loc);
|
||||
@@ -71,7 +69,8 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
Value result;
|
||||
if (axis == inputType.getRank() - 1) {
|
||||
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
SmallVector<int64_t> permutation;
|
||||
permutation.reserve(inputType.getRank());
|
||||
for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
|
||||
@@ -85,14 +84,15 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
|
||||
auto transposedType = RankedTensorType::get(
|
||||
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
|
||||
auto preTransposeCompute = createSpatCompute<1>(
|
||||
rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
|
||||
auto preTransposeCompute =
|
||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) {
|
||||
Value transposed = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
|
||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||
});
|
||||
Value transposedInput = preTransposeCompute.getResult(0);
|
||||
Value transposedResult = buildSoftmax(transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
Value transposedResult = buildSoftmax(
|
||||
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
result = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));
|
||||
}
|
||||
|
||||
@@ -23,8 +23,6 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
||||
}
|
||||
};
|
||||
|
||||
void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<Concat>(ctx);
|
||||
}
|
||||
void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Concat>(ctx); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -19,11 +19,8 @@ static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? ax
|
||||
|
||||
static int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
|
||||
|
||||
static Value extractSliceAt(Value input,
|
||||
int64_t axis,
|
||||
int64_t offset,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value
|
||||
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
@@ -110,12 +107,18 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
|
||||
SmallVector<int64_t> flatIndices(indicesAttr.getValues<int64_t>().begin(), indicesAttr.getValues<int64_t>().end());
|
||||
Location loc = gatherOp.getLoc();
|
||||
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, loc, TypeRange {gatherOp.getResult().getType()}, {}, adaptor.getData(), [&](Value data) -> LogicalResult {
|
||||
auto computeOp =
|
||||
createSpatCompute<1>(rewriter,
|
||||
loc,
|
||||
TypeRange {gatherOp.getResult().getType()},
|
||||
{},
|
||||
adaptor.getData(),
|
||||
[&](Value data) -> LogicalResult {
|
||||
Value result;
|
||||
if (indicesType.getRank() == 1) {
|
||||
result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc);
|
||||
} else if (indicesType.getRank() == 2) {
|
||||
}
|
||||
else if (indicesType.getRank() == 2) {
|
||||
int64_t rowCount = indicesType.getShape()[0];
|
||||
int64_t rowWidth = indicesType.getShape()[1];
|
||||
SmallVector<Value> rows;
|
||||
@@ -127,9 +130,11 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
|
||||
return failure();
|
||||
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
|
||||
}
|
||||
result =
|
||||
rows.size() == 1 ? rows.front() : tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult();
|
||||
} else {
|
||||
result = rows.size() == 1
|
||||
? rows.front()
|
||||
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult();
|
||||
}
|
||||
else {
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
||||
@@ -15,11 +15,8 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static Value extractSliceAt(Value input,
|
||||
int64_t axis,
|
||||
int64_t offset,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value
|
||||
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
@@ -67,8 +64,7 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
if (resizeOp.getMode() != "nearest"
|
||||
|| resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
||||
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
||||
|| resizeOp.getNearestMode() != "floor")
|
||||
return failure();
|
||||
|
||||
@@ -76,9 +72,10 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
||||
return failure();
|
||||
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
|
||||
Value result = buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc());
|
||||
auto computeOp =
|
||||
createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
|
||||
Value result =
|
||||
buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc());
|
||||
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
|
||||
});
|
||||
rewriter.replaceOp(resizeOp, computeOp.getResults());
|
||||
|
||||
@@ -12,12 +12,8 @@ namespace {
|
||||
|
||||
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
|
||||
|
||||
static Value extractSliceAt(Value input,
|
||||
int64_t axis,
|
||||
int64_t offset,
|
||||
int64_t size,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value extractSliceAt(
|
||||
Value input, int64_t axis, int64_t offset, int64_t size, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
@@ -33,9 +29,8 @@ static Value extractSliceAt(Value input,
|
||||
struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXSplitOp splitOp,
|
||||
ONNXSplitOpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXSplitOp splitOp, ONNXSplitOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
|
||||
auto inputType = dyn_cast<RankedTensorType>(adaptor.getInput().getType());
|
||||
if (!inputType || !inputType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
@@ -485,7 +485,7 @@ DCPAnalysisResult GraphDCP::getResult() {
|
||||
size_t i = 0;
|
||||
for (auto node : nodes) {
|
||||
ret.computeToCPUMap[node->getSpatWeightedCompute()] = cpu;
|
||||
if (i++ == nodes.size() - 1){
|
||||
if (i++ == nodes.size() - 1) {
|
||||
ret.isLastComputeOfACpu.insert(node->getSpatWeightedCompute());
|
||||
ret.cpuToLastComputeMap[cpu] = node->getSpatWeightedCompute();
|
||||
}
|
||||
|
||||
@@ -43,7 +43,5 @@ bool TaskDCP::hasDescendent(TaskDCP* child) {
|
||||
return false;
|
||||
}
|
||||
|
||||
//TODO fare qualcosa di sensato
|
||||
int TaskDCP::computeWeight(GraphDCP* graph, CPU cpu) {
|
||||
return orig_weight;
|
||||
}
|
||||
// TODO fare qualcosa di sensato
|
||||
int TaskDCP::computeWeight(GraphDCP* graph, CPU cpu) { return orig_weight; }
|
||||
|
||||
@@ -75,11 +75,11 @@ public:
|
||||
alst = val;
|
||||
}
|
||||
bool hasDescendent(TaskDCP* child);
|
||||
int64_t Id() const { return (int64_t)spatWeightedCompute.getAsOpaquePointer(); }
|
||||
int64_t Id() const { return (int64_t) spatWeightedCompute.getAsOpaquePointer(); }
|
||||
|
||||
bool isCP() const { return alst == aest; }
|
||||
bool isScheduled() const { return scheduledCPU.has_value(); }
|
||||
onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute(){return spatWeightedCompute;}
|
||||
onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() { return spatWeightedCompute; }
|
||||
|
||||
friend std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
|
||||
friend void removeEdge(TaskDCP* parent, TaskDCP* child);
|
||||
|
||||
@@ -71,12 +71,7 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
auto begin() {
|
||||
return storage.begin();
|
||||
}
|
||||
|
||||
auto end() {
|
||||
return storage.end();
|
||||
}
|
||||
auto begin() { return storage.begin(); }
|
||||
|
||||
auto end() { return storage.end(); }
|
||||
};
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
@@ -50,11 +52,10 @@ inline int64_t getSpatWeightCompute(onnx_mlir::spatial::SpatWeightedCompute spat
|
||||
int64_t tot = 0;
|
||||
for (auto& region : spatWeightedCompute.getBody()) {
|
||||
for (auto& inst : region) {
|
||||
for(auto result : inst.getResults()){
|
||||
if(auto element = llvm::dyn_cast<mlir::ShapedType>(result.getType()))
|
||||
for (auto result : inst.getResults())
|
||||
if (auto element = llvm::dyn_cast<mlir::ShapedType>(result.getType()))
|
||||
tot += onnx_mlir::getSizeInBytes(element);
|
||||
}
|
||||
}
|
||||
}
|
||||
return tot;
|
||||
}
|
||||
|
||||
@@ -243,8 +243,7 @@ void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<RewriteCoreSubviewCopyPattern,
|
||||
RewriteHostSubviewLoadPattern,
|
||||
RewriteHostSubviewStorePattern,
|
||||
FoldConstantCoreSubviewPattern>(
|
||||
patterns.getContext());
|
||||
FoldConstantCoreSubviewPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user