better MaterializeMergeSchedule.cpp that emits much more compact IR
Validate Operations / validate-operations (push) Waiting to run

add support for other constant-time arith ops in codegen
This commit is contained in:
NiccoloN
2026-05-24 10:10:24 +02:00
parent b79ce8eeaa
commit c734f1b37e
6 changed files with 2067 additions and 428 deletions
+64
View File
@@ -4,6 +4,8 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include <limits>
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -99,6 +101,33 @@ static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp l
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
switch (predicate) {
case mlir::arith::CmpIPredicate::eq:
return lhs == rhs;
case mlir::arith::CmpIPredicate::ne:
return lhs != rhs;
case mlir::arith::CmpIPredicate::slt:
return lhs < rhs;
case mlir::arith::CmpIPredicate::sle:
return lhs <= rhs;
case mlir::arith::CmpIPredicate::sgt:
return lhs > rhs;
case mlir::arith::CmpIPredicate::sge:
return lhs >= rhs;
case mlir::arith::CmpIPredicate::ult:
return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ule:
return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ugt:
return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::uge:
return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
}
llvm_unreachable("unknown cmpi predicate");
}
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
@@ -153,6 +182,16 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return mlir::failure();
return *lhs / *rhs;
}
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
@@ -169,6 +208,31 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return 0;
return *lhs % *rhs;
}
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(cmpOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(cmpOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return evaluateCmpPredicate(cmpOp.getPredicate(), *lhs, *rhs) ? 1 : 0;
}
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
auto condition = resolveIndexValueImpl(selectOp.getCondition(), knowledge);
if (failed(condition))
return mlir::failure();
return resolveIndexValueImpl(*condition != 0 ? selectOp.getTrueValue() : selectOp.getFalseValue(), knowledge);
}
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return resolveConstantGlobalLoad(loadOp, knowledge);
+11 -2
View File
@@ -8,19 +8,28 @@
namespace onnx_mlir {
bool isCoreStaticAddressOp(mlir::Operation* op) {
return mlir::isa<mlir::arith::ConstantOp,
if (mlir::isa<mlir::arith::ConstantOp,
mlir::arith::AddIOp,
mlir::arith::SubIOp,
mlir::arith::MulIOp,
mlir::arith::DivUIOp,
mlir::arith::DivSIOp,
mlir::arith::MinUIOp,
mlir::arith::RemUIOp,
mlir::arith::RemSIOp,
mlir::arith::IndexCastOp,
mlir::arith::CmpIOp,
mlir::memref::AllocOp,
mlir::memref::SubViewOp,
mlir::memref::CastOp,
mlir::memref::CollapseShapeOp,
mlir::memref::ExpandShapeOp>(op);
mlir::memref::ExpandShapeOp>(op))
return true;
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(op))
return selectOp.getType().isIntOrIndex();
return false;
}
mlir::LogicalResult
+12
View File
@@ -405,6 +405,16 @@ void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticVa
loadOp.getSize());
}
void PimCodeGen::codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp,
const StaticValueKnowledge& knowledge) const {
emitMemCopyOp("ld",
addressOf(loadOp.getDeviceTarget(), knowledge),
loadOp.getDeviceTargetOffset(),
addressOf(loadOp.getHostSource(), knowledge),
loadOp.getHostSourceOffset(),
loadOp.getSize());
}
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
auto hostTargetOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge);
auto deviceSourceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge);
@@ -825,6 +835,8 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
coreCodeGen.codeGenLoadOp(loadOp, knowledge);
else if (auto loadBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op))
coreCodeGen.codeGenLoadBatchOp(loadBatchOp, knowledge);
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
coreCodeGen.codeGenStoreOp(storeOp, knowledge);
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
+1
View File
@@ -143,6 +143,7 @@ public:
uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; }
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
File diff suppressed because it is too large Load Diff
@@ -2,6 +2,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -37,6 +38,7 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
OperationFolder& constantFolder,
bool& hasFailure) {
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
DominanceInfo dominance(coreOp);
SmallVector<Operation*> ops;
coreOp.getBody().front().walk([&](Operation* op) {
if (!isa<pim::PimHaltOp, scf::YieldOp>(op))
@@ -70,7 +72,7 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
auto& cachedByOffset = materializedValues[resolvedAddress->base];
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
auto cachedValue = cachedByType.find(originalType);
if (cachedValue != cachedByType.end()) {
if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) {
operand.set(cachedValue->second);
continue;
}