better MaterializeMergeSchedule.cpp that emits much more compact IR
Validate Operations / validate-operations (push) Waiting to run
Validate Operations / validate-operations (push) Waiting to run
add support for other constant-time arith ops in codegen
This commit is contained in:
@@ -4,6 +4,8 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
@@ -99,6 +101,33 @@ static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp l
|
||||
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
|
||||
}
|
||||
|
||||
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
|
||||
switch (predicate) {
|
||||
case mlir::arith::CmpIPredicate::eq:
|
||||
return lhs == rhs;
|
||||
case mlir::arith::CmpIPredicate::ne:
|
||||
return lhs != rhs;
|
||||
case mlir::arith::CmpIPredicate::slt:
|
||||
return lhs < rhs;
|
||||
case mlir::arith::CmpIPredicate::sle:
|
||||
return lhs <= rhs;
|
||||
case mlir::arith::CmpIPredicate::sgt:
|
||||
return lhs > rhs;
|
||||
case mlir::arith::CmpIPredicate::sge:
|
||||
return lhs >= rhs;
|
||||
case mlir::arith::CmpIPredicate::ult:
|
||||
return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::ule:
|
||||
return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::ugt:
|
||||
return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::uge:
|
||||
return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
|
||||
}
|
||||
|
||||
llvm_unreachable("unknown cmpi predicate");
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
@@ -153,6 +182,16 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||
return mlir::failure();
|
||||
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
|
||||
return mlir::failure();
|
||||
return *lhs / *rhs;
|
||||
}
|
||||
|
||||
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
|
||||
@@ -169,6 +208,31 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||
return mlir::failure();
|
||||
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
|
||||
return 0;
|
||||
return *lhs % *rhs;
|
||||
}
|
||||
|
||||
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(cmpOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(cmpOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
return evaluateCmpPredicate(cmpOp.getPredicate(), *lhs, *rhs) ? 1 : 0;
|
||||
}
|
||||
|
||||
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
|
||||
auto condition = resolveIndexValueImpl(selectOp.getCondition(), knowledge);
|
||||
if (failed(condition))
|
||||
return mlir::failure();
|
||||
return resolveIndexValueImpl(*condition != 0 ? selectOp.getTrueValue() : selectOp.getFalseValue(), knowledge);
|
||||
}
|
||||
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
|
||||
return resolveConstantGlobalLoad(loadOp, knowledge);
|
||||
|
||||
|
||||
@@ -8,19 +8,28 @@
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isCoreStaticAddressOp(mlir::Operation* op) {
|
||||
return mlir::isa<mlir::arith::ConstantOp,
|
||||
mlir::arith::AddIOp,
|
||||
mlir::arith::SubIOp,
|
||||
mlir::arith::MulIOp,
|
||||
mlir::arith::DivUIOp,
|
||||
mlir::arith::MinUIOp,
|
||||
mlir::arith::RemUIOp,
|
||||
mlir::arith::IndexCastOp,
|
||||
mlir::memref::AllocOp,
|
||||
mlir::memref::SubViewOp,
|
||||
mlir::memref::CastOp,
|
||||
mlir::memref::CollapseShapeOp,
|
||||
mlir::memref::ExpandShapeOp>(op);
|
||||
if (mlir::isa<mlir::arith::ConstantOp,
|
||||
mlir::arith::AddIOp,
|
||||
mlir::arith::SubIOp,
|
||||
mlir::arith::MulIOp,
|
||||
mlir::arith::DivUIOp,
|
||||
mlir::arith::DivSIOp,
|
||||
mlir::arith::MinUIOp,
|
||||
mlir::arith::RemUIOp,
|
||||
mlir::arith::RemSIOp,
|
||||
mlir::arith::IndexCastOp,
|
||||
mlir::arith::CmpIOp,
|
||||
mlir::memref::AllocOp,
|
||||
mlir::memref::SubViewOp,
|
||||
mlir::memref::CastOp,
|
||||
mlir::memref::CollapseShapeOp,
|
||||
mlir::memref::ExpandShapeOp>(op))
|
||||
return true;
|
||||
|
||||
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(op))
|
||||
return selectOp.getType().isIntOrIndex();
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
|
||||
@@ -405,6 +405,16 @@ void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticVa
|
||||
loadOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp,
|
||||
const StaticValueKnowledge& knowledge) const {
|
||||
emitMemCopyOp("ld",
|
||||
addressOf(loadOp.getDeviceTarget(), knowledge),
|
||||
loadOp.getDeviceTargetOffset(),
|
||||
addressOf(loadOp.getHostSource(), knowledge),
|
||||
loadOp.getHostSourceOffset(),
|
||||
loadOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto hostTargetOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge);
|
||||
auto deviceSourceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge);
|
||||
@@ -825,6 +835,8 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
|
||||
coreCodeGen.codeGenLoadOp(loadOp, knowledge);
|
||||
else if (auto loadBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op))
|
||||
coreCodeGen.codeGenLoadBatchOp(loadBatchOp, knowledge);
|
||||
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
|
||||
coreCodeGen.codeGenStoreOp(storeOp, knowledge);
|
||||
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
|
||||
|
||||
@@ -143,6 +143,7 @@ public:
|
||||
uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; }
|
||||
|
||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
|
||||
|
||||
|
||||
+1965
-414
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,7 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dominance.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
@@ -37,6 +38,7 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
||||
OperationFolder& constantFolder,
|
||||
bool& hasFailure) {
|
||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||
DominanceInfo dominance(coreOp);
|
||||
SmallVector<Operation*> ops;
|
||||
coreOp.getBody().front().walk([&](Operation* op) {
|
||||
if (!isa<pim::PimHaltOp, scf::YieldOp>(op))
|
||||
@@ -70,7 +72,7 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
||||
auto& cachedByOffset = materializedValues[resolvedAddress->base];
|
||||
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
|
||||
auto cachedValue = cachedByType.find(originalType);
|
||||
if (cachedValue != cachedByType.end()) {
|
||||
if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) {
|
||||
operand.set(cachedValue->second);
|
||||
continue;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user