add shared loop creation helpers
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
add shared checked arithmetic helpers refactor pim passes into Pim/Transforms more robust memory coalescing pass
This commit is contained in:
@@ -0,0 +1,96 @@
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "ConstantUtils.hpp"
|
||||
#include "LoopUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static std::optional<int64_t> getStaticTripCount(Value lowerBound, Value upperBound, Value step) {
|
||||
auto lower = matchConstantIndexValue(lowerBound);
|
||||
auto upper = matchConstantIndexValue(upperBound);
|
||||
auto stepValue = matchConstantIndexValue(step);
|
||||
if (!lower || !upper || !stepValue)
|
||||
return std::nullopt;
|
||||
if (*stepValue <= 0)
|
||||
return std::nullopt;
|
||||
if (*upper <= *lower)
|
||||
return int64_t {0};
|
||||
return llvm::divideCeil(*upper - *lower, *stepValue);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
static LogicalResult validateNormalizedLoopYields(Location loc, ValueRange initArgs, ArrayRef<Value> yieldedValues) {
|
||||
if (yieldedValues.size() == initArgs.size())
|
||||
return success();
|
||||
|
||||
emitError(loc) << "normalized loop body yielded " << yieldedValues.size() << " values for " << initArgs.size()
|
||||
<< " iter args";
|
||||
return failure();
|
||||
}
|
||||
|
||||
FailureOr<NormalizedLoopResult> buildNormalizedScfFor(OpBuilder& builder,
|
||||
Location loc,
|
||||
Value lowerBound,
|
||||
Value upperBound,
|
||||
Value step,
|
||||
ValueRange initArgs,
|
||||
NormalizedLoopBodyBuilder bodyBuilder) {
|
||||
NormalizedLoopResult result;
|
||||
|
||||
if (auto stepValue = matchConstantIndexValue(step); stepValue && *stepValue <= 0) {
|
||||
emitError(loc) << "normalized scf.for requires a positive step, got " << *stepValue;
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (auto tripCount = getStaticTripCount(lowerBound, upperBound, step)) {
|
||||
if (*tripCount == 0) {
|
||||
llvm::append_range(result.results, initArgs);
|
||||
return result;
|
||||
}
|
||||
|
||||
if (*tripCount == 1) {
|
||||
result.inductionVar = lowerBound;
|
||||
if (failed(bodyBuilder(builder, loc, lowerBound, initArgs, result.results)))
|
||||
return failure();
|
||||
if (failed(validateNormalizedLoopYields(loc, initArgs, result.results)))
|
||||
return failure();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
result.loop = scf::ForOp::create(builder, loc, lowerBound, upperBound, step, initArgs);
|
||||
result.inductionVar = result.loop.getInductionVar();
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Block* body = result.loop.getBody();
|
||||
if (!body->empty())
|
||||
if (auto yieldOp = dyn_cast<scf::YieldOp>(body->back()))
|
||||
yieldOp->erase();
|
||||
builder.setInsertionPointToEnd(body);
|
||||
ValueRange iterArgs = result.loop.getRegionIterArgs();
|
||||
if (failed(bodyBuilder(builder, loc, result.inductionVar, iterArgs, result.results))) {
|
||||
result.loop.erase();
|
||||
return failure();
|
||||
}
|
||||
if (failed(validateNormalizedLoopYields(loc, initArgs, result.results))) {
|
||||
result.loop.erase();
|
||||
return failure();
|
||||
}
|
||||
scf::YieldOp::create(builder, loc, result.results);
|
||||
}
|
||||
builder.setInsertionPointAfter(result.loop);
|
||||
result.results.assign(result.loop.getResults().begin(), result.loop.getResults().end());
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,30 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct NormalizedLoopResult {
|
||||
mlir::Value inductionVar;
|
||||
llvm::SmallVector<mlir::Value, 4> results;
|
||||
mlir::scf::ForOp loop;
|
||||
|
||||
bool wasInlined() const { return !loop; }
|
||||
};
|
||||
|
||||
using NormalizedLoopBodyBuilder = llvm::function_ref<mlir::LogicalResult(
|
||||
mlir::OpBuilder&, mlir::Location, mlir::Value, mlir::ValueRange, llvm::SmallVectorImpl<mlir::Value>&)>;
|
||||
|
||||
mlir::FailureOr<NormalizedLoopResult> buildNormalizedScfFor(mlir::OpBuilder& builder,
|
||||
mlir::Location loc,
|
||||
mlir::Value lowerBound,
|
||||
mlir::Value upperBound,
|
||||
mlir::Value step,
|
||||
mlir::ValueRange initArgs,
|
||||
NormalizedLoopBodyBuilder bodyBuilder);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user