Refactor PIM/Common (splitting in files, adding helpers, adding brief docs)
CI status: Validate Operations / validate-operations (push) — failing after 18m36s
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -12,4 +12,4 @@ build
|
||||
cmake-build-debug
|
||||
cmake-build-release
|
||||
|
||||
**/__pycache__
|
||||
**/__*
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
add_pim_library(OMPimCommon
|
||||
PimCommon.cpp
|
||||
IR/AddressAnalysis.cpp
|
||||
IR/CoreBlockUtils.cpp
|
||||
IR/EntryPointUtils.cpp
|
||||
IR/ShapeUtils.cpp
|
||||
IR/WeightUtils.cpp
|
||||
Support/DebugDump.cpp
|
||||
Support/Diagnostics.cpp
|
||||
Support/FileSystemUtils.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
258
src/PIM/Common/IR/AddressAnalysis.cpp
Normal file
258
src/PIM/Common/IR/AddressAnalysis.cpp
Normal file
@@ -0,0 +1,258 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Finds the memref.global referenced by a memref.get_global, or a null op
/// when either handle is invalid or the symbol cannot be resolved.
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) {
  const bool haveBothOps = moduleOp && getGlobalOp;
  if (!haveBothOps)
    return {};
  return moduleOp.lookupSymbol<mlir::memref::GlobalOp>(getGlobalOp.getName());
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Follows entries in `knowledge->aliases` until a value without a recorded
/// alias is reached.
///
/// The number of hops is bounded by the number of alias entries: an acyclic
/// chain can visit each entry at most once, so the bound is exact for valid
/// maps while guarding against an accidental cycle (e.g. a value aliased to
/// itself through loop-carried bookkeeping), which previously hung forever.
mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) {
  if (!knowledge)
    return value;

  size_t remainingHops = knowledge->aliases.size();
  auto iter = knowledge->aliases.find(value);
  while (iter != knowledge->aliases.end() && remainingHops-- > 0) {
    value = iter->second;
    iter = knowledge->aliases.find(value);
  }
  return value;
}
|
||||
|
||||
/// Peels value-aliasing wrappers — recorded aliases, DPS result/init ties,
/// and view-like memref ops — until a base value is reached.
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);

  if (mlir::isa<mlir::BlockArgument>(value))
    return value;
  mlir::Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return value;

  // A destination-style result aliases the init operand it is tied to.
  if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
    auto result = mlir::dyn_cast<mlir::OpResult>(value);
    if (result) {
      if (mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
        return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
    }
  }

  // View-like memref ops do not change the underlying storage.
  if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
    return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
  if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
    return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
  if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp))
    return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);

  return value;
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||
|
||||
/// Statically evaluates an index-like SSA value.
///
/// Precedence: facts recorded in `knowledge` win, then arith constants, then
/// simple integer arithmetic folded recursively. Fails when any operand in
/// the chain cannot be proven constant (or on unsigned division by zero).
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);

  // Facts recorded by the caller (e.g. unrolled loop induction values) take
  // precedence over anything derivable from the IR.
  if (knowledge) {
    auto iter = knowledge->indexValues.find(value);
    if (iter != knowledge->indexValues.end())
      return iter->second;
  }

  if (auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>()) {
    if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue()))
      return integerAttr.getInt();
  }

  mlir::Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return mlir::failure();

  if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
    return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);

  // Resolves both operands of a binary op and combines them; `combine` may
  // itself fail (e.g. unsigned division by zero).
  auto evalBinary = [&](mlir::Value lhsValue, mlir::Value rhsValue, auto combine) -> llvm::FailureOr<int64_t> {
    auto lhs = resolveIndexValueImpl(lhsValue, knowledge);
    auto rhs = resolveIndexValueImpl(rhsValue, knowledge);
    if (failed(lhs) || failed(rhs))
      return mlir::failure();
    return combine(*lhs, *rhs);
  };

  if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp))
    return evalBinary(addOp.getLhs(), addOp.getRhs(),
        [](int64_t lhs, int64_t rhs) -> llvm::FailureOr<int64_t> { return lhs + rhs; });

  if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp))
    return evalBinary(subOp.getLhs(), subOp.getRhs(),
        [](int64_t lhs, int64_t rhs) -> llvm::FailureOr<int64_t> { return lhs - rhs; });

  if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp))
    return evalBinary(mulOp.getLhs(), mulOp.getRhs(),
        [](int64_t lhs, int64_t rhs) -> llvm::FailureOr<int64_t> { return lhs * rhs; });

  if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp))
    return evalBinary(divOp.getLhs(), divOp.getRhs(),
        [](int64_t lhs, int64_t rhs) -> llvm::FailureOr<int64_t> {
          if (rhs == 0)
            return mlir::failure(); // Division by zero cannot be folded.
          // Matches arith.divui semantics: operate on the unsigned bit pattern.
          return static_cast<int64_t>(static_cast<uint64_t>(lhs) / static_cast<uint64_t>(rhs));
        });

  if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp))
    return evalBinary(remOp.getLhs(), remOp.getRhs(),
        [](int64_t lhs, int64_t rhs) -> llvm::FailureOr<int64_t> {
          if (rhs == 0)
            return mlir::failure();
          return static_cast<int64_t>(static_cast<uint64_t>(lhs) % static_cast<uint64_t>(rhs));
        });

  return mlir::failure();
}
|
||||
|
||||
/// Evaluates an OpFoldResult: an attribute must be an IntegerAttr; a Value is
/// handed to the index-value resolver.
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
  auto attr = mlir::dyn_cast<mlir::Attribute>(ofr);
  if (!attr)
    return resolveIndexValueImpl(mlir::cast<mlir::Value>(ofr), knowledge);

  if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr))
    return integerAttr.getInt();
  return mlir::failure();
}
|
||||
|
||||
/// Walks `value` backwards through aliases, DPS ties, scf.for loop-carried
/// values, view-like memref ops, and contiguous static subviews, accumulating
/// a static byte offset. Succeeds when the chain terminates in a block
/// argument, a memref.alloc, or a memref.get_global; fails as soon as any
/// step cannot be proven static and contiguous.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Value value,
    const StaticValueKnowledge* knowledge) {
  int64_t byteOffset = 0;
  value = resolveAlias(value, knowledge);

  // Resolves every OpFoldResult in `range` to a static integer, or fails.
  auto resolveAll = [&](auto range, llvm::SmallVector<int64_t>& resolvedValues) -> mlir::LogicalResult {
    resolvedValues.reserve(range.size());
    for (mlir::OpFoldResult ofr : range) {
      auto resolved = resolveOpFoldResult(ofr, knowledge);
      if (failed(resolved))
        return mlir::failure();
      resolvedValues.push_back(*resolved);
    }
    return mlir::success();
  };

  while (true) {
    // A block argument is an acceptable base (e.g. a function argument).
    if (mlir::isa<mlir::BlockArgument>(value))
      return ResolvedContiguousAddress {value, byteOffset};

    mlir::Operation* definingOp = value.getDefiningOp();
    if (!definingOp)
      return mlir::failure();

    // A destination-style result aliases its tied init operand.
    if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
      mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
      if (!tiedOperand)
        return mlir::failure();
      value = resolveAlias(tiedOperand->get(), knowledge);
      continue;
    }

    // An scf.for result aliases the yielded value; when the loop merely
    // forwards its iter arg, step out to the corresponding init value.
    if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
      auto result = mlir::dyn_cast<mlir::OpResult>(value);
      if (!result)
        return mlir::failure();

      auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
      mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
      if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
        // Body arg 0 is the induction variable; iter args start at index 1.
        if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
            && static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
          value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
          continue;
        }
      }

      value = yieldedValue;
      continue;
    }

    // A contiguous static subview contributes a constant byte offset.
    if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
      auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
      auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
      if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
        return mlir::failure();

      llvm::SmallVector<int64_t> offsets;
      llvm::SmallVector<int64_t> sizes;
      llvm::SmallVector<int64_t> strides;
      if (failed(resolveAll(subviewOp.getMixedOffsets(), offsets))
          || failed(resolveAll(subviewOp.getMixedSizes(), sizes))
          || failed(resolveAll(subviewOp.getMixedStrides(), strides)))
        return mlir::failure();

      if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
        return mlir::failure();

      // Sub-byte element types (e.g. i1) have no well-defined byte offset
      // without a packing convention; refuse rather than silently truncate.
      int64_t elementBitWidth = subviewType.getElementTypeBitWidth();
      if (elementBitWidth % 8 != 0)
        return mlir::failure();

      auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
      byteOffset += linearizeIndex(offsets, sourceStrides) * (elementBitWidth / 8);
      value = resolveAlias(subviewOp.getSource(), knowledge);
      continue;
    }

    // Pure view ops: same storage, no offset change.
    if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
      value = resolveAlias(castOp.getSource(), knowledge);
      continue;
    }
    if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
      value = resolveAlias(collapseOp.getSrc(), knowledge);
      continue;
    }
    if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
      value = resolveAlias(expandOp.getSrc(), knowledge);
      continue;
    }

    // Allocations and globals are terminal bases.
    if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
      return ResolvedContiguousAddress {value, byteOffset};

    return mlir::failure();
  }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Convenience overload: evaluate using IR facts only, no recorded knowledge.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
|
||||
|
||||
// Overload that additionally consults the facts recorded in `knowledge`.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
  return resolveIndexValueImpl(value, &knowledge);
}
|
||||
|
||||
// Convenience overload: resolve using IR facts only, no recorded knowledge.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
  return resolveContiguousAddressImpl(value, nullptr);
}
|
||||
|
||||
// Overload that additionally consults the facts recorded in `knowledge`.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
    const StaticValueKnowledge& knowledge) {
  return resolveContiguousAddressImpl(value, &knowledge);
}
|
||||
|
||||
// Public wrapper over the recursive alias-peeling implementation.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
  return resolveLoopCarriedAliasImpl(value, &knowledge);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
43
src/PIM/Common/IR/AddressAnalysis.hpp
Normal file
43
src/PIM/Common/IR/AddressAnalysis.hpp
Normal file
@@ -0,0 +1,43 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Describes a value as a base addressable object plus a statically known
|
||||
/// byte offset after peeling aliases, casts, and contiguous subviews.
|
||||
struct ResolvedContiguousAddress {
  // Underlying storage the offset is relative to — per resolveContiguousAddress,
  // a block argument, a memref.alloc, or a memref.get_global.
  mlir::Value base;
  // Statically computed offset from the start of `base`, in bytes.
  int64_t byteOffset = 0;
};
|
||||
|
||||
/// Records compile-time facts used when interpreting address arithmetic and
|
||||
/// loop-carried aliases inside PIM regions.
|
||||
struct StaticValueKnowledge {
|
||||
llvm::DenseMap<mlir::Value, int64_t> indexValues;
|
||||
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
|
||||
|
||||
StaticValueKnowledge() {}
|
||||
};
|
||||
|
||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||
|
||||
/// Resolves a value to contiguous backing storage when that storage can be
|
||||
/// proven statically from aliases, DPS ties, casts, and subviews.
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||
const StaticValueKnowledge& knowledge);
|
||||
|
||||
/// Statically evaluates index-like SSA values, including simple integer
|
||||
/// arithmetic and loop facts recorded in `knowledge`.
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
|
||||
/// Follows alias, view, and DPS chains to recover the backing value of a
|
||||
/// loop-carried memref/result.
|
||||
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
67
src/PIM/Common/IR/CoreBlockUtils.cpp
Normal file
67
src/PIM/Common/IR/CoreBlockUtils.cpp
Normal file
@@ -0,0 +1,67 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Returns true for ops inside a `pim.core` body that only perform static
/// index arithmetic or address/view computation and therefore never emit
/// PIM instructions themselves.
bool isCoreStaticAddressOp(mlir::Operation* op) {
  // Pure integer/index arithmetic.
  if (mlir::isa<mlir::arith::ConstantOp, mlir::arith::AddIOp, mlir::arith::SubIOp, mlir::arith::MulIOp,
          mlir::arith::DivUIOp, mlir::arith::RemUIOp, mlir::arith::IndexCastOp>(op))
    return true;
  // Pure address/view computation on memrefs.
  return mlir::isa<mlir::memref::AllocOp, mlir::memref::SubViewOp, mlir::memref::CastOp,
      mlir::memref::CollapseShapeOp, mlir::memref::ExpandShapeOp>(op);
}
|
||||
|
||||
/// Walks a `pim.core` body, invoking `callback` on every instruction-emitting
/// op. Nested `scf.for` loops with statically evaluable bounds are unrolled
/// symbolically: each iteration records the induction value and iter-arg
/// aliases in a per-loop StaticValueKnowledge before recursing.
///
/// Returns failure if any loop bounds could not be evaluated or any callback
/// failed; the walk keeps going so all diagnostics are emitted.
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
    const StaticValueKnowledge& knowledge,
    llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
  bool hasFailure = false;
  for (mlir::Operation& op : block) {
    // Terminators and pure address computation emit no PIM instructions.
    if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
      continue;

    if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
      mlir::Block& loopBody = forOp.getRegion().front();
      auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
      auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
      auto step = resolveIndexValue(forOp.getStep(), knowledge);
      if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
        forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
        hasFailure = true;
        continue;
      }

      llvm::SmallVector<mlir::Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
      // Copy the incoming facts once. Each iteration only overwrites the
      // induction-variable and iter-arg entries (the keys are identical
      // every iteration and the recursive walk takes the knowledge by
      // const&), so reusing one copy avoids re-copying both maps per
      // unrolled iteration.
      StaticValueKnowledge loopKnowledge = knowledge;
      for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
        loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
        for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
          loopKnowledge.aliases[iterArg] = iterValue;

        if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
          hasFailure = true;

        // Advance the loop-carried values for the next iteration.
        auto yieldOp = mlir::cast<mlir::scf::YieldOp>(loopBody.getTerminator());
        for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
          iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
      }
      continue;
    }

    if (failed(callback(op, knowledge)))
      hasFailure = true;
  }
  return mlir::success(!hasFailure);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
24
src/PIM/Common/IR/CoreBlockUtils.hpp
Normal file
24
src/PIM/Common/IR/CoreBlockUtils.hpp
Normal file
@@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Returns true for ops in a `pim.core` body that only participate in static
|
||||
/// address or index computation and therefore do not emit PIM instructions.
|
||||
bool isCoreStaticAddressOp(mlir::Operation* op);
|
||||
|
||||
/// Walks a `pim.core` body, statically unrolling nested `scf.for` loops when
|
||||
/// their bounds are known and invoking `callback` only on instruction-emitting
|
||||
/// operations.
|
||||
mlir::LogicalResult
|
||||
walkPimCoreBlock(mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
45
src/PIM/Common/IR/EntryPointUtils.cpp
Normal file
45
src/PIM/Common/IR/EntryPointUtils.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp) {
  // Resolution order: ONNX entry-point metadata, then the conventional
  // `main_graph` symbol, then a unique non-external function. Anything else
  // is ambiguous and reported as an error on the module.
  if (!moduleOp)
    return mlir::failure();

  llvm::SmallVector<mlir::ONNXEntryPointOp> entryPoints(moduleOp.getOps<mlir::ONNXEntryPointOp>());
  if (entryPoints.size() > 1) {
    moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
    return mlir::failure();
  }
  if (!entryPoints.empty()) {
    // The entry point op names its function via a symbol-ref attribute.
    auto entryPointAttr =
        entryPoints.front()->getAttrOfType<mlir::SymbolRefAttr>(mlir::ONNXEntryPointOp::getEntryPointFuncAttrName());
    if (!entryPointAttr) {
      entryPoints.front().emitOpError("is missing the entry point function attribute");
      return mlir::failure();
    }
    auto entryFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>(entryPointAttr.getLeafReference().getValue());
    if (!entryFunc) {
      entryPoints.front().emitOpError("references an unknown entry function ")
          << entryPointAttr.getLeafReference().getValue();
      return mlir::failure();
    }
    return entryFunc;
  }

  // No ONNX metadata: fall back to the conventional onnx-mlir entry name.
  if (auto mainGraphFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>("main_graph"))
    return mainGraphFunc;

  // Last resort: accept a module with exactly one function that has a body.
  llvm::SmallVector<mlir::func::FuncOp> nonExternalFuncs;
  for (auto funcOp : moduleOp.getOps<mlir::func::FuncOp>())
    if (!funcOp.isExternal())
      nonExternalFuncs.push_back(funcOp);
  if (nonExternalFuncs.size() == 1)
    return nonExternalFuncs.front();

  moduleOp.emitError("could not resolve a unique PIM entry function");
  return mlir::failure();
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
13
src/PIM/Common/IR/EntryPointUtils.hpp
Normal file
13
src/PIM/Common/IR/EntryPointUtils.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Resolves the function the PIM pipeline should treat as its entry point.
|
||||
/// Prefers ONNX entry-point metadata, then `main_graph`, then the only
|
||||
/// non-external function if the module is otherwise unambiguous.
|
||||
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
89
src/PIM/Common/IR/ShapeUtils.cpp
Normal file
89
src/PIM/Common/IR/ShapeUtils.cpp
Normal file
@@ -0,0 +1,89 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Computes row-major (C-order) strides for `shape`: the innermost dimension
/// gets stride 1 and every outer stride is the product of all inner extents.
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape) {
  llvm::SmallVector<int64_t> strides(shape.size(), 1);
  int64_t runningProduct = 1;
  for (size_t dim = shape.size(); dim-- > 0;) {
    strides[dim] = runningProduct;
    runningProduct *= shape[dim];
  }
  return strides;
}
|
||||
|
||||
/// Splits a flat `linearIndex` into per-dimension indices using `strides`.
/// The result is sized by `shape`; dimensions beyond `strides` stay zero.
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
  llvm::SmallVector<int64_t> indices(shape.size(), 0);
  int64_t remainder = linearIndex;
  for (size_t dim = 0; dim < strides.size(); ++dim) {
    indices[dim] = remainder / strides[dim];
    remainder %= strides[dim];
  }
  return indices;
}
|
||||
|
||||
/// Collapses per-dimension `indices` into a flat index via the dot product
/// with `strides` (both ranges must have the same length).
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
  int64_t flatIndex = 0;
  for (auto [dimIndex, dimStride] : llvm::zip_equal(indices, strides))
    flatIndex += dimIndex * dimStride;
  return flatIndex;
}
|
||||
|
||||
/// Returns the product of all extents in `shape` (1 for an empty shape).
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
  int64_t product = 1;
  for (size_t dim = 0; dim < shape.size(); ++dim)
    product *= shape[dim];
  return product;
}
|
||||
|
||||
/// Returns true when the subview described by (offsets, sizes, strides) of a
/// row-major tensor with `srcShape` selects a single contiguous memory span.
/// Checked conditions, working from the innermost dimension outward:
///   1. every stride is exactly 1;
///   2. at the innermost dimension with a non-zero offset, the slice stays
///      in bounds, and every dimension *outside* it has size 1;
///   3. every dimension outside the innermost partially-covered dimension
///      (size != extent) has size 1.
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
    llvm::ArrayRef<int64_t> offsets,
    llvm::ArrayRef<int64_t> sizes,
    llvm::ArrayRef<int64_t> strides) {
  // Any stride != 1 makes selected elements non-adjacent.
  if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
    return false;

  // Iterate innermost-first by zipping the reversed ranges together.
  auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
      llvm::make_range(sizes.rbegin(), sizes.rend()),
      llvm::make_range(srcShape.rbegin(), srcShape.rend()));

  auto firstNonZeroOffset = std::find_if(
      offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
        auto [offset, _size, _dimension] = offsetAndSizeAndShape;
        return offset != 0;
      });

  if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
    auto [offset, size, dimension] = *firstNonZeroOffset;
    // The offset dimension itself must not read past the source extent.
    if (size > dimension - offset)
      return false;
    ++firstNonZeroOffset;

    // Every dimension outside the offset one must be degenerate (size 1);
    // otherwise the selected rows are separated by skipped elements.
    if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
          auto [_offset, size, _dimension] = offsetAndSizeAndShape;
          return size != 1;
        }))
      return false;
  }

  auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
      llvm::make_range(srcShape.rbegin(), srcShape.rend()));

  auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
    auto [size, dimension] = sizeAndShape;
    return size != dimension;
  });

  if (firstDifferentSize != sizesAndShape.end()) {
    ++firstDifferentSize;

    // Once a dimension is only partially covered, all outer dimensions must
    // be degenerate for the covered elements to remain adjacent.
    if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
          auto [size, _dimension] = sizeAndShape;
          return size != 1;
        }))
      return false;
  }

  return true;
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
22
src/PIM/Common/IR/ShapeUtils.hpp
Normal file
22
src/PIM/Common/IR/ShapeUtils.hpp
Normal file
@@ -0,0 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t>
|
||||
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> offsets,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
101
src/PIM/Common/IR/WeightUtils.cpp
Normal file
101
src/PIM/Common/IR/WeightUtils.cpp
Normal file
@@ -0,0 +1,101 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Returns true when `op` is non-null and carries the weight-always marker.
bool hasWeightAlways(mlir::Operation* op) {
  if (!op)
    return false;
  return op->getAttr(PimWeightAlwaysAttrName) != nullptr;
}
|
||||
|
||||
/// Attaches the weight-always marker (a unit attribute) to `op`.
void markWeightAlways(mlir::Operation* op) {
  assert(op && "expected valid op");
  auto marker = mlir::UnitAttr::get(op->getContext());
  op->setAttr(PimWeightAlwaysAttrName, marker);
}
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
bool found = false;
|
||||
parentOp.walk([&](mlir::Operation* op) {
|
||||
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
||||
found |= mvmOp.getWeightIndex() == weightIndex;
|
||||
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||
found |= vmmOp.getWeightIndex() == weightIndex;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
/// Invokes `callback` once per weight index of `parentOp` that is referenced
/// by at least one nested MVM or VMM op; the operand handed to the callback
/// is `parentOp`'s own operand at that index. Note: all MVM ops are walked
/// before any VMM ops, so callback order is by op kind first, not IR order.
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
  auto weights = parentOp.getWeights();
  // Deduplicates indices referenced by multiple compute ops.
  llvm::SmallSet<unsigned, 8> visited;
  auto walkWeightIndex = [&](unsigned weightIndex) {
    // Out-of-range indices are silently skipped.
    // assumes weight operands are the leading operands of parentOp, so the
    // weight index doubles as an operand number — TODO confirm against the op
    // definition.
    if (weightIndex < weights.size() && visited.insert(weightIndex).second)
      callback(parentOp->getOpOperand(weightIndex));
  };

  parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
  parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Returns true when `use` is a weight operand of a spatial.compute op that
/// some nested weighted MVM/VMM op actually reads.
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
  auto computeOp = mlir::dyn_cast<spatial::SpatCompute>(use.getOwner());
  if (!computeOp)
    return false;

  const unsigned operandIndex = use.getOperandNumber();
  if (operandIndex >= computeOp.getWeights().size())
    return false;

  return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
}
|
||||
|
||||
/// Returns true when every transitive use of `value` is a Spatial MVM/VMM
/// weight operand, looking through tensor.extract_slice/expand/collapse and
/// onnx.Transpose. A value with no uses at all yields false.
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
  // Guards against revisiting a value reachable through several view chains
  // (and against cycles in malformed IR).
  llvm::SmallPtrSet<mlir::Value, 8> visited;
  // Recursive lambda; `self` is the lambda itself (pre-C++23 idiom).
  auto walkUses = [&](mlir::Value currentValue, auto& self) -> bool {
    if (!visited.insert(currentValue).second)
      return true; // Already visited: do not recount its uses.
    if (currentValue.use_empty())
      return false; // An unused value is not a weight use.

    return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) {
      if (isSpatialMvmVmmWeightUse(use))
        return true;

      // Look through pure view/permutation ops, but only when the value
      // feeds the data operand (not e.g. an index operand).
      mlir::Operation* user = use.getOwner();
      if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(user))
        return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
      if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(user))
        return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
      if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
        return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
      if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user))
        return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);

      // Any other user disqualifies the value.
      return false;
    });
  };

  return walkUses(value, walkUses);
}
|
||||
|
||||
/// Visits weight operands under `root`: for each pim.core, the operands its
/// nested MVM/VMM ops reference; for each pim.core_batch, every use of each
/// of its weight values that is an operand of the core_batch itself.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
  assert(root && "expected valid root op");
  root->walk([&](pim::PimCoreOp coreOp) { walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback); });
  root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
    // core_batch has no per-instruction weight indices; report every operand
    // slot on the core_batch that carries one of its weight values.
    auto weights = coreBatchOp.getWeights();
    for (auto weight : weights)
      for (mlir::OpOperand& use : weight.getUses())
        if (use.getOwner() == coreBatchOp.getOperation())
          callback(use);
  });
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
29
src/PIM/Common/IR/WeightUtils.hpp
Normal file
29
src/PIM/Common/IR/WeightUtils.hpp
Normal file
@@ -0,0 +1,29 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool hasWeightAlways(mlir::Operation* op);
|
||||
|
||||
/// Tags an op as producing a value that should stay materialized as a reusable
|
||||
/// weight across later PIM lowering/codegen stages.
|
||||
void markWeightAlways(mlir::Operation* op);
|
||||
|
||||
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
|
||||
|
||||
/// Returns true when a value flows only into Spatial weighted MVM/VMM operands,
|
||||
/// allowing later passes to preserve it as a dedicated weight-like object.
|
||||
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
||||
|
||||
/// Visits weight operands consumed by Pim core ops/core batches so downstream
|
||||
/// passes can identify globals that must remain weight-backed.
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,575 +0,0 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::string getOutputDir() {
|
||||
if (outputBaseName.empty() || outputBaseName == "-")
|
||||
return {};
|
||||
|
||||
size_t lastSlash = outputBaseName.find_last_of('/');
|
||||
if (lastSlash == std::string::npos)
|
||||
return ".";
|
||||
return outputBaseName.substr(0, lastSlash);
|
||||
}
|
||||
|
||||
/// Creates `directory` (and any missing parents). Succeeds silently when the
/// directory already exists.
void createDirectory(const std::string& directory) {
  std::error_code errorCode;
  std::filesystem::create_directories(directory, errorCode);
  // The previous assert built a runtime std::string message, but assert only
  // prints the *source text* of its condition, so that message could never be
  // shown; a plain literal is equivalent and cheaper.
  // NOTE(review): under NDEBUG this check vanishes and failures are silently
  // ignored — confirm that is acceptable for release builds.
  assert(!errorCode && "Failed to create directory");
  (void)errorCode; // keep release (NDEBUG) builds warning-free
}
|
||||
|
||||
/// Writes a snapshot of `moduleOp` to "<outputDir>/dialects/<name>.mlir" for
/// pass-level debugging. Silently does nothing when no output directory is
/// configured or the file cannot be opened.
void dumpModule(ModuleOp moduleOp, const std::string& name) {
  std::string outputDir = getOutputDir();
  if (outputDir.empty())
    return; // no output location configured (e.g. emitting to stdout)

  std::string dialectsDir = outputDir + "/dialects";
  createDirectory(dialectsDir);

  // ofstream expresses write-only intent; bail out early instead of silently
  // printing into a stream that failed to open.
  std::ofstream file(dialectsDir + "/" + name + ".mlir");
  if (!file.is_open())
    return;

  llvm::raw_os_ostream os(file);
  OpPrintingFlags flags;
  // Large constants (weights) would bloat the dump; elide them.
  flags.elideLargeElementsAttrs();
  moduleOp.print(os, flags);
  os.flush(); // raw_os_ostream buffers; flush before the ofstream closes (RAII)
}
|
||||
|
||||
/// Resolves the function acting as the PIM entry point for `moduleOp`.
/// Resolution order:
///   1. The single onnx.EntryPoint op (more than one is an error) whose
///      symbol attribute names the entry function.
///   2. A function literally named "main_graph".
///   3. The module's only non-external function.
/// Emits a diagnostic and returns failure() when no unique entry is found.
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
  if (!moduleOp)
    return failure();

  SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
  if (entryPoints.size() > 1) {
    moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
    return failure();
  }
  if (!entryPoints.empty()) {
    // The entry point op names its function through a symbol-ref attribute.
    auto entryPointAttr =
        entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
    if (!entryPointAttr) {
      entryPoints.front().emitOpError("is missing the entry point function attribute");
      return failure();
    }
    auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
    if (!entryFunc) {
      entryPoints.front().emitOpError("references an unknown entry function ")
          << entryPointAttr.getLeafReference().getValue();
      return failure();
    }
    return entryFunc;
  }

  // Fallback 1: the conventional onnx-mlir entry function name.
  if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
    return mainGraphFunc;

  // Fallback 2: a unique function with a body.
  SmallVector<func::FuncOp> nonExternalFuncs;
  for (auto funcOp : moduleOp.getOps<func::FuncOp>())
    if (!funcOp.isExternal())
      nonExternalFuncs.push_back(funcOp);
  if (nonExternalFuncs.size() == 1)
    return nonExternalFuncs.front();

  moduleOp.emitError("could not resolve a unique PIM entry function");
  return failure();
}
|
||||
|
||||
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
||||
|
||||
void markWeightAlways(Operation* op) {
|
||||
assert(op && "expected valid op");
|
||||
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
||||
}
|
||||
|
||||
namespace {

/// Returns true when any MVM/VMM op nested under `parentOp` consumes the
/// weight operand at `weightIndex` (matched by the op's weight index attr).
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
  bool found = false;
  parentOp.walk([&](Operation* op) {
    if (auto mvmOp = dyn_cast<MVMOpTy>(op))
      found |= mvmOp.getWeightIndex() == weightIndex;
    else if (auto vmmOp = dyn_cast<VMMOpTy>(op))
      found |= vmmOp.getWeightIndex() == weightIndex;
  });
  return found;
}

/// Invokes `callback` once per distinct weight operand of `parentOp` that is
/// referenced (by index) from any nested MVM/VMM op.
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, function_ref<void(OpOperand&)> callback) {
  auto weights = parentOp.getWeights();
  llvm::SmallSet<unsigned, 8> visited; // de-duplicates repeated weight indices
  auto walkWeightIndex = [&](unsigned weightIndex) {
    // Out-of-range indices are silently ignored; verification happens elsewhere.
    if (weightIndex < weights.size() && visited.insert(weightIndex).second)
      callback(parentOp->getOpOperand(weightIndex));
  };

  parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
  parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
}

} // namespace
|
||||
|
||||
/// Returns true when `use` is a weight operand of a spatial.compute op that is
/// actually consumed by a nested weighted MVM/VMM op.
bool isSpatialMvmVmmWeightUse(OpOperand& use) {
  auto computeOp = dyn_cast<spatial::SpatCompute>(use.getOwner());
  if (!computeOp)
    return false;

  // Only the leading operands of spatial.compute are weights.
  unsigned operandIndex = use.getOperandNumber();
  if (operandIndex >= computeOp.getWeights().size())
    return false;

  return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
}
|
||||
|
||||
/// Returns true when every transitive use of `value` ends in a Spatial
/// weighted MVM/VMM weight operand, possibly flowing through shape-only ops
/// (tensor.extract_slice / expand_shape / collapse_shape / onnx.Transpose).
/// A value with no uses at all does NOT count as weight-only.
bool hasOnlySpatialMvmVmmWeightUses(Value value) {
  SmallPtrSet<Value, 8> visited; // guards against revisiting shared values
  auto walkUses = [&](Value currentValue, auto& self) -> bool {
    // Already visited: treat as satisfied so shared subtrees are checked once.
    if (!visited.insert(currentValue).second)
      return true;
    if (currentValue.use_empty())
      return false;

    return llvm::all_of(currentValue.getUses(), [&](OpOperand& use) {
      if (isSpatialMvmVmmWeightUse(use))
        return true;

      // Shape-only ops are transparent: follow their result, but only when
      // `currentValue` feeds the data operand (not some auxiliary operand).
      Operation* user = use.getOwner();
      if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user))
        return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
      if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user))
        return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
      if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(user))
        return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
      if (auto transposeOp = dyn_cast<ONNXTransposeOp>(user))
        return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);

      // Any other user disqualifies the value from being weight-only.
      return false;
    });
  };

  return walkUses(value, walkUses);
}
|
||||
|
||||
/// Invokes `callback` for every weight operand consumed under `root`:
/// pim.core ops report the weights referenced by nested MVM/VMM weight
/// indices; pim.core_batch ops report all of their declared weight operands.
void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
  assert(root && "expected valid root op");
  root->walk([&](pim::PimCoreOp coreOp) {
    walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback);
  });
  root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
    // A core batch consumes every weight directly: report only the uses of
    // each weight value that belong to the batch op itself.
    auto weights = coreBatchOp.getWeights();
    for (auto weight : weights)
      for (OpOperand& use : weight.getUses())
        if (use.getOwner() == coreBatchOp.getOperation())
          callback(use);
  });
}
|
||||
|
||||
/// Resolves the memref.global referenced by `getGlobalOp` within `moduleOp`.
/// Returns a null op when either input is null or the symbol is missing.
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
  if (moduleOp && getGlobalOp)
    return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
  return {};
}
|
||||
|
||||
/// Computes row-major (C-order) strides for `shape`; the innermost dimension
/// always has stride 1. Empty and rank-1 shapes yield all-ones strides.
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> strides(shape.size(), 1);
  int64_t runningProduct = 1;
  // Walk outward from the innermost dimension, accumulating the element count
  // of all dimensions inner to the one being assigned.
  for (size_t i = shape.size(); i > 1; --i) {
    runningProduct *= shape[i - 1];
    strides[i - 2] = runningProduct;
  }
  return strides;
}
|
||||
|
||||
/// Converts a linear element index into per-dimension indices for the given
/// `strides` (as produced by computeRowMajorStrides). `shape` only determines
/// the rank of the result.
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
  SmallVector<int64_t> indices;
  indices.reserve(shape.size());
  int64_t remainder = linearIndex;
  for (int64_t stride : strides) {
    int64_t coordinate = remainder / stride;
    indices.push_back(coordinate);
    remainder -= coordinate * stride;
  }
  return indices;
}
|
||||
|
||||
/// Folds per-dimension `indices` into a single linear index using `strides`
/// (dot product). Both arrays must have the same rank.
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
  assert(indices.size() == strides.size() && "rank mismatch");
  int64_t linearIndex = 0;
  for (size_t dim = 0, rank = indices.size(); dim < rank; ++dim)
    linearIndex += indices[dim] * strides[dim];
  return linearIndex;
}
|
||||
|
||||
/// Returns the total element count of `shape` (product of all extents).
/// An empty shape yields 1, matching the rank-0 tensor convention.
int64_t getNumElements(ArrayRef<int64_t> shape) {
  int64_t product = 1;
  for (size_t i = 0, rank = shape.size(); i < rank; ++i)
    product *= shape[i];
  return product;
}
|
||||
|
||||
/// Checks whether the slice described by (offsets, sizes, strides) over a
/// row-major buffer of `srcShape` selects one contiguous run of memory.
/// All strides must be 1. Scanning from the innermost dimension outward:
/// once a nonzero offset appears, that dimension's slice must stay in bounds
/// and every dimension outer to it must have size 1; likewise, once a size
/// differs from the full extent, every outer dimension must have size 1.
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
                        ArrayRef<int64_t> offsets,
                        ArrayRef<int64_t> sizes,
                        ArrayRef<int64_t> strides) {
  // Any non-unit stride breaks contiguity immediately.
  if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
    return false;

  // Iterate innermost-first (reversed) over (offset, size, extent) triples.
  auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
                                                 llvm::make_range(sizes.rbegin(), sizes.rend()),
                                                 llvm::make_range(srcShape.rbegin(), srcShape.rend()));

  auto firstNonZeroOffset = std::find_if(
      offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
        auto [offset, _size, _dimension] = offsetAndSizeAndShape;
        return offset != 0;
      });

  if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
    auto [offset, size, dimension] = *firstNonZeroOffset;
    // The offset dimension's slice must fit in the remainder of that dim.
    if (size > dimension - offset)
      return false;
    ++firstNonZeroOffset;

    // Every dimension outer to the first offset one must be degenerate.
    if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
          auto [_offset, size, _dimension] = offsetAndSizeAndShape;
          return size != 1;
        }))
      return false;
  }

  // Same innermost-first reasoning for partially covered dimensions.
  auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
                                       llvm::make_range(srcShape.rbegin(), srcShape.rend()));

  auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
    auto [size, dimension] = sizeAndShape;
    return size != dimension;
  });

  if (firstDifferentSize != sizesAndShape.end()) {
    ++firstDifferentSize;

    // Dimensions outer to the first partial one must have size 1.
    if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
          auto [size, _dimension] = sizeAndShape;
          return size != 1;
        }))
      return false;
  }

  return true;
}
|
||||
|
||||
/// Chases the alias chain recorded in `knowledge` until a fixed point; a null
/// `knowledge` leaves the value untouched.
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
  if (!knowledge)
    return value;

  for (auto iter = knowledge->aliases.find(value); iter != knowledge->aliases.end();
       iter = knowledge->aliases.find(value))
    value = iter->second;
  return value;
}
|
||||
|
||||
// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
// and when propagating yielded values across iterations during static unrolling.
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);

  // Block arguments (iter-args, function args) are as far back as we can go.
  if (auto blockArgument = dyn_cast<BlockArgument>(value))
    return value;

  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return value;

  // A DPS result aliases its tied init operand; keep tracing through it.
  if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
    if (auto result = dyn_cast<OpResult>(value))
      if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
        return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
  }

  // View-like memref ops do not change the underlying allocation.
  if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
    return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
  if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
    return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
  if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
    return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);

  // Anything else (e.g. a fresh allocation) is the underlying value itself.
  return value;
}
|
||||
|
||||
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge);

/// Statically evaluates an integer/index SSA value. Supports values pinned in
/// `knowledge`, arith.constant, index_cast, and add/sub/mul/divui/remui over
/// recursively resolvable operands. Fails when any leaf is not statically
/// known (or on division by zero).
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);

  // Values pinned by the static interpreter (e.g. the current loop induction
  // value) take precedence over any IR-level analysis.
  if (knowledge) {
    auto iter = knowledge->indexValues.find(value);
    if (iter != knowledge->indexValues.end())
      return iter->second;
  }

  auto constantOp = value.getDefiningOp<arith::ConstantOp>();
  if (constantOp) {
    if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
      return integerAttr.getInt();
  }

  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return failure();

  // index_cast is value-preserving for the ranges we care about.
  if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
    return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);

  if (auto addOp = dyn_cast<arith::AddIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs))
      return failure();
    return *lhs + *rhs;
  }

  if (auto subOp = dyn_cast<arith::SubIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs))
      return failure();
    return *lhs - *rhs;
  }

  if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs))
      return failure();
    return *lhs * *rhs;
  }

  // Unsigned div/rem: evaluated with uint64_t semantics to match the op;
  // a zero divisor is reported as "not statically known" rather than UB.
  if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs) || *rhs == 0)
      return failure();
    return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
  }

  if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs) || *rhs == 0)
      return failure();
    return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
  }

  return failure();
}
|
||||
|
||||
/// Statically evaluates an OpFoldResult: an IntegerAttr yields its value
/// directly, an SSA value is delegated to resolveIndexValueImpl, anything
/// else fails.
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
  auto attr = dyn_cast<Attribute>(ofr);
  if (!attr)
    return resolveIndexValueImpl(cast<Value>(ofr), knowledge);

  if (auto integerAttr = dyn_cast<IntegerAttr>(attr))
    return integerAttr.getInt();
  return failure();
}
|
||||
|
||||
/// Traces `value` back through DPS results, scf.for loop carries, statically
/// contiguous subviews and view-like memref ops until it reaches a base
/// allocation (memref.alloc / memref.get_global) or a block argument,
/// accumulating the byte offset contributed by subviews along the way.
/// Fails when anything in the chain is not statically resolvable/contiguous.
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(Value value,
                                                                         const StaticValueKnowledge* knowledge) {
  int64_t byteOffset = 0;
  value = resolveAlias(value, knowledge);

  while (true) {
    // A block argument (function/region arg) is treated as a base address.
    if (isa<BlockArgument>(value))
      return ResolvedContiguousAddress {value, byteOffset};

    Operation* definingOp = value.getDefiningOp();
    if (!definingOp)
      return failure();

    // DPS results alias their tied init operand; trace through it.
    if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
      OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
      if (!tiedOperand)
        return failure();
      value = resolveAlias(tiedOperand->get(), knowledge);
      continue;
    }

    if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
      auto result = dyn_cast<OpResult>(value);
      if (!result)
        return failure();

      // Trace the loop carry back to its underlying memref, then if that memref is the
      // loop's own iter-arg we know the base comes from the corresponding init arg
      // (every iteration yields the same backing memory in the DPS sense).
      auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
      Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
      if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
        // Arg 0 is the induction variable; iter-args start at index 1.
        if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
            && static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
          value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
          continue;
        }
      }

      value = yieldedValue;
      continue;
    }

    if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
      // Only statically shaped source/result subviews can be folded into a
      // constant byte offset.
      auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
      auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
      if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
        return failure();

      SmallVector<int64_t> offsets;
      SmallVector<int64_t> sizes;
      SmallVector<int64_t> strides;
      offsets.reserve(subviewOp.getMixedOffsets().size());
      sizes.reserve(subviewOp.getMixedSizes().size());
      strides.reserve(subviewOp.getMixedStrides().size());

      // Every offset/size/stride must evaluate to a compile-time constant.
      for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
        auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
        if (failed(resolvedOffset))
          return failure();
        offsets.push_back(*resolvedOffset);
      }

      for (OpFoldResult size : subviewOp.getMixedSizes()) {
        auto resolvedSize = resolveOpFoldResult(size, knowledge);
        if (failed(resolvedSize))
          return failure();
        sizes.push_back(*resolvedSize);
      }

      for (OpFoldResult stride : subviewOp.getMixedStrides()) {
        auto resolvedStride = resolveOpFoldResult(stride, knowledge);
        if (failed(resolvedStride))
          return failure();
        strides.push_back(*resolvedStride);
      }

      if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
        return failure();

      // Fold the subview's start position into the running byte offset.
      auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
      byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
      value = resolveAlias(subviewOp.getSource(), knowledge);
      continue;
    }

    // View-like ops: same allocation, no offset contribution.
    if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
      value = resolveAlias(castOp.getSource(), knowledge);
      continue;
    }
    if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
      value = resolveAlias(collapseOp.getSrc(), knowledge);
      continue;
    }
    if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
      value = resolveAlias(expandOp.getSrc(), knowledge);
      continue;
    }

    // Reached a base allocation: done.
    if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
      return ResolvedContiguousAddress {value, byteOffset};

    return failure();
  }
}
|
||||
|
||||
/// Statically evaluates an index value with no extra interpreter state.
FailureOr<int64_t> resolveIndexValue(Value value) { return resolveIndexValueImpl(value, nullptr); }

/// Statically evaluates an index value, consulting `knowledge` for pinned
/// index values and aliases (e.g. while unrolling scf.for loops).
FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
  return resolveIndexValueImpl(value, &knowledge);
}

/// Traces `value` to a contiguous base allocation plus byte offset.
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
  return resolveContiguousAddressImpl(value, nullptr);
}

/// Same as above, consulting `knowledge` during the trace.
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) {
  return resolveContiguousAddressImpl(value, &knowledge);
}

/// Public wrapper over resolveLoopCarriedAliasImpl; knowledge is mandatory.
Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
  return resolveLoopCarriedAliasImpl(value, &knowledge);
}
|
||||
|
||||
/// Returns true for ops inside a pim.core body that emit no PIM instruction
/// themselves and only contribute to static address/index computation
/// (integer arith, memref view ops, allocation, constants).
bool isCoreStaticAddressOp(Operation* op) {
  return isa<arith::ConstantOp,
             arith::AddIOp,
             arith::SubIOp,
             arith::MulIOp,
             arith::DivUIOp,
             arith::RemUIOp,
             arith::IndexCastOp,
             memref::AllocOp,
             memref::SubViewOp,
             memref::CastOp,
             memref::CollapseShapeOp,
             memref::ExpandShapeOp>(op);
}
|
||||
|
||||
/// Statically interprets `block` (a pim.core body, or an scf.for body nested
/// in one). scf.for ops with statically evaluable bounds are unrolled: the
/// walker recurses once per iteration with the induction value and iter-arg
/// aliases recorded in a copy of `knowledge`. Ops that only do static address
/// math (plus pim.halt / scf.yield) are skipped; every other op is handed to
/// `callback`. Walking continues after a callback failure so multiple
/// diagnostics can be collected; the overall result still reports failure.
LogicalResult walkPimCoreBlock(Block& block,
                               const StaticValueKnowledge& knowledge,
                               llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
  bool hasFailure = false;
  for (Operation& op : block) {
    if (isa<pim::PimHaltOp, scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
      continue;

    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      Block& loopBody = forOp.getRegion().front();
      auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
      auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
      auto step = resolveIndexValue(forOp.getStep(), knowledge);
      if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
        forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
        hasFailure = true;
        continue;
      }

      // Static unrolling: track the concrete value behind each iter-arg
      // across iterations, starting from the loop's init args.
      SmallVector<Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
      for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
        StaticValueKnowledge loopKnowledge = knowledge;
        loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
        for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
          loopKnowledge.aliases[iterArg] = iterValue;

        if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
          hasFailure = true;

        // Propagate yielded values into the next iteration's iter-arg aliases.
        auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
        for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
          iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
      }
      continue;
    }

    if (failed(callback(op, knowledge)))
      hasFailure = true;
  }
  return success(!hasFailure);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -11,83 +11,17 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
inline constexpr llvm::StringLiteral kCoreIdAttrName = "core_id";
|
||||
|
||||
struct ResolvedContiguousAddress {
|
||||
mlir::Value base;
|
||||
int64_t byteOffset = 0;
|
||||
};
|
||||
|
||||
struct StaticValueKnowledge {
|
||||
llvm::DenseMap<mlir::Value, int64_t> indexValues;
|
||||
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
|
||||
|
||||
StaticValueKnowledge() {}
|
||||
};
|
||||
|
||||
std::string getOutputDir();
|
||||
|
||||
void createDirectory(const std::string& directory);
|
||||
|
||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
||||
|
||||
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
|
||||
|
||||
bool hasWeightAlways(mlir::Operation* op);
|
||||
|
||||
void markWeightAlways(mlir::Operation* op);
|
||||
|
||||
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
|
||||
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
||||
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||
|
||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||
|
||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t>
|
||||
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> offsets,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||
const StaticValueKnowledge& knowledge);
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
|
||||
/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for
|
||||
/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries.
|
||||
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
|
||||
/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and
|
||||
/// only contribute to static addressing or index computations (arith integer math,
|
||||
/// memref view ops, memref.alloc, arith.constant).
|
||||
bool isCoreStaticAddressOp(mlir::Operation* op);
|
||||
|
||||
/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically
|
||||
/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op
|
||||
/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
|
||||
/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback
|
||||
/// failure so callers can collect multiple diagnostics, but propagates the overall result.
|
||||
mlir::LogicalResult
|
||||
walkPimCoreBlock(mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
27
src/PIM/Common/Support/DebugDump.cpp
Normal file
27
src/PIM/Common/Support/DebugDump.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Writes a snapshot of `moduleOp` to "<outputDir>/dialects/<name>.mlir" for
/// pass-level debugging. Silently does nothing when no output directory is
/// configured or the file cannot be opened.
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
  std::string outputDir = getOutputDir();
  if (outputDir.empty())
    return; // no output location configured (e.g. emitting to stdout)

  std::string dialectsDir = outputDir + "/dialects";
  createDirectory(dialectsDir);

  // ofstream expresses write-only intent; bail out early instead of silently
  // printing into a stream that failed to open.
  std::ofstream file(dialectsDir + "/" + name + ".mlir");
  if (!file.is_open())
    return;

  llvm::raw_os_ostream os(file);
  mlir::OpPrintingFlags flags;
  // Large constants (weights) would bloat the dump; elide them.
  flags.elideLargeElementsAttrs();
  moduleOp.print(os, flags);
  os.flush(); // raw_os_ostream buffers; flush before the ofstream closes (RAII)
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
13
src/PIM/Common/Support/DebugDump.hpp
Normal file
13
src/PIM/Common/Support/DebugDump.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Emits a MLIR snapshot under the current compiler output
|
||||
/// directory for pass-level debugging.
|
||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
41
src/PIM/Common/Support/Diagnostics.cpp
Normal file
41
src/PIM/Common/Support/Diagnostics.cpp
Normal file
@@ -0,0 +1,41 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
/// Emits the standard "requires statically shaped ..." op error for `op`,
/// naming the offending value via `valueDescription`.
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription) {
  auto diag = op->emitOpError();
  diag << "requires statically shaped " << valueDescription;
  return diag;
}
|
||||
|
||||
/// Emits the standard unsupported-rank op error for `op`, optionally listing
/// the ranks the current lowering path accepts (comma-separated).
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
                                                       llvm::StringRef valueDescription,
                                                       int64_t actualRank,
                                                       llvm::ArrayRef<int64_t> supportedRanks) {
  auto diag = op->emitOpError() << "has unsupported rank " << actualRank << " for " << valueDescription;
  if (!supportedRanks.empty()) {
    // "rank" vs "ranks" depending on how many are supported.
    diag << "; supported rank" << (supportedRanks.size() == 1 ? "" : "s") << ' ';
    llvm::interleaveComma(supportedRanks, diag, [&](int64_t rank) { diag << rank; });
  }
  return diag;
}
|
||||
|
||||
/// Emits the standard "references missing <kind> `<name>`" op error for
/// unresolved symbol/global references.
mlir::InFlightDiagnostic
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName) {
  auto diag = op->emitOpError();
  diag << "references missing " << symbolKind << " `" << symbolName << "`";
  return diag;
}
|
||||
|
||||
/// Reports a filesystem failure (`action` on `path` with `errorCode`) as an
/// MLIR error anchored at `loc`, and returns failure for easy propagation.
mlir::LogicalResult emitFileSystemError(mlir::Location loc,
                                        llvm::StringRef action,
                                        llvm::StringRef path,
                                        const std::error_code& errorCode) {
  auto diag = mlir::emitError(loc);
  diag << "failed to " << action << " `" << path << "`: " << errorCode.message();
  return mlir::failure();
}
|
||||
|
||||
} // namespace onnx_mlir::pim
|
||||
38
src/PIM/Common/Support/Diagnostics.hpp
Normal file
38
src/PIM/Common/Support/Diagnostics.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <system_error>
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
/// Emits a consistent diagnostic for target paths that require static shapes.
|
||||
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
|
||||
|
||||
/// Emits a consistent diagnostic for unsupported ranks while listing the ranks
|
||||
/// accepted by the current lowering/codegen path.
|
||||
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
|
||||
llvm::StringRef valueDescription,
|
||||
int64_t actualRank,
|
||||
llvm::ArrayRef<int64_t> supportedRanks);
|
||||
|
||||
/// Emits a consistent diagnostic for missing symbol/global references.
|
||||
mlir::InFlightDiagnostic
|
||||
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName);
|
||||
|
||||
/// Converts a filesystem error into an MLIR failure diagnostic anchored at
|
||||
/// the relevant IR location.
|
||||
mlir::LogicalResult
|
||||
emitFileSystemError(mlir::Location loc, llvm::StringRef action, llvm::StringRef path, const std::error_code& errorCode);
|
||||
|
||||
/// Collapses a FailureOr<T> to its success/failure state, discarding the
/// contained value. Handy when a helper only needs to propagate failure.
template <typename T>
mlir::LogicalResult failureOrToLogicalResult(const llvm::FailureOr<T>& value) {
  return mlir::success(succeeded(value));
}
|
||||
|
||||
} // namespace onnx_mlir::pim
|
||||
24
src/PIM/Common/Support/FileSystemUtils.cpp
Normal file
24
src/PIM/Common/Support/FileSystemUtils.cpp
Normal file
@@ -0,0 +1,24 @@
|
||||
#include <filesystem>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::string getOutputDir() {
|
||||
if (outputBaseName.empty() || outputBaseName == "-")
|
||||
return {};
|
||||
|
||||
size_t lastSlash = outputBaseName.find_last_of('/');
|
||||
if (lastSlash == std::string::npos)
|
||||
return ".";
|
||||
return outputBaseName.substr(0, lastSlash);
|
||||
}
|
||||
|
||||
/// Creates `directory` (and any missing parents). Succeeds silently when the
/// directory already exists.
void createDirectory(const std::string& directory) {
  std::error_code errorCode;
  std::filesystem::create_directories(directory, errorCode);
  // The previous assert built a runtime std::string message, but assert only
  // prints the *source text* of its condition, so that message could never be
  // shown; a plain literal is equivalent and cheaper.
  // NOTE(review): under NDEBUG this check vanishes and failures are silently
  // ignored — confirm that is acceptable for release builds.
  assert(!errorCode && "Failed to create directory");
  (void)errorCode; // keep release (NDEBUG) builds warning-free
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
13
src/PIM/Common/Support/FileSystemUtils.hpp
Normal file
13
src/PIM/Common/Support/FileSystemUtils.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Returns the directory that should hold PIM artifacts/debug dumps for the
|
||||
/// current compiler invocation.
|
||||
std::string getOutputDir();
|
||||
|
||||
void createDirectory(const std::string& directory);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user