This commit is contained in:
@@ -111,39 +111,29 @@ static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp l
|
||||
|
||||
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
|
||||
switch (predicate) {
|
||||
case mlir::arith::CmpIPredicate::eq:
|
||||
return lhs == rhs;
|
||||
case mlir::arith::CmpIPredicate::ne:
|
||||
return lhs != rhs;
|
||||
case mlir::arith::CmpIPredicate::slt:
|
||||
return lhs < rhs;
|
||||
case mlir::arith::CmpIPredicate::sle:
|
||||
return lhs <= rhs;
|
||||
case mlir::arith::CmpIPredicate::sgt:
|
||||
return lhs > rhs;
|
||||
case mlir::arith::CmpIPredicate::sge:
|
||||
return lhs >= rhs;
|
||||
case mlir::arith::CmpIPredicate::ult:
|
||||
return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::ule:
|
||||
return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::ugt:
|
||||
return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::uge:
|
||||
return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::eq: return lhs == rhs;
|
||||
case mlir::arith::CmpIPredicate::ne: return lhs != rhs;
|
||||
case mlir::arith::CmpIPredicate::slt: return lhs < rhs;
|
||||
case mlir::arith::CmpIPredicate::sle: return lhs <= rhs;
|
||||
case mlir::arith::CmpIPredicate::sgt: return lhs > rhs;
|
||||
case mlir::arith::CmpIPredicate::sge: return lhs >= rhs;
|
||||
case mlir::arith::CmpIPredicate::ult: return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::ule: return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::ugt: return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::uge: return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
|
||||
}
|
||||
|
||||
llvm_unreachable("unknown cmpi predicate");
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr, const StaticValueKnowledge& knowledge) {
|
||||
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr,
|
||||
const StaticValueKnowledge& knowledge) {
|
||||
if (!expr.node)
|
||||
return mlir::failure();
|
||||
|
||||
switch (expr.node->kind) {
|
||||
case CompiledIndexExprNode::Kind::Constant:
|
||||
return expr.node->constant;
|
||||
case CompiledIndexExprNode::Kind::Symbol: {
|
||||
case CompiledIndexExprNode::Kind::Constant: return expr.node->constant;
|
||||
case CompiledIndexExprNode::Kind::Symbol: {
|
||||
auto value = resolveAlias(expr.node->symbol, &knowledge);
|
||||
auto iter = knowledge.indexValues.find(value);
|
||||
if (iter != knowledge.indexValues.end())
|
||||
@@ -158,19 +148,16 @@ llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr
|
||||
case CompiledIndexExprNode::Kind::RemUI:
|
||||
case CompiledIndexExprNode::Kind::RemSI:
|
||||
case CompiledIndexExprNode::Kind::MinUI:
|
||||
case CompiledIndexExprNode::Kind::CmpI: {
|
||||
case CompiledIndexExprNode::Kind::CmpI: {
|
||||
auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
|
||||
auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
|
||||
switch (expr.node->kind) {
|
||||
case CompiledIndexExprNode::Kind::Add:
|
||||
return *lhs + *rhs;
|
||||
case CompiledIndexExprNode::Kind::Sub:
|
||||
return *lhs - *rhs;
|
||||
case CompiledIndexExprNode::Kind::Mul:
|
||||
return *lhs * *rhs;
|
||||
case CompiledIndexExprNode::Kind::Add: return *lhs + *rhs;
|
||||
case CompiledIndexExprNode::Kind::Sub: return *lhs - *rhs;
|
||||
case CompiledIndexExprNode::Kind::Mul: return *lhs * *rhs;
|
||||
case CompiledIndexExprNode::Kind::DivUI:
|
||||
if (*rhs == 0)
|
||||
return mlir::failure();
|
||||
@@ -191,10 +178,8 @@ llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr
|
||||
return *lhs % *rhs;
|
||||
case CompiledIndexExprNode::Kind::MinUI:
|
||||
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
|
||||
case CompiledIndexExprNode::Kind::CmpI:
|
||||
return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
|
||||
default:
|
||||
llvm_unreachable("unexpected binary compiled index kind");
|
||||
case CompiledIndexExprNode::Kind::CmpI: return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
|
||||
default: llvm_unreachable("unexpected binary compiled index kind");
|
||||
}
|
||||
}
|
||||
case CompiledIndexExprNode::Kind::Select: {
|
||||
@@ -639,24 +624,21 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
||||
staticStrides.reserve(subviewOp.getMixedStrides().size());
|
||||
bool allStatic = true;
|
||||
|
||||
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
|
||||
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
else
|
||||
allStatic = false;
|
||||
}
|
||||
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
for (mlir::OpFoldResult size : subviewOp.getMixedSizes())
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(size))
|
||||
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
else
|
||||
allStatic = false;
|
||||
}
|
||||
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides())
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride))
|
||||
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
else
|
||||
allStatic = false;
|
||||
}
|
||||
|
||||
if (allStatic) {
|
||||
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
|
||||
@@ -796,8 +778,8 @@ llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge&
|
||||
return evaluateCompiledIndexExpr(*this, knowledge);
|
||||
}
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress>
|
||||
CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const {
|
||||
llvm::FailureOr<ResolvedContiguousAddress> CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge,
|
||||
std::optional<unsigned> lane) const {
|
||||
(void) lane;
|
||||
auto resolvedOffset = byteOffset.evaluate(knowledge);
|
||||
if (failed(resolvedOffset))
|
||||
|
||||
@@ -33,7 +33,8 @@ struct CompiledIndexExpr {
|
||||
std::shared_ptr<CompiledIndexExprNode> node;
|
||||
|
||||
CompiledIndexExpr() = default;
|
||||
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node) : node(std::move(node)) {}
|
||||
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node)
|
||||
: node(std::move(node)) {}
|
||||
|
||||
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
|
||||
};
|
||||
@@ -68,8 +69,8 @@ struct CompiledAddressExpr {
|
||||
mlir::Value base;
|
||||
CompiledIndexExpr byteOffset;
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress>
|
||||
evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const;
|
||||
llvm::FailureOr<ResolvedContiguousAddress> evaluate(const StaticValueKnowledge& knowledge,
|
||||
std::optional<unsigned> lane) const;
|
||||
};
|
||||
|
||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -10,8 +9,7 @@ llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
}
|
||||
|
||||
llvm::SmallVector<int32_t>
|
||||
getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
|
||||
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
|
||||
llvm::SmallVector<int32_t> laneCoreIds;
|
||||
laneCoreIds.reserve(coreIds.size() / laneCount);
|
||||
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
|
||||
|
||||
@@ -9,7 +9,6 @@ namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
|
||||
|
||||
llvm::SmallVector<int32_t>
|
||||
getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
|
||||
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -24,10 +24,9 @@ walkPimCoreBlock(mlir::Block& block,
|
||||
/// Walks a `pim.core`-like body structurally for verification without
|
||||
/// enumerating full loop trip counts. Loop bounds must still be statically
|
||||
/// evaluable so address resolution remains well-defined.
|
||||
mlir::LogicalResult
|
||||
walkPimCoreBlockStructurally(mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)>
|
||||
callback);
|
||||
mlir::LogicalResult walkPimCoreBlockStructurally(
|
||||
mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
Reference in New Issue
Block a user