faster pim VerificationPass.cpp and pim code emission
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -32,6 +32,14 @@ mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledg
|
||||
return value;
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
|
||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
|
||||
|
||||
template <typename... Args>
|
||||
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
|
||||
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
@@ -128,6 +136,225 @@ static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t l
|
||||
llvm_unreachable("unknown cmpi predicate");
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr, const StaticValueKnowledge& knowledge) {
|
||||
if (!expr.node)
|
||||
return mlir::failure();
|
||||
|
||||
switch (expr.node->kind) {
|
||||
case CompiledIndexExprNode::Kind::Constant:
|
||||
return expr.node->constant;
|
||||
case CompiledIndexExprNode::Kind::Symbol: {
|
||||
auto value = resolveAlias(expr.node->symbol, &knowledge);
|
||||
auto iter = knowledge.indexValues.find(value);
|
||||
if (iter != knowledge.indexValues.end())
|
||||
return iter->second;
|
||||
return mlir::failure();
|
||||
}
|
||||
case CompiledIndexExprNode::Kind::Add:
|
||||
case CompiledIndexExprNode::Kind::Sub:
|
||||
case CompiledIndexExprNode::Kind::Mul:
|
||||
case CompiledIndexExprNode::Kind::DivUI:
|
||||
case CompiledIndexExprNode::Kind::DivSI:
|
||||
case CompiledIndexExprNode::Kind::RemUI:
|
||||
case CompiledIndexExprNode::Kind::RemSI:
|
||||
case CompiledIndexExprNode::Kind::MinUI:
|
||||
case CompiledIndexExprNode::Kind::CmpI: {
|
||||
auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
|
||||
auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
|
||||
switch (expr.node->kind) {
|
||||
case CompiledIndexExprNode::Kind::Add:
|
||||
return *lhs + *rhs;
|
||||
case CompiledIndexExprNode::Kind::Sub:
|
||||
return *lhs - *rhs;
|
||||
case CompiledIndexExprNode::Kind::Mul:
|
||||
return *lhs * *rhs;
|
||||
case CompiledIndexExprNode::Kind::DivUI:
|
||||
if (*rhs == 0)
|
||||
return mlir::failure();
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||
case CompiledIndexExprNode::Kind::DivSI:
|
||||
if (*rhs == 0 || (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1))
|
||||
return mlir::failure();
|
||||
return *lhs / *rhs;
|
||||
case CompiledIndexExprNode::Kind::RemUI:
|
||||
if (*rhs == 0)
|
||||
return mlir::failure();
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||
case CompiledIndexExprNode::Kind::RemSI:
|
||||
if (*rhs == 0)
|
||||
return mlir::failure();
|
||||
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
|
||||
return 0;
|
||||
return *lhs % *rhs;
|
||||
case CompiledIndexExprNode::Kind::MinUI:
|
||||
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
|
||||
case CompiledIndexExprNode::Kind::CmpI:
|
||||
return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
|
||||
default:
|
||||
llvm_unreachable("unexpected binary compiled index kind");
|
||||
}
|
||||
}
|
||||
case CompiledIndexExprNode::Kind::Select: {
|
||||
auto condition = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
|
||||
if (failed(condition))
|
||||
return mlir::failure();
|
||||
return evaluateCompiledIndexExpr(*condition != 0 ? expr.node->operands[1] : expr.node->operands[2], knowledge);
|
||||
}
|
||||
case CompiledIndexExprNode::Kind::ConstantGlobalLoad: {
|
||||
if (!expr.node->globalOp || !expr.node->globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*expr.node->globalOp.getInitialValue());
|
||||
auto globalType = mlir::dyn_cast<mlir::MemRefType>(expr.node->globalOp.getType());
|
||||
if (!denseAttr || !globalType)
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> indices;
|
||||
indices.reserve(expr.node->operands.size());
|
||||
for (const CompiledIndexExpr& operand : expr.node->operands) {
|
||||
auto resolvedIndex = evaluateCompiledIndexExpr(operand, knowledge);
|
||||
if (failed(resolvedIndex))
|
||||
return mlir::failure();
|
||||
indices.push_back(*resolvedIndex);
|
||||
}
|
||||
|
||||
int64_t linearIndex = linearizeIndex(indices, expr.node->globalStrides);
|
||||
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
|
||||
return mlir::failure();
|
||||
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
|
||||
}
|
||||
}
|
||||
|
||||
llvm_unreachable("unknown compiled index kind");
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledIndexExpr> compileConstantGlobalLoad(mlir::memref::LoadOp loadOp) {
|
||||
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return mlir::failure();
|
||||
|
||||
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
|
||||
if (!denseAttr || !globalType || !globalType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
auto elementType = denseAttr.getElementType();
|
||||
if (!elementType.isIndex() && !elementType.isInteger())
|
||||
return mlir::failure();
|
||||
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::ConstantGlobalLoad;
|
||||
expr.globalOp = globalOp;
|
||||
expr.globalStrides = computeRowMajorStrides(globalType.getShape());
|
||||
expr.operands.reserve(loadOp.getIndices().size());
|
||||
for (mlir::Value index : loadOp.getIndices()) {
|
||||
auto compiledIndex = compileIndexValueImpl(index);
|
||||
if (failed(compiledIndex))
|
||||
return mlir::failure();
|
||||
expr.operands.push_back(*compiledIndex);
|
||||
}
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>()) {
|
||||
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue())) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = integerAttr.getInt();
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
}
|
||||
|
||||
mlir::Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Symbol;
|
||||
expr.symbol = value;
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
auto buildBinaryExpr = [&](CompiledIndexExprNode::Kind kind, mlir::Value lhsValue, mlir::Value rhsValue) {
|
||||
auto lhs = compileIndexValueImpl(lhsValue);
|
||||
auto rhs = compileIndexValueImpl(rhsValue);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return llvm::FailureOr<CompiledIndexExpr>(mlir::failure());
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = kind;
|
||||
expr.operands = {*lhs, *rhs};
|
||||
return llvm::FailureOr<CompiledIndexExpr>(makeCompiledIndexExpr(std::move(expr)));
|
||||
};
|
||||
|
||||
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
|
||||
return compileIndexValueImpl(indexCastOp.getIn());
|
||||
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::Add, addOp.getLhs(), addOp.getRhs());
|
||||
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::Sub, subOp.getLhs(), subOp.getRhs());
|
||||
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::Mul, mulOp.getLhs(), mulOp.getRhs());
|
||||
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivUI, divOp.getLhs(), divOp.getRhs());
|
||||
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivSI, divOp.getLhs(), divOp.getRhs());
|
||||
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::MinUI, minOp.getLhs(), minOp.getRhs());
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemUI, remOp.getLhs(), remOp.getRhs());
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemSI, remOp.getLhs(), remOp.getRhs());
|
||||
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
|
||||
auto expr = buildBinaryExpr(CompiledIndexExprNode::Kind::CmpI, cmpOp.getLhs(), cmpOp.getRhs());
|
||||
if (failed(expr))
|
||||
return mlir::failure();
|
||||
auto exprNode = std::make_shared<CompiledIndexExprNode>(*expr->node);
|
||||
exprNode->predicate = cmpOp.getPredicate();
|
||||
return CompiledIndexExpr(exprNode);
|
||||
}
|
||||
if (auto maxOp = mlir::dyn_cast<mlir::arith::MaxUIOp>(definingOp)) {
|
||||
auto lhs = compileIndexValueImpl(maxOp.getLhs());
|
||||
auto rhs = compileIndexValueImpl(maxOp.getRhs());
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
|
||||
CompiledIndexExprNode cmpExpr;
|
||||
cmpExpr.kind = CompiledIndexExprNode::Kind::CmpI;
|
||||
cmpExpr.predicate = mlir::arith::CmpIPredicate::uge;
|
||||
cmpExpr.operands = {*lhs, *rhs};
|
||||
|
||||
CompiledIndexExprNode selectExpr;
|
||||
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
|
||||
selectExpr.operands = {makeCompiledIndexExpr(std::move(cmpExpr)), *lhs, *rhs};
|
||||
return makeCompiledIndexExpr(std::move(selectExpr));
|
||||
}
|
||||
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
|
||||
auto condition = compileIndexValueImpl(selectOp.getCondition());
|
||||
auto trueValue = compileIndexValueImpl(selectOp.getTrueValue());
|
||||
auto falseValue = compileIndexValueImpl(selectOp.getFalseValue());
|
||||
if (failed(condition) || failed(trueValue) || failed(falseValue))
|
||||
return mlir::failure();
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Select;
|
||||
expr.operands = {*condition, *trueValue, *falseValue};
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
|
||||
return compileConstantGlobalLoad(loadOp);
|
||||
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Symbol;
|
||||
expr.symbol = value;
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
@@ -353,6 +580,191 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
||||
}
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value) {
|
||||
int64_t constantByteOffset = 0;
|
||||
CompiledIndexExpr byteOffsetExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = 0;
|
||||
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
while (true) {
|
||||
if (mlir::isa<mlir::BlockArgument>(value))
|
||||
return CompiledAddressExpr {value, byteOffsetExpr};
|
||||
|
||||
mlir::Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return mlir::failure();
|
||||
|
||||
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
|
||||
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
return mlir::failure();
|
||||
value = tiedOperand->get();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
|
||||
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||
if (!result)
|
||||
return mlir::failure();
|
||||
|
||||
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber());
|
||||
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
||||
value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1];
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
value = yieldedValue;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
||||
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
||||
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
|
||||
llvm::SmallVector<int64_t> staticSizes;
|
||||
staticSizes.reserve(subviewOp.getMixedSizes().size());
|
||||
llvm::SmallVector<int64_t> staticStrides;
|
||||
staticStrides.reserve(subviewOp.getMixedStrides().size());
|
||||
bool allStatic = true;
|
||||
|
||||
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
|
||||
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
else
|
||||
allStatic = false;
|
||||
}
|
||||
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(size))
|
||||
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
else
|
||||
allStatic = false;
|
||||
}
|
||||
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride))
|
||||
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
else
|
||||
allStatic = false;
|
||||
}
|
||||
|
||||
if (allStatic) {
|
||||
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
|
||||
return mlir::failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
constantByteOffset +=
|
||||
linearizeIndex(staticOffsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
}
|
||||
else {
|
||||
llvm::SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
CompiledIndexExpr offsetExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = 0;
|
||||
offsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), sourceStrides)) {
|
||||
CompiledIndexExpr operandExpr;
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = mlir::cast<mlir::IntegerAttr>(attr).getInt() * sourceStride
|
||||
* getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
operandExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
else {
|
||||
auto compiledOffset = compileIndexValueImpl(mlir::cast<mlir::Value>(mixedOffset));
|
||||
if (failed(compiledOffset))
|
||||
return mlir::failure();
|
||||
CompiledIndexExpr scaleExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = sourceStride * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
scaleExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Mul;
|
||||
expr.operands = {*compiledOffset, scaleExpr};
|
||||
operandExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Add;
|
||||
expr.operands = {offsetExpr, operandExpr};
|
||||
offsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
CompiledIndexExpr constantExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = constantByteOffset;
|
||||
constantExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Add;
|
||||
expr.operands = {constantExpr, offsetExpr};
|
||||
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
constantByteOffset = 0;
|
||||
}
|
||||
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp)) {
|
||||
if (constantByteOffset != 0) {
|
||||
CompiledIndexExpr constantExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = constantByteOffset;
|
||||
constantExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
if (byteOffsetExpr.node->kind == CompiledIndexExprNode::Kind::Constant && byteOffsetExpr.node->constant == 0)
|
||||
byteOffsetExpr = constantExpr;
|
||||
else {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Add;
|
||||
expr.operands = {constantExpr, byteOffsetExpr};
|
||||
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
}
|
||||
return CompiledAddressExpr {value, byteOffsetExpr};
|
||||
}
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
|
||||
@@ -361,6 +773,8 @@ llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueK
|
||||
return resolveIndexValueImpl(value, &knowledge);
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); }
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
|
||||
return resolveContiguousAddressImpl(value, nullptr);
|
||||
}
|
||||
@@ -374,4 +788,21 @@ mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledg
|
||||
return resolveLoopCarriedAliasImpl(value, &knowledge);
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value) {
|
||||
return compileContiguousAddressExprImpl(value);
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge& knowledge) const {
|
||||
return evaluateCompiledIndexExpr(*this, knowledge);
|
||||
}
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress>
|
||||
CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const {
|
||||
(void) lane;
|
||||
auto resolvedOffset = byteOffset.evaluate(knowledge);
|
||||
if (failed(resolvedOffset))
|
||||
return mlir::failure();
|
||||
return ResolvedContiguousAddress {base, *resolvedOffset};
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Describes a value as a base addressable object plus a statically known
|
||||
@@ -23,6 +27,51 @@ struct StaticValueKnowledge {
|
||||
StaticValueKnowledge() {}
|
||||
};
|
||||
|
||||
struct CompiledIndexExprNode;
|
||||
|
||||
struct CompiledIndexExpr {
|
||||
std::shared_ptr<CompiledIndexExprNode> node;
|
||||
|
||||
CompiledIndexExpr() = default;
|
||||
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node) : node(std::move(node)) {}
|
||||
|
||||
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
|
||||
};
|
||||
|
||||
struct CompiledIndexExprNode {
|
||||
enum class Kind {
|
||||
Constant,
|
||||
Symbol,
|
||||
Add,
|
||||
Sub,
|
||||
Mul,
|
||||
DivUI,
|
||||
DivSI,
|
||||
RemUI,
|
||||
RemSI,
|
||||
MinUI,
|
||||
CmpI,
|
||||
Select,
|
||||
ConstantGlobalLoad
|
||||
};
|
||||
|
||||
Kind kind = Kind::Constant;
|
||||
int64_t constant = 0;
|
||||
mlir::Value symbol;
|
||||
mlir::arith::CmpIPredicate predicate = mlir::arith::CmpIPredicate::eq;
|
||||
mlir::memref::GlobalOp globalOp;
|
||||
llvm::SmallVector<int64_t, 4> globalStrides;
|
||||
llvm::SmallVector<CompiledIndexExpr, 4> operands;
|
||||
};
|
||||
|
||||
struct CompiledAddressExpr {
|
||||
mlir::Value base;
|
||||
CompiledIndexExpr byteOffset;
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress>
|
||||
evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const;
|
||||
};
|
||||
|
||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||
|
||||
/// Resolves a value to contiguous backing storage when that storage can be
|
||||
@@ -35,9 +84,12 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value
|
||||
/// arithmetic and loop facts recorded in `knowledge`.
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value);
|
||||
|
||||
/// Follows alias, view, and DPS chains to recover the backing value of a
|
||||
/// loop-carried memref/result.
|
||||
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
|
||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto coreIdsAttr = coreBatchOp->getAttrOfType<mlir::DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
|
||||
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
}
|
||||
|
||||
llvm::SmallVector<int32_t>
|
||||
getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
|
||||
llvm::SmallVector<int32_t> laneCoreIds;
|
||||
laneCoreIds.reserve(coreIds.size() / laneCount);
|
||||
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
|
||||
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
|
||||
return laneCoreIds;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
|
||||
|
||||
llvm::SmallVector<int32_t>
|
||||
getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -2,6 +2,8 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -78,4 +80,58 @@ walkPimCoreBlock(mlir::Block& block,
|
||||
return mlir::success(!hasFailure);
|
||||
}
|
||||
|
||||
mlir::LogicalResult walkPimCoreBlockStructurally(
|
||||
mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
|
||||
bool hasFailure = false;
|
||||
for (mlir::Operation& op : block) {
|
||||
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||
continue;
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
|
||||
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
|
||||
continue;
|
||||
|
||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
|
||||
mlir::Block& loopBody = forOp.getRegion().front();
|
||||
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
|
||||
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
|
||||
auto step = resolveIndexValue(forOp.getStep(), knowledge);
|
||||
if (failed(lowerBound) || failed(upperBound) || failed(step)) {
|
||||
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM verification");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
if (*step <= 0) {
|
||||
forOp.emitOpError("requires positive scf.for step for PIM verification");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t, 2> samples;
|
||||
if (*lowerBound < *upperBound) {
|
||||
samples.push_back(*lowerBound);
|
||||
int64_t last = *lowerBound + ((*upperBound - 1 - *lowerBound) / *step) * *step;
|
||||
if (last != *lowerBound)
|
||||
samples.push_back(last);
|
||||
}
|
||||
|
||||
for (int64_t inductionValue : samples) {
|
||||
StaticValueKnowledge loopKnowledge = knowledge;
|
||||
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
|
||||
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), forOp.getInitArgs()))
|
||||
loopKnowledge.aliases[iterArg] = iterValue;
|
||||
|
||||
if (failed(walkPimCoreBlockStructurally(loopBody, loopKnowledge, callback)))
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (failed(callback(op, knowledge)))
|
||||
hasFailure = true;
|
||||
}
|
||||
return mlir::success(!hasFailure);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -21,4 +21,13 @@ walkPimCoreBlock(mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||
|
||||
/// Walks a `pim.core`-like body structurally for verification without
|
||||
/// enumerating full loop trip counts. Loop bounds must still be statically
|
||||
/// evaluable so address resolution remains well-defined.
|
||||
mlir::LogicalResult
|
||||
walkPimCoreBlockStructurally(mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)>
|
||||
callback);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -117,4 +117,22 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
|
||||
});
|
||||
}
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp) {
|
||||
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
|
||||
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -3,9 +3,15 @@
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -26,4 +32,24 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
||||
/// passes can identify globals that must remain weight-backed.
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||
|
||||
template <typename CoreLikeOpTy>
|
||||
llvm::SmallVector<unsigned, 8> getUsedWeightIndices(CoreLikeOpTy coreLikeOp) {
|
||||
llvm::SmallVector<unsigned, 8> indices;
|
||||
auto addWeight = [&](mlir::Value weight) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreLikeOp.getWeights().size(); ++weightIndex) {
|
||||
if (coreLikeOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user