From e8a08f6dd0bd8519c09e75ef9536911fbc43befa Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 25 May 2026 15:24:12 +0200 Subject: [PATCH] faster pim VerificationPass.cpp and pim code emission --- src/PIM/Common/CMakeLists.txt | 1 + src/PIM/Common/IR/AddressAnalysis.cpp | 431 ++++++++ src/PIM/Common/IR/AddressAnalysis.hpp | 52 + src/PIM/Common/IR/BatchCoreUtils.cpp | 22 + src/PIM/Common/IR/BatchCoreUtils.hpp | 15 + src/PIM/Common/IR/CoreBlockUtils.cpp | 56 ++ src/PIM/Common/IR/CoreBlockUtils.hpp | 9 + src/PIM/Common/IR/WeightUtils.cpp | 18 + src/PIM/Common/IR/WeightUtils.hpp | 26 + src/PIM/Compiler/CMakeLists.txt | 1 - src/PIM/Compiler/PimArtifactWriter.cpp | 2 +- src/PIM/Compiler/PimBatchEmission.cpp | 193 ---- src/PIM/Compiler/PimBatchEmission.hpp | 16 - src/PIM/Compiler/PimCodeGen.cpp | 942 +++++++++++++----- src/PIM/Compiler/PimCodeGen.hpp | 81 +- src/PIM/Compiler/PimWeightEmitter.cpp | 177 ++-- .../StaticMemoryCoalescingPass.cpp | 7 +- src/PIM/Pass/PimCodegen/VerificationPass.cpp | 134 ++- 18 files changed, 1610 insertions(+), 573 deletions(-) create mode 100644 src/PIM/Common/IR/BatchCoreUtils.cpp create mode 100644 src/PIM/Common/IR/BatchCoreUtils.hpp delete mode 100644 src/PIM/Compiler/PimBatchEmission.cpp delete mode 100644 src/PIM/Compiler/PimBatchEmission.hpp diff --git a/src/PIM/Common/CMakeLists.txt b/src/PIM/Common/CMakeLists.txt index 8f62537..1bdc67f 100644 --- a/src/PIM/Common/CMakeLists.txt +++ b/src/PIM/Common/CMakeLists.txt @@ -1,5 +1,6 @@ add_pim_library(OMPimCommon IR/AddressAnalysis.cpp + IR/BatchCoreUtils.cpp IR/ConstantUtils.cpp IR/CoreBlockUtils.cpp IR/EntryPointUtils.cpp diff --git a/src/PIM/Common/IR/AddressAnalysis.cpp b/src/PIM/Common/IR/AddressAnalysis.cpp index 3019749..9aa664c 100644 --- a/src/PIM/Common/IR/AddressAnalysis.cpp +++ b/src/PIM/Common/IR/AddressAnalysis.cpp @@ -32,6 +32,14 @@ mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledg return value; } +llvm::FailureOr compileIndexValueImpl(mlir::Value value); +llvm::FailureOr compileContiguousAddressExprImpl(mlir::Value value); + +template +CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) { + return CompiledIndexExpr(std::make_shared(std::forward(args)...)); +} + mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { value = resolveAlias(value, knowledge); @@ -128,6 +136,225 @@ static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t l llvm_unreachable("unknown cmpi predicate"); } +llvm::FailureOr evaluateCompiledIndexExpr(const CompiledIndexExpr& expr, const StaticValueKnowledge& knowledge) { + if (!expr.node) + return mlir::failure(); + + switch (expr.node->kind) { + case CompiledIndexExprNode::Kind::Constant: + return expr.node->constant; + case CompiledIndexExprNode::Kind::Symbol: { + auto value = resolveAlias(expr.node->symbol, &knowledge); + auto iter = knowledge.indexValues.find(value); + if (iter != knowledge.indexValues.end()) + return iter->second; + return mlir::failure(); + } + case CompiledIndexExprNode::Kind::Add: + case CompiledIndexExprNode::Kind::Sub: + case CompiledIndexExprNode::Kind::Mul: + case CompiledIndexExprNode::Kind::DivUI: + case CompiledIndexExprNode::Kind::DivSI: + case CompiledIndexExprNode::Kind::RemUI: + case CompiledIndexExprNode::Kind::RemSI: + case CompiledIndexExprNode::Kind::MinUI: + case CompiledIndexExprNode::Kind::CmpI: { + auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge); + auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge); + if (failed(lhs) || failed(rhs)) + return mlir::failure(); + + switch (expr.node->kind) { + case CompiledIndexExprNode::Kind::Add: + return *lhs + *rhs; + case CompiledIndexExprNode::Kind::Sub: + return *lhs - *rhs; + case CompiledIndexExprNode::Kind::Mul: + return *lhs * *rhs; + case CompiledIndexExprNode::Kind::DivUI: + if (*rhs == 0) + return mlir::failure(); + return static_cast(static_cast(*lhs) / static_cast(*rhs)); + case CompiledIndexExprNode::Kind::DivSI: + if (*rhs == 0 || (*lhs == std::numeric_limits::min() && *rhs == -1)) + return mlir::failure(); + return *lhs / *rhs; + case CompiledIndexExprNode::Kind::RemUI: + if (*rhs == 0) + return mlir::failure(); + return static_cast(static_cast(*lhs) % static_cast(*rhs)); + case CompiledIndexExprNode::Kind::RemSI: + if (*rhs == 0) + return mlir::failure(); + if (*lhs == std::numeric_limits::min() && *rhs == -1) + return 0; + return *lhs % *rhs; + case CompiledIndexExprNode::Kind::MinUI: + return static_cast(std::min(static_cast(*lhs), static_cast(*rhs))); + case CompiledIndexExprNode::Kind::CmpI: + return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0; + default: + llvm_unreachable("unexpected binary compiled index kind"); + } + } + case CompiledIndexExprNode::Kind::Select: { + auto condition = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge); + if (failed(condition)) + return mlir::failure(); + return evaluateCompiledIndexExpr(*condition != 0 ? expr.node->operands[1] : expr.node->operands[2], knowledge); + } + case CompiledIndexExprNode::Kind::ConstantGlobalLoad: { + if (!expr.node->globalOp || !expr.node->globalOp.getInitialValue()) + return mlir::failure(); + + auto denseAttr = mlir::dyn_cast(*expr.node->globalOp.getInitialValue()); + auto globalType = mlir::dyn_cast(expr.node->globalOp.getType()); + if (!denseAttr || !globalType) + return mlir::failure(); + + llvm::SmallVector indices; + indices.reserve(expr.node->operands.size()); + for (const CompiledIndexExpr& operand : expr.node->operands) { + auto resolvedIndex = evaluateCompiledIndexExpr(operand, knowledge); + if (failed(resolvedIndex)) + return mlir::failure(); + indices.push_back(*resolvedIndex); + } + + int64_t linearIndex = linearizeIndex(indices, expr.node->globalStrides); + if (linearIndex < 0 || linearIndex >= globalType.getNumElements()) + return mlir::failure(); + return denseAttr.getValues()[linearIndex].getSExtValue(); + } + } + + llvm_unreachable("unknown compiled index kind"); +} + +llvm::FailureOr compileConstantGlobalLoad(mlir::memref::LoadOp loadOp) { + auto getGlobalOp = loadOp.getMemRef().getDefiningOp(); + if (!getGlobalOp) + return mlir::failure(); + + auto moduleOp = loadOp->getParentOfType(); + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue()) + return mlir::failure(); + + auto denseAttr = mlir::dyn_cast(*globalOp.getInitialValue()); + auto globalType = mlir::dyn_cast(getGlobalOp.getType()); + if (!denseAttr || !globalType || !globalType.hasStaticShape()) + return mlir::failure(); + + auto elementType = denseAttr.getElementType(); + if (!elementType.isIndex() && !elementType.isInteger()) + return mlir::failure(); + + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::ConstantGlobalLoad; + expr.globalOp = globalOp; + expr.globalStrides = computeRowMajorStrides(globalType.getShape()); + expr.operands.reserve(loadOp.getIndices().size()); + for (mlir::Value index : loadOp.getIndices()) { + auto compiledIndex = compileIndexValueImpl(index); + if (failed(compiledIndex)) + return mlir::failure(); + expr.operands.push_back(*compiledIndex); + } + return makeCompiledIndexExpr(std::move(expr)); +} + +llvm::FailureOr compileIndexValueImpl(mlir::Value value) { + if (auto constantOp = value.getDefiningOp()) { + if (auto integerAttr = mlir::dyn_cast(constantOp.getValue())) { + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Constant; + expr.constant = integerAttr.getInt(); + return makeCompiledIndexExpr(std::move(expr)); + } + } + + mlir::Operation* definingOp = value.getDefiningOp(); + if (!definingOp) { + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Symbol; + expr.symbol = value; + return makeCompiledIndexExpr(std::move(expr)); + } + + auto buildBinaryExpr = [&](CompiledIndexExprNode::Kind kind, mlir::Value lhsValue, mlir::Value rhsValue) { + auto lhs = compileIndexValueImpl(lhsValue); + auto rhs = compileIndexValueImpl(rhsValue); + if (failed(lhs) || failed(rhs)) + return llvm::FailureOr(mlir::failure()); + CompiledIndexExprNode expr; + expr.kind = kind; + expr.operands = {*lhs, *rhs}; + return llvm::FailureOr(makeCompiledIndexExpr(std::move(expr))); + }; + + if (auto indexCastOp = mlir::dyn_cast(definingOp)) + return compileIndexValueImpl(indexCastOp.getIn()); + if (auto addOp = mlir::dyn_cast(definingOp)) + return buildBinaryExpr(CompiledIndexExprNode::Kind::Add, addOp.getLhs(), addOp.getRhs()); + if (auto subOp = mlir::dyn_cast(definingOp)) + return buildBinaryExpr(CompiledIndexExprNode::Kind::Sub, subOp.getLhs(), subOp.getRhs()); + if (auto mulOp = mlir::dyn_cast(definingOp)) + return buildBinaryExpr(CompiledIndexExprNode::Kind::Mul, mulOp.getLhs(), mulOp.getRhs()); + if (auto divOp = mlir::dyn_cast(definingOp)) + return buildBinaryExpr(CompiledIndexExprNode::Kind::DivUI, divOp.getLhs(), divOp.getRhs()); + if (auto divOp = mlir::dyn_cast(definingOp)) + return buildBinaryExpr(CompiledIndexExprNode::Kind::DivSI, divOp.getLhs(), divOp.getRhs()); + if (auto minOp = mlir::dyn_cast(definingOp)) + return buildBinaryExpr(CompiledIndexExprNode::Kind::MinUI, minOp.getLhs(), minOp.getRhs()); + if (auto remOp = mlir::dyn_cast(definingOp)) + return buildBinaryExpr(CompiledIndexExprNode::Kind::RemUI, remOp.getLhs(), remOp.getRhs()); + if (auto remOp = mlir::dyn_cast(definingOp)) + return buildBinaryExpr(CompiledIndexExprNode::Kind::RemSI, remOp.getLhs(), remOp.getRhs()); + if (auto cmpOp = mlir::dyn_cast(definingOp)) { + auto expr = buildBinaryExpr(CompiledIndexExprNode::Kind::CmpI, cmpOp.getLhs(), cmpOp.getRhs()); + if (failed(expr)) + return mlir::failure(); + auto exprNode = std::make_shared(*expr->node); + exprNode->predicate = cmpOp.getPredicate(); + return CompiledIndexExpr(exprNode); + } + if (auto maxOp = mlir::dyn_cast(definingOp)) { + auto lhs = compileIndexValueImpl(maxOp.getLhs()); + auto rhs = compileIndexValueImpl(maxOp.getRhs()); + if (failed(lhs) || failed(rhs)) + return mlir::failure(); + + CompiledIndexExprNode cmpExpr; + cmpExpr.kind = CompiledIndexExprNode::Kind::CmpI; + cmpExpr.predicate = mlir::arith::CmpIPredicate::uge; + cmpExpr.operands = {*lhs, *rhs}; + + CompiledIndexExprNode selectExpr; + selectExpr.kind = CompiledIndexExprNode::Kind::Select; + selectExpr.operands = {makeCompiledIndexExpr(std::move(cmpExpr)), *lhs, *rhs}; + return makeCompiledIndexExpr(std::move(selectExpr)); + } + if (auto selectOp = mlir::dyn_cast(definingOp)) { + auto condition = compileIndexValueImpl(selectOp.getCondition()); + auto trueValue = compileIndexValueImpl(selectOp.getTrueValue()); + auto falseValue = compileIndexValueImpl(selectOp.getFalseValue()); + if (failed(condition) || failed(trueValue) || failed(falseValue)) + return mlir::failure(); + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Select; + expr.operands = {*condition, *trueValue, *falseValue}; + return makeCompiledIndexExpr(std::move(expr)); + } + if (auto loadOp = mlir::dyn_cast(definingOp)) + return compileConstantGlobalLoad(loadOp); + + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Symbol; + expr.symbol = value; + return makeCompiledIndexExpr(std::move(expr)); +} + llvm::FailureOr resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { value = resolveAlias(value, knowledge); @@ -353,6 +580,191 @@ llvm::FailureOr resolveContiguousAddressImpl(mlir::Va } } +llvm::FailureOr compileContiguousAddressExprImpl(mlir::Value value) { + int64_t constantByteOffset = 0; + CompiledIndexExpr byteOffsetExpr; + { + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Constant; + expr.constant = 0; + byteOffsetExpr = makeCompiledIndexExpr(std::move(expr)); + } + + while (true) { + if (mlir::isa(value)) + return CompiledAddressExpr {value, byteOffsetExpr}; + + mlir::Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return mlir::failure(); + + if (auto dpsDefiningOp = mlir::dyn_cast(definingOp)) { + mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast(value)); + if (!tiedOperand) + return mlir::failure(); + value = tiedOperand->get(); + continue; + } + + if (auto forOp = mlir::dyn_cast(definingOp)) { + auto result = mlir::dyn_cast(value); + if (!result) + return mlir::failure(); + + auto yieldOp = mlir::cast(forOp.getBody()->getTerminator()); + mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber()); + if (auto blockArgument = mlir::dyn_cast(yieldedValue)) { + if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0 + && static_cast(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) { + value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1]; + continue; + } + } + + value = yieldedValue; + continue; + } + + if (auto subviewOp = mlir::dyn_cast(definingOp)) { + auto sourceType = mlir::dyn_cast(subviewOp.getSource().getType()); + auto subviewType = mlir::dyn_cast(subviewOp.getType()); + if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape()) + return mlir::failure(); + + llvm::SmallVector staticOffsets; + staticOffsets.reserve(subviewOp.getMixedOffsets().size()); + llvm::SmallVector staticSizes; + staticSizes.reserve(subviewOp.getMixedSizes().size()); + llvm::SmallVector staticStrides; + staticStrides.reserve(subviewOp.getMixedStrides().size()); + bool allStatic = true; + + for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) { + if (auto attr = mlir::dyn_cast(offset)) + staticOffsets.push_back(mlir::cast(attr).getInt()); + else + allStatic = false; + } + for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) { + if (auto attr = mlir::dyn_cast(size)) + staticSizes.push_back(mlir::cast(attr).getInt()); + else + allStatic = false; + } + for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) { + if (auto attr = mlir::dyn_cast(stride)) + staticStrides.push_back(mlir::cast(attr).getInt()); + else + allStatic = false; + } + + if (allStatic) { + if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides)) + return mlir::failure(); + + auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); + constantByteOffset += + linearizeIndex(staticOffsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType()); + } + else { + llvm::SmallVector sourceStrides = computeRowMajorStrides(sourceType.getShape()); + CompiledIndexExpr offsetExpr; + { + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Constant; + expr.constant = 0; + offsetExpr = makeCompiledIndexExpr(std::move(expr)); + } + + for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), sourceStrides)) { + CompiledIndexExpr operandExpr; + if (auto attr = mlir::dyn_cast(mixedOffset)) { + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Constant; + expr.constant = mlir::cast(attr).getInt() * sourceStride + * getElementTypeSizeInBytes(subviewType.getElementType()); + operandExpr = makeCompiledIndexExpr(std::move(expr)); + } + else { + auto compiledOffset = compileIndexValueImpl(mlir::cast(mixedOffset)); + if (failed(compiledOffset)) + return mlir::failure(); + CompiledIndexExpr scaleExpr; + { + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Constant; + expr.constant = sourceStride * getElementTypeSizeInBytes(subviewType.getElementType()); + scaleExpr = makeCompiledIndexExpr(std::move(expr)); + } + + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Mul; + expr.operands = {*compiledOffset, scaleExpr}; + operandExpr = makeCompiledIndexExpr(std::move(expr)); + } + + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Add; + expr.operands = {offsetExpr, operandExpr}; + offsetExpr = makeCompiledIndexExpr(std::move(expr)); + } + + CompiledIndexExpr constantExpr; + { + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Constant; + expr.constant = constantByteOffset; + constantExpr = makeCompiledIndexExpr(std::move(expr)); + } + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Add; + expr.operands = {constantExpr, offsetExpr}; + byteOffsetExpr = makeCompiledIndexExpr(std::move(expr)); + constantByteOffset = 0; + } + + value = subviewOp.getSource(); + continue; + } + + if (auto castOp = mlir::dyn_cast(definingOp)) { + value = castOp.getSource(); + continue; + } + if (auto collapseOp = mlir::dyn_cast(definingOp)) { + value = collapseOp.getSrc(); + continue; + } + if (auto expandOp = mlir::dyn_cast(definingOp)) { + value = expandOp.getSrc(); + continue; + } + + if (mlir::isa(definingOp)) { + if (constantByteOffset != 0) { + CompiledIndexExpr constantExpr; + { + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Constant; + expr.constant = constantByteOffset; + constantExpr = makeCompiledIndexExpr(std::move(expr)); + } + if (byteOffsetExpr.node->kind == CompiledIndexExprNode::Kind::Constant && byteOffsetExpr.node->constant == 0) + byteOffsetExpr = constantExpr; + else { + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Add; + expr.operands = {constantExpr, byteOffsetExpr}; + byteOffsetExpr = makeCompiledIndexExpr(std::move(expr)); + } + } + return CompiledAddressExpr {value, byteOffsetExpr}; + } + + return mlir::failure(); + } +} + } // namespace llvm::FailureOr resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); } @@ -361,6 +773,8 @@ llvm::FailureOr resolveIndexValue(mlir::Value value, const StaticValueK return resolveIndexValueImpl(value, &knowledge); } +llvm::FailureOr compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); } + llvm::FailureOr resolveContiguousAddress(mlir::Value value) { return resolveContiguousAddressImpl(value, nullptr); } @@ -374,4 +788,21 @@ mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledg return resolveLoopCarriedAliasImpl(value, &knowledge); } +llvm::FailureOr compileContiguousAddressExpr(mlir::Value value) { + return compileContiguousAddressExprImpl(value); +} + +llvm::FailureOr CompiledIndexExpr::evaluate(const StaticValueKnowledge& knowledge) const { + return evaluateCompiledIndexExpr(*this, knowledge); +} + +llvm::FailureOr +CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge, std::optional lane) const { + (void) lane; + auto resolvedOffset = byteOffset.evaluate(knowledge); + if (failed(resolvedOffset)) + return mlir::failure(); + return ResolvedContiguousAddress {base, *resolvedOffset}; +} + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/AddressAnalysis.hpp b/src/PIM/Common/IR/AddressAnalysis.hpp index eb099bd..9b73e6e 100644 --- a/src/PIM/Common/IR/AddressAnalysis.hpp +++ b/src/PIM/Common/IR/AddressAnalysis.hpp @@ -1,10 +1,14 @@ #pragma once +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Value.h" #include "llvm/ADT/DenseMap.h" +#include +#include + namespace onnx_mlir { /// Describes a value as a base addressable object plus a statically known @@ -23,6 +27,51 @@ struct StaticValueKnowledge { StaticValueKnowledge() {} }; +struct CompiledIndexExprNode; + +struct CompiledIndexExpr { + std::shared_ptr node; + + CompiledIndexExpr() = default; + explicit CompiledIndexExpr(std::shared_ptr node) : node(std::move(node)) {} + + llvm::FailureOr evaluate(const StaticValueKnowledge& knowledge) const; +}; + +struct CompiledIndexExprNode { + enum class Kind { + Constant, + Symbol, + Add, + Sub, + Mul, + DivUI, + DivSI, + RemUI, + RemSI, + MinUI, + CmpI, + Select, + ConstantGlobalLoad + }; + + Kind kind = Kind::Constant; + int64_t constant = 0; + mlir::Value symbol; + mlir::arith::CmpIPredicate predicate = mlir::arith::CmpIPredicate::eq; + mlir::memref::GlobalOp globalOp; + llvm::SmallVector globalStrides; + llvm::SmallVector operands; +}; + +struct CompiledAddressExpr { + mlir::Value base; + CompiledIndexExpr byteOffset; + + llvm::FailureOr + evaluate(const StaticValueKnowledge& knowledge, std::optional lane) const; +}; + mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp); /// Resolves a value to contiguous backing storage when that storage can be @@ -35,9 +84,12 @@ llvm::FailureOr resolveContiguousAddress(mlir::Value /// arithmetic and loop facts recorded in `knowledge`. llvm::FailureOr resolveIndexValue(mlir::Value value); llvm::FailureOr resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge); +llvm::FailureOr compileIndexExpr(mlir::Value value); /// Follows alias, view, and DPS chains to recover the backing value of a /// loop-carried memref/result. mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge); +llvm::FailureOr compileContiguousAddressExpr(mlir::Value value); + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/BatchCoreUtils.cpp b/src/PIM/Common/IR/BatchCoreUtils.cpp new file mode 100644 index 0000000..bc1a837 --- /dev/null +++ b/src/PIM/Common/IR/BatchCoreUtils.cpp @@ -0,0 +1,22 @@ +#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" + +namespace onnx_mlir { + +llvm::SmallVector getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) { + auto coreIdsAttr = coreBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName); + assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute"); + return llvm::SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); +} + +llvm::SmallVector +getLaneChunkCoreIds(llvm::ArrayRef coreIds, size_t laneCount, unsigned lane) { + llvm::SmallVector laneCoreIds; + laneCoreIds.reserve(coreIds.size() / laneCount); + for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex) + laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]); + return laneCoreIds; +} + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/BatchCoreUtils.hpp b/src/PIM/Common/IR/BatchCoreUtils.hpp new file mode 100644 index 0000000..0b92644 --- /dev/null +++ b/src/PIM/Common/IR/BatchCoreUtils.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" + +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +namespace onnx_mlir { + +llvm::SmallVector getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp); + +llvm::SmallVector +getLaneChunkCoreIds(llvm::ArrayRef coreIds, size_t laneCount, unsigned lane); + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/CoreBlockUtils.cpp b/src/PIM/Common/IR/CoreBlockUtils.cpp index 6c104ac..03bdc8f 100644 --- a/src/PIM/Common/IR/CoreBlockUtils.cpp +++ b/src/PIM/Common/IR/CoreBlockUtils.cpp @@ -2,6 +2,8 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "llvm/ADT/SmallVector.h" + #include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -78,4 +80,58 @@ walkPimCoreBlock(mlir::Block& block, return mlir::success(!hasFailure); } +mlir::LogicalResult walkPimCoreBlockStructurally( + mlir::Block& block, + const StaticValueKnowledge& knowledge, + llvm::function_ref callback) { + bool hasFailure = false; + for (mlir::Operation& op : block) { + if (mlir::isa(op) || isCoreStaticAddressOp(&op)) + continue; + if (auto loadOp = mlir::dyn_cast(op); + loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge))) + continue; + + if (auto forOp = mlir::dyn_cast(op)) { + mlir::Block& loopBody = forOp.getRegion().front(); + auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge); + auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge); + auto step = resolveIndexValue(forOp.getStep(), knowledge); + if (failed(lowerBound) || failed(upperBound) || failed(step)) { + forOp.emitOpError("requires statically evaluable scf.for bounds for PIM verification"); + hasFailure = true; + continue; + } + if (*step <= 0) { + forOp.emitOpError("requires positive scf.for step for PIM verification"); + hasFailure = true; + continue; + } + + llvm::SmallVector samples; + if (*lowerBound < *upperBound) { + samples.push_back(*lowerBound); + int64_t last = *lowerBound + ((*upperBound - 1 - *lowerBound) / *step) * *step; + if (last != *lowerBound) + samples.push_back(last); + } + + for (int64_t inductionValue : samples) { + StaticValueKnowledge loopKnowledge = knowledge; + loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue; + for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), forOp.getInitArgs())) + loopKnowledge.aliases[iterArg] = iterValue; + + if (failed(walkPimCoreBlockStructurally(loopBody, loopKnowledge, callback))) + hasFailure = true; + } + continue; + } + + if (failed(callback(op, knowledge))) + hasFailure = true; + } + return mlir::success(!hasFailure); +} + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/CoreBlockUtils.hpp b/src/PIM/Common/IR/CoreBlockUtils.hpp index 91fb7cf..a002098 100644 --- a/src/PIM/Common/IR/CoreBlockUtils.hpp +++ b/src/PIM/Common/IR/CoreBlockUtils.hpp @@ -21,4 +21,13 @@ walkPimCoreBlock(mlir::Block& block, const StaticValueKnowledge& knowledge, llvm::function_ref callback); +/// Walks a `pim.core`-like body structurally for verification without +/// enumerating full loop trip counts. Loop bounds must still be statically +/// evaluable so address resolution remains well-defined. +mlir::LogicalResult +walkPimCoreBlockStructurally(mlir::Block& block, + const StaticValueKnowledge& knowledge, + llvm::function_ref + callback); + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/WeightUtils.cpp b/src/PIM/Common/IR/WeightUtils.cpp index a206f3a..9ba9f21 100644 --- a/src/PIM/Common/IR/WeightUtils.cpp +++ b/src/PIM/Common/IR/WeightUtils.cpp @@ -117,4 +117,22 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp) { + if (auto coreOp = mlir::dyn_cast_or_null(weightOwner)) { + for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) + if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) + return weightIndex; + return std::nullopt; + } + + if (auto coreBatchOp = mlir::dyn_cast_or_null(weightOwner)) { + for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex) + if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) + return weightIndex; + return std::nullopt; + } + + return std::nullopt; +} + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/WeightUtils.hpp b/src/PIM/Common/IR/WeightUtils.hpp index f0a1b2f..c02c839 100644 --- a/src/PIM/Common/IR/WeightUtils.hpp +++ b/src/PIM/Common/IR/WeightUtils.hpp @@ -3,9 +3,15 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/StringRef.h" +#include + +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways"; namespace onnx_mlir { @@ -26,4 +32,24 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value); /// passes can identify globals that must remain weight-backed. void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref callback); +template +llvm::SmallVector getUsedWeightIndices(CoreLikeOpTy coreLikeOp) { + llvm::SmallVector indices; + auto addWeight = [&](mlir::Value weight) { + for (unsigned weightIndex = 0; weightIndex < coreLikeOp.getWeights().size(); ++weightIndex) { + if (coreLikeOp.getWeightArgument(weightIndex) != weight) + continue; + if (!llvm::is_contained(indices, weightIndex)) + indices.push_back(weightIndex); + return; + } + }; + + coreLikeOp.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); }); + llvm::sort(indices); + return indices; +} + +std::optional resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp); + } // namespace onnx_mlir diff --git a/src/PIM/Compiler/CMakeLists.txt b/src/PIM/Compiler/CMakeLists.txt index a578c12..c40bac1 100644 --- a/src/PIM/Compiler/CMakeLists.txt +++ b/src/PIM/Compiler/CMakeLists.txt @@ -16,7 +16,6 @@ add_pim_library(OMPimCompilerOptions add_pim_library(OMPimCompilerUtils PimCompilerUtils.cpp PimArtifactWriter.cpp - PimBatchEmission.cpp PimCodeGen.cpp PimWeightEmitter.cpp diff --git a/src/PIM/Compiler/PimArtifactWriter.cpp b/src/PIM/Compiler/PimArtifactWriter.cpp index f04211f..faa9c01 100644 --- a/src/PIM/Compiler/PimArtifactWriter.cpp +++ b/src/PIM/Compiler/PimArtifactWriter.cpp @@ -48,7 +48,7 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& if (!denseAttr) return; - MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult()); + MemEntry memEntry = memory.hostMem.getMemEntry({getGlobalOp.getResult(), std::nullopt}); ArrayRef rawData = denseAttr.getRawData(); char* dst = memoryBuffer.data() + memEntry.address; diff --git a/src/PIM/Compiler/PimBatchEmission.cpp b/src/PIM/Compiler/PimBatchEmission.cpp deleted file mode 100644 index 20b5262..0000000 --- a/src/PIM/Compiler/PimBatchEmission.cpp +++ /dev/null @@ -1,193 +0,0 @@ -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/IRMapping.h" - -#include "llvm/ADT/StringRef.h" - -#include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" - -using namespace mlir; - -namespace onnx_mlir { -namespace { - -static SmallVector getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) { - auto coreIdsAttr = coreBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName); - assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute"); - return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); -} - -static SmallVector getLaneChunkCoreIds(ArrayRef coreIds, size_t laneCount, unsigned lane) { - SmallVector laneCoreIds; - laneCoreIds.reserve(coreIds.size() / laneCount); - for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex) - laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]); - return laneCoreIds; -} - -static Value getOrCloneCapturedValue(OpBuilder& builder, Block& oldBlock, Value value, IRMapping& mapper) { - if (Value mapped = mapper.lookupOrNull(value)) - return mapped; - - if (auto blockArgument = dyn_cast(value)) { - assert(blockArgument.getOwner() != &oldBlock && "expected block argument to be mapped before cloning"); - assert(false && "unexpected captured block argument while scalarizing pim.core_batch"); - } - - Operation* definingOp = value.getDefiningOp(); - assert(definingOp && "expected captured value to be defined by an operation"); - assert(definingOp->getBlock() != &oldBlock && "expected in-block value to be mapped before cloning"); - - for (Value operand : definingOp->getOperands()) - (void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper); - - Operation* cloned = builder.clone(*definingOp, mapper); - for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults())) - mapper.map(originalResult, clonedResult); - return mapper.lookup(value); -} - -static void cloneScalarizedLaneBody(OpBuilder& builder, - pim::PimCoreBatchOp coreBatchOp, - unsigned lane, - OperationFolder& constantFolder) { - Block& oldBlock = coreBatchOp.getBody().front(); - Operation* anchorOp = builder.getInsertionBlock()->getParentOp(); - size_t laneCount = static_cast(coreBatchOp.getLaneCount()); - size_t weightCount = coreBatchOp.getWeights().size(); - - IRMapping mapper; - for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) { - if (blockArg.getType().isIndex()) { - mapper.map(blockArg, getOrCreateHostIndexConstant(anchorOp, static_cast(lane), constantFolder)); - continue; - } - - if (argIndex <= weightCount) { - auto scalarCoreOp = cast(anchorOp); - mapper.map(blockArg, scalarCoreOp.getWeightArgument(argIndex - 1)); - continue; - } - - size_t inputIndex = argIndex - 1 - weightCount; - assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range"); - mapper.map(blockArg, coreBatchOp.getInputs()[inputIndex]); - } - - for (Operation& op : oldBlock) { - if (isa(op)) - continue; - - for (Value operand : op.getOperands()) - (void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper); - - if (auto sendBatchOp = dyn_cast(op)) { - pim::PimSendOp::create( - builder, - sendBatchOp.getLoc(), - mapper.lookup(sendBatchOp.getInput()), - sendBatchOp.getSizeAttr(), - getOrCreateHostIndexConstant(anchorOp, sendBatchOp.getTargetCoreIds()[lane], constantFolder)); - continue; - } - - if (auto sendTensorBatchOp = dyn_cast(op)) { - pim::PimSendTensorOp::create( - builder, - sendTensorBatchOp.getLoc(), - mapper.lookup(sendTensorBatchOp.getInput()), - builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane))); - continue; - } - - if (auto receiveBatchOp = dyn_cast(op)) { - auto scalarReceive = pim::PimReceiveOp::create( - builder, - receiveBatchOp.getLoc(), - receiveBatchOp.getOutput().getType(), - mapper.lookup(receiveBatchOp.getOutputBuffer()), - receiveBatchOp.getSizeAttr(), - getOrCreateHostIndexConstant(anchorOp, receiveBatchOp.getSourceCoreIds()[lane], constantFolder)); - mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput()); - continue; - } - - if (auto receiveTensorBatchOp = dyn_cast(op)) { - auto scalarReceive = pim::PimReceiveTensorOp::create( - builder, - receiveTensorBatchOp.getLoc(), - receiveTensorBatchOp.getOutput().getType(), - mapper.lookup(receiveTensorBatchOp.getOutputBuffer()), - builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane))); - mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput()); - continue; - } - - if (auto memcpBatchOp = dyn_cast(op)) { - auto scalarCopy = pim::PimMemCopyHostToDevOp::create( - builder, - memcpBatchOp.getLoc(), - memcpBatchOp.getOutput().getType(), - getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder), - getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getHostSourceOffset(), constantFolder), - mapper.lookup(memcpBatchOp.getDeviceTarget()), - mapper.lookup(memcpBatchOp.getHostSource()), - memcpBatchOp.getSizeAttr()); - mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput()); - continue; - } - - Operation* cloned = builder.clone(op, mapper); - for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) - mapper.map(originalResult, clonedResult); - } -} - -} // namespace - -LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp, - ArrayRef lanes, - llvm::function_ref callback) { - assert(!lanes.empty() && "expected at least one batch lane"); - - OwningOpRef scratchModule = ModuleOp::create(coreBatchOp.getLoc()); - OpBuilder builder(scratchModule->getContext()); - OperationFolder constantFolder(scratchModule->getContext()); - builder.setInsertionPointToStart(scratchModule->getBody()); - - SmallVector weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end()); - auto coreIds = getBatchCoreIds(coreBatchOp); - int32_t coreId = coreIds[lanes.front()]; - for (unsigned lane : lanes) - assert(coreIds[lane] == coreId && "all grouped lanes must target the same core"); - - auto scalarCore = - pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId)); - SmallVector weightTypes; - SmallVector weightLocs; - weightTypes.reserve(weights.size()); - weightLocs.reserve(weights.size()); - for (Value weight : weights) { - weightTypes.push_back(weight.getType()); - weightLocs.push_back(weight.getLoc()); - } - Block* block = - builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end(), TypeRange(weightTypes), weightLocs); - builder.setInsertionPointToEnd(block); - for (unsigned lane : lanes) - cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder); - if (block->empty() || !isa(block->back())) - pim::PimHaltOp::create(builder, coreBatchOp.getLoc()); - return callback(scalarCore); -} - -LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, - unsigned lane, - llvm::function_ref callback) { - return withScalarCoreFromBatchLanes(coreBatchOp, ArrayRef {lane}, callback); -} - -} // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimBatchEmission.hpp b/src/PIM/Compiler/PimBatchEmission.hpp deleted file mode 100644 index 1977d55..0000000 --- a/src/PIM/Compiler/PimBatchEmission.hpp +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include "llvm/ADT/STLFunctionalExtras.h" - -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" - -namespace onnx_mlir { - -mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, - unsigned lane, - llvm::function_ref callback); -mlir::LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp, - llvm::ArrayRef lanes, - llvm::function_ref callback); - -} // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 3897e52..e576b78 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -1,14 +1,18 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Threading.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FileSystem.h" @@ -21,6 +25,7 @@ #include #include #include +#include #include #include @@ -28,8 +33,9 @@ #include "Common/PimCommon.hpp" #include "Common/Support/ReportUtils.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp" -#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" #include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" @@ -41,29 +47,56 @@ using namespace mlir; using namespace onnx_mlir; using namespace onnx_mlir::compact_asm; -MemEntry* PimMemory::gatherMemEntry(mlir::Value value) { +namespace { + +static std::optional getLaneForMemoryValue(mlir::Value value, std::optional lane) { + if (!lane) + return std::nullopt; + auto allocOp = value.getDefiningOp(); + if (!allocOp || !allocOp->getParentOfType()) + return std::nullopt; + return lane; +} + +static mlir::Value resolveCachedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) { + auto iter = knowledge.aliases.find(value); + while (iter != knowledge.aliases.end()) { + value = iter->second; + iter = knowledge.aliases.find(value); + } + return value; +} + +static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional lane = std::nullopt) { + return {value, getLaneForMemoryValue(value, lane)}; +} + +} // namespace + +MemEntry* PimMemory::gatherMemEntry(mlir::Value value, std::optional lane) { auto type = cast(value.getType()); assert("Only static shape is supported" && type.hasStaticShape()); size_t allocSize = getShapedTypeSizeInBytes(type); MemEntry memEntry = {0, allocSize}; - return &memEntries.emplace_back(memEntry, value).first; + return &memEntries.emplace_back(memEntry, getMemoryValueKey(value, lane)).first; } void PimMemory::allocateGatheredMemory() { llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; }); - for (auto& [memEntry, value] : memEntries) - allocateMemoryForValue(value, memEntry); + for (auto& [memEntry, key] : memEntries) + allocateMemoryForValue(key, memEntry); + memEntries.clear(); } -void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) { +void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry) { memEntry.address = firstAvailableAddress; firstAvailableAddress += memEntry.size; // Alignment if (size_t remainder = firstAvailableAddress % minAlignment) firstAvailableAddress += minAlignment - remainder; - ownedMemEntriesMap[value] = memEntry; - globalMemEntriesMap[value] = memEntry; + ownedMemEntriesMap[key] = memEntry; + globalMemEntriesMap[key] = memEntry; } void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { @@ -101,11 +134,11 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { allocateGatheredMemory(); for (auto [alias, original] : globalAliases) - globalMemEntriesMap[alias] = getMemEntry(original); + globalMemEntriesMap[getMemoryValueKey(alias)] = getMemEntry(getMemoryValueKey(original)); } -void PimMemory::allocateCore(Operation* op) { - op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); }); +void PimMemory::allocateCore(Operation* op, std::optional lane) { + op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp, lane); }); allocateGatheredMemory(); } @@ -149,8 +182,8 @@ static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const Mem MemoryReportRow PimMemory::getReportRow() const { MemoryReportRow row; - for (auto& [val, memEntry] : ownedMemEntriesMap) { - if (auto op = val.getDefiningOp()) { + for (auto& [key, memEntry] : ownedMemEntriesMap) { + if (auto op = key.value.getDefiningOp()) { if (isa(op)) { row.numAlloca++; row.sizeAlloca += memEntry.size; @@ -166,14 +199,26 @@ MemoryReportRow PimMemory::getReportRow() const { } void PimMemory::remove(mlir::Value val) { - if (auto removeIter = ownedMemEntriesMap.find(val); removeIter != ownedMemEntriesMap.end()) - ownedMemEntriesMap.erase(removeIter); - if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end()) - globalMemEntriesMap.erase(removeIter); + for (auto it = ownedMemEntriesMap.begin(); it != ownedMemEntriesMap.end();) { + if (it->first.value == val) { + auto eraseIt = it++; + ownedMemEntriesMap.erase(eraseIt); + } + else + ++it; + } + for (auto it = globalMemEntriesMap.begin(); it != globalMemEntriesMap.end();) { + if (it->first.value == val) { + auto eraseIt = it++; + globalMemEntriesMap.erase(eraseIt); + } + else + ++it; + } } -MemEntry PimMemory::getMemEntry(mlir::Value value) const { - auto iter = globalMemEntriesMap.find(value); +MemEntry PimMemory::getMemEntry(const MemoryValueKey& key) const { + auto iter = globalMemEntriesMap.find(key); assert("Missing memEntry for value" && iter != globalMemEntriesMap.end()); return iter->second; } @@ -182,10 +227,25 @@ PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) { return deviceMem.try_emplace(id, memEntriesMap).first->second; } -size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge) const { - auto resolvedAddress = resolveContiguousAddress(value, knowledge); +size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, + const StaticValueKnowledge& knowledge, + std::optional lane) const { + value = resolveCachedAlias(value, knowledge); + auto compiledIt = compiledAddressExprs.find(value); + if (compiledIt == compiledAddressExprs.end()) { + auto compiledExpr = compileContiguousAddressExpr(value); + if (failed(compiledExpr)) { + errs() << "Failed to compile contiguous address for value: "; + value.print(errs()); + errs() << "\n"; + llvm_unreachable("Failed to compile contiguous address"); + } + compiledIt = compiledAddressExprs.try_emplace(value, *compiledExpr).first; + } + + auto resolvedAddress = compiledIt->second.evaluate(knowledge, lane); if (failed(resolvedAddress)) { - errs() << "Failed to resolve contiguous address for value: "; + errs() << "Failed to evaluate contiguous address for value: "; value.print(errs()); errs() << "\n"; if (auto* definingOp = value.getDefiningOp()) { @@ -196,11 +256,14 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu llvm_unreachable("Failed to resolve contiguous address"); } - auto iter = memEntriesMap.find(resolvedAddress->base); + MemoryValueKey key = getMemoryValueKey(resolvedAddress->base, lane); + auto iter = memEntriesMap.find(key); if (iter == memEntriesMap.end()) { errs() << "Missing mem entry for value: "; resolvedAddress->base.print(errs()); errs() << "\n"; + if (key.lane) + errs() << "Lane: " << *key.lane << "\n"; if (auto* definingOp = resolvedAddress->base.getDefiningOp()) { errs() << "Defining op:\n"; definingOp->print(errs()); @@ -212,6 +275,18 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu return iter->second.address + resolvedAddress->byteOffset; } +llvm::FailureOr PimAcceleratorMemory::getIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) const { + value = resolveCachedAlias(value, knowledge); + auto compiledIt = compiledIndexExprs.find(value); + if (compiledIt == compiledIndexExprs.end()) { + auto compiledExpr = compileIndexExpr(value); + if (failed(compiledExpr)) + return mlir::failure(); + compiledIt = compiledIndexExprs.try_emplace(value, *compiledExpr).first; + } + return compiledIt->second.evaluate(knowledge); +} + void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); } void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) { @@ -393,8 +468,8 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_ } void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const { - auto deviceTargetOffset = resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge); - auto hostSourceOffset = resolveIndexValue(loadOp.getHostSourceOffset(), knowledge); + auto deviceTargetOffset = indexOf(loadOp.getDeviceTargetOffset(), knowledge); + auto hostSourceOffset = indexOf(loadOp.getHostSourceOffset(), knowledge); assert(succeeded(deviceTargetOffset) && succeeded(hostSourceOffset) && "pim.memcp_hd offsets must be statically resolvable during codegen"); emitMemCopyOp("ld", @@ -416,8 +491,8 @@ void PimCodeGen::codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp, } void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const { - auto hostTargetOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge); - auto deviceSourceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge); + auto hostTargetOffset = indexOf(storeOp.getHostTargetOffset(), knowledge); + auto deviceSourceOffset = indexOf(storeOp.getDeviceSourceOffset(), knowledge); assert(succeeded(hostTargetOffset) && succeeded(deviceSourceOffset) && "pim.memcp_dh offsets must be statically resolvable during codegen"); emitMemCopyOp("st", @@ -439,7 +514,7 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledg } void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const { - auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge); + auto sourceCoreId = indexOf(receiveOp.getSourceCoreId(), knowledge); assert(succeeded(sourceCoreId) && "pim.receive source core id must be statically resolvable during codegen"); emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize()); } @@ -453,8 +528,25 @@ void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize); } +void PimCodeGen::codeGenReceiveBatchOp(pim::PimReceiveBatchOp receiveOp, + unsigned lane, + const StaticValueKnowledge& knowledge) const { + emitCommunicationOp( + "recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreIds()[lane], receiveOp.getSize()); +} + +void PimCodeGen::codeGenReceiveTensorBatchOp(pim::PimReceiveTensorBatchOp receiveOp, + ArrayRef laneCoreIds, + const StaticValueKnowledge& knowledge) const { + size_t outputAddr = addressOf(receiveOp.getOutputBuffer(), knowledge); + size_t chunkSize = getShapedTypeSizeInBytes(cast(receiveOp.getOutputBuffer().getType())) + / laneCoreIds.size(); + for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(laneCoreIds)) + emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize); +} + void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const { - auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge); + auto targetCoreId = indexOf(sendOp.getTargetCoreId(), knowledge); assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen"); emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize()); } @@ -467,6 +559,21 @@ void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const St emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize); } +void PimCodeGen::codeGenSendBatchOp(pim::PimSendBatchOp sendOp, + unsigned lane, + const StaticValueKnowledge& knowledge) const { + emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreIds()[lane], sendOp.getSize()); +} + +void PimCodeGen::codeGenSendTensorBatchOp(pim::PimSendTensorBatchOp sendOp, + ArrayRef laneCoreIds, + const StaticValueKnowledge& knowledge) const { + size_t inputAddr = addressOf(sendOp.getInput(), knowledge); + size_t chunkSize = getShapedTypeSizeInBytes(cast(sendOp.getInput().getType())) / laneCoreIds.size(); + for (auto [chunkIndex, targetCoreId] : llvm::enumerate(laneCoreIds)) + emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize); +} + void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const { auto outputType = cast(concatOp.getOutputBuffer().getType()); assert(outputType.hasStaticShape() && "concat codegen requires static output shape"); @@ -745,35 +852,6 @@ std::string getMemorySizeAsString(size_t size) { return std::to_string(size) + " Bytes"; } -static SmallVector getUsedWeightIndices(Block& block) { - SmallVector indices; - auto coreOp = dyn_cast(block.getParentOp()); - auto addWeight = [&](mlir::Value weight) { - if (!coreOp) - return; - for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) { - if (coreOp.getWeightArgument(weightIndex) != weight) - continue; - if (!llvm::is_contained(indices, weightIndex)) - indices.push_back(weightIndex); - return; - } - }; - block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); }); - llvm::sort(indices); - return indices; -} - -static SmallVector getUsedWeightIndices(pim::PimCoreOp coreOp) { - return getUsedWeightIndices(coreOp.getBody().front()); -} - -static SmallVector getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) { - auto coreIdsAttr = coreBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName); - assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute"); - return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); -} - static SmallVector collectTopLevelCoreLikeOps(func::FuncOp funcOp) { SmallVector coreLikeOps; for (Operation& op : funcOp.getBody().front()) @@ -782,6 +860,341 @@ static SmallVector collectTopLevelCoreLikeOps(func::FuncOp funcOp) { return coreLikeOps; } +struct CoreEmissionResult { + OnnxMlirCompilerErrorCodes status = CompilerSuccess; + MemoryReportRow reportRow; + llvm::SmallVector usedWeightIndices; +}; + +template +class ScopedMapBindings { + using KeyTy = typename MapTy::key_type; + using ValueTy = typename MapTy::mapped_type; + + MapTy& map; + llvm::SmallVector>, 8> savedEntries; + +public: + explicit ScopedMapBindings(MapTy& map) : map(map) {} + + void bind(const KeyTy& key, const ValueTy& value) { + auto it = map.find(key); + if (it == map.end()) + savedEntries.emplace_back(key, std::nullopt); + else + savedEntries.emplace_back(key, it->second); + map[key] = value; + } + + ~ScopedMapBindings() { + for (auto it = savedEntries.rbegin(); it != savedEntries.rend(); ++it) { + if (it->second) + map[it->first] = *it->second; + else + map.erase(it->first); + } + } +}; + +enum class CompiledCoreOpKind : uint8_t { + Load, + LoadBatch, + Store, + Lmv, + Receive, + ReceiveBatch, + ReceiveTensor, + ReceiveTensorBatch, + Send, + SendBatch, + SendTensor, + SendTensorBatch, + Concat, + Vmm, + Transpose, + VVAdd, + VVSub, + VVMul, + VVMax, + VVDMul, + VAvg, + VRelu, + VTanh, + VSigm, + VSoftmax, + GetGlobal +}; + +struct CompiledCoreNode { + enum class Kind : uint8_t { + Op, + Loop + }; + + Kind kind = Kind::Op; + Operation* op = nullptr; + CompiledCoreOpKind opKind = CompiledCoreOpKind::Load; + std::optional weightIndex; + CompiledIndexExpr lowerBound; + CompiledIndexExpr upperBound; + CompiledIndexExpr step; + std::unique_ptr> loopBody; +}; + +static FailureOr classifyCompiledCoreOpKind(Operation& op) { + if (isa(op)) + return CompiledCoreOpKind::Load; + if (isa(op)) + return CompiledCoreOpKind::LoadBatch; + if (isa(op)) + return CompiledCoreOpKind::Store; + if (isa(op)) + return CompiledCoreOpKind::Lmv; + if (isa(op)) + return CompiledCoreOpKind::Receive; + if (isa(op)) + return CompiledCoreOpKind::ReceiveBatch; + if (isa(op)) + return CompiledCoreOpKind::ReceiveTensor; + if (isa(op)) + return CompiledCoreOpKind::ReceiveTensorBatch; + if (isa(op)) + return CompiledCoreOpKind::Send; + if (isa(op)) + return CompiledCoreOpKind::SendBatch; + if (isa(op)) + return CompiledCoreOpKind::SendTensor; + if (isa(op)) + return CompiledCoreOpKind::SendTensorBatch; + if (isa(op)) + return CompiledCoreOpKind::Concat; + if (isa(op)) + return CompiledCoreOpKind::Vmm; + if (isa(op)) + return CompiledCoreOpKind::Transpose; + if (isa(op)) + return CompiledCoreOpKind::VVAdd; + if (isa(op)) + return CompiledCoreOpKind::VVSub; + if (isa(op)) + return CompiledCoreOpKind::VVMul; + if (isa(op)) + return CompiledCoreOpKind::VVMax; + if (isa(op)) + return CompiledCoreOpKind::VVDMul; + if (isa(op)) + return CompiledCoreOpKind::VAvg; + if (isa(op)) + return CompiledCoreOpKind::VRelu; + if (isa(op)) + return CompiledCoreOpKind::VTanh; + if (isa(op)) + return CompiledCoreOpKind::VSigm; + if (isa(op)) + return CompiledCoreOpKind::VSoftmax; + if (isa(op)) + return CompiledCoreOpKind::GetGlobal; + return failure(); +} + +static LogicalResult compileCoreEmissionPlan(Block& block, + Operation* weightOwner, + llvm::SmallVectorImpl& plan) { + for (Operation& op : block) { + if (isa(op) || isCoreStaticAddressOp(&op)) + continue; + + if (auto loadOp = dyn_cast(op)) { + if (succeeded(compileIndexExpr(loadOp.getResult()))) + continue; + } + + if (auto forOp = dyn_cast(op)) { + auto lowerBound = compileIndexExpr(forOp.getLowerBound()); + auto upperBound = compileIndexExpr(forOp.getUpperBound()); + auto step = compileIndexExpr(forOp.getStep()); + if (failed(lowerBound) || failed(upperBound) || failed(step)) { + forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen"); + return failure(); + } + + CompiledCoreNode loopNode; + loopNode.kind = CompiledCoreNode::Kind::Loop; + loopNode.op = forOp.getOperation(); + loopNode.lowerBound = *lowerBound; + loopNode.upperBound = *upperBound; + loopNode.step = *step; + loopNode.loopBody = std::make_unique>(); + if (failed(compileCoreEmissionPlan(forOp.getRegion().front(), weightOwner, *loopNode.loopBody))) + return failure(); + plan.push_back(std::move(loopNode)); + continue; + } + + auto opKind = classifyCompiledCoreOpKind(op); + if (failed(opKind)) { + InFlightDiagnostic diag = op.emitError() << "unsupported codegen for op '" << op.getName().getStringRef() << "'"; + if (auto coreOp = op.getParentOfType()) + diag << " inside pim.core " << coreOp.getCoreId(); + else if (auto coreBatchOp = op.getParentOfType()) + diag << " inside pim.core_batch with laneCount " << coreBatchOp.getLaneCount(); + return failure(); + } + + CompiledCoreNode opNode; + opNode.kind = CompiledCoreNode::Kind::Op; + opNode.op = &op; + opNode.opKind = *opKind; + if (auto vmmOp = dyn_cast(op)) { + auto weightIndex = onnx_mlir::resolveWeightIndex(weightOwner, vmmOp); + if (!weightIndex) + return failure(); + opNode.weightIndex = *weightIndex; + } + plan.push_back(std::move(opNode)); + } + return success(); +} + +static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl& plan, + PimCodeGen& coreCodeGen, + StaticValueKnowledge& knowledge, + size_t& processedOperations, + std::optional batchLane = std::nullopt, + std::optional batchLaneCount = std::nullopt) { + for (const CompiledCoreNode& node : plan) { + if (node.kind == CompiledCoreNode::Kind::Loop) { + auto lowerBound = node.lowerBound.evaluate(knowledge); + auto upperBound = node.upperBound.evaluate(knowledge); + auto step = node.step.evaluate(knowledge); + auto forOp = cast(node.op); + if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) { + forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen"); + return failure(); + } + + llvm::SmallVector iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end()); + for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) { + ScopedMapBindings indexBindings(knowledge.indexValues); + ScopedMapBindings aliasBindings(knowledge.aliases); + indexBindings.bind(forOp.getInductionVar(), inductionValue); + for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues)) + aliasBindings.bind(iterArg, iterValue); + + if (failed(executeCompiledCorePlan( + *node.loopBody, coreCodeGen, knowledge, processedOperations, batchLane, batchLaneCount))) + return failure(); + + auto yieldOp = cast(forOp.getRegion().front().getTerminator()); + for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands())) + iterValues[index] = resolveLoopCarriedAlias(yieldedValue, knowledge); + } + continue; + } + + switch (node.opKind) { + case CompiledCoreOpKind::Load: + coreCodeGen.codeGenLoadOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::LoadBatch: + coreCodeGen.codeGenLoadBatchOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::Store: + coreCodeGen.codeGenStoreOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::Lmv: + coreCodeGen.codeGenLmvOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::Receive: + coreCodeGen.codeGenReceiveOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::ReceiveBatch: + if (!batchLane) + return failure(); + coreCodeGen.codeGenReceiveBatchOp(cast(node.op), *batchLane, knowledge); + break; + case CompiledCoreOpKind::ReceiveTensor: + coreCodeGen.codeGenReceiveTensorOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::ReceiveTensorBatch: + if (!batchLane || !batchLaneCount) + return failure(); + coreCodeGen.codeGenReceiveTensorBatchOp(cast(node.op), + getLaneChunkCoreIds(cast(node.op).getSourceCoreIds(), + *batchLaneCount, + *batchLane), + knowledge); + break; + case CompiledCoreOpKind::Send: + coreCodeGen.codeGenSendOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::SendBatch: + if (!batchLane) + return failure(); + coreCodeGen.codeGenSendBatchOp(cast(node.op), *batchLane, knowledge); + break; + case CompiledCoreOpKind::SendTensor: + coreCodeGen.codeGenSendTensorOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::SendTensorBatch: + if (!batchLane || !batchLaneCount) + return failure(); + coreCodeGen.codeGenSendTensorBatchOp(cast(node.op), + getLaneChunkCoreIds(cast(node.op).getTargetCoreIds(), + *batchLaneCount, + *batchLane), + knowledge); + break; + case CompiledCoreOpKind::Concat: + coreCodeGen.codeGenConcatOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::Vmm: + assert(node.weightIndex && "compiled VMM op must have cached weight index"); + coreCodeGen.codeGenMVMLikeOp( + *node.weightIndex, cast(node.op), true, knowledge); + break; + case CompiledCoreOpKind::Transpose: + coreCodeGen.codeGenTransposeOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::VVAdd: + coreCodeGen.codeGenVVAddOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::VVSub: + coreCodeGen.codeGenVVSubOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::VVMul: + coreCodeGen.codeGenVVMulOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::VVMax: + coreCodeGen.codeGenVVMaxOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::VVDMul: + coreCodeGen.codeGenVVDMulOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::VAvg: + coreCodeGen.codeGenVAvgOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::VRelu: + coreCodeGen.codeGenVReluOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::VTanh: + coreCodeGen.codeGenVTanhOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::VSigm: + coreCodeGen.codeGenVSigmOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::VSoftmax: + coreCodeGen.codeGenVSoftmaxOp(cast(node.op), knowledge); + break; + case CompiledCoreOpKind::GetGlobal: + coreCodeGen.codeGetGlobalOp(cast(node.op), knowledge); + break; + } + processedOperations++; + } + return success(); +} + static SmallDenseMap collectMaterializedHostGlobals(ModuleOp moduleOp, func::FuncOp funcOp, const PimAcceleratorMemory& memory) { SmallDenseMap materializedHostGlobals; @@ -791,19 +1204,21 @@ collectMaterializedHostGlobals(ModuleOp moduleOp, func::FuncOp funcOp, const Pim auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); if (!targetGlobal || materializedHostGlobals.contains(targetGlobal)) return; - auto it = memory.memEntriesMap.find(getGlobalOp.getResult()); + auto it = memory.memEntriesMap.find(getMemoryValueKey(getGlobalOp.getResult())); if (it != memory.memEntriesMap.end()) materializedHostGlobals[targetGlobal] = it->second; }); return materializedHostGlobals; } -static void aliasMaterializedHostGlobals(ModuleOp moduleOp, - pim::PimCoreOp coreOp, +template +static void aliasMaterializedHostGlobals(CoreLikeOpTy coreLikeOp, + ModuleOp moduleOp, const SmallDenseMap& materializedHostGlobals, PimAcceleratorMemory& memory) { - coreOp.walk([&](memref::GetGlobalOp getGlobalOp) { - if (hasWeightAlways(getGlobalOp) || memory.memEntriesMap.contains(getGlobalOp.getResult())) + coreLikeOp.walk([&](memref::GetGlobalOp getGlobalOp) { + MemoryValueKey key = getMemoryValueKey(getGlobalOp.getResult()); + if (hasWeightAlways(getGlobalOp) || memory.memEntriesMap.contains(key)) return; auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); @@ -812,7 +1227,7 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp, auto it = materializedHostGlobals.find(targetGlobal); if (it != materializedHostGlobals.end()) - memory.memEntriesMap[getGlobalOp.getResult()] = it->second; + memory.memEntriesMap[key] = it->second; }); } @@ -820,79 +1235,19 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp, /// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is /// fully resolved before the JSON instructions are emitted. /// Returns the number of emitted instructions, or -1 on failure. -static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { - auto resolveWeightIndex = [&](pim::PimVMMOp vmmOp) -> std::optional { - auto coreOp = vmmOp->getParentOfType(); - if (!coreOp) - return std::nullopt; - for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) - if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) - return weightIndex; - return std::nullopt; - }; +static int64_t codeGenCoreOps(Block& block, + PimCodeGen& coreCodeGen, + const StaticValueKnowledge& initialKnowledge, + Operation* weightOwner, + std::optional batchLane = std::nullopt, + std::optional batchLaneCount = std::nullopt) { + llvm::SmallVector plan; + if (failed(compileCoreEmissionPlan(block, weightOwner, plan))) + return -1; + size_t processedOperations = 0; - auto result = - walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) { - if (auto loadOp = dyn_cast(op)) - coreCodeGen.codeGenLoadOp(loadOp, knowledge); - else if (auto loadBatchOp = dyn_cast(op)) - coreCodeGen.codeGenLoadBatchOp(loadBatchOp, knowledge); - else if (auto storeOp = dyn_cast(op)) - coreCodeGen.codeGenStoreOp(storeOp, knowledge); - else if (auto lmvOp = dyn_cast(op)) - coreCodeGen.codeGenLmvOp(lmvOp, knowledge); - else if (auto receiveOp = dyn_cast(op)) - coreCodeGen.codeGenReceiveOp(receiveOp, knowledge); - else if (auto receiveTensorOp = dyn_cast(op)) - coreCodeGen.codeGenReceiveTensorOp(receiveTensorOp, knowledge); - else if (auto sendOp = dyn_cast(op)) - coreCodeGen.codeGenSendOp(sendOp, knowledge); - else if (auto sendTensorOp = dyn_cast(op)) - coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge); - else if (auto concatOp = dyn_cast(op)) - coreCodeGen.codeGenConcatOp(concatOp, knowledge); - else if (auto vmmOp = dyn_cast(op)) { - auto weightIndex = resolveWeightIndex(vmmOp); - if (!weightIndex) - return failure(); - coreCodeGen.codeGenMVMLikeOp(*weightIndex, vmmOp, true, knowledge); - } - else if (auto transposeOp = dyn_cast(op)) - coreCodeGen.codeGenTransposeOp(transposeOp, knowledge); - else if (auto vvaddOp = dyn_cast(op)) - coreCodeGen.codeGenVVAddOp(vvaddOp, knowledge); - else if (auto vvsubOp = dyn_cast(op)) - coreCodeGen.codeGenVVSubOp(vvsubOp, knowledge); - else if (auto vvmulOp = dyn_cast(op)) - coreCodeGen.codeGenVVMulOp(vvmulOp, knowledge); - else if (auto vvmaxOp = dyn_cast(op)) - coreCodeGen.codeGenVVMaxOp(vvmaxOp, knowledge); - else if (auto vvdmulOp = dyn_cast(op)) - coreCodeGen.codeGenVVDMulOp(vvdmulOp, knowledge); - else if (auto vavgOp = dyn_cast(op)) - coreCodeGen.codeGenVAvgOp(vavgOp, knowledge); - else if (auto vreluOp = dyn_cast(op)) - coreCodeGen.codeGenVReluOp(vreluOp, knowledge); - else if (auto vtanhOp = dyn_cast(op)) - coreCodeGen.codeGenVTanhOp(vtanhOp, knowledge); - else if (auto vsigmOp = dyn_cast(op)) - coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge); - else if (auto vsoftmaxOp = dyn_cast(op)) - coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge); - else if (auto getGlobalOp = dyn_cast(op)) - coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge); - else { - InFlightDiagnostic diag = op.emitError() - << "unsupported codegen for op '" << op.getName().getStringRef() << "'"; - if (auto coreOp = op.getParentOfType()) - diag << " inside pim.core " << coreOp.getCoreId(); - else if (auto coreBatchOp = op.getParentOfType()) - diag << " inside pim.core_batch with laneCount " << coreBatchOp.getLaneCount(); - return failure(); - } - processedOperations++; - return success(); - }); + StaticValueKnowledge knowledge = initialKnowledge; + auto result = executeCompiledCorePlan(plan, coreCodeGen, knowledge, processedOperations, batchLane, batchLaneCount); return failed(result) ? -1 : static_cast(processedOperations); } @@ -946,138 +1301,223 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: } } + SmallVector jobs; + SmallVector> batchJobIndices; for (Operation* op : coreLikeOps) { - auto emitCore = [&](pim::PimCoreOp coreOp, - bool temporaryCore, - MemoryReportRow* reportRow = nullptr) -> OnnxMlirCompilerErrorCodes { - size_t originalCoreId = static_cast(coreOp.getCoreId()); - size_t coreId = emittedCoreIds.lookup(originalCoreId); - maxCoreId = std::max(maxCoreId, coreId); - - std::error_code errorCode; - auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".pim"; - raw_fd_ostream coreBinaryStream(outputCorePath, errorCode, sys::fs::OF_None); - if (errorCode) { - errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n'; - return InvalidOutputFileAccess; - } - - std::unique_ptr coreJsonStream; - if (pimEmitJson.getValue()) { - std::string outputCoreJsonPath = outputDirPath + "/core_" + std::to_string(coreId) + ".json"; - errorCode = std::error_code(); - coreJsonStream = std::make_unique(outputCoreJsonPath, errorCode); - if (errorCode) { - errs() << "Error while opening core json file `" << outputCoreJsonPath << "`: " << errorCode.message() - << '\n'; - return InvalidOutputFileAccess; - } - *coreJsonStream << '['; - } - - pim_binary::writeHeader(coreBinaryStream); - - PimCodeGen coreCodeGen(memory, coreBinaryStream, coreJsonStream.get(), emittedCoreIds); - aliasMaterializedHostGlobals(moduleOp, coreOp, materializedHostGlobals, memory); - auto& deviceMemory = memory.getOrCreateDeviceMem(coreId); - deviceMemory.allocateCore(coreOp); - - int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen); - if (processedOperations < 0) - return CompilerFailure; - assert(processedOperations > 0); - - if (reportRow) - *reportRow = deviceMemory.getReportRow(); - - pim_binary::patchInstructionCount(coreBinaryStream, coreCodeGen.getEmittedInstructionCount()); - coreBinaryStream.close(); - - if (coreJsonStream) { - coreJsonStream->seek(coreJsonStream->tell() - 1); - *coreJsonStream << ']'; - coreJsonStream->close(); - } - - auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId); - if (auto error = sys::fs::create_directory(coreWeightsDirPath)) { - errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n'; - return InvalidOutputFileAccess; - } - - auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId]; - json::Array xbarsPerGroup; - for (unsigned index : getUsedWeightIndices(coreOp)) { - if (index >= coreOp.getWeights().size()) { - coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range"); - assert(index < coreOp.getWeights().size() && "Weight index is out of range"); - } - mlir::Value weight = coreOp.getWeights()[index]; - xbarsPerGroup.push_back(index); - assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!"); - auto& fileName = mapWeightToFile[weight]; - if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName, - coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) { - errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to " - << (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") - << "\nError:" << error.message() << '\n'; - return InvalidOutputFileAccess; - } - } - - xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup); - if (temporaryCore) - coreOp.walk([&memory](Operation* op) { memory.clean(op); }); - return CompilerSuccess; - }; - if (auto coreOp = dyn_cast(op)) { - MemoryReportRow coreRow; - if (auto err = emitCore(coreOp, false, &coreRow)) - return err; - memory.recordCoreReport(emittedCoreIds.lookup(static_cast(coreOp.getCoreId())), coreRow); + size_t originalCoreId = static_cast(coreOp.getCoreId()); + CoreEmissionJob job; + job.coreLikeOp = coreOp; + job.originalCoreId = originalCoreId; + job.emittedCoreId = emittedCoreIds.lookup(originalCoreId); + jobs.push_back(std::move(job)); continue; } auto coreBatchOp = cast(op); auto batchCoreIds = getBatchCoreIds(coreBatchOp); - SmallVector reportedCoreIds; - reportedCoreIds.reserve(batchCoreIds.size()); - MemoryReportRow batchRow; - std::optional batchPerCoreRow; llvm::DenseMap> lanesByCoreId; - SmallVector orderedOriginalCoreIds; - for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) { - size_t originalCoreId = static_cast(batchCoreIds[lane]); - auto [it, inserted] = lanesByCoreId.try_emplace(originalCoreId); - if (inserted) - orderedOriginalCoreIds.push_back(originalCoreId); - it->second.push_back(lane); + for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) + lanesByCoreId[static_cast(batchCoreIds[lane])].push_back(lane); + + SmallVector jobIndices; + SmallVector orderedOriginalCoreIds = llvm::to_vector(lanesByCoreId.keys()); + llvm::sort(orderedOriginalCoreIds, [&](size_t lhs, size_t rhs) { + return emittedCoreIds.lookup(lhs) < emittedCoreIds.lookup(rhs); + }); + for (size_t originalCoreId : orderedOriginalCoreIds) { + CoreEmissionJob job; + job.coreLikeOp = coreBatchOp; + job.originalCoreId = originalCoreId; + job.emittedCoreId = emittedCoreIds.lookup(originalCoreId); + job.lanes = lanesByCoreId.lookup(originalCoreId); + job.batchReportId = nextBatchReportId; + jobIndices.push_back(jobs.size()); + jobs.push_back(std::move(job)); + } + batchJobIndices.push_back(std::move(jobIndices)); + ++nextBatchReportId; + } + + auto linkCoreWeights = [&](size_t originalCoreId, + size_t coreId, + ArrayRef usedIndices, + ValueRange weights, + Operation* weightOwner, + json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes { + auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId); + if (auto error = sys::fs::create_directory(coreWeightsDirPath)) { + errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n'; + return InvalidOutputFileAccess; } - for (size_t originalCoreId : orderedOriginalCoreIds) { - OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess; - if (failed(withScalarCoreFromBatchLanes(coreBatchOp, lanesByCoreId[originalCoreId], [&](pim::PimCoreOp coreOp) { - size_t coreId = emittedCoreIds.lookup(originalCoreId); - reportedCoreIds.push_back(static_cast(coreId)); - MemoryReportRow laneRow; - laneResult = emitCore(coreOp, true, &laneRow); - if (laneResult == CompilerSuccess) { - if (!batchPerCoreRow.has_value()) - batchPerCoreRow = laneRow; - batchRow = addMemoryReportRows(batchRow, laneRow); - } - return laneResult == CompilerSuccess ? success() : failure(); - }))) - return laneResult == CompilerSuccess ? CompilerFailure : laneResult; + auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId]; + for (unsigned index : usedIndices) { + if (index >= weights.size()) { + weightOwner->emitWarning("Weight index " + std::to_string(index) + " is out of range"); + assert(index < weights.size() && "Weight index is out of range"); + } + mlir::Value weight = weights[index]; + xbarsPerGroup.push_back(index); + assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!"); + auto& fileName = mapWeightToFile[weight]; + if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName, + coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) { + errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to " + << (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") + << "\nError:" << error.message() << '\n'; + return InvalidOutputFileAccess; + } } - memory.recordBatchReport(nextBatchReportId++, + + return CompilerSuccess; + }; + + auto emitJob = [&](const CoreEmissionJob& job) -> CoreEmissionResult { + CoreEmissionResult result; + PimAcceleratorMemory jobMemory(memory.memEntriesMap, false); + + std::error_code errorCode; + auto outputCorePath = outputDirPath + "/core_" + std::to_string(job.emittedCoreId) + ".pim"; + raw_fd_ostream coreBinaryStream(outputCorePath, errorCode, sys::fs::OF_None); + if (errorCode) { + errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n'; + result.status = InvalidOutputFileAccess; + return result; + } + + std::unique_ptr coreJsonStream; + if (pimEmitJson.getValue()) { + std::string outputCoreJsonPath = outputDirPath + "/core_" + std::to_string(job.emittedCoreId) + ".json"; + errorCode = std::error_code(); + coreJsonStream = std::make_unique(outputCoreJsonPath, errorCode); + if (errorCode) { + errs() << "Error while opening core json file `" << outputCoreJsonPath << "`: " << errorCode.message() + << '\n'; + result.status = InvalidOutputFileAccess; + return result; + } + *coreJsonStream << '['; + } + + pim_binary::writeHeader(coreBinaryStream); + PimCodeGen coreCodeGen(jobMemory, coreBinaryStream, coreJsonStream.get(), emittedCoreIds); + + if (auto coreOp = dyn_cast(job.coreLikeOp)) { + aliasMaterializedHostGlobals(coreOp, moduleOp, materializedHostGlobals, jobMemory); + auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId); + deviceMemory.allocateCore(coreOp); + + int64_t processedOperations = + codeGenCoreOps(coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation()); + if (processedOperations < 0) { + result.status = CompilerFailure; + return result; + } + assert(processedOperations > 0); + result.reportRow = deviceMemory.getReportRow(); + result.usedWeightIndices = getUsedWeightIndices(coreOp); + } + else { + auto coreBatchOp = cast(job.coreLikeOp); + aliasMaterializedHostGlobals(coreBatchOp, moduleOp, materializedHostGlobals, jobMemory); + auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId); + result.usedWeightIndices = getUsedWeightIndices(coreBatchOp); + + for (unsigned lane : job.lanes) { + StaticValueKnowledge knowledge; + knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane; + for (unsigned i = 0; i < coreBatchOp.getInputs().size(); ++i) + knowledge.aliases[coreBatchOp.getInputArgument(i)] = coreBatchOp.getInputs()[i]; + + deviceMemory.allocateCore(coreBatchOp, lane); + coreCodeGen.setBatchLane(lane); + int64_t processedOperations = codeGenCoreOps(coreBatchOp.getBody().front(), + coreCodeGen, + knowledge, + coreBatchOp.getOperation(), + lane, + static_cast(coreBatchOp.getLaneCount())); + if (processedOperations < 0) { + result.status = CompilerFailure; + return result; + } + assert(processedOperations > 0); + } + + result.reportRow = deviceMemory.getReportRow(); + } + + pim_binary::patchInstructionCount(coreBinaryStream, coreCodeGen.getEmittedInstructionCount()); + coreBinaryStream.close(); + + if (coreJsonStream) { + coreJsonStream->seek(coreJsonStream->tell() - 1); + *coreJsonStream << ']'; + coreJsonStream->close(); + } + + return result; + }; + + std::vector jobResults(jobs.size()); + mlir::parallelFor(moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { + jobResults[index] = emitJob(jobs[index]); + }); + + for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) + if (jobResults[jobIndex].status != CompilerSuccess) + return jobResults[jobIndex].status; + + for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) { + const CoreEmissionJob& job = jobs[jobIndex]; + const CoreEmissionResult& result = jobResults[jobIndex]; + json::Array xbarsPerGroup; + + if (auto coreOp = dyn_cast(job.coreLikeOp)) { + if (auto err = linkCoreWeights( + job.originalCoreId, job.emittedCoreId, result.usedWeightIndices, coreOp.getWeights(), coreOp.getOperation(), xbarsPerGroup)) + return err; + xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup); + memory.recordCoreReport(job.emittedCoreId, result.reportRow); + continue; + } + } + + for (const SmallVector& group : batchJobIndices) { + SmallVector reportedCoreIds; + MemoryReportRow batchRow; + std::optional batchPerCoreRow; + + for (size_t jobIndex : group) { + const CoreEmissionJob& job = jobs[jobIndex]; + const CoreEmissionResult& result = jobResults[jobIndex]; + auto coreBatchOp = cast(job.coreLikeOp); + json::Array xbarsPerGroup; + if (auto err = linkCoreWeights(job.originalCoreId, + job.emittedCoreId, + result.usedWeightIndices, + coreBatchOp.getWeights(), + coreBatchOp.getOperation(), + xbarsPerGroup)) + return err; + xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup); + reportedCoreIds.push_back(static_cast(job.emittedCoreId)); + if (!batchPerCoreRow) + batchPerCoreRow = result.reportRow; + batchRow = addMemoryReportRows(batchRow, result.reportRow); + } + + uint64_t batchReportId = jobs[group.front()].batchReportId.value_or(0); + memory.recordBatchReport(batchReportId, reportedCoreIds, batchPerCoreRow.value_or(MemoryReportRow {}), batchRow.numAlloca, batchRow.sizeAlloca); } + maxCoreId = nextEmittedCoreId == 0 ? 0 : nextEmittedCoreId - 1; + memory.flushReport(); return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath); } diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index bbe45dc..23db991 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -4,13 +4,16 @@ #include "llvm-project/clang/include/clang/Basic/LLVM.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Hashing.h" #include "llvm/Support/JSON.h" #include "llvm/Support/raw_os_ostream.h" #include +#include #include #include "onnx-mlir/Compiler/OMCompilerTypes.h" +#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp" @@ -23,6 +26,13 @@ struct MemEntry { size_t size; }; +struct MemoryValueKey { + mlir::Value value; + std::optional lane; + + bool operator==(const MemoryValueKey& other) const { return value == other.value && lane == other.lane; } +}; + struct MemoryReportRow { uint64_t numAlloca = 0; uint64_t sizeAlloca = 0; @@ -50,33 +60,33 @@ struct MemoryReportEntry { }; class PimMemory { - llvm::SmallVector, 32> memEntries; - llvm::SmallDenseMap& globalMemEntriesMap; - llvm::SmallDenseMap ownedMemEntriesMap; + llvm::SmallVector, 32> memEntries; + llvm::SmallDenseMap& globalMemEntriesMap; + llvm::SmallDenseMap ownedMemEntriesMap; size_t minAlignment = 4; size_t firstAvailableAddress = 0; - MemEntry* gatherMemEntry(mlir::Value value); + MemEntry* gatherMemEntry(mlir::Value value, std::optional lane = std::nullopt); void allocateGatheredMemory(); - void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry); + void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry); public: - PimMemory(llvm::SmallDenseMap& globalMemEntriesMap) + PimMemory(llvm::SmallDenseMap& globalMemEntriesMap) : globalMemEntriesMap(globalMemEntriesMap) {} void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp); - void allocateCore(mlir::Operation* op); + void allocateCore(mlir::Operation* op, std::optional lane = std::nullopt); MemoryReportRow getReportRow() const; void remove(mlir::Value val); size_t getFirstAvailableAddress() const { return firstAvailableAddress; } - MemEntry getMemEntry(mlir::Value value) const; + MemEntry getMemEntry(const MemoryValueKey& key) const; }; class PimAcceleratorMemory { public: - llvm::SmallDenseMap memEntriesMap; + llvm::SmallDenseMap memEntriesMap; PimMemory hostMem; private: @@ -84,14 +94,21 @@ private: std::fstream fileReport; std::optional hostReportRow; llvm::SmallVector reportEntries; + mutable llvm::DenseMap compiledIndexExprs; + mutable llvm::DenseMap compiledAddressExprs; public: PimAcceleratorMemory() : hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {} + PimAcceleratorMemory(const llvm::SmallDenseMap& initialMemEntries, bool enableReport) + : memEntriesMap(initialMemEntries), hostMem(memEntriesMap), fileReport(enableReport ? openReportFile("memory_report") : std::fstream()) {} PimMemory& getOrCreateDeviceMem(size_t id); - size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const; + size_t getValueAddress(mlir::Value value, + const StaticValueKnowledge& knowledge = {}, + std::optional lane = std::nullopt) const; + llvm::FailureOr getIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const; void reportHost(); void recordCoreReport(size_t coreId, const MemoryReportRow& row); void recordBatchReport(uint64_t batchId, @@ -103,15 +120,24 @@ public: void clean(mlir::Operation* op); }; +struct CoreEmissionJob { + mlir::Operation* coreLikeOp = nullptr; + size_t originalCoreId = 0; + size_t emittedCoreId = 0; + llvm::SmallVector lanes; + std::optional batchReportId; +}; + class PimCodeGen { PimAcceleratorMemory& memory; llvm::raw_fd_ostream& coreBinaryStream; llvm::raw_fd_ostream* coreJsonStream; const llvm::DenseMap& emittedCoreIds; + std::optional batchLane; mutable uint32_t emittedInstructionCount = 0; size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const { - return memory.getValueAddress(value, knowledge); + return memory.getValueAddress(value, knowledge, batchLane); } size_t remapCoreId(size_t coreId) const; @@ -141,6 +167,10 @@ public: : memory(memory), coreBinaryStream(coreBinary), coreJsonStream(coreJson), emittedCoreIds(emittedCoreIds) {} uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; } + void setBatchLane(std::optional lane) { batchLane = lane; } + llvm::FailureOr indexOf(mlir::Value value, const StaticValueKnowledge& knowledge) const { + return memory.getIndexValue(value, knowledge); + } void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const; void codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp, const StaticValueKnowledge& knowledge) const; @@ -151,6 +181,14 @@ public: void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const; void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const; void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const; + void codeGenReceiveBatchOp(pim::PimReceiveBatchOp receiveOp, unsigned lane, const StaticValueKnowledge& knowledge) const; + void codeGenReceiveTensorBatchOp(pim::PimReceiveTensorBatchOp receiveOp, + llvm::ArrayRef laneCoreIds, + const StaticValueKnowledge& knowledge) const; + void codeGenSendBatchOp(pim::PimSendBatchOp sendOp, unsigned lane, const StaticValueKnowledge& knowledge) const; + void codeGenSendTensorBatchOp(pim::PimSendTensorBatchOp sendOp, + llvm::ArrayRef laneCoreIds, + const StaticValueKnowledge& knowledge) const; void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const; template @@ -173,3 +211,24 @@ public: OnnxMlirCompilerErrorCodes compileToPimCode(mlir::ModuleOp& moduleOpRef, std::string& outputDirName); } // namespace onnx_mlir + +namespace llvm { + +template <> +struct DenseMapInfo { + static onnx_mlir::MemoryValueKey getEmptyKey() { + return {DenseMapInfo::getEmptyKey(), 0}; + } + + static onnx_mlir::MemoryValueKey getTombstoneKey() { + return {DenseMapInfo::getTombstoneKey(), 0}; + } + + static unsigned getHashValue(const onnx_mlir::MemoryValueKey& key) { + return hash_combine(key.value, key.lane.value_or(std::numeric_limits::max())); + } + + static bool isEqual(const onnx_mlir::MemoryValueKey& lhs, const onnx_mlir::MemoryValueKey& rhs) { return lhs == rhs; } +}; + +} // namespace llvm diff --git a/src/PIM/Compiler/PimWeightEmitter.cpp b/src/PIM/Compiler/PimWeightEmitter.cpp index 971fa47..5bd7ebf 100644 --- a/src/PIM/Compiler/PimWeightEmitter.cpp +++ b/src/PIM/Compiler/PimWeightEmitter.cpp @@ -3,17 +3,19 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/raw_ostream.h" #include +#include #include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" -#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp" @@ -126,30 +128,6 @@ FailureOr resolveDenseWeightView(ModuleOp moduleOp, mlir::Value return view; } -SmallVector getUsedWeightIndices(Block& block) { - SmallVector indices; - auto coreOp = dyn_cast(block.getParentOp()); - auto addWeight = [&](mlir::Value weight) { - if (!coreOp) - return; - for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) { - if (coreOp.getWeightArgument(weightIndex) != weight) - continue; - if (!llvm::is_contained(indices, weightIndex)) - indices.push_back(weightIndex); - return; - } - }; - - block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); }); - llvm::sort(indices); - return indices; -} - -SmallVector getUsedWeightIndices(pim::PimCoreOp coreOp) { - return getUsedWeightIndices(coreOp.getBody().front()); -} - SmallVector collectTopLevelCoreLikeOps(func::FuncOp funcOp) { SmallVector coreLikeOps; for (Operation& op : funcOp.getBody().front()) @@ -171,86 +149,117 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { int64_t xbarSize = crossbarSize.getValue(); llvm::DenseMap> mapCoreWeightToFileName; llvm::DenseMap mapGlobalOpToFileName; + llvm::DenseMap mapWeightValueToFileName; SmallVector coreLikeOps = collectTopLevelCoreLikeOps(funcOp); for (Operation* op : coreLikeOps) { - auto processCore = [&](pim::PimCoreOp coreOp) { - size_t coreId = static_cast(coreOp.getCoreId()); - for (unsigned index : getUsedWeightIndices(coreOp)) { - if (index >= coreOp.getWeights().size()) { - coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range"); - assert(index < coreOp.getWeights().size() && "Weight index is out of range"); - } - mlir::Value weight = coreOp.getWeights()[index]; + auto processWeight = [&](Operation* ownerOp, + mlir::Value weight, + size_t weightIndex, + size_t coreId) -> LogicalResult { + auto weightView = resolveDenseWeightView(moduleOp, weight); + if (failed(weightView)) { + ownerOp->emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex)); + assert(succeeded(weightView) && "Weight is not from a dense memref.global view"); + } - auto weightView = resolveDenseWeightView(moduleOp, weight); - if (failed(weightView)) { - coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index)); - assert(succeeded(weightView) && "Weight is not from a dense memref.global view"); - } + if (mapCoreWeightToFileName[coreId].contains(weight)) + return success(); - if (mapCoreWeightToFileName[coreId].contains(weight)) - continue; + if (auto weightFile = mapWeightValueToFileName.find(weight); weightFile != mapWeightValueToFileName.end()) { + mapCoreWeightToFileName[coreId].insert({weight, weightFile->second}); + return success(); + } - auto getGlobalOp = weight.getDefiningOp(); - auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {}; - if (globalOp && mapGlobalOpToFileName.contains(globalOp)) { - auto& fileName = mapGlobalOpToFileName[globalOp]; - mapCoreWeightToFileName[coreId].insert({weight, fileName}); - continue; - } + auto getGlobalOp = weight.getDefiningOp(); + auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {}; + if (globalOp && mapGlobalOpToFileName.contains(globalOp)) { + auto& fileName = mapGlobalOpToFileName[globalOp]; + mapWeightValueToFileName[weight] = fileName; + mapCoreWeightToFileName[coreId].insert({weight, fileName}); + return success(); + } - DenseElementsAttr denseAttr = weightView->denseAttr; - ArrayRef shape = weightView->shape; - assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional"); - int64_t numRows = shape[0]; - int64_t numCols = shape[1]; - assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); + DenseElementsAttr denseAttr = weightView->denseAttr; + ArrayRef shape = weightView->shape; + assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional"); + int64_t numRows = shape[0]; + int64_t numCols = shape[1]; + assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); - size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType()); + size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType()); - std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin"; - auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str(); - std::error_code errorCode; - raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None); - if (errorCode) { - errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n'; - assert(errorCode); - } + std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin"; + auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str(); + std::error_code errorCode; + raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None); + if (errorCode) { + errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n'; + assert(errorCode); + } - uint64_t zero = 0; - for (int64_t row = 0; row < xbarSize; row++) { - for (int64_t col = 0; col < xbarSize; col++) { - if (row < numRows && col < numCols) { - int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1]; - APInt bits = denseAttr.getValues()[elementIndex].bitcastToAPInt(); - uint64_t word = bits.getZExtValue(); - weightFileStream.write(reinterpret_cast(&word), elementByteWidth); - } - else { - weightFileStream.write(reinterpret_cast(&zero), elementByteWidth); - } + uint64_t zero = 0; + for (int64_t row = 0; row < xbarSize; row++) { + for (int64_t col = 0; col < xbarSize; col++) { + if (row < numRows && col < numCols) { + int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1]; + APInt bits = denseAttr.getValues()[elementIndex].bitcastToAPInt(); + uint64_t word = bits.getZExtValue(); + weightFileStream.write(reinterpret_cast(&word), elementByteWidth); + } + else { + weightFileStream.write(reinterpret_cast(&zero), elementByteWidth); } } - - weightFileStream.close(); - if (globalOp) - mapGlobalOpToFileName.insert({globalOp, newFileName}); - mapCoreWeightToFileName[coreId].insert({weight, newFileName}); } + + weightFileStream.close(); + if (globalOp) + mapGlobalOpToFileName.insert({globalOp, newFileName}); + mapWeightValueToFileName[weight] = newFileName; + mapCoreWeightToFileName[coreId].insert({weight, newFileName}); return success(); }; + auto processCoreLike = [&](auto coreLikeOp) { + auto usedIndices = getUsedWeightIndices(coreLikeOp); + for (unsigned index : usedIndices) { + if (index >= coreLikeOp.getWeights().size()) { + coreLikeOp.emitWarning("Weight index " + std::to_string(index) + " is out of range"); + assert(index < coreLikeOp.getWeights().size() && "Weight index is out of range"); + } + } + + if constexpr (std::is_same_v, pim::PimCoreOp>) { + size_t coreId = static_cast(coreLikeOp.getCoreId()); + for (unsigned index : usedIndices) + if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId))) + return failure(); + return success(); + } + else { + auto batchCoreIds = getBatchCoreIds(coreLikeOp); + SmallVector orderedCoreIds; + llvm::SmallSet seenCoreIds; + for (int32_t coreId : batchCoreIds) + if (seenCoreIds.insert(static_cast(coreId)).second) + orderedCoreIds.push_back(static_cast(coreId)); + + for (size_t coreId : orderedCoreIds) + for (unsigned index : usedIndices) + if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId))) + return failure(); + return success(); + } + }; + if (auto coreOp = dyn_cast(op)) { - (void) processCore(coreOp); + (void) processCoreLike(coreOp); continue; } - auto coreBatchOp = cast(op); - for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) - if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore))) - return mapCoreWeightToFileName; + (void) processCoreLike(cast(op)); } return mapCoreWeightToFileName; } diff --git a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp b/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp index db550be..6454a3b 100644 --- a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp @@ -8,6 +8,7 @@ #include #include "Common/IR/CompactAsmUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/DebugDump.hpp" #include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp" @@ -47,12 +48,6 @@ struct CoalescingReportEntry { static std::string formatMemory(uint64_t bytes) { return formatReportMemory(bytes); } -static SmallVector getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) { - auto coreIdsAttr = coreBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName); - assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute"); - return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); -} - static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) { llvm::SmallVector fields = { {"Number of candidates", std::to_string(row.numCandidates)}, diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index aa9cad4..2b44771 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -6,10 +6,10 @@ #include "llvm/ADT/STLExtras.h" +#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" -#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -173,6 +173,106 @@ static bool isSupportedCoreInstructionOp(Operation* op) { memref::GetGlobalOp>(op); } +static FailureOr getStaticByteSizedShapedType(Type type) { + auto shapedType = dyn_cast(type); + if (!shapedType || !shapedType.hasStaticShape()) + return failure(); + + int64_t elementBits = shapedType.getElementTypeBitWidth(); + if (elementBits <= 0 || elementBits % 8 != 0) + return failure(); + + return shapedType; +} + +static LogicalResult verifyBatchOpSemantics(Operation& op, + const StaticValueKnowledge& knowledge, + pim::CappedDiagnosticReporter& diagnostics) { + bool hasFailure = false; + auto reportFailure = [&](auto emitDiagnostic) { + diagnostics.report(&op, [&](Operation* illegalOp) { emitDiagnostic(illegalOp); }); + hasFailure = true; + }; + + if (auto memcpHdBatchOp = dyn_cast(op)) { + if (!isCodegenAddressableValue(memcpHdBatchOp.getHostSource(), knowledge)) { + reportFailure([](Operation* illegalOp) { + illegalOp->emitOpError("host operand #1 is not backed by contiguous addressable storage"); + }); + } + return success(!hasFailure); + } + + if (auto sendBatchOp = dyn_cast(op)) { + if (sendBatchOp.getTargetCoreIds().size() != static_cast(sendBatchOp->getParentOfType() + .getLaneCount())) { + reportFailure([](Operation* illegalOp) { + illegalOp->emitOpError("targetCoreIds size must match parent laneCount"); + }); + } + return success(!hasFailure); + } + + if (auto receiveBatchOp = dyn_cast(op)) { + if (receiveBatchOp.getSourceCoreIds().size() + != static_cast(receiveBatchOp->getParentOfType().getLaneCount())) { + reportFailure([](Operation* illegalOp) { + illegalOp->emitOpError("sourceCoreIds size must match parent laneCount"); + }); + } + return success(!hasFailure); + } + + auto verifyTensorBatchCommunication = [&](Value tensorValue, ArrayRef coreIds, StringRef kind) { + if (coreIds.empty()) { + reportFailure([&](Operation* illegalOp) { illegalOp->emitOpError() << kind << " must carry at least one chunk"; }); + return; + } + + auto parentBatchOp = op.getParentOfType(); + int32_t laneCount = parentBatchOp.getLaneCount(); + if (laneCount <= 0) { + reportFailure([&](Operation* illegalOp) { + illegalOp->emitOpError() << kind << " requires a positive parent laneCount"; + }); + return; + } + if (coreIds.size() % static_cast(laneCount) != 0) { + reportFailure([&](Operation* illegalOp) { + illegalOp->emitOpError() << kind << " core id count must be divisible by the parent laneCount"; + }); + return; + } + + auto shapedType = getStaticByteSizedShapedType(tensorValue.getType()); + if (failed(shapedType)) { + reportFailure([&](Operation* illegalOp) { + illegalOp->emitOpError() << kind << " requires a static shaped tensor or memref with byte-sized elements"; + }); + return; + } + + int64_t chunkCount = static_cast(coreIds.size()) / laneCount; + int64_t totalBytes = (*shapedType).getNumElements() * (*shapedType).getElementTypeBitWidth() / 8; + if (totalBytes % chunkCount != 0) { + reportFailure([&](Operation* illegalOp) { + illegalOp->emitOpError() << kind << " tensor byte size must be divisible by the chunk count per lane"; + }); + } + }; + + if (auto sendTensorBatchOp = dyn_cast(op)) + verifyTensorBatchCommunication(sendTensorBatchOp.getInput(), + sendTensorBatchOp.getTargetCoreIds(), + "send_tensor_batch"); + else if (auto receiveTensorBatchOp = dyn_cast(op)) + verifyTensorBatchCommunication(receiveTensorBatchOp.getOutput(), + receiveTensorBatchOp.getSourceCoreIds(), + "receive_tensor_batch"); + + return success(!hasFailure); +} + struct VerificationPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass) @@ -204,16 +304,24 @@ struct VerificationPass : PassWrapper> for (Operation& op : funcOp.getBody().front().getOperations()) { if (auto coreOp = dyn_cast(&op)) { (void) verifyCoreWeights(moduleOp, coreOp, diagnostics); - (void) verifyCoreOperands(coreOp, diagnostics); + StaticValueKnowledge knowledge; + (void) verifyCoreLikeOperands(coreOp, knowledge, diagnostics); continue; } if (auto coreBatchOp = dyn_cast(&op)) { (void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics); - for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) - (void) withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) { - return verifyCoreOperands(scalarCore, diagnostics); - }); + llvm::SmallVector lanes; + lanes.push_back(0); + if (coreBatchOp.getLaneCount() > 1) + lanes.push_back(static_cast(coreBatchOp.getLaneCount() - 1)); + for (unsigned lane : lanes) { + StaticValueKnowledge knowledge; + knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane; + for (unsigned i = 0; i < coreBatchOp.getInputs().size(); ++i) + knowledge.aliases[coreBatchOp.getInputArgument(i)] = coreBatchOp.getInputs()[i]; + (void) verifyCoreLikeOperands(coreBatchOp, knowledge, diagnostics); + } continue; } @@ -299,10 +407,13 @@ private: return success(!hasFailure); } - template - static LogicalResult verifyCoreOperands(CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) { - return walkPimCoreBlock( - coreOp.getBody().front(), StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) { + template + static LogicalResult verifyCoreLikeOperands(CoreLikeOpTy coreLikeOp, + const StaticValueKnowledge& initialKnowledge, + pim::CappedDiagnosticReporter& diagnostics) { + return walkPimCoreBlockStructurally(coreLikeOp.getBody().front(), + initialKnowledge, + [&](Operation& op, const StaticValueKnowledge& knowledge) { bool hasFailure = false; if (!isSupportedCoreInstructionOp(&op)) { diagnostics.report(&op, [](Operation* illegalOp) { @@ -370,6 +481,9 @@ private: hasFailure = true; } } + + if (failed(verifyBatchOpSemantics(op, knowledge, diagnostics))) + hasFailure = true; return success(!hasFailure); }); }