automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-27 16:39:56 +02:00
parent 4bdaa57656
commit 874a2f53e6
23 changed files with 136 additions and 198 deletions
+25 -43
View File
@@ -111,39 +111,29 @@ static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp l
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) { static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
switch (predicate) { switch (predicate) {
case mlir::arith::CmpIPredicate::eq: case mlir::arith::CmpIPredicate::eq: return lhs == rhs;
return lhs == rhs; case mlir::arith::CmpIPredicate::ne: return lhs != rhs;
case mlir::arith::CmpIPredicate::ne: case mlir::arith::CmpIPredicate::slt: return lhs < rhs;
return lhs != rhs; case mlir::arith::CmpIPredicate::sle: return lhs <= rhs;
case mlir::arith::CmpIPredicate::slt: case mlir::arith::CmpIPredicate::sgt: return lhs > rhs;
return lhs < rhs; case mlir::arith::CmpIPredicate::sge: return lhs >= rhs;
case mlir::arith::CmpIPredicate::sle: case mlir::arith::CmpIPredicate::ult: return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
return lhs <= rhs; case mlir::arith::CmpIPredicate::ule: return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::sgt: case mlir::arith::CmpIPredicate::ugt: return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
return lhs > rhs; case mlir::arith::CmpIPredicate::uge: return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::sge:
return lhs >= rhs;
case mlir::arith::CmpIPredicate::ult:
return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ule:
return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ugt:
return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::uge:
return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
} }
llvm_unreachable("unknown cmpi predicate"); llvm_unreachable("unknown cmpi predicate");
} }
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr, const StaticValueKnowledge& knowledge) { llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr,
const StaticValueKnowledge& knowledge) {
if (!expr.node) if (!expr.node)
return mlir::failure(); return mlir::failure();
switch (expr.node->kind) { switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Constant: case CompiledIndexExprNode::Kind::Constant: return expr.node->constant;
return expr.node->constant; case CompiledIndexExprNode::Kind::Symbol: {
case CompiledIndexExprNode::Kind::Symbol: {
auto value = resolveAlias(expr.node->symbol, &knowledge); auto value = resolveAlias(expr.node->symbol, &knowledge);
auto iter = knowledge.indexValues.find(value); auto iter = knowledge.indexValues.find(value);
if (iter != knowledge.indexValues.end()) if (iter != knowledge.indexValues.end())
@@ -158,19 +148,16 @@ llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr
case CompiledIndexExprNode::Kind::RemUI: case CompiledIndexExprNode::Kind::RemUI:
case CompiledIndexExprNode::Kind::RemSI: case CompiledIndexExprNode::Kind::RemSI:
case CompiledIndexExprNode::Kind::MinUI: case CompiledIndexExprNode::Kind::MinUI:
case CompiledIndexExprNode::Kind::CmpI: { case CompiledIndexExprNode::Kind::CmpI: {
auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge); auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge); auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge);
if (failed(lhs) || failed(rhs)) if (failed(lhs) || failed(rhs))
return mlir::failure(); return mlir::failure();
switch (expr.node->kind) { switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Add: case CompiledIndexExprNode::Kind::Add: return *lhs + *rhs;
return *lhs + *rhs; case CompiledIndexExprNode::Kind::Sub: return *lhs - *rhs;
case CompiledIndexExprNode::Kind::Sub: case CompiledIndexExprNode::Kind::Mul: return *lhs * *rhs;
return *lhs - *rhs;
case CompiledIndexExprNode::Kind::Mul:
return *lhs * *rhs;
case CompiledIndexExprNode::Kind::DivUI: case CompiledIndexExprNode::Kind::DivUI:
if (*rhs == 0) if (*rhs == 0)
return mlir::failure(); return mlir::failure();
@@ -191,10 +178,8 @@ llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr
return *lhs % *rhs; return *lhs % *rhs;
case CompiledIndexExprNode::Kind::MinUI: case CompiledIndexExprNode::Kind::MinUI:
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs))); return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
case CompiledIndexExprNode::Kind::CmpI: case CompiledIndexExprNode::Kind::CmpI: return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0; default: llvm_unreachable("unexpected binary compiled index kind");
default:
llvm_unreachable("unexpected binary compiled index kind");
} }
} }
case CompiledIndexExprNode::Kind::Select: { case CompiledIndexExprNode::Kind::Select: {
@@ -639,24 +624,21 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
staticStrides.reserve(subviewOp.getMixedStrides().size()); staticStrides.reserve(subviewOp.getMixedStrides().size());
bool allStatic = true; bool allStatic = true;
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) { for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt()); staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else else
allStatic = false; allStatic = false;
} for (mlir::OpFoldResult size : subviewOp.getMixedSizes())
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(size)) if (auto attr = mlir::dyn_cast<mlir::Attribute>(size))
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt()); staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else else
allStatic = false; allStatic = false;
} for (mlir::OpFoldResult stride : subviewOp.getMixedStrides())
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride)) if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride))
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt()); staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else else
allStatic = false; allStatic = false;
}
if (allStatic) { if (allStatic) {
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides)) if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
@@ -796,8 +778,8 @@ llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge&
return evaluateCompiledIndexExpr(*this, knowledge); return evaluateCompiledIndexExpr(*this, knowledge);
} }
llvm::FailureOr<ResolvedContiguousAddress> llvm::FailureOr<ResolvedContiguousAddress> CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge,
CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const { std::optional<unsigned> lane) const {
(void) lane; (void) lane;
auto resolvedOffset = byteOffset.evaluate(knowledge); auto resolvedOffset = byteOffset.evaluate(knowledge);
if (failed(resolvedOffset)) if (failed(resolvedOffset))
+4 -3
View File
@@ -33,7 +33,8 @@ struct CompiledIndexExpr {
std::shared_ptr<CompiledIndexExprNode> node; std::shared_ptr<CompiledIndexExprNode> node;
CompiledIndexExpr() = default; CompiledIndexExpr() = default;
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node) : node(std::move(node)) {} explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node)
: node(std::move(node)) {}
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const; llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
}; };
@@ -68,8 +69,8 @@ struct CompiledAddressExpr {
mlir::Value base; mlir::Value base;
CompiledIndexExpr byteOffset; CompiledIndexExpr byteOffset;
llvm::FailureOr<ResolvedContiguousAddress> llvm::FailureOr<ResolvedContiguousAddress> evaluate(const StaticValueKnowledge& knowledge,
evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const; std::optional<unsigned> lane) const;
}; };
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp); mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
+1 -3
View File
@@ -1,5 +1,4 @@
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -10,8 +9,7 @@ llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
} }
llvm::SmallVector<int32_t> llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
llvm::SmallVector<int32_t> laneCoreIds; llvm::SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount); laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex) for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
+1 -2
View File
@@ -9,7 +9,6 @@ namespace onnx_mlir {
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp); llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
llvm::SmallVector<int32_t> llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
} // namespace onnx_mlir } // namespace onnx_mlir
+4 -5
View File
@@ -24,10 +24,9 @@ walkPimCoreBlock(mlir::Block& block,
/// Walks a `pim.core`-like body structurally for verification without /// Walks a `pim.core`-like body structurally for verification without
/// enumerating full loop trip counts. Loop bounds must still be statically /// enumerating full loop trip counts. Loop bounds must still be statically
/// evaluable so address resolution remains well-defined. /// evaluable so address resolution remains well-defined.
mlir::LogicalResult mlir::LogicalResult walkPimCoreBlockStructurally(
walkPimCoreBlockStructurally(mlir::Block& block, mlir::Block& block,
const StaticValueKnowledge& knowledge, const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
callback);
} // namespace onnx_mlir } // namespace onnx_mlir
+1 -1
View File
@@ -4,9 +4,9 @@
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include <optional> #include <optional>
+56 -85
View File
@@ -199,22 +199,20 @@ MemoryReportRow PimMemory::getReportRow() const {
} }
void PimMemory::remove(mlir::Value val) { void PimMemory::remove(mlir::Value val) {
for (auto it = ownedMemEntriesMap.begin(); it != ownedMemEntriesMap.end();) { for (auto it = ownedMemEntriesMap.begin(); it != ownedMemEntriesMap.end();)
if (it->first.value == val) { if (it->first.value == val) {
auto eraseIt = it++; auto eraseIt = it++;
ownedMemEntriesMap.erase(eraseIt); ownedMemEntriesMap.erase(eraseIt);
} }
else else
++it; ++it;
} for (auto it = globalMemEntriesMap.begin(); it != globalMemEntriesMap.end();)
for (auto it = globalMemEntriesMap.begin(); it != globalMemEntriesMap.end();) {
if (it->first.value == val) { if (it->first.value == val) {
auto eraseIt = it++; auto eraseIt = it++;
globalMemEntriesMap.erase(eraseIt); globalMemEntriesMap.erase(eraseIt);
} }
else else
++it; ++it;
}
} }
MemEntry PimMemory::getMemEntry(const MemoryValueKey& key) const { MemEntry PimMemory::getMemEntry(const MemoryValueKey& key) const {
@@ -275,7 +273,8 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
return iter->second.address + resolvedAddress->byteOffset; return iter->second.address + resolvedAddress->byteOffset;
} }
llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) const { llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value,
const StaticValueKnowledge& knowledge) const {
value = resolveCachedAlias(value, knowledge); value = resolveCachedAlias(value, knowledge);
auto compiledIt = compiledIndexExprs.find(value); auto compiledIt = compiledIndexExprs.find(value);
if (compiledIt == compiledIndexExprs.end()) { if (compiledIt == compiledIndexExprs.end()) {
@@ -826,7 +825,8 @@ class ScopedMapBindings {
llvm::SmallVector<std::pair<KeyTy, std::optional<ValueTy>>, 8> savedEntries; llvm::SmallVector<std::pair<KeyTy, std::optional<ValueTy>>, 8> savedEntries;
public: public:
explicit ScopedMapBindings(MapTy& map) : map(map) {} explicit ScopedMapBindings(MapTy& map)
: map(map) {}
void bind(const KeyTy& key, const ValueTy& value) { void bind(const KeyTy& key, const ValueTy& value) {
auto it = map.find(key); auto it = map.find(key);
@@ -838,12 +838,11 @@ public:
} }
~ScopedMapBindings() { ~ScopedMapBindings() {
for (auto it = savedEntries.rbegin(); it != savedEntries.rend(); ++it) { for (auto it = savedEntries.rbegin(); it != savedEntries.rend(); ++it)
if (it->second) if (it->second)
map[it->first] = *it->second; map[it->first] = *it->second;
else else
map.erase(it->first); map.erase(it->first);
}
} }
}; };
@@ -929,9 +928,8 @@ static FailureOr<CompiledCoreOpKind> classifyCompiledCoreOpKind(Operation& op) {
return failure(); return failure();
} }
static LogicalResult compileCoreEmissionPlan(Block& block, static LogicalResult
Operation* weightOwner, compileCoreEmissionPlan(Block& block, Operation* weightOwner, llvm::SmallVectorImpl<CompiledCoreNode>& plan) {
llvm::SmallVectorImpl<CompiledCoreNode>& plan) {
for (Operation& op : block) { for (Operation& op : block) {
if (isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op)) if (isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue; continue;
@@ -982,15 +980,14 @@ static LogicalResult compileCoreEmissionPlan(Block& block,
return success(); return success();
} }
static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<CompiledCoreNode>& plan, static LogicalResult executeCompiledCorePlan(
PimCodeGen& coreCodeGen, const llvm::SmallVectorImpl<CompiledCoreNode>& plan,
StaticValueKnowledge& knowledge, PimCodeGen& coreCodeGen,
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp, StaticValueKnowledge& knowledge,
const StaticValueKnowledge&)> llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp, const StaticValueKnowledge&)> resolveWeightSlot,
resolveWeightSlot, size_t& processedOperations,
size_t& processedOperations, std::optional<unsigned> batchLane = std::nullopt,
std::optional<unsigned> batchLane = std::nullopt, std::optional<unsigned> batchLaneCount = std::nullopt) {
std::optional<unsigned> batchLaneCount = std::nullopt) {
for (const CompiledCoreNode& node : plan) { for (const CompiledCoreNode& node : plan) {
if (node.kind == CompiledCoreNode::Kind::Loop) { if (node.kind == CompiledCoreNode::Kind::Loop) {
auto lowerBound = node.lowerBound.evaluate(knowledge); auto lowerBound = node.lowerBound.evaluate(knowledge);
@@ -1010,8 +1007,13 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues)) for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
aliasBindings.bind(iterArg, iterValue); aliasBindings.bind(iterArg, iterValue);
if (failed(executeCompiledCorePlan( if (failed(executeCompiledCorePlan(*node.loopBody,
*node.loopBody, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount))) coreCodeGen,
knowledge,
resolveWeightSlot,
processedOperations,
batchLane,
batchLaneCount)))
return failure(); return failure();
auto yieldOp = cast<mlir::scf::YieldOp>(forOp.getRegion().front().getTerminator()); auto yieldOp = cast<mlir::scf::YieldOp>(forOp.getRegion().front().getTerminator());
@@ -1031,18 +1033,10 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
case CompiledCoreOpKind::Store: case CompiledCoreOpKind::Store:
coreCodeGen.codeGenStoreOp(cast<pim::PimMemCopyDevToHostOp>(node.op), knowledge); coreCodeGen.codeGenStoreOp(cast<pim::PimMemCopyDevToHostOp>(node.op), knowledge);
break; break;
case CompiledCoreOpKind::Lmv: case CompiledCoreOpKind::Lmv: coreCodeGen.codeGenLmvOp(cast<pim::PimMemCopyOp>(node.op), knowledge); break;
coreCodeGen.codeGenLmvOp(cast<pim::PimMemCopyOp>(node.op), knowledge); case CompiledCoreOpKind::Receive: coreCodeGen.codeGenReceiveOp(cast<pim::PimReceiveOp>(node.op), knowledge); break;
break; case CompiledCoreOpKind::Send: coreCodeGen.codeGenSendOp(cast<pim::PimSendOp>(node.op), knowledge); break;
case CompiledCoreOpKind::Receive: case CompiledCoreOpKind::Concat: coreCodeGen.codeGenConcatOp(cast<pim::PimConcatOp>(node.op), knowledge); break;
coreCodeGen.codeGenReceiveOp(cast<pim::PimReceiveOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::Send:
coreCodeGen.codeGenSendOp(cast<pim::PimSendOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::Concat:
coreCodeGen.codeGenConcatOp(cast<pim::PimConcatOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::Vmm: case CompiledCoreOpKind::Vmm:
if (auto weightSlot = resolveWeightSlot(cast<pim::PimVMMOp>(node.op), knowledge); succeeded(weightSlot)) if (auto weightSlot = resolveWeightSlot(cast<pim::PimVMMOp>(node.op), knowledge); succeeded(weightSlot))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightSlot, cast<pim::PimVMMOp>(node.op), true, knowledge); coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightSlot, cast<pim::PimVMMOp>(node.op), true, knowledge);
@@ -1052,33 +1046,15 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
case CompiledCoreOpKind::Transpose: case CompiledCoreOpKind::Transpose:
coreCodeGen.codeGenTransposeOp(cast<pim::PimTransposeOp>(node.op), knowledge); coreCodeGen.codeGenTransposeOp(cast<pim::PimTransposeOp>(node.op), knowledge);
break; break;
case CompiledCoreOpKind::VVAdd: case CompiledCoreOpKind::VVAdd: coreCodeGen.codeGenVVAddOp(cast<pim::PimVVAddOp>(node.op), knowledge); break;
coreCodeGen.codeGenVVAddOp(cast<pim::PimVVAddOp>(node.op), knowledge); case CompiledCoreOpKind::VVSub: coreCodeGen.codeGenVVSubOp(cast<pim::PimVVSubOp>(node.op), knowledge); break;
break; case CompiledCoreOpKind::VVMul: coreCodeGen.codeGenVVMulOp(cast<pim::PimVVMulOp>(node.op), knowledge); break;
case CompiledCoreOpKind::VVSub: case CompiledCoreOpKind::VVMax: coreCodeGen.codeGenVVMaxOp(cast<pim::PimVVMaxOp>(node.op), knowledge); break;
coreCodeGen.codeGenVVSubOp(cast<pim::PimVVSubOp>(node.op), knowledge); case CompiledCoreOpKind::VVDMul: coreCodeGen.codeGenVVDMulOp(cast<pim::PimVVDMulOp>(node.op), knowledge); break;
break; case CompiledCoreOpKind::VAvg: coreCodeGen.codeGenVAvgOp(cast<pim::PimVAvgOp>(node.op), knowledge); break;
case CompiledCoreOpKind::VVMul: case CompiledCoreOpKind::VRelu: coreCodeGen.codeGenVReluOp(cast<pim::PimVReluOp>(node.op), knowledge); break;
coreCodeGen.codeGenVVMulOp(cast<pim::PimVVMulOp>(node.op), knowledge); case CompiledCoreOpKind::VTanh: coreCodeGen.codeGenVTanhOp(cast<pim::PimVTanhOp>(node.op), knowledge); break;
break; case CompiledCoreOpKind::VSigm: coreCodeGen.codeGenVSigmOp(cast<pim::PimVSigmOp>(node.op), knowledge); break;
case CompiledCoreOpKind::VVMax:
coreCodeGen.codeGenVVMaxOp(cast<pim::PimVVMaxOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VVDMul:
coreCodeGen.codeGenVVDMulOp(cast<pim::PimVVDMulOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VAvg:
coreCodeGen.codeGenVAvgOp(cast<pim::PimVAvgOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VRelu:
coreCodeGen.codeGenVReluOp(cast<pim::PimVReluOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VTanh:
coreCodeGen.codeGenVTanhOp(cast<pim::PimVTanhOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VSigm:
coreCodeGen.codeGenVSigmOp(cast<pim::PimVSigmOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::VSoftmax: case CompiledCoreOpKind::VSoftmax:
coreCodeGen.codeGenVSoftmaxOp(cast<pim::PimVSoftmaxOp>(node.op), knowledge); coreCodeGen.codeGenVSoftmaxOp(cast<pim::PimVSoftmaxOp>(node.op), knowledge);
break; break;
@@ -1131,23 +1107,22 @@ static void aliasMaterializedHostGlobals(CoreLikeOpTy coreLikeOp,
/// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is /// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is
/// fully resolved before the JSON instructions are emitted. /// fully resolved before the JSON instructions are emitted.
/// Returns the number of emitted instructions, or -1 on failure. /// Returns the number of emitted instructions, or -1 on failure.
static int64_t codeGenCoreOps(Block& block, static int64_t codeGenCoreOps(
PimCodeGen& coreCodeGen, Block& block,
const StaticValueKnowledge& initialKnowledge, PimCodeGen& coreCodeGen,
Operation* weightOwner, const StaticValueKnowledge& initialKnowledge,
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp, Operation* weightOwner,
const StaticValueKnowledge&)> llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp, const StaticValueKnowledge&)> resolveWeightSlot,
resolveWeightSlot, std::optional<unsigned> batchLane = std::nullopt,
std::optional<unsigned> batchLane = std::nullopt, std::optional<unsigned> batchLaneCount = std::nullopt) {
std::optional<unsigned> batchLaneCount = std::nullopt) {
llvm::SmallVector<CompiledCoreNode, 32> plan; llvm::SmallVector<CompiledCoreNode, 32> plan;
if (failed(compileCoreEmissionPlan(block, weightOwner, plan))) if (failed(compileCoreEmissionPlan(block, weightOwner, plan)))
return -1; return -1;
size_t processedOperations = 0; size_t processedOperations = 0;
StaticValueKnowledge knowledge = initialKnowledge; StaticValueKnowledge knowledge = initialKnowledge;
auto result = auto result = executeCompiledCorePlan(
executeCompiledCorePlan(plan, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount); plan, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount);
return failed(result) ? -1 : static_cast<int64_t>(processedOperations); return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
} }
@@ -1219,9 +1194,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
SmallVector<size_t> jobIndices; SmallVector<size_t> jobIndices;
SmallVector<size_t> orderedOriginalCoreIds = llvm::to_vector(lanesByCoreId.keys()); SmallVector<size_t> orderedOriginalCoreIds = llvm::to_vector(lanesByCoreId.keys());
llvm::sort(orderedOriginalCoreIds, [&](size_t lhs, size_t rhs) { llvm::sort(orderedOriginalCoreIds,
return emittedCoreIds.lookup(lhs) < emittedCoreIds.lookup(rhs); [&](size_t lhs, size_t rhs) { return emittedCoreIds.lookup(lhs) < emittedCoreIds.lookup(rhs); });
});
for (size_t originalCoreId : orderedOriginalCoreIds) { for (size_t originalCoreId : orderedOriginalCoreIds) {
CoreEmissionJob job; CoreEmissionJob job;
job.coreLikeOp = coreBatchOp; job.coreLikeOp = coreBatchOp;
@@ -1236,9 +1210,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
++nextBatchReportId; ++nextBatchReportId;
} }
auto linkCoreWeights = [&](size_t coreId, auto linkCoreWeights =
ArrayRef<std::string> weightFiles, [&](size_t coreId, ArrayRef<std::string> weightFiles, json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes {
json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes {
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId); auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) { if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n'; errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
@@ -1250,8 +1223,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName, if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")) { coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")) {
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to " errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin") << (coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin") << "\nError:" << error.message()
<< "\nError:" << error.message() << '\n'; << '\n';
return InvalidOutputFileAccess; return InvalidOutputFileAccess;
} }
} }
@@ -1294,8 +1267,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
errorCode = std::error_code(); errorCode = std::error_code();
coreJsonStream = std::make_unique<raw_fd_ostream>(outputCoreJsonPath, errorCode); coreJsonStream = std::make_unique<raw_fd_ostream>(outputCoreJsonPath, errorCode);
if (errorCode) { if (errorCode) {
errs() << "Error while opening core json file `" << outputCoreJsonPath << "`: " << errorCode.message() errs() << "Error while opening core json file `" << outputCoreJsonPath << "`: " << errorCode.message() << '\n';
<< '\n';
result.status = InvalidOutputFileAccess; result.status = InvalidOutputFileAccess;
return result; return result;
} }
@@ -1364,9 +1336,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
}; };
std::vector<CoreEmissionResult> jobResults(jobs.size()); std::vector<CoreEmissionResult> jobResults(jobs.size());
mlir::parallelFor(moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { mlir::parallelFor(
jobResults[index] = emitJob(jobs[index]); moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { jobResults[index] = emitJob(jobs[index]); });
});
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex)
if (jobResults[jobIndex].status != CompilerSuccess) if (jobResults[jobIndex].status != CompilerSuccess)
+5 -7
View File
@@ -101,7 +101,9 @@ public:
PimAcceleratorMemory() PimAcceleratorMemory()
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {} : hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
PimAcceleratorMemory(const llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& initialMemEntries, bool enableReport) PimAcceleratorMemory(const llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& initialMemEntries, bool enableReport)
: memEntriesMap(initialMemEntries), hostMem(memEntriesMap), fileReport(enableReport ? openReportFile("memory_report") : std::fstream()) {} : memEntriesMap(initialMemEntries),
hostMem(memEntriesMap),
fileReport(enableReport ? openReportFile("memory_report") : std::fstream()) {}
PimMemory& getOrCreateDeviceMem(size_t id); PimMemory& getOrCreateDeviceMem(size_t id);
@@ -206,13 +208,9 @@ namespace llvm {
template <> template <>
struct DenseMapInfo<onnx_mlir::MemoryValueKey> { struct DenseMapInfo<onnx_mlir::MemoryValueKey> {
static onnx_mlir::MemoryValueKey getEmptyKey() { static onnx_mlir::MemoryValueKey getEmptyKey() { return {DenseMapInfo<mlir::Value>::getEmptyKey(), 0}; }
return {DenseMapInfo<mlir::Value>::getEmptyKey(), 0};
}
static onnx_mlir::MemoryValueKey getTombstoneKey() { static onnx_mlir::MemoryValueKey getTombstoneKey() { return {DenseMapInfo<mlir::Value>::getTombstoneKey(), 0}; }
return {DenseMapInfo<mlir::Value>::getTombstoneKey(), 0};
}
static unsigned getHashValue(const onnx_mlir::MemoryValueKey& key) { static unsigned getHashValue(const onnx_mlir::MemoryValueKey& key) {
return hash_combine(key.value, key.lane.value_or(std::numeric_limits<unsigned>::max())); return hash_combine(key.value, key.lane.value_or(std::numeric_limits<unsigned>::max()));
+1 -3
View File
@@ -16,9 +16,7 @@ using namespace llvm;
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {} // namespace
} // namespace
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) { createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
@@ -198,7 +198,6 @@ static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm:
return nullptr; return nullptr;
} }
static std::optional<CompileTimeSource> static std::optional<CompileTimeSource>
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) { getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
if (!op) if (!op)
@@ -217,7 +216,9 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
chainLength += 1; chainLength += 1;
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op)) if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
return hasConstantIndices(extractOp) ? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength) : std::nullopt; return hasConstantIndices(extractOp)
? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength)
: std::nullopt;
if (!isStaticTensorResult(op)) if (!isStaticTensorResult(op))
return std::nullopt; return std::nullopt;
@@ -232,8 +233,9 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength); return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp) ? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength) return hasStaticUnitStrides(extractSliceOp)
: std::nullopt; ? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
: std::nullopt;
if (auto splatOp = dyn_cast<tensor::SplatOp>(op)) if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength); return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength);
@@ -252,9 +254,8 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
res = partialRes; res = partialRes;
continue; continue;
} }
if(res->chainLength < partialRes->chainLength){ if (res->chainLength < partialRes->chainLength)
res = partialRes; res = partialRes;
}
} }
return res; return res;
} }
@@ -264,8 +265,7 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
} // namespace } // namespace
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited; llvm::SmallPtrSet<Operation*, 8> visited;
return getCompileTimeSourceImpl(op, visited); return getCompileTimeSourceImpl(op, visited);
} }
@@ -2,9 +2,9 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LogicalResult.h" #include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
@@ -143,13 +143,12 @@ static Value createGemmBatchKOffset(
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane}); rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
} }
static Value createGemmBatchHOffset( static Value createGemmBatchHOffset(Value lane,
Value lane, int64_t numOutRows,
int64_t numOutRows, int64_t numKSlices,
int64_t numKSlices, int64_t numOutHSlices,
int64_t numOutHSlices, ConversionPatternRewriter& rewriter,
ConversionPatternRewriter& rewriter, Location loc) {
Location loc) {
if (numOutHSlices == 1) if (numOutHSlices == 1)
return createIndexConstant(rewriter, 0); return createIndexConstant(rewriter, 0);
@@ -9,8 +9,8 @@
#include <numeric> #include <numeric>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -6,8 +6,8 @@
#include <algorithm> #include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -4,8 +4,8 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -2,8 +2,8 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -171,7 +171,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
markOpToRemove(receiveOp); markOpToRemove(receiveOp);
continue; continue;
} }
} }
if (computeOp.getNumResults() != yieldOp.getNumOperands()) if (computeOp.getNumResults() != yieldOp.getNumOperands())
@@ -606,7 +606,6 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
markOpToRemove(receiveOp); markOpToRemove(receiveOp);
return; return;
} }
}; };
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end()); SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
+1 -1
View File
@@ -1,5 +1,5 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
#include <string> #include <string>
@@ -122,9 +122,8 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
llvm::errs() << "[merge-profile] " << phaseName << " counts:" llvm::errs() << "[merge-profile] " << phaseName << " counts:"
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount << " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
<< " scalar_send=" << counts.scalarChannelSendCount << " scalar_send=" << counts.scalarChannelSendCount
<< " scalar_recv=" << counts.scalarChannelReceiveCount << " scalar_recv=" << counts.scalarChannelReceiveCount << " wvmm=" << counts.wvmmCount
<< " wvmm=" << counts.wvmmCount << " vadd=" << counts.vaddCount << " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n";
<< " scf_for=" << counts.scfForCount << "\n";
} }
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) { static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
@@ -514,7 +513,8 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
SmallVector<int32_t> coreIds; SmallVector<int32_t> coreIds;
if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
llvm::append_range(coreIds, coreIdsAttr.asArrayRef()); llvm::append_range(coreIds, coreIdsAttr.asArrayRef());
collectedData.push_back({nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds}); collectedData.push_back(
{nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds});
totalComputeOps += 1; totalComputeOps += 1;
totalLogicalComputes += logicalCount; totalLogicalComputes += logicalCount;
totalBatchComputeOps += 1; totalBatchComputeOps += 1;
@@ -206,10 +206,8 @@ static FailureOr<int64_t> evaluateIndexLike(OpFoldResult value,
return evaluateIndexLike(llvm::cast<Value>(value), bindings, lane, laneArg); return evaluateIndexLike(llvm::cast<Value>(value), bindings, lane, laneArg);
} }
static FailureOr<int64_t> evaluateIndexLike(Value value, static FailureOr<int64_t>
const DenseMap<Value, int64_t>& bindings, evaluateIndexLike(Value value, const DenseMap<Value, int64_t>& bindings, std::optional<uint32_t> lane, Value laneArg) {
std::optional<uint32_t> lane,
Value laneArg) {
if (lane && value == laneArg) if (lane && value == laneArg)
return *lane; return *lane;
if (auto it = bindings.find(value); it != bindings.end()) if (auto it = bindings.find(value); it != bindings.end())
@@ -260,11 +258,10 @@ static FailureOr<int64_t> evaluateIndexLike(Value value,
return evaluateAffineExpr(map.getResult(0), dims, symbols); return evaluateAffineExpr(map.getResult(0), dims, symbols);
} }
static FailureOr<SmallVector<int64_t, 4>> static FailureOr<SmallVector<int64_t, 4>> evaluateIndexList(ArrayRef<OpFoldResult> values,
evaluateIndexList(ArrayRef<OpFoldResult> values, const DenseMap<Value, int64_t>& bindings,
const DenseMap<Value, int64_t>& bindings, std::optional<uint32_t> lane,
std::optional<uint32_t> lane, Value laneArg) {
Value laneArg) {
SmallVector<int64_t, 4> result; SmallVector<int64_t, 4> result;
result.reserve(values.size()); result.reserve(values.size());
for (OpFoldResult value : values) { for (OpFoldResult value : values) {
@@ -308,12 +305,11 @@ static CrossbarWeight completeCrossbarWeight(Value root,
return weight; return weight;
} }
static FailureOr<CrossbarWeight> static FailureOr<CrossbarWeight> getStaticCrossbarWeight(Operation* owner,
getStaticCrossbarWeight(Operation* owner, Value value,
Value value, const DenseMap<Value, int64_t>& bindings,
const DenseMap<Value, int64_t>& bindings, std::optional<uint32_t> lane,
std::optional<uint32_t> lane, Value laneArg) {
Value laneArg) {
if (auto extract = value.getDefiningOp<tensor::ExtractSliceOp>()) { if (auto extract = value.getDefiningOp<tensor::ExtractSliceOp>()) {
FailureOr<CrossbarWeight> sourceWeight = FailureOr<CrossbarWeight> sourceWeight =
getStaticCrossbarWeight(owner, extract.getSource(), bindings, lane, laneArg); getStaticCrossbarWeight(owner, extract.getSource(), bindings, lane, laneArg);
@@ -19,7 +19,6 @@ using CPU = int;
using Cost = unsigned long long; using Cost = unsigned long long;
using Time = unsigned long long; using Time = unsigned long long;
template <typename T> template <typename T>
inline T checkedAdd(T lhs, T rhs) { inline T checkedAdd(T lhs, T rhs) {
static_assert(std::is_unsigned_v<T>, "checkedAdd only supports unsigned types"); static_assert(std::is_unsigned_v<T>, "checkedAdd only supports unsigned types");
+2 -3
View File
@@ -327,9 +327,8 @@ private:
static LogicalResult verifyCoreLikeOperands(CoreLikeOpTy coreLikeOp, static LogicalResult verifyCoreLikeOperands(CoreLikeOpTy coreLikeOp,
const StaticValueKnowledge& initialKnowledge, const StaticValueKnowledge& initialKnowledge,
pim::CappedDiagnosticReporter& diagnostics) { pim::CappedDiagnosticReporter& diagnostics) {
return walkPimCoreBlockStructurally(coreLikeOp.getBody().front(), return walkPimCoreBlockStructurally(
initialKnowledge, coreLikeOp.getBody().front(), initialKnowledge, [&](Operation& op, const StaticValueKnowledge& knowledge) {
[&](Operation& op, const StaticValueKnowledge& knowledge) {
bool hasFailure = false; bool hasFailure = false;
if (!isSupportedCoreInstructionOp(&op)) { if (!isSupportedCoreInstructionOp(&op)) {
diagnostics.report(&op, [](Operation* illegalOp) { diagnostics.report(&op, [](Operation* illegalOp) {