refactorone
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-20 19:06:41 +02:00
parent f56c4159b5
commit a50e77ff38
50 changed files with 3420 additions and 1187 deletions
+1
View File
@@ -1,5 +1,6 @@
add_pim_library(OMPimCommon
IR/AddressAnalysis.cpp
IR/ConstantUtils.cpp
IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp
IR/ShapeUtils.cpp
+46
View File
@@ -1,5 +1,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
@@ -55,6 +57,47 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
const StaticValueKnowledge* knowledge) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
if (!getGlobalOp)
return mlir::failure();
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
if (!denseAttr || !globalType || !globalType.hasStaticShape())
return mlir::failure();
auto elementType = denseAttr.getElementType();
if (!elementType.isIndex() && !elementType.isInteger())
return mlir::failure();
llvm::SmallVector<int64_t> indices;
indices.reserve(loadOp.getIndices().size());
for (mlir::Value index : loadOp.getIndices()) {
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
if (failed(resolvedIndex))
return mlir::failure();
indices.push_back(*resolvedIndex);
}
if (indices.size() != static_cast<size_t>(globalType.getRank()))
return mlir::failure();
auto strides = computeRowMajorStrides(globalType.getShape());
int64_t linearIndex = linearizeIndex(indices, strides);
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
return mlir::failure();
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
@@ -126,6 +169,9 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return resolveConstantGlobalLoad(loadOp, knowledge);
return mlir::failure();
}
+62
View File
@@ -0,0 +1,62 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "ConstantUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Block* getHostConstantBlock(Operation* anchorOp) {
assert(anchorOp && "expected a valid anchor operation");
for (Operation* current = anchorOp; current; current = current->getParentOp())
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
return current->getBlock();
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
return &funcOp.getBody().front();
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
return moduleOp.getBody();
return anchorOp->getBlock();
}
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, OperationFolder& folder) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getHostConstantBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
}
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
}
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
}
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
}
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder);
}
} // namespace onnx_mlir
+28
View File
@@ -0,0 +1,28 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/FoldUtils.h"
namespace onnx_mlir {
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp);
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
mlir::Attribute value,
mlir::Type type,
mlir::OperationFolder& folder);
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
} // namespace onnx_mlir
+4
View File
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
@@ -30,6 +31,9 @@ walkPimCoreBlock(mlir::Block& block,
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
+25 -16
View File
@@ -21,12 +21,13 @@ namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
mlir::Value weightArg = parentOp.getWeightArgument(weightIndex);
bool found = false;
parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex;
found |= mvmOp.getWeight() == weightArg;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex;
found |= vmmOp.getWeight() == weightArg;
});
return found;
}
@@ -35,13 +36,18 @@ template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
auto walkWeight = [&](mlir::Value weight) {
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
if (parentOp.getWeightArgument(weightIndex) != weight)
continue;
if (visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
break;
}
};
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); });
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
}
} // namespace
@@ -90,18 +96,21 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) {
coreOp.walk([&](pim::PimVMMOp vmmOp) {
auto weights = coreOp.getWeights();
unsigned weightIndex = vmmOp.getWeightIndex();
if (weightIndex < weights.size())
callback(coreOp->getOpOperand(weightIndex));
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
callback(coreOp->getOpOperand(weightIndex));
break;
}
});
});
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
auto weights = coreBatchOp.getWeights();
for (auto weight : weights)
for (mlir::OpOperand& use : weight.getUses())
if (use.getOwner() == coreBatchOp.getOperation())
callback(use);
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
callback(coreBatchOp->getOpOperand(weightIndex));
break;
}
});
});
}
+1
View File
@@ -12,6 +12,7 @@
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"