refactorone
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-20 19:06:41 +02:00
parent f56c4159b5
commit a50e77ff38
50 changed files with 3420 additions and 1187 deletions
@@ -23,11 +23,11 @@ namespace {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 1;
return operandIndex == 3;
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 0;
return operandIndex == 2;
return false;
}
@@ -39,7 +39,10 @@ static int64_t getValueSizeInBytes(Value value) {
}
template <typename CoreOpTy>
static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter, bool& hasFailure) {
static void materializeHostConstantsInCore(CoreOpTy coreOp,
IRRewriter& rewriter,
OperationFolder& constantFolder,
bool& hasFailure) {
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
SmallVector<Operation*> ops;
coreOp.getBody().front().walk([&](Operation* op) {
@@ -48,6 +51,9 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter
});
for (Operation* op : ops) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
continue;
for (OpOperand& operand : op->getOpOperands()) {
Value originalValue = operand.get();
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber()))
@@ -105,16 +111,17 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter
.getOutput();
}
else {
copiedValue = pim::PimMemCopyHostToDevOp::create(
rewriter,
op->getLoc(),
originalType,
deviceDst,
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
.getOutput();
copiedValue =
pim::PimMemCopyHostToDevOp::create(
rewriter,
op->getLoc(),
originalType,
getOrCreateHostIndexConstant(op, 0, constantFolder),
getOrCreateHostIndexConstant(op, static_cast<int64_t>(resolvedAddress->byteOffset), constantFolder),
deviceDst,
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
.getOutput();
}
cachedByType[originalType] = copiedValue;
@@ -134,6 +141,7 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
IRRewriter rewriter(moduleOp.getContext());
OperationFolder constantFolder(moduleOp.getContext());
bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
@@ -141,10 +149,10 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
continue;
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
materializeHostConstantsInCore(coreOp, rewriter, hasFailure);
materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure);
for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
materializeHostConstantsInCore(coreBatchOp, rewriter, hasFailure);
materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure);
SmallVector<Operation*> hostCompactOps;
for (Operation& op : funcOp.getBody().front())