This commit is contained in:
@@ -111,38 +111,28 @@ static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp l
|
|||||||
|
|
||||||
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
|
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
|
||||||
switch (predicate) {
|
switch (predicate) {
|
||||||
case mlir::arith::CmpIPredicate::eq:
|
case mlir::arith::CmpIPredicate::eq: return lhs == rhs;
|
||||||
return lhs == rhs;
|
case mlir::arith::CmpIPredicate::ne: return lhs != rhs;
|
||||||
case mlir::arith::CmpIPredicate::ne:
|
case mlir::arith::CmpIPredicate::slt: return lhs < rhs;
|
||||||
return lhs != rhs;
|
case mlir::arith::CmpIPredicate::sle: return lhs <= rhs;
|
||||||
case mlir::arith::CmpIPredicate::slt:
|
case mlir::arith::CmpIPredicate::sgt: return lhs > rhs;
|
||||||
return lhs < rhs;
|
case mlir::arith::CmpIPredicate::sge: return lhs >= rhs;
|
||||||
case mlir::arith::CmpIPredicate::sle:
|
case mlir::arith::CmpIPredicate::ult: return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
|
||||||
return lhs <= rhs;
|
case mlir::arith::CmpIPredicate::ule: return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
|
||||||
case mlir::arith::CmpIPredicate::sgt:
|
case mlir::arith::CmpIPredicate::ugt: return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
|
||||||
return lhs > rhs;
|
case mlir::arith::CmpIPredicate::uge: return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
|
||||||
case mlir::arith::CmpIPredicate::sge:
|
|
||||||
return lhs >= rhs;
|
|
||||||
case mlir::arith::CmpIPredicate::ult:
|
|
||||||
return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
|
|
||||||
case mlir::arith::CmpIPredicate::ule:
|
|
||||||
return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
|
|
||||||
case mlir::arith::CmpIPredicate::ugt:
|
|
||||||
return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
|
|
||||||
case mlir::arith::CmpIPredicate::uge:
|
|
||||||
return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm_unreachable("unknown cmpi predicate");
|
llvm_unreachable("unknown cmpi predicate");
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr, const StaticValueKnowledge& knowledge) {
|
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr,
|
||||||
|
const StaticValueKnowledge& knowledge) {
|
||||||
if (!expr.node)
|
if (!expr.node)
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|
||||||
switch (expr.node->kind) {
|
switch (expr.node->kind) {
|
||||||
case CompiledIndexExprNode::Kind::Constant:
|
case CompiledIndexExprNode::Kind::Constant: return expr.node->constant;
|
||||||
return expr.node->constant;
|
|
||||||
case CompiledIndexExprNode::Kind::Symbol: {
|
case CompiledIndexExprNode::Kind::Symbol: {
|
||||||
auto value = resolveAlias(expr.node->symbol, &knowledge);
|
auto value = resolveAlias(expr.node->symbol, &knowledge);
|
||||||
auto iter = knowledge.indexValues.find(value);
|
auto iter = knowledge.indexValues.find(value);
|
||||||
@@ -165,12 +155,9 @@ llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr
|
|||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|
||||||
switch (expr.node->kind) {
|
switch (expr.node->kind) {
|
||||||
case CompiledIndexExprNode::Kind::Add:
|
case CompiledIndexExprNode::Kind::Add: return *lhs + *rhs;
|
||||||
return *lhs + *rhs;
|
case CompiledIndexExprNode::Kind::Sub: return *lhs - *rhs;
|
||||||
case CompiledIndexExprNode::Kind::Sub:
|
case CompiledIndexExprNode::Kind::Mul: return *lhs * *rhs;
|
||||||
return *lhs - *rhs;
|
|
||||||
case CompiledIndexExprNode::Kind::Mul:
|
|
||||||
return *lhs * *rhs;
|
|
||||||
case CompiledIndexExprNode::Kind::DivUI:
|
case CompiledIndexExprNode::Kind::DivUI:
|
||||||
if (*rhs == 0)
|
if (*rhs == 0)
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
@@ -191,10 +178,8 @@ llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr
|
|||||||
return *lhs % *rhs;
|
return *lhs % *rhs;
|
||||||
case CompiledIndexExprNode::Kind::MinUI:
|
case CompiledIndexExprNode::Kind::MinUI:
|
||||||
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
|
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
|
||||||
case CompiledIndexExprNode::Kind::CmpI:
|
case CompiledIndexExprNode::Kind::CmpI: return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
|
||||||
return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
|
default: llvm_unreachable("unexpected binary compiled index kind");
|
||||||
default:
|
|
||||||
llvm_unreachable("unexpected binary compiled index kind");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case CompiledIndexExprNode::Kind::Select: {
|
case CompiledIndexExprNode::Kind::Select: {
|
||||||
@@ -639,24 +624,21 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
|||||||
staticStrides.reserve(subviewOp.getMixedStrides().size());
|
staticStrides.reserve(subviewOp.getMixedStrides().size());
|
||||||
bool allStatic = true;
|
bool allStatic = true;
|
||||||
|
|
||||||
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
|
||||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
|
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
|
||||||
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||||
else
|
else
|
||||||
allStatic = false;
|
allStatic = false;
|
||||||
}
|
for (mlir::OpFoldResult size : subviewOp.getMixedSizes())
|
||||||
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
|
|
||||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(size))
|
if (auto attr = mlir::dyn_cast<mlir::Attribute>(size))
|
||||||
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||||
else
|
else
|
||||||
allStatic = false;
|
allStatic = false;
|
||||||
}
|
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides())
|
||||||
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
|
|
||||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride))
|
if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride))
|
||||||
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||||
else
|
else
|
||||||
allStatic = false;
|
allStatic = false;
|
||||||
}
|
|
||||||
|
|
||||||
if (allStatic) {
|
if (allStatic) {
|
||||||
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
|
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
|
||||||
@@ -796,8 +778,8 @@ llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge&
|
|||||||
return evaluateCompiledIndexExpr(*this, knowledge);
|
return evaluateCompiledIndexExpr(*this, knowledge);
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::FailureOr<ResolvedContiguousAddress>
|
llvm::FailureOr<ResolvedContiguousAddress> CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge,
|
||||||
CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const {
|
std::optional<unsigned> lane) const {
|
||||||
(void) lane;
|
(void) lane;
|
||||||
auto resolvedOffset = byteOffset.evaluate(knowledge);
|
auto resolvedOffset = byteOffset.evaluate(knowledge);
|
||||||
if (failed(resolvedOffset))
|
if (failed(resolvedOffset))
|
||||||
|
|||||||
@@ -33,7 +33,8 @@ struct CompiledIndexExpr {
|
|||||||
std::shared_ptr<CompiledIndexExprNode> node;
|
std::shared_ptr<CompiledIndexExprNode> node;
|
||||||
|
|
||||||
CompiledIndexExpr() = default;
|
CompiledIndexExpr() = default;
|
||||||
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node) : node(std::move(node)) {}
|
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node)
|
||||||
|
: node(std::move(node)) {}
|
||||||
|
|
||||||
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
|
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
|
||||||
};
|
};
|
||||||
@@ -68,8 +69,8 @@ struct CompiledAddressExpr {
|
|||||||
mlir::Value base;
|
mlir::Value base;
|
||||||
CompiledIndexExpr byteOffset;
|
CompiledIndexExpr byteOffset;
|
||||||
|
|
||||||
llvm::FailureOr<ResolvedContiguousAddress>
|
llvm::FailureOr<ResolvedContiguousAddress> evaluate(const StaticValueKnowledge& knowledge,
|
||||||
evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const;
|
std::optional<unsigned> lane) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -10,8 +9,7 @@ llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
|||||||
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<int32_t>
|
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
|
||||||
getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
|
|
||||||
llvm::SmallVector<int32_t> laneCoreIds;
|
llvm::SmallVector<int32_t> laneCoreIds;
|
||||||
laneCoreIds.reserve(coreIds.size() / laneCount);
|
laneCoreIds.reserve(coreIds.size() / laneCount);
|
||||||
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
|
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
|
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
|
||||||
|
|
||||||
llvm::SmallVector<int32_t>
|
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
|
||||||
getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -24,10 +24,9 @@ walkPimCoreBlock(mlir::Block& block,
|
|||||||
/// Walks a `pim.core`-like body structurally for verification without
|
/// Walks a `pim.core`-like body structurally for verification without
|
||||||
/// enumerating full loop trip counts. Loop bounds must still be statically
|
/// enumerating full loop trip counts. Loop bounds must still be statically
|
||||||
/// evaluable so address resolution remains well-defined.
|
/// evaluable so address resolution remains well-defined.
|
||||||
mlir::LogicalResult
|
mlir::LogicalResult walkPimCoreBlockStructurally(
|
||||||
walkPimCoreBlockStructurally(mlir::Block& block,
|
mlir::Block& block,
|
||||||
const StaticValueKnowledge& knowledge,
|
const StaticValueKnowledge& knowledge,
|
||||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)>
|
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||||
callback);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -4,9 +4,9 @@
|
|||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|||||||
@@ -199,15 +199,14 @@ MemoryReportRow PimMemory::getReportRow() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PimMemory::remove(mlir::Value val) {
|
void PimMemory::remove(mlir::Value val) {
|
||||||
for (auto it = ownedMemEntriesMap.begin(); it != ownedMemEntriesMap.end();) {
|
for (auto it = ownedMemEntriesMap.begin(); it != ownedMemEntriesMap.end();)
|
||||||
if (it->first.value == val) {
|
if (it->first.value == val) {
|
||||||
auto eraseIt = it++;
|
auto eraseIt = it++;
|
||||||
ownedMemEntriesMap.erase(eraseIt);
|
ownedMemEntriesMap.erase(eraseIt);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
++it;
|
++it;
|
||||||
}
|
for (auto it = globalMemEntriesMap.begin(); it != globalMemEntriesMap.end();)
|
||||||
for (auto it = globalMemEntriesMap.begin(); it != globalMemEntriesMap.end();) {
|
|
||||||
if (it->first.value == val) {
|
if (it->first.value == val) {
|
||||||
auto eraseIt = it++;
|
auto eraseIt = it++;
|
||||||
globalMemEntriesMap.erase(eraseIt);
|
globalMemEntriesMap.erase(eraseIt);
|
||||||
@@ -215,7 +214,6 @@ void PimMemory::remove(mlir::Value val) {
|
|||||||
else
|
else
|
||||||
++it;
|
++it;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
MemEntry PimMemory::getMemEntry(const MemoryValueKey& key) const {
|
MemEntry PimMemory::getMemEntry(const MemoryValueKey& key) const {
|
||||||
auto iter = globalMemEntriesMap.find(key);
|
auto iter = globalMemEntriesMap.find(key);
|
||||||
@@ -275,7 +273,8 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
|
|||||||
return iter->second.address + resolvedAddress->byteOffset;
|
return iter->second.address + resolvedAddress->byteOffset;
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value,
|
||||||
|
const StaticValueKnowledge& knowledge) const {
|
||||||
value = resolveCachedAlias(value, knowledge);
|
value = resolveCachedAlias(value, knowledge);
|
||||||
auto compiledIt = compiledIndexExprs.find(value);
|
auto compiledIt = compiledIndexExprs.find(value);
|
||||||
if (compiledIt == compiledIndexExprs.end()) {
|
if (compiledIt == compiledIndexExprs.end()) {
|
||||||
@@ -826,7 +825,8 @@ class ScopedMapBindings {
|
|||||||
llvm::SmallVector<std::pair<KeyTy, std::optional<ValueTy>>, 8> savedEntries;
|
llvm::SmallVector<std::pair<KeyTy, std::optional<ValueTy>>, 8> savedEntries;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit ScopedMapBindings(MapTy& map) : map(map) {}
|
explicit ScopedMapBindings(MapTy& map)
|
||||||
|
: map(map) {}
|
||||||
|
|
||||||
void bind(const KeyTy& key, const ValueTy& value) {
|
void bind(const KeyTy& key, const ValueTy& value) {
|
||||||
auto it = map.find(key);
|
auto it = map.find(key);
|
||||||
@@ -838,13 +838,12 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
~ScopedMapBindings() {
|
~ScopedMapBindings() {
|
||||||
for (auto it = savedEntries.rbegin(); it != savedEntries.rend(); ++it) {
|
for (auto it = savedEntries.rbegin(); it != savedEntries.rend(); ++it)
|
||||||
if (it->second)
|
if (it->second)
|
||||||
map[it->first] = *it->second;
|
map[it->first] = *it->second;
|
||||||
else
|
else
|
||||||
map.erase(it->first);
|
map.erase(it->first);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
enum class CompiledCoreOpKind : uint8_t {
|
enum class CompiledCoreOpKind : uint8_t {
|
||||||
@@ -929,9 +928,8 @@ static FailureOr<CompiledCoreOpKind> classifyCompiledCoreOpKind(Operation& op) {
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult compileCoreEmissionPlan(Block& block,
|
static LogicalResult
|
||||||
Operation* weightOwner,
|
compileCoreEmissionPlan(Block& block, Operation* weightOwner, llvm::SmallVectorImpl<CompiledCoreNode>& plan) {
|
||||||
llvm::SmallVectorImpl<CompiledCoreNode>& plan) {
|
|
||||||
for (Operation& op : block) {
|
for (Operation& op : block) {
|
||||||
if (isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
if (isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||||
continue;
|
continue;
|
||||||
@@ -982,12 +980,11 @@ static LogicalResult compileCoreEmissionPlan(Block& block,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<CompiledCoreNode>& plan,
|
static LogicalResult executeCompiledCorePlan(
|
||||||
|
const llvm::SmallVectorImpl<CompiledCoreNode>& plan,
|
||||||
PimCodeGen& coreCodeGen,
|
PimCodeGen& coreCodeGen,
|
||||||
StaticValueKnowledge& knowledge,
|
StaticValueKnowledge& knowledge,
|
||||||
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp,
|
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp, const StaticValueKnowledge&)> resolveWeightSlot,
|
||||||
const StaticValueKnowledge&)>
|
|
||||||
resolveWeightSlot,
|
|
||||||
size_t& processedOperations,
|
size_t& processedOperations,
|
||||||
std::optional<unsigned> batchLane = std::nullopt,
|
std::optional<unsigned> batchLane = std::nullopt,
|
||||||
std::optional<unsigned> batchLaneCount = std::nullopt) {
|
std::optional<unsigned> batchLaneCount = std::nullopt) {
|
||||||
@@ -1010,8 +1007,13 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
|
|||||||
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
|
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
|
||||||
aliasBindings.bind(iterArg, iterValue);
|
aliasBindings.bind(iterArg, iterValue);
|
||||||
|
|
||||||
if (failed(executeCompiledCorePlan(
|
if (failed(executeCompiledCorePlan(*node.loopBody,
|
||||||
*node.loopBody, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount)))
|
coreCodeGen,
|
||||||
|
knowledge,
|
||||||
|
resolveWeightSlot,
|
||||||
|
processedOperations,
|
||||||
|
batchLane,
|
||||||
|
batchLaneCount)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto yieldOp = cast<mlir::scf::YieldOp>(forOp.getRegion().front().getTerminator());
|
auto yieldOp = cast<mlir::scf::YieldOp>(forOp.getRegion().front().getTerminator());
|
||||||
@@ -1031,18 +1033,10 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
|
|||||||
case CompiledCoreOpKind::Store:
|
case CompiledCoreOpKind::Store:
|
||||||
coreCodeGen.codeGenStoreOp(cast<pim::PimMemCopyDevToHostOp>(node.op), knowledge);
|
coreCodeGen.codeGenStoreOp(cast<pim::PimMemCopyDevToHostOp>(node.op), knowledge);
|
||||||
break;
|
break;
|
||||||
case CompiledCoreOpKind::Lmv:
|
case CompiledCoreOpKind::Lmv: coreCodeGen.codeGenLmvOp(cast<pim::PimMemCopyOp>(node.op), knowledge); break;
|
||||||
coreCodeGen.codeGenLmvOp(cast<pim::PimMemCopyOp>(node.op), knowledge);
|
case CompiledCoreOpKind::Receive: coreCodeGen.codeGenReceiveOp(cast<pim::PimReceiveOp>(node.op), knowledge); break;
|
||||||
break;
|
case CompiledCoreOpKind::Send: coreCodeGen.codeGenSendOp(cast<pim::PimSendOp>(node.op), knowledge); break;
|
||||||
case CompiledCoreOpKind::Receive:
|
case CompiledCoreOpKind::Concat: coreCodeGen.codeGenConcatOp(cast<pim::PimConcatOp>(node.op), knowledge); break;
|
||||||
coreCodeGen.codeGenReceiveOp(cast<pim::PimReceiveOp>(node.op), knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::Send:
|
|
||||||
coreCodeGen.codeGenSendOp(cast<pim::PimSendOp>(node.op), knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::Concat:
|
|
||||||
coreCodeGen.codeGenConcatOp(cast<pim::PimConcatOp>(node.op), knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::Vmm:
|
case CompiledCoreOpKind::Vmm:
|
||||||
if (auto weightSlot = resolveWeightSlot(cast<pim::PimVMMOp>(node.op), knowledge); succeeded(weightSlot))
|
if (auto weightSlot = resolveWeightSlot(cast<pim::PimVMMOp>(node.op), knowledge); succeeded(weightSlot))
|
||||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightSlot, cast<pim::PimVMMOp>(node.op), true, knowledge);
|
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightSlot, cast<pim::PimVMMOp>(node.op), true, knowledge);
|
||||||
@@ -1052,33 +1046,15 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
|
|||||||
case CompiledCoreOpKind::Transpose:
|
case CompiledCoreOpKind::Transpose:
|
||||||
coreCodeGen.codeGenTransposeOp(cast<pim::PimTransposeOp>(node.op), knowledge);
|
coreCodeGen.codeGenTransposeOp(cast<pim::PimTransposeOp>(node.op), knowledge);
|
||||||
break;
|
break;
|
||||||
case CompiledCoreOpKind::VVAdd:
|
case CompiledCoreOpKind::VVAdd: coreCodeGen.codeGenVVAddOp(cast<pim::PimVVAddOp>(node.op), knowledge); break;
|
||||||
coreCodeGen.codeGenVVAddOp(cast<pim::PimVVAddOp>(node.op), knowledge);
|
case CompiledCoreOpKind::VVSub: coreCodeGen.codeGenVVSubOp(cast<pim::PimVVSubOp>(node.op), knowledge); break;
|
||||||
break;
|
case CompiledCoreOpKind::VVMul: coreCodeGen.codeGenVVMulOp(cast<pim::PimVVMulOp>(node.op), knowledge); break;
|
||||||
case CompiledCoreOpKind::VVSub:
|
case CompiledCoreOpKind::VVMax: coreCodeGen.codeGenVVMaxOp(cast<pim::PimVVMaxOp>(node.op), knowledge); break;
|
||||||
coreCodeGen.codeGenVVSubOp(cast<pim::PimVVSubOp>(node.op), knowledge);
|
case CompiledCoreOpKind::VVDMul: coreCodeGen.codeGenVVDMulOp(cast<pim::PimVVDMulOp>(node.op), knowledge); break;
|
||||||
break;
|
case CompiledCoreOpKind::VAvg: coreCodeGen.codeGenVAvgOp(cast<pim::PimVAvgOp>(node.op), knowledge); break;
|
||||||
case CompiledCoreOpKind::VVMul:
|
case CompiledCoreOpKind::VRelu: coreCodeGen.codeGenVReluOp(cast<pim::PimVReluOp>(node.op), knowledge); break;
|
||||||
coreCodeGen.codeGenVVMulOp(cast<pim::PimVVMulOp>(node.op), knowledge);
|
case CompiledCoreOpKind::VTanh: coreCodeGen.codeGenVTanhOp(cast<pim::PimVTanhOp>(node.op), knowledge); break;
|
||||||
break;
|
case CompiledCoreOpKind::VSigm: coreCodeGen.codeGenVSigmOp(cast<pim::PimVSigmOp>(node.op), knowledge); break;
|
||||||
case CompiledCoreOpKind::VVMax:
|
|
||||||
coreCodeGen.codeGenVVMaxOp(cast<pim::PimVVMaxOp>(node.op), knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::VVDMul:
|
|
||||||
coreCodeGen.codeGenVVDMulOp(cast<pim::PimVVDMulOp>(node.op), knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::VAvg:
|
|
||||||
coreCodeGen.codeGenVAvgOp(cast<pim::PimVAvgOp>(node.op), knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::VRelu:
|
|
||||||
coreCodeGen.codeGenVReluOp(cast<pim::PimVReluOp>(node.op), knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::VTanh:
|
|
||||||
coreCodeGen.codeGenVTanhOp(cast<pim::PimVTanhOp>(node.op), knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::VSigm:
|
|
||||||
coreCodeGen.codeGenVSigmOp(cast<pim::PimVSigmOp>(node.op), knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::VSoftmax:
|
case CompiledCoreOpKind::VSoftmax:
|
||||||
coreCodeGen.codeGenVSoftmaxOp(cast<pim::PimVSoftmaxOp>(node.op), knowledge);
|
coreCodeGen.codeGenVSoftmaxOp(cast<pim::PimVSoftmaxOp>(node.op), knowledge);
|
||||||
break;
|
break;
|
||||||
@@ -1131,13 +1107,12 @@ static void aliasMaterializedHostGlobals(CoreLikeOpTy coreLikeOp,
|
|||||||
/// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is
|
/// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is
|
||||||
/// fully resolved before the JSON instructions are emitted.
|
/// fully resolved before the JSON instructions are emitted.
|
||||||
/// Returns the number of emitted instructions, or -1 on failure.
|
/// Returns the number of emitted instructions, or -1 on failure.
|
||||||
static int64_t codeGenCoreOps(Block& block,
|
static int64_t codeGenCoreOps(
|
||||||
|
Block& block,
|
||||||
PimCodeGen& coreCodeGen,
|
PimCodeGen& coreCodeGen,
|
||||||
const StaticValueKnowledge& initialKnowledge,
|
const StaticValueKnowledge& initialKnowledge,
|
||||||
Operation* weightOwner,
|
Operation* weightOwner,
|
||||||
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp,
|
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp, const StaticValueKnowledge&)> resolveWeightSlot,
|
||||||
const StaticValueKnowledge&)>
|
|
||||||
resolveWeightSlot,
|
|
||||||
std::optional<unsigned> batchLane = std::nullopt,
|
std::optional<unsigned> batchLane = std::nullopt,
|
||||||
std::optional<unsigned> batchLaneCount = std::nullopt) {
|
std::optional<unsigned> batchLaneCount = std::nullopt) {
|
||||||
llvm::SmallVector<CompiledCoreNode, 32> plan;
|
llvm::SmallVector<CompiledCoreNode, 32> plan;
|
||||||
@@ -1146,8 +1121,8 @@ static int64_t codeGenCoreOps(Block& block,
|
|||||||
|
|
||||||
size_t processedOperations = 0;
|
size_t processedOperations = 0;
|
||||||
StaticValueKnowledge knowledge = initialKnowledge;
|
StaticValueKnowledge knowledge = initialKnowledge;
|
||||||
auto result =
|
auto result = executeCompiledCorePlan(
|
||||||
executeCompiledCorePlan(plan, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount);
|
plan, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount);
|
||||||
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
|
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1219,9 +1194,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
|
|
||||||
SmallVector<size_t> jobIndices;
|
SmallVector<size_t> jobIndices;
|
||||||
SmallVector<size_t> orderedOriginalCoreIds = llvm::to_vector(lanesByCoreId.keys());
|
SmallVector<size_t> orderedOriginalCoreIds = llvm::to_vector(lanesByCoreId.keys());
|
||||||
llvm::sort(orderedOriginalCoreIds, [&](size_t lhs, size_t rhs) {
|
llvm::sort(orderedOriginalCoreIds,
|
||||||
return emittedCoreIds.lookup(lhs) < emittedCoreIds.lookup(rhs);
|
[&](size_t lhs, size_t rhs) { return emittedCoreIds.lookup(lhs) < emittedCoreIds.lookup(rhs); });
|
||||||
});
|
|
||||||
for (size_t originalCoreId : orderedOriginalCoreIds) {
|
for (size_t originalCoreId : orderedOriginalCoreIds) {
|
||||||
CoreEmissionJob job;
|
CoreEmissionJob job;
|
||||||
job.coreLikeOp = coreBatchOp;
|
job.coreLikeOp = coreBatchOp;
|
||||||
@@ -1236,9 +1210,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
++nextBatchReportId;
|
++nextBatchReportId;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto linkCoreWeights = [&](size_t coreId,
|
auto linkCoreWeights =
|
||||||
ArrayRef<std::string> weightFiles,
|
[&](size_t coreId, ArrayRef<std::string> weightFiles, json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes {
|
||||||
json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes {
|
|
||||||
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
|
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
|
||||||
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
|
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
|
||||||
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
|
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
|
||||||
@@ -1250,8 +1223,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
|
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
|
||||||
coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")) {
|
coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")) {
|
||||||
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
|
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
|
||||||
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")
|
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin") << "\nError:" << error.message()
|
||||||
<< "\nError:" << error.message() << '\n';
|
<< '\n';
|
||||||
return InvalidOutputFileAccess;
|
return InvalidOutputFileAccess;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1294,8 +1267,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
errorCode = std::error_code();
|
errorCode = std::error_code();
|
||||||
coreJsonStream = std::make_unique<raw_fd_ostream>(outputCoreJsonPath, errorCode);
|
coreJsonStream = std::make_unique<raw_fd_ostream>(outputCoreJsonPath, errorCode);
|
||||||
if (errorCode) {
|
if (errorCode) {
|
||||||
errs() << "Error while opening core json file `" << outputCoreJsonPath << "`: " << errorCode.message()
|
errs() << "Error while opening core json file `" << outputCoreJsonPath << "`: " << errorCode.message() << '\n';
|
||||||
<< '\n';
|
|
||||||
result.status = InvalidOutputFileAccess;
|
result.status = InvalidOutputFileAccess;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@@ -1364,9 +1336,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
};
|
};
|
||||||
|
|
||||||
std::vector<CoreEmissionResult> jobResults(jobs.size());
|
std::vector<CoreEmissionResult> jobResults(jobs.size());
|
||||||
mlir::parallelFor(moduleOp.getContext(), 0, jobs.size(), [&](size_t index) {
|
mlir::parallelFor(
|
||||||
jobResults[index] = emitJob(jobs[index]);
|
moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { jobResults[index] = emitJob(jobs[index]); });
|
||||||
});
|
|
||||||
|
|
||||||
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex)
|
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex)
|
||||||
if (jobResults[jobIndex].status != CompilerSuccess)
|
if (jobResults[jobIndex].status != CompilerSuccess)
|
||||||
|
|||||||
@@ -101,7 +101,9 @@ public:
|
|||||||
PimAcceleratorMemory()
|
PimAcceleratorMemory()
|
||||||
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
|
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
|
||||||
PimAcceleratorMemory(const llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& initialMemEntries, bool enableReport)
|
PimAcceleratorMemory(const llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& initialMemEntries, bool enableReport)
|
||||||
: memEntriesMap(initialMemEntries), hostMem(memEntriesMap), fileReport(enableReport ? openReportFile("memory_report") : std::fstream()) {}
|
: memEntriesMap(initialMemEntries),
|
||||||
|
hostMem(memEntriesMap),
|
||||||
|
fileReport(enableReport ? openReportFile("memory_report") : std::fstream()) {}
|
||||||
|
|
||||||
PimMemory& getOrCreateDeviceMem(size_t id);
|
PimMemory& getOrCreateDeviceMem(size_t id);
|
||||||
|
|
||||||
@@ -206,13 +208,9 @@ namespace llvm {
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct DenseMapInfo<onnx_mlir::MemoryValueKey> {
|
struct DenseMapInfo<onnx_mlir::MemoryValueKey> {
|
||||||
static onnx_mlir::MemoryValueKey getEmptyKey() {
|
static onnx_mlir::MemoryValueKey getEmptyKey() { return {DenseMapInfo<mlir::Value>::getEmptyKey(), 0}; }
|
||||||
return {DenseMapInfo<mlir::Value>::getEmptyKey(), 0};
|
|
||||||
}
|
|
||||||
|
|
||||||
static onnx_mlir::MemoryValueKey getTombstoneKey() {
|
static onnx_mlir::MemoryValueKey getTombstoneKey() { return {DenseMapInfo<mlir::Value>::getTombstoneKey(), 0}; }
|
||||||
return {DenseMapInfo<mlir::Value>::getTombstoneKey(), 0};
|
|
||||||
}
|
|
||||||
|
|
||||||
static unsigned getHashValue(const onnx_mlir::MemoryValueKey& key) {
|
static unsigned getHashValue(const onnx_mlir::MemoryValueKey& key) {
|
||||||
return hash_combine(key.value, key.lane.value_or(std::numeric_limits<unsigned>::max()));
|
return hash_combine(key.value, key.lane.value_or(std::numeric_limits<unsigned>::max()));
|
||||||
|
|||||||
@@ -16,9 +16,7 @@ using namespace llvm;
|
|||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {} // namespace
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
|
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
|
||||||
createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
|
createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
|
||||||
|
|||||||
@@ -198,7 +198,6 @@ static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm:
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static std::optional<CompileTimeSource>
|
static std::optional<CompileTimeSource>
|
||||||
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
|
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
|
||||||
if (!op)
|
if (!op)
|
||||||
@@ -217,7 +216,9 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
|
|||||||
chainLength += 1;
|
chainLength += 1;
|
||||||
|
|
||||||
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||||
return hasConstantIndices(extractOp) ? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength) : std::nullopt;
|
return hasConstantIndices(extractOp)
|
||||||
|
? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength)
|
||||||
|
: std::nullopt;
|
||||||
|
|
||||||
if (!isStaticTensorResult(op))
|
if (!isStaticTensorResult(op))
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
@@ -232,7 +233,8 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
|
|||||||
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
||||||
|
|
||||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||||
return hasStaticUnitStrides(extractSliceOp) ? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
|
return hasStaticUnitStrides(extractSliceOp)
|
||||||
|
? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
|
||||||
: std::nullopt;
|
: std::nullopt;
|
||||||
|
|
||||||
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||||
@@ -252,10 +254,9 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
|
|||||||
res = partialRes;
|
res = partialRes;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if(res->chainLength < partialRes->chainLength){
|
if (res->chainLength < partialRes->chainLength)
|
||||||
res = partialRes;
|
res = partialRes;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -264,7 +265,6 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
||||||
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
|
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
|
||||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
return getCompileTimeSourceImpl(op, visited);
|
return getCompileTimeSourceImpl(op, visited);
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/Location.h"
|
#include "mlir/IR/Location.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
@@ -143,8 +143,7 @@ static Value createGemmBatchKOffset(
|
|||||||
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createGemmBatchHOffset(
|
static Value createGemmBatchHOffset(Value lane,
|
||||||
Value lane,
|
|
||||||
int64_t numOutRows,
|
int64_t numOutRows,
|
||||||
int64_t numKSlices,
|
int64_t numKSlices,
|
||||||
int64_t numOutHSlices,
|
int64_t numOutHSlices,
|
||||||
|
|||||||
@@ -9,8 +9,8 @@
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,8 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,8 @@
|
|||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|||||||
@@ -171,7 +171,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
|||||||
markOpToRemove(receiveOp);
|
markOpToRemove(receiveOp);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
||||||
|
|||||||
@@ -606,7 +606,6 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
|
|||||||
markOpToRemove(receiveOp);
|
markOpToRemove(receiveOp);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/ADT/DenseSet.h"
|
#include "llvm/ADT/DenseSet.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|||||||
@@ -122,9 +122,8 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
|
|||||||
llvm::errs() << "[merge-profile] " << phaseName << " counts:"
|
llvm::errs() << "[merge-profile] " << phaseName << " counts:"
|
||||||
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
|
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
|
||||||
<< " scalar_send=" << counts.scalarChannelSendCount
|
<< " scalar_send=" << counts.scalarChannelSendCount
|
||||||
<< " scalar_recv=" << counts.scalarChannelReceiveCount
|
<< " scalar_recv=" << counts.scalarChannelReceiveCount << " wvmm=" << counts.wvmmCount
|
||||||
<< " wvmm=" << counts.wvmmCount << " vadd=" << counts.vaddCount
|
<< " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n";
|
||||||
<< " scf_for=" << counts.scfForCount << "\n";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||||
@@ -514,7 +513,8 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
|||||||
SmallVector<int32_t> coreIds;
|
SmallVector<int32_t> coreIds;
|
||||||
if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||||
llvm::append_range(coreIds, coreIdsAttr.asArrayRef());
|
llvm::append_range(coreIds, coreIdsAttr.asArrayRef());
|
||||||
collectedData.push_back({nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds});
|
collectedData.push_back(
|
||||||
|
{nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds});
|
||||||
totalComputeOps += 1;
|
totalComputeOps += 1;
|
||||||
totalLogicalComputes += logicalCount;
|
totalLogicalComputes += logicalCount;
|
||||||
totalBatchComputeOps += 1;
|
totalBatchComputeOps += 1;
|
||||||
|
|||||||
@@ -206,10 +206,8 @@ static FailureOr<int64_t> evaluateIndexLike(OpFoldResult value,
|
|||||||
return evaluateIndexLike(llvm::cast<Value>(value), bindings, lane, laneArg);
|
return evaluateIndexLike(llvm::cast<Value>(value), bindings, lane, laneArg);
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<int64_t> evaluateIndexLike(Value value,
|
static FailureOr<int64_t>
|
||||||
const DenseMap<Value, int64_t>& bindings,
|
evaluateIndexLike(Value value, const DenseMap<Value, int64_t>& bindings, std::optional<uint32_t> lane, Value laneArg) {
|
||||||
std::optional<uint32_t> lane,
|
|
||||||
Value laneArg) {
|
|
||||||
if (lane && value == laneArg)
|
if (lane && value == laneArg)
|
||||||
return *lane;
|
return *lane;
|
||||||
if (auto it = bindings.find(value); it != bindings.end())
|
if (auto it = bindings.find(value); it != bindings.end())
|
||||||
@@ -260,8 +258,7 @@ static FailureOr<int64_t> evaluateIndexLike(Value value,
|
|||||||
return evaluateAffineExpr(map.getResult(0), dims, symbols);
|
return evaluateAffineExpr(map.getResult(0), dims, symbols);
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<SmallVector<int64_t, 4>>
|
static FailureOr<SmallVector<int64_t, 4>> evaluateIndexList(ArrayRef<OpFoldResult> values,
|
||||||
evaluateIndexList(ArrayRef<OpFoldResult> values,
|
|
||||||
const DenseMap<Value, int64_t>& bindings,
|
const DenseMap<Value, int64_t>& bindings,
|
||||||
std::optional<uint32_t> lane,
|
std::optional<uint32_t> lane,
|
||||||
Value laneArg) {
|
Value laneArg) {
|
||||||
@@ -308,8 +305,7 @@ static CrossbarWeight completeCrossbarWeight(Value root,
|
|||||||
return weight;
|
return weight;
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<CrossbarWeight>
|
static FailureOr<CrossbarWeight> getStaticCrossbarWeight(Operation* owner,
|
||||||
getStaticCrossbarWeight(Operation* owner,
|
|
||||||
Value value,
|
Value value,
|
||||||
const DenseMap<Value, int64_t>& bindings,
|
const DenseMap<Value, int64_t>& bindings,
|
||||||
std::optional<uint32_t> lane,
|
std::optional<uint32_t> lane,
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ using CPU = int;
|
|||||||
using Cost = unsigned long long;
|
using Cost = unsigned long long;
|
||||||
using Time = unsigned long long;
|
using Time = unsigned long long;
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline T checkedAdd(T lhs, T rhs) {
|
inline T checkedAdd(T lhs, T rhs) {
|
||||||
static_assert(std::is_unsigned_v<T>, "checkedAdd only supports unsigned types");
|
static_assert(std::is_unsigned_v<T>, "checkedAdd only supports unsigned types");
|
||||||
|
|||||||
@@ -327,9 +327,8 @@ private:
|
|||||||
static LogicalResult verifyCoreLikeOperands(CoreLikeOpTy coreLikeOp,
|
static LogicalResult verifyCoreLikeOperands(CoreLikeOpTy coreLikeOp,
|
||||||
const StaticValueKnowledge& initialKnowledge,
|
const StaticValueKnowledge& initialKnowledge,
|
||||||
pim::CappedDiagnosticReporter& diagnostics) {
|
pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
return walkPimCoreBlockStructurally(coreLikeOp.getBody().front(),
|
return walkPimCoreBlockStructurally(
|
||||||
initialKnowledge,
|
coreLikeOp.getBody().front(), initialKnowledge, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||||
[&](Operation& op, const StaticValueKnowledge& knowledge) {
|
|
||||||
bool hasFailure = false;
|
bool hasFailure = false;
|
||||||
if (!isSupportedCoreInstructionOp(&op)) {
|
if (!isSupportedCoreInstructionOp(&op)) {
|
||||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
|||||||
Reference in New Issue
Block a user