faster pim VerificationPass.cpp and pim code emission
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-05-25 15:24:12 +02:00
parent 4855a2e105
commit e8a08f6dd0
18 changed files with 1610 additions and 573 deletions
+1
View File
@@ -1,5 +1,6 @@
add_pim_library(OMPimCommon
IR/AddressAnalysis.cpp
IR/BatchCoreUtils.cpp
IR/ConstantUtils.cpp
IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp
+431
View File
@@ -32,6 +32,14 @@ mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledg
return value;
}
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
template <typename... Args>
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
}
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
@@ -128,6 +136,225 @@ static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t l
llvm_unreachable("unknown cmpi predicate");
}
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr, const StaticValueKnowledge& knowledge) {
if (!expr.node)
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Constant:
return expr.node->constant;
case CompiledIndexExprNode::Kind::Symbol: {
auto value = resolveAlias(expr.node->symbol, &knowledge);
auto iter = knowledge.indexValues.find(value);
if (iter != knowledge.indexValues.end())
return iter->second;
return mlir::failure();
}
case CompiledIndexExprNode::Kind::Add:
case CompiledIndexExprNode::Kind::Sub:
case CompiledIndexExprNode::Kind::Mul:
case CompiledIndexExprNode::Kind::DivUI:
case CompiledIndexExprNode::Kind::DivSI:
case CompiledIndexExprNode::Kind::RemUI:
case CompiledIndexExprNode::Kind::RemSI:
case CompiledIndexExprNode::Kind::MinUI:
case CompiledIndexExprNode::Kind::CmpI: {
auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Add:
return *lhs + *rhs;
case CompiledIndexExprNode::Kind::Sub:
return *lhs - *rhs;
case CompiledIndexExprNode::Kind::Mul:
return *lhs * *rhs;
case CompiledIndexExprNode::Kind::DivUI:
if (*rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
case CompiledIndexExprNode::Kind::DivSI:
if (*rhs == 0 || (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1))
return mlir::failure();
return *lhs / *rhs;
case CompiledIndexExprNode::Kind::RemUI:
if (*rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
case CompiledIndexExprNode::Kind::RemSI:
if (*rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return 0;
return *lhs % *rhs;
case CompiledIndexExprNode::Kind::MinUI:
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
case CompiledIndexExprNode::Kind::CmpI:
return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
default:
llvm_unreachable("unexpected binary compiled index kind");
}
}
case CompiledIndexExprNode::Kind::Select: {
auto condition = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
if (failed(condition))
return mlir::failure();
return evaluateCompiledIndexExpr(*condition != 0 ? expr.node->operands[1] : expr.node->operands[2], knowledge);
}
case CompiledIndexExprNode::Kind::ConstantGlobalLoad: {
if (!expr.node->globalOp || !expr.node->globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*expr.node->globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(expr.node->globalOp.getType());
if (!denseAttr || !globalType)
return mlir::failure();
llvm::SmallVector<int64_t> indices;
indices.reserve(expr.node->operands.size());
for (const CompiledIndexExpr& operand : expr.node->operands) {
auto resolvedIndex = evaluateCompiledIndexExpr(operand, knowledge);
if (failed(resolvedIndex))
return mlir::failure();
indices.push_back(*resolvedIndex);
}
int64_t linearIndex = linearizeIndex(indices, expr.node->globalStrides);
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
return mlir::failure();
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
}
llvm_unreachable("unknown compiled index kind");
}
llvm::FailureOr<CompiledIndexExpr> compileConstantGlobalLoad(mlir::memref::LoadOp loadOp) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
if (!getGlobalOp)
return mlir::failure();
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
if (!denseAttr || !globalType || !globalType.hasStaticShape())
return mlir::failure();
auto elementType = denseAttr.getElementType();
if (!elementType.isIndex() && !elementType.isInteger())
return mlir::failure();
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::ConstantGlobalLoad;
expr.globalOp = globalOp;
expr.globalStrides = computeRowMajorStrides(globalType.getShape());
expr.operands.reserve(loadOp.getIndices().size());
for (mlir::Value index : loadOp.getIndices()) {
auto compiledIndex = compileIndexValueImpl(index);
if (failed(compiledIndex))
return mlir::failure();
expr.operands.push_back(*compiledIndex);
}
return makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value) {
if (auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>()) {
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue())) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = integerAttr.getInt();
return makeCompiledIndexExpr(std::move(expr));
}
}
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Symbol;
expr.symbol = value;
return makeCompiledIndexExpr(std::move(expr));
}
auto buildBinaryExpr = [&](CompiledIndexExprNode::Kind kind, mlir::Value lhsValue, mlir::Value rhsValue) {
auto lhs = compileIndexValueImpl(lhsValue);
auto rhs = compileIndexValueImpl(rhsValue);
if (failed(lhs) || failed(rhs))
return llvm::FailureOr<CompiledIndexExpr>(mlir::failure());
CompiledIndexExprNode expr;
expr.kind = kind;
expr.operands = {*lhs, *rhs};
return llvm::FailureOr<CompiledIndexExpr>(makeCompiledIndexExpr(std::move(expr)));
};
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
return compileIndexValueImpl(indexCastOp.getIn());
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Add, addOp.getLhs(), addOp.getRhs());
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Sub, subOp.getLhs(), subOp.getRhs());
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Mul, mulOp.getLhs(), mulOp.getRhs());
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivUI, divOp.getLhs(), divOp.getRhs());
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivSI, divOp.getLhs(), divOp.getRhs());
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::MinUI, minOp.getLhs(), minOp.getRhs());
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemUI, remOp.getLhs(), remOp.getRhs());
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemSI, remOp.getLhs(), remOp.getRhs());
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
auto expr = buildBinaryExpr(CompiledIndexExprNode::Kind::CmpI, cmpOp.getLhs(), cmpOp.getRhs());
if (failed(expr))
return mlir::failure();
auto exprNode = std::make_shared<CompiledIndexExprNode>(*expr->node);
exprNode->predicate = cmpOp.getPredicate();
return CompiledIndexExpr(exprNode);
}
if (auto maxOp = mlir::dyn_cast<mlir::arith::MaxUIOp>(definingOp)) {
auto lhs = compileIndexValueImpl(maxOp.getLhs());
auto rhs = compileIndexValueImpl(maxOp.getRhs());
if (failed(lhs) || failed(rhs))
return mlir::failure();
CompiledIndexExprNode cmpExpr;
cmpExpr.kind = CompiledIndexExprNode::Kind::CmpI;
cmpExpr.predicate = mlir::arith::CmpIPredicate::uge;
cmpExpr.operands = {*lhs, *rhs};
CompiledIndexExprNode selectExpr;
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
selectExpr.operands = {makeCompiledIndexExpr(std::move(cmpExpr)), *lhs, *rhs};
return makeCompiledIndexExpr(std::move(selectExpr));
}
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
auto condition = compileIndexValueImpl(selectOp.getCondition());
auto trueValue = compileIndexValueImpl(selectOp.getTrueValue());
auto falseValue = compileIndexValueImpl(selectOp.getFalseValue());
if (failed(condition) || failed(trueValue) || failed(falseValue))
return mlir::failure();
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Select;
expr.operands = {*condition, *trueValue, *falseValue};
return makeCompiledIndexExpr(std::move(expr));
}
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return compileConstantGlobalLoad(loadOp);
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Symbol;
expr.symbol = value;
return makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
@@ -353,6 +580,191 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
}
}
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value) {
int64_t constantByteOffset = 0;
CompiledIndexExpr byteOffsetExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = 0;
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
}
while (true) {
if (mlir::isa<mlir::BlockArgument>(value))
return CompiledAddressExpr {value, byteOffsetExpr};
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
if (!tiedOperand)
return mlir::failure();
value = tiedOperand->get();
continue;
}
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber());
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1];
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return mlir::failure();
llvm::SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
llvm::SmallVector<int64_t> staticSizes;
staticSizes.reserve(subviewOp.getMixedSizes().size());
llvm::SmallVector<int64_t> staticStrides;
staticStrides.reserve(subviewOp.getMixedStrides().size());
bool allStatic = true;
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
}
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(size))
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
}
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride))
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
}
if (allStatic) {
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
constantByteOffset +=
linearizeIndex(staticOffsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
}
else {
llvm::SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
CompiledIndexExpr offsetExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = 0;
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), sourceStrides)) {
CompiledIndexExpr operandExpr;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = mlir::cast<mlir::IntegerAttr>(attr).getInt() * sourceStride
* getElementTypeSizeInBytes(subviewType.getElementType());
operandExpr = makeCompiledIndexExpr(std::move(expr));
}
else {
auto compiledOffset = compileIndexValueImpl(mlir::cast<mlir::Value>(mixedOffset));
if (failed(compiledOffset))
return mlir::failure();
CompiledIndexExpr scaleExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = sourceStride * getElementTypeSizeInBytes(subviewType.getElementType());
scaleExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Mul;
expr.operands = {*compiledOffset, scaleExpr};
operandExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {offsetExpr, operandExpr};
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExpr constantExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constantByteOffset;
constantExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {constantExpr, offsetExpr};
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
constantByteOffset = 0;
}
value = subviewOp.getSource();
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp)) {
if (constantByteOffset != 0) {
CompiledIndexExpr constantExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constantByteOffset;
constantExpr = makeCompiledIndexExpr(std::move(expr));
}
if (byteOffsetExpr.node->kind == CompiledIndexExprNode::Kind::Constant && byteOffsetExpr.node->constant == 0)
byteOffsetExpr = constantExpr;
else {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {constantExpr, byteOffsetExpr};
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
}
}
return CompiledAddressExpr {value, byteOffsetExpr};
}
return mlir::failure();
}
}
} // namespace
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
@@ -361,6 +773,8 @@ llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueK
return resolveIndexValueImpl(value, &knowledge);
}
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); }
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
@@ -374,4 +788,21 @@ mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledg
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value) {
return compileContiguousAddressExprImpl(value);
}
llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge& knowledge) const {
return evaluateCompiledIndexExpr(*this, knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress>
CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const {
(void) lane;
auto resolvedOffset = byteOffset.evaluate(knowledge);
if (failed(resolvedOffset))
return mlir::failure();
return ResolvedContiguousAddress {base, *resolvedOffset};
}
} // namespace onnx_mlir
+52
View File
@@ -1,10 +1,14 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include <memory>
#include <optional>
namespace onnx_mlir {
/// Describes a value as a base addressable object plus a statically known
@@ -23,6 +27,51 @@ struct StaticValueKnowledge {
StaticValueKnowledge() {}
};
struct CompiledIndexExprNode;
struct CompiledIndexExpr {
std::shared_ptr<CompiledIndexExprNode> node;
CompiledIndexExpr() = default;
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node) : node(std::move(node)) {}
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
};
struct CompiledIndexExprNode {
enum class Kind {
Constant,
Symbol,
Add,
Sub,
Mul,
DivUI,
DivSI,
RemUI,
RemSI,
MinUI,
CmpI,
Select,
ConstantGlobalLoad
};
Kind kind = Kind::Constant;
int64_t constant = 0;
mlir::Value symbol;
mlir::arith::CmpIPredicate predicate = mlir::arith::CmpIPredicate::eq;
mlir::memref::GlobalOp globalOp;
llvm::SmallVector<int64_t, 4> globalStrides;
llvm::SmallVector<CompiledIndexExpr, 4> operands;
};
struct CompiledAddressExpr {
mlir::Value base;
CompiledIndexExpr byteOffset;
llvm::FailureOr<ResolvedContiguousAddress>
evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const;
};
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
/// Resolves a value to contiguous backing storage when that storage can be
@@ -35,9 +84,12 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value
/// arithmetic and loop facts recorded in `knowledge`.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value);
/// Follows alias, view, and DPS chains to recover the backing value of a
/// loop-carried memref/result.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value);
} // namespace onnx_mlir
+22
View File
@@ -0,0 +1,22 @@
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
namespace onnx_mlir {
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<mlir::DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
llvm::SmallVector<int32_t>
getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
llvm::SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
return laneCoreIds;
}
} // namespace onnx_mlir
+15
View File
@@ -0,0 +1,15 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
llvm::SmallVector<int32_t>
getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
} // namespace onnx_mlir
+56
View File
@@ -2,6 +2,8 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -78,4 +80,58 @@ walkPimCoreBlock(mlir::Block& block,
return mlir::success(!hasFailure);
}
mlir::LogicalResult walkPimCoreBlockStructurally(
mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step)) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM verification");
hasFailure = true;
continue;
}
if (*step <= 0) {
forOp.emitOpError("requires positive scf.for step for PIM verification");
hasFailure = true;
continue;
}
llvm::SmallVector<int64_t, 2> samples;
if (*lowerBound < *upperBound) {
samples.push_back(*lowerBound);
int64_t last = *lowerBound + ((*upperBound - 1 - *lowerBound) / *step) * *step;
if (last != *lowerBound)
samples.push_back(last);
}
for (int64_t inductionValue : samples) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), forOp.getInitArgs()))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlockStructurally(loopBody, loopKnowledge, callback)))
hasFailure = true;
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return mlir::success(!hasFailure);
}
} // namespace onnx_mlir
+9
View File
@@ -21,4 +21,13 @@ walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
/// Walks a `pim.core`-like body structurally for verification without
/// enumerating full loop trip counts. Loop bounds must still be statically
/// evaluable so address resolution remains well-defined.
mlir::LogicalResult
walkPimCoreBlockStructurally(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)>
callback);
} // namespace onnx_mlir
+18
View File
@@ -117,4 +117,22 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
});
}
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp) {
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
return weightIndex;
return std::nullopt;
}
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
return weightIndex;
return std::nullopt;
}
return std::nullopt;
}
} // namespace onnx_mlir
+26
View File
@@ -3,9 +3,15 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/StringRef.h"
#include <optional>
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
@@ -26,4 +32,24 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
/// passes can identify globals that must remain weight-backed.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
template <typename CoreLikeOpTy>
llvm::SmallVector<unsigned, 8> getUsedWeightIndices(CoreLikeOpTy coreLikeOp) {
llvm::SmallVector<unsigned, 8> indices;
auto addWeight = [&](mlir::Value weight) {
for (unsigned weightIndex = 0; weightIndex < coreLikeOp.getWeights().size(); ++weightIndex) {
if (coreLikeOp.getWeightArgument(weightIndex) != weight)
continue;
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
return;
}
};
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
llvm::sort(indices);
return indices;
}
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp);
} // namespace onnx_mlir
-1
View File
@@ -16,7 +16,6 @@ add_pim_library(OMPimCompilerOptions
add_pim_library(OMPimCompilerUtils
PimCompilerUtils.cpp
PimArtifactWriter.cpp
PimBatchEmission.cpp
PimCodeGen.cpp
PimWeightEmitter.cpp
+1 -1
View File
@@ -48,7 +48,7 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
if (!denseAttr)
return;
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
MemEntry memEntry = memory.hostMem.getMemEntry({getGlobalOp.getResult(), std::nullopt});
ArrayRef<char> rawData = denseAttr.getRawData();
char* dst = memoryBuffer.data() + memEntry.address;
-193
View File
@@ -1,193 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
return laneCoreIds;
}
static Value getOrCloneCapturedValue(OpBuilder& builder, Block& oldBlock, Value value, IRMapping& mapper) {
if (Value mapped = mapper.lookupOrNull(value))
return mapped;
if (auto blockArgument = dyn_cast<BlockArgument>(value)) {
assert(blockArgument.getOwner() != &oldBlock && "expected block argument to be mapped before cloning");
assert(false && "unexpected captured block argument while scalarizing pim.core_batch");
}
Operation* definingOp = value.getDefiningOp();
assert(definingOp && "expected captured value to be defined by an operation");
assert(definingOp->getBlock() != &oldBlock && "expected in-block value to be mapped before cloning");
for (Value operand : definingOp->getOperands())
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
Operation* cloned = builder.clone(*definingOp, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
return mapper.lookup(value);
}
static void cloneScalarizedLaneBody(OpBuilder& builder,
pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
OperationFolder& constantFolder) {
Block& oldBlock = coreBatchOp.getBody().front();
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightCount = coreBatchOp.getWeights().size();
IRMapping mapper;
for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
if (blockArg.getType().isIndex()) {
mapper.map(blockArg, getOrCreateHostIndexConstant(anchorOp, static_cast<int64_t>(lane), constantFolder));
continue;
}
if (argIndex <= weightCount) {
auto scalarCoreOp = cast<pim::PimCoreOp>(anchorOp);
mapper.map(blockArg, scalarCoreOp.getWeightArgument(argIndex - 1));
continue;
}
size_t inputIndex = argIndex - 1 - weightCount;
assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range");
mapper.map(blockArg, coreBatchOp.getInputs()[inputIndex]);
}
for (Operation& op : oldBlock) {
if (isa<pim::PimHaltOp>(op))
continue;
for (Value operand : op.getOperands())
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(
builder,
sendBatchOp.getLoc(),
mapper.lookup(sendBatchOp.getInput()),
sendBatchOp.getSizeAttr(),
getOrCreateHostIndexConstant(anchorOp, sendBatchOp.getTargetCoreIds()[lane], constantFolder));
continue;
}
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
pim::PimSendTensorOp::create(
builder,
sendTensorBatchOp.getLoc(),
mapper.lookup(sendTensorBatchOp.getInput()),
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveOp::create(
builder,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
mapper.lookup(receiveBatchOp.getOutputBuffer()),
receiveBatchOp.getSizeAttr(),
getOrCreateHostIndexConstant(anchorOp, receiveBatchOp.getSourceCoreIds()[lane], constantFolder));
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveTensorOp::create(
builder,
receiveTensorBatchOp.getLoc(),
receiveTensorBatchOp.getOutput().getType(),
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
continue;
}
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(
builder,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
mapper.lookup(memcpBatchOp.getDeviceTarget()),
mapper.lookup(memcpBatchOp.getHostSource()),
memcpBatchOp.getSizeAttr());
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
continue;
}
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
}
} // namespace
LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
ArrayRef<unsigned> lanes,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
assert(!lanes.empty() && "expected at least one batch lane");
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
OpBuilder builder(scratchModule->getContext());
OperationFolder constantFolder(scratchModule->getContext());
builder.setInsertionPointToStart(scratchModule->getBody());
SmallVector<Value> weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end());
auto coreIds = getBatchCoreIds(coreBatchOp);
int32_t coreId = coreIds[lanes.front()];
for (unsigned lane : lanes)
assert(coreIds[lane] == coreId && "all grouped lanes must target the same core");
auto scalarCore =
pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
SmallVector<Type> weightTypes;
SmallVector<Location> weightLocs;
weightTypes.reserve(weights.size());
weightLocs.reserve(weights.size());
for (Value weight : weights) {
weightTypes.push_back(weight.getType());
weightLocs.push_back(weight.getLoc());
}
Block* block =
builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end(), TypeRange(weightTypes), weightLocs);
builder.setInsertionPointToEnd(block);
for (unsigned lane : lanes)
cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
return callback(scalarCore);
}
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
return withScalarCoreFromBatchLanes(coreBatchOp, ArrayRef<unsigned> {lane}, callback);
}
} // namespace onnx_mlir
-16
View File
@@ -1,16 +0,0 @@
#pragma once
#include "llvm/ADT/STLFunctionalExtras.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
mlir::LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
llvm::ArrayRef<unsigned> lanes,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
} // namespace onnx_mlir
File diff suppressed because it is too large Load Diff
+70 -11
View File
@@ -4,13 +4,16 @@
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include <limits>
#include <optional>
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
@@ -23,6 +26,13 @@ struct MemEntry {
size_t size;
};
struct MemoryValueKey {
mlir::Value value;
std::optional<unsigned> lane;
bool operator==(const MemoryValueKey& other) const { return value == other.value && lane == other.lane; }
};
struct MemoryReportRow {
uint64_t numAlloca = 0;
uint64_t sizeAlloca = 0;
@@ -50,33 +60,33 @@ struct MemoryReportEntry {
};
class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> ownedMemEntriesMap;
llvm::SmallVector<std::pair<MemEntry, MemoryValueKey>, 32> memEntries;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> ownedMemEntriesMap;
size_t minAlignment = 4;
size_t firstAvailableAddress = 0;
MemEntry* gatherMemEntry(mlir::Value value);
MemEntry* gatherMemEntry(mlir::Value value, std::optional<unsigned> lane = std::nullopt);
void allocateGatheredMemory();
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry);
public:
PimMemory(llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap)
PimMemory(llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap)
: globalMemEntriesMap(globalMemEntriesMap) {}
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
void allocateCore(mlir::Operation* op);
void allocateCore(mlir::Operation* op, std::optional<unsigned> lane = std::nullopt);
MemoryReportRow getReportRow() const;
void remove(mlir::Value val);
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
MemEntry getMemEntry(mlir::Value value) const;
MemEntry getMemEntry(const MemoryValueKey& key) const;
};
class PimAcceleratorMemory {
public:
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> memEntriesMap;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> memEntriesMap;
PimMemory hostMem;
private:
@@ -84,14 +94,21 @@ private:
std::fstream fileReport;
std::optional<MemoryReportRow> hostReportRow;
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
mutable llvm::DenseMap<mlir::Value, CompiledIndexExpr> compiledIndexExprs;
mutable llvm::DenseMap<mlir::Value, CompiledAddressExpr> compiledAddressExprs;
public:
PimAcceleratorMemory()
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
PimAcceleratorMemory(const llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& initialMemEntries, bool enableReport)
: memEntriesMap(initialMemEntries), hostMem(memEntriesMap), fileReport(enableReport ? openReportFile("memory_report") : std::fstream()) {}
PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
size_t getValueAddress(mlir::Value value,
const StaticValueKnowledge& knowledge = {},
std::optional<unsigned> lane = std::nullopt) const;
llvm::FailureOr<int64_t> getIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
void reportHost();
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
void recordBatchReport(uint64_t batchId,
@@ -103,15 +120,24 @@ public:
void clean(mlir::Operation* op);
};
struct CoreEmissionJob {
mlir::Operation* coreLikeOp = nullptr;
size_t originalCoreId = 0;
size_t emittedCoreId = 0;
llvm::SmallVector<unsigned, 4> lanes;
std::optional<uint64_t> batchReportId;
};
class PimCodeGen {
PimAcceleratorMemory& memory;
llvm::raw_fd_ostream& coreBinaryStream;
llvm::raw_fd_ostream* coreJsonStream;
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
std::optional<unsigned> batchLane;
mutable uint32_t emittedInstructionCount = 0;
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getValueAddress(value, knowledge);
return memory.getValueAddress(value, knowledge, batchLane);
}
size_t remapCoreId(size_t coreId) const;
@@ -141,6 +167,10 @@ public:
: memory(memory), coreBinaryStream(coreBinary), coreJsonStream(coreJson), emittedCoreIds(emittedCoreIds) {}
uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; }
void setBatchLane(std::optional<unsigned> lane) { batchLane = lane; }
llvm::FailureOr<int64_t> indexOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getIndexValue(value, knowledge);
}
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp, const StaticValueKnowledge& knowledge) const;
@@ -151,6 +181,14 @@ public:
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveBatchOp(pim::PimReceiveBatchOp receiveOp, unsigned lane, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveTensorBatchOp(pim::PimReceiveTensorBatchOp receiveOp,
llvm::ArrayRef<int32_t> laneCoreIds,
const StaticValueKnowledge& knowledge) const;
void codeGenSendBatchOp(pim::PimSendBatchOp sendOp, unsigned lane, const StaticValueKnowledge& knowledge) const;
void codeGenSendTensorBatchOp(pim::PimSendTensorBatchOp sendOp,
llvm::ArrayRef<int32_t> laneCoreIds,
const StaticValueKnowledge& knowledge) const;
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
template <typename MVMTy>
@@ -173,3 +211,24 @@ public:
OnnxMlirCompilerErrorCodes compileToPimCode(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
} // namespace onnx_mlir
namespace llvm {
template <>
struct DenseMapInfo<onnx_mlir::MemoryValueKey> {
static onnx_mlir::MemoryValueKey getEmptyKey() {
return {DenseMapInfo<mlir::Value>::getEmptyKey(), 0};
}
static onnx_mlir::MemoryValueKey getTombstoneKey() {
return {DenseMapInfo<mlir::Value>::getTombstoneKey(), 0};
}
static unsigned getHashValue(const onnx_mlir::MemoryValueKey& key) {
return hash_combine(key.value, key.lane.value_or(std::numeric_limits<unsigned>::max()));
}
static bool isEqual(const onnx_mlir::MemoryValueKey& lhs, const onnx_mlir::MemoryValueKey& rhs) { return lhs == rhs; }
};
} // namespace llvm
+52 -43
View File
@@ -3,17 +3,19 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <type_traits>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
@@ -126,30 +128,6 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
return view;
}
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
auto addWeight = [&](mlir::Value weight) {
if (!coreOp)
return;
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
if (coreOp.getWeightArgument(weightIndex) != weight)
continue;
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
return;
}
};
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
llvm::sort(indices);
return indices;
}
SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
return getUsedWeightIndices(coreOp.getBody().front());
}
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
SmallVector<Operation*> coreLikeOps;
for (Operation& op : funcOp.getBody().front())
@@ -171,34 +149,36 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
llvm::DenseMap<mlir::Value, std::string> mapWeightValueToFileName;
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
for (Operation* op : coreLikeOps) {
auto processCore = [&](pim::PimCoreOp coreOp) {
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
auto processWeight = [&](Operation* ownerOp,
mlir::Value weight,
size_t weightIndex,
size_t coreId) -> LogicalResult {
auto weightView = resolveDenseWeightView(moduleOp, weight);
if (failed(weightView)) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
ownerOp->emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex));
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
}
if (mapCoreWeightToFileName[coreId].contains(weight))
continue;
return success();
if (auto weightFile = mapWeightValueToFileName.find(weight); weightFile != mapWeightValueToFileName.end()) {
mapCoreWeightToFileName[coreId].insert({weight, weightFile->second});
return success();
}
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
mapWeightValueToFileName[weight] = fileName;
mapCoreWeightToFileName[coreId].insert({weight, fileName});
continue;
return success();
}
DenseElementsAttr denseAttr = weightView->denseAttr;
@@ -237,20 +217,49 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
weightFileStream.close();
if (globalOp)
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapWeightValueToFileName[weight] = newFileName;
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
}
return success();
};
auto processCoreLike = [&](auto coreLikeOp) {
auto usedIndices = getUsedWeightIndices(coreLikeOp);
for (unsigned index : usedIndices) {
if (index >= coreLikeOp.getWeights().size()) {
coreLikeOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreLikeOp.getWeights().size() && "Weight index is out of range");
}
}
if constexpr (std::is_same_v<std::decay_t<decltype(coreLikeOp)>, pim::PimCoreOp>) {
size_t coreId = static_cast<size_t>(coreLikeOp.getCoreId());
for (unsigned index : usedIndices)
if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId)))
return failure();
return success();
}
else {
auto batchCoreIds = getBatchCoreIds(coreLikeOp);
SmallVector<size_t> orderedCoreIds;
llvm::SmallSet<size_t, 8> seenCoreIds;
for (int32_t coreId : batchCoreIds)
if (seenCoreIds.insert(static_cast<size_t>(coreId)).second)
orderedCoreIds.push_back(static_cast<size_t>(coreId));
for (size_t coreId : orderedCoreIds)
for (unsigned index : usedIndices)
if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId)))
return failure();
return success();
}
};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
(void) processCore(coreOp);
(void) processCoreLike(coreOp);
continue;
}
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore)))
return mapCoreWeightToFileName;
(void) processCoreLike(cast<pim::PimCoreBatchOp>(op));
}
return mapCoreWeightToFileName;
}
@@ -8,6 +8,7 @@
#include <fstream>
#include "Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
@@ -47,12 +48,6 @@ struct CoalescingReportEntry {
static std::string formatMemory(uint64_t bytes) { return formatReportMemory(bytes); }
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) {
llvm::SmallVector<ReportField, 4> fields = {
{"Number of candidates", std::to_string(row.numCandidates)},
+124 -10
View File
@@ -6,10 +6,10 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -173,6 +173,106 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
memref::GetGlobalOp>(op);
}
static FailureOr<ShapedType> getStaticByteSizedShapedType(Type type) {
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return failure();
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return failure();
return shapedType;
}
static LogicalResult verifyBatchOpSemantics(Operation& op,
const StaticValueKnowledge& knowledge,
pim::CappedDiagnosticReporter& diagnostics) {
bool hasFailure = false;
auto reportFailure = [&](auto emitDiagnostic) {
diagnostics.report(&op, [&](Operation* illegalOp) { emitDiagnostic(illegalOp); });
hasFailure = true;
};
if (auto memcpHdBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
if (!isCodegenAddressableValue(memcpHdBatchOp.getHostSource(), knowledge)) {
reportFailure([](Operation* illegalOp) {
illegalOp->emitOpError("host operand #1 is not backed by contiguous addressable storage");
});
}
return success(!hasFailure);
}
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
if (sendBatchOp.getTargetCoreIds().size() != static_cast<size_t>(sendBatchOp->getParentOfType<pim::PimCoreBatchOp>()
.getLaneCount())) {
reportFailure([](Operation* illegalOp) {
illegalOp->emitOpError("targetCoreIds size must match parent laneCount");
});
}
return success(!hasFailure);
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
if (receiveBatchOp.getSourceCoreIds().size()
!= static_cast<size_t>(receiveBatchOp->getParentOfType<pim::PimCoreBatchOp>().getLaneCount())) {
reportFailure([](Operation* illegalOp) {
illegalOp->emitOpError("sourceCoreIds size must match parent laneCount");
});
}
return success(!hasFailure);
}
auto verifyTensorBatchCommunication = [&](Value tensorValue, ArrayRef<int32_t> coreIds, StringRef kind) {
if (coreIds.empty()) {
reportFailure([&](Operation* illegalOp) { illegalOp->emitOpError() << kind << " must carry at least one chunk"; });
return;
}
auto parentBatchOp = op.getParentOfType<pim::PimCoreBatchOp>();
int32_t laneCount = parentBatchOp.getLaneCount();
if (laneCount <= 0) {
reportFailure([&](Operation* illegalOp) {
illegalOp->emitOpError() << kind << " requires a positive parent laneCount";
});
return;
}
if (coreIds.size() % static_cast<size_t>(laneCount) != 0) {
reportFailure([&](Operation* illegalOp) {
illegalOp->emitOpError() << kind << " core id count must be divisible by the parent laneCount";
});
return;
}
auto shapedType = getStaticByteSizedShapedType(tensorValue.getType());
if (failed(shapedType)) {
reportFailure([&](Operation* illegalOp) {
illegalOp->emitOpError() << kind << " requires a static shaped tensor or memref with byte-sized elements";
});
return;
}
int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
int64_t totalBytes = (*shapedType).getNumElements() * (*shapedType).getElementTypeBitWidth() / 8;
if (totalBytes % chunkCount != 0) {
reportFailure([&](Operation* illegalOp) {
illegalOp->emitOpError() << kind << " tensor byte size must be divisible by the chunk count per lane";
});
}
};
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op))
verifyTensorBatchCommunication(sendTensorBatchOp.getInput(),
sendTensorBatchOp.getTargetCoreIds(),
"send_tensor_batch");
else if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op))
verifyTensorBatchCommunication(receiveTensorBatchOp.getOutput(),
receiveTensorBatchOp.getSourceCoreIds(),
"receive_tensor_batch");
return success(!hasFailure);
}
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
@@ -204,16 +304,24 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
for (Operation& op : funcOp.getBody().front().getOperations()) {
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
(void) verifyCoreWeights(moduleOp, coreOp, diagnostics);
(void) verifyCoreOperands(coreOp, diagnostics);
StaticValueKnowledge knowledge;
(void) verifyCoreLikeOperands(coreOp, knowledge, diagnostics);
continue;
}
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
(void) withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) {
return verifyCoreOperands(scalarCore, diagnostics);
});
llvm::SmallVector<unsigned, 2> lanes;
lanes.push_back(0);
if (coreBatchOp.getLaneCount() > 1)
lanes.push_back(static_cast<unsigned>(coreBatchOp.getLaneCount() - 1));
for (unsigned lane : lanes) {
StaticValueKnowledge knowledge;
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
for (unsigned i = 0; i < coreBatchOp.getInputs().size(); ++i)
knowledge.aliases[coreBatchOp.getInputArgument(i)] = coreBatchOp.getInputs()[i];
(void) verifyCoreLikeOperands(coreBatchOp, knowledge, diagnostics);
}
continue;
}
@@ -299,10 +407,13 @@ private:
return success(!hasFailure);
}
template <typename CoreOpTy>
static LogicalResult verifyCoreOperands(CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
return walkPimCoreBlock(
coreOp.getBody().front(), StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
template <typename CoreLikeOpTy>
static LogicalResult verifyCoreLikeOperands(CoreLikeOpTy coreLikeOp,
const StaticValueKnowledge& initialKnowledge,
pim::CappedDiagnosticReporter& diagnostics) {
return walkPimCoreBlockStructurally(coreLikeOp.getBody().front(),
initialKnowledge,
[&](Operation& op, const StaticValueKnowledge& knowledge) {
bool hasFailure = false;
if (!isSupportedCoreInstructionOp(&op)) {
diagnostics.report(&op, [](Operation* illegalOp) {
@@ -370,6 +481,9 @@ private:
hasFailure = true;
}
}
if (failed(verifyBatchOpSemantics(op, knowledge, diagnostics)))
hasFailure = true;
return success(!hasFailure);
});
}