fix bufferization and weight emission after new gemm patterns
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -1,8 +1,13 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -19,6 +24,50 @@ void markWeightAlways(mlir::Operation* op) {
|
||||
|
||||
namespace {
|
||||
|
||||
CompiledIndexExpr makeConstantExpr(int64_t constant) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = constant;
|
||||
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::move(expr)));
|
||||
}
|
||||
|
||||
CompiledIndexExpr makeBinaryExpr(CompiledIndexExprNode::Kind kind, CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = kind;
|
||||
expr.operands = {std::move(lhs), std::move(rhs)};
|
||||
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::move(expr)));
|
||||
}
|
||||
|
||||
CompiledIndexExpr addExpr(CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
|
||||
return makeBinaryExpr(CompiledIndexExprNode::Kind::Add, std::move(lhs), std::move(rhs));
|
||||
}
|
||||
|
||||
CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
|
||||
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
|
||||
}
|
||||
|
||||
mlir::Value stripWeightViewOps(mlir::Value value) {
|
||||
while (true) {
|
||||
if (auto subviewOp = value.getDefiningOp<mlir::memref::SubViewOp>()) {
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto castOp = value.getDefiningOp<mlir::memref::CastOp>()) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = value.getDefiningOp<mlir::memref::CollapseShapeOp>()) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = value.getDefiningOp<mlir::memref::ExpandShapeOp>()) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
@@ -96,35 +145,31 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
|
||||
assert(root && "expected valid root op");
|
||||
root->walk([&](pim::PimCoreOp coreOp) {
|
||||
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||
callback(coreOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
if (auto weightIndex = resolveWeightIndex(coreOp.getOperation(), vmmOp.getWeight()))
|
||||
callback(coreOp->getOpOperand(*weightIndex));
|
||||
});
|
||||
});
|
||||
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
|
||||
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||
callback(coreBatchOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
if (auto weightIndex = resolveWeightIndex(coreBatchOp.getOperation(), vmmOp.getWeight()))
|
||||
callback(coreBatchOp->getOpOperand(*weightIndex));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp) {
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
|
||||
weight = stripWeightViewOps(weight);
|
||||
|
||||
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
|
||||
if (coreOp.getWeightArgument(weightIndex) == weight)
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
|
||||
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
|
||||
if (coreBatchOp.getWeightArgument(weightIndex) == weight)
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
}
|
||||
@@ -132,4 +177,121 @@ std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::Pi
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp) {
|
||||
return resolveWeightIndex(weightOwner, vmmOp.getWeight());
|
||||
}
|
||||
|
||||
llvm::FailureOr<ResolvedWeightView>
|
||||
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge) {
|
||||
llvm::SmallVector<mlir::Operation*> viewOps;
|
||||
mlir::Value current = weight;
|
||||
|
||||
while (true) {
|
||||
if (auto defOp = current.getDefiningOp()) {
|
||||
if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
|
||||
auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return mlir::failure();
|
||||
|
||||
ResolvedWeightView view;
|
||||
view.globalOp = globalOp;
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
|
||||
CompiledIndexExpr offsetExpr = makeConstantExpr(0);
|
||||
for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
|
||||
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
|
||||
llvm::SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getMixedOffsets().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||
CompiledIndexExpr offsetValue = makeConstantExpr(0);
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
|
||||
auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
|
||||
if (!intAttr)
|
||||
return mlir::failure();
|
||||
offsetValue = makeConstantExpr(intAttr.getInt());
|
||||
}
|
||||
else if (auto value = mlir::dyn_cast<mlir::Value>(offset)) {
|
||||
auto compiledOffset = compileIndexExpr(value);
|
||||
if (failed(compiledOffset))
|
||||
return mlir::failure();
|
||||
offsetValue = *compiledOffset;
|
||||
}
|
||||
else {
|
||||
return mlir::failure();
|
||||
}
|
||||
offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
|
||||
nextStrides.push_back(stride * sourceStride);
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return mlir::failure();
|
||||
auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return mlir::failure();
|
||||
auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
}
|
||||
}
|
||||
|
||||
auto resolvedOffset = offsetExpr.evaluate(knowledge);
|
||||
if (failed(resolvedOffset))
|
||||
return mlir::failure();
|
||||
view.offset = *resolvedOffset;
|
||||
return view;
|
||||
}
|
||||
|
||||
if (mlir::isa<mlir::memref::SubViewOp, mlir::memref::CollapseShapeOp, mlir::memref::ExpandShapeOp>(defOp)) {
|
||||
viewOps.push_back(defOp);
|
||||
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp))
|
||||
current = subview.getSource();
|
||||
else if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp))
|
||||
current = collapse.getSrc();
|
||||
else
|
||||
current = mlir::cast<mlir::memref::ExpandShapeOp>(defOp).getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
|
||||
current = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
auto weightIndex = resolveWeightIndex(weightOwner, current);
|
||||
if (!weightIndex)
|
||||
return mlir::failure();
|
||||
|
||||
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
|
||||
current = coreOp.getWeights()[*weightIndex];
|
||||
continue;
|
||||
}
|
||||
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
|
||||
current = coreBatchOp.getWeights()[*weightIndex];
|
||||
continue;
|
||||
}
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
@@ -10,12 +11,24 @@
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ResolvedWeightView {
|
||||
mlir::memref::GlobalOp globalOp;
|
||||
llvm::SmallVector<int64_t> shape;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
|
||||
bool operator==(const ResolvedWeightView& other) const {
|
||||
return globalOp == other.globalOp && shape == other.shape && strides == other.strides && offset == other.offset;
|
||||
}
|
||||
};
|
||||
|
||||
bool hasWeightAlways(mlir::Operation* op);
|
||||
|
||||
/// Tags an op as producing a value that should stay materialized as a reusable
|
||||
@@ -32,24 +45,21 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
||||
/// passes can identify globals that must remain weight-backed.
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight);
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp);
|
||||
llvm::FailureOr<ResolvedWeightView>
|
||||
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
template <typename CoreLikeOpTy>
|
||||
llvm::SmallVector<unsigned, 8> getUsedWeightIndices(CoreLikeOpTy coreLikeOp) {
|
||||
llvm::SmallVector<unsigned, 8> indices;
|
||||
auto addWeight = [&](mlir::Value weight) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreLikeOp.getWeights().size(); ++weightIndex) {
|
||||
if (coreLikeOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
auto weightIndex = resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight());
|
||||
if (weightIndex && !llvm::is_contained(indices, *weightIndex))
|
||||
indices.push_back(*weightIndex);
|
||||
});
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -814,7 +814,7 @@ static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
||||
struct CoreEmissionResult {
|
||||
OnnxMlirCompilerErrorCodes status = CompilerSuccess;
|
||||
MemoryReportRow reportRow;
|
||||
llvm::SmallVector<unsigned, 8> usedWeightIndices;
|
||||
llvm::SmallVector<ResolvedWeightView, 8> usedWeights;
|
||||
};
|
||||
|
||||
template <typename MapTy>
|
||||
@@ -879,7 +879,6 @@ struct CompiledCoreNode {
|
||||
Kind kind = Kind::Op;
|
||||
Operation* op = nullptr;
|
||||
CompiledCoreOpKind opKind = CompiledCoreOpKind::Load;
|
||||
std::optional<unsigned> weightIndex;
|
||||
CompiledIndexExpr lowerBound;
|
||||
CompiledIndexExpr upperBound;
|
||||
CompiledIndexExpr step;
|
||||
@@ -978,12 +977,6 @@ static LogicalResult compileCoreEmissionPlan(Block& block,
|
||||
opNode.kind = CompiledCoreNode::Kind::Op;
|
||||
opNode.op = &op;
|
||||
opNode.opKind = *opKind;
|
||||
if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) {
|
||||
auto weightIndex = onnx_mlir::resolveWeightIndex(weightOwner, vmmOp);
|
||||
if (!weightIndex)
|
||||
return failure();
|
||||
opNode.weightIndex = *weightIndex;
|
||||
}
|
||||
plan.push_back(std::move(opNode));
|
||||
}
|
||||
return success();
|
||||
@@ -992,6 +985,9 @@ static LogicalResult compileCoreEmissionPlan(Block& block,
|
||||
static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<CompiledCoreNode>& plan,
|
||||
PimCodeGen& coreCodeGen,
|
||||
StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp,
|
||||
const StaticValueKnowledge&)>
|
||||
resolveWeightSlot,
|
||||
size_t& processedOperations,
|
||||
std::optional<unsigned> batchLane = std::nullopt,
|
||||
std::optional<unsigned> batchLaneCount = std::nullopt) {
|
||||
@@ -1015,7 +1011,7 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
|
||||
aliasBindings.bind(iterArg, iterValue);
|
||||
|
||||
if (failed(executeCompiledCorePlan(
|
||||
*node.loopBody, coreCodeGen, knowledge, processedOperations, batchLane, batchLaneCount)))
|
||||
*node.loopBody, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount)))
|
||||
return failure();
|
||||
|
||||
auto yieldOp = cast<mlir::scf::YieldOp>(forOp.getRegion().front().getTerminator());
|
||||
@@ -1048,9 +1044,10 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
|
||||
coreCodeGen.codeGenConcatOp(cast<pim::PimConcatOp>(node.op), knowledge);
|
||||
break;
|
||||
case CompiledCoreOpKind::Vmm:
|
||||
assert(node.weightIndex && "compiled VMM op must have cached weight index");
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(
|
||||
*node.weightIndex, cast<pim::PimVMMOp>(node.op), true, knowledge);
|
||||
if (auto weightSlot = resolveWeightSlot(cast<pim::PimVMMOp>(node.op), knowledge); succeeded(weightSlot))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightSlot, cast<pim::PimVMMOp>(node.op), true, knowledge);
|
||||
else
|
||||
return failure();
|
||||
break;
|
||||
case CompiledCoreOpKind::Transpose:
|
||||
coreCodeGen.codeGenTransposeOp(cast<pim::PimTransposeOp>(node.op), knowledge);
|
||||
@@ -1138,6 +1135,9 @@ static int64_t codeGenCoreOps(Block& block,
|
||||
PimCodeGen& coreCodeGen,
|
||||
const StaticValueKnowledge& initialKnowledge,
|
||||
Operation* weightOwner,
|
||||
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp,
|
||||
const StaticValueKnowledge&)>
|
||||
resolveWeightSlot,
|
||||
std::optional<unsigned> batchLane = std::nullopt,
|
||||
std::optional<unsigned> batchLaneCount = std::nullopt) {
|
||||
llvm::SmallVector<CompiledCoreNode, 32> plan;
|
||||
@@ -1146,7 +1146,8 @@ static int64_t codeGenCoreOps(Block& block,
|
||||
|
||||
size_t processedOperations = 0;
|
||||
StaticValueKnowledge knowledge = initialKnowledge;
|
||||
auto result = executeCompiledCorePlan(plan, coreCodeGen, knowledge, processedOperations, batchLane, batchLaneCount);
|
||||
auto result =
|
||||
executeCompiledCorePlan(plan, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount);
|
||||
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
|
||||
}
|
||||
|
||||
@@ -1174,9 +1175,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
size_t maxCoreId = 0;
|
||||
uint64_t nextBatchReportId = 0;
|
||||
|
||||
// Create Weight Folder
|
||||
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
||||
|
||||
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
||||
SmallDenseMap<memref::GlobalOp, MemEntry, 16> materializedHostGlobals =
|
||||
collectMaterializedHostGlobals(moduleOp, funcOp, memory);
|
||||
@@ -1238,11 +1236,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
++nextBatchReportId;
|
||||
}
|
||||
|
||||
auto linkCoreWeights = [&](size_t originalCoreId,
|
||||
size_t coreId,
|
||||
ArrayRef<unsigned> usedIndices,
|
||||
ValueRange weights,
|
||||
Operation* weightOwner,
|
||||
auto linkCoreWeights = [&](size_t coreId,
|
||||
ArrayRef<std::string> weightFiles,
|
||||
json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes {
|
||||
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
|
||||
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
|
||||
@@ -1250,20 +1245,12 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
|
||||
auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId];
|
||||
for (unsigned index : usedIndices) {
|
||||
if (index >= weights.size()) {
|
||||
weightOwner->emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||
assert(index < weights.size() && "Weight index is out of range");
|
||||
}
|
||||
mlir::Value weight = weights[index];
|
||||
xbarsPerGroup.push_back(index);
|
||||
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
|
||||
auto& fileName = mapWeightToFile[weight];
|
||||
for (auto [slot, fileName] : llvm::enumerate(weightFiles)) {
|
||||
xbarsPerGroup.push_back(static_cast<int64_t>(slot));
|
||||
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
|
||||
coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
|
||||
coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")) {
|
||||
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
|
||||
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")
|
||||
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")
|
||||
<< "\nError:" << error.message() << '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
@@ -1275,6 +1262,22 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
auto emitJob = [&](const CoreEmissionJob& job) -> CoreEmissionResult {
|
||||
CoreEmissionResult result;
|
||||
PimAcceleratorMemory jobMemory(memory.memEntriesMap, false);
|
||||
llvm::SmallVector<ResolvedWeightView, 8> usedWeights;
|
||||
|
||||
auto resolveWeightSlot = [&](pim::PimVMMOp vmmOp,
|
||||
const StaticValueKnowledge& knowledge) -> llvm::FailureOr<unsigned> {
|
||||
auto weightView = onnx_mlir::resolveWeightView(job.coreLikeOp, vmmOp.getWeight(), knowledge);
|
||||
if (failed(weightView)) {
|
||||
vmmOp.emitOpError("requires a statically resolvable dense global weight view during PIM codegen");
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (auto it = llvm::find(usedWeights, *weightView); it != usedWeights.end())
|
||||
return static_cast<unsigned>(std::distance(usedWeights.begin(), it));
|
||||
|
||||
usedWeights.push_back(*weightView);
|
||||
return static_cast<unsigned>(usedWeights.size() - 1);
|
||||
};
|
||||
|
||||
std::error_code errorCode;
|
||||
auto outputCorePath = outputDirPath + "/core_" + std::to_string(job.emittedCoreId) + ".pim";
|
||||
@@ -1307,21 +1310,20 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
|
||||
deviceMemory.allocateCore(coreOp);
|
||||
|
||||
int64_t processedOperations =
|
||||
codeGenCoreOps(coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation());
|
||||
int64_t processedOperations = codeGenCoreOps(
|
||||
coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation(), resolveWeightSlot);
|
||||
if (processedOperations < 0) {
|
||||
result.status = CompilerFailure;
|
||||
return result;
|
||||
}
|
||||
assert(processedOperations > 0);
|
||||
result.reportRow = deviceMemory.getReportRow();
|
||||
result.usedWeightIndices = getUsedWeightIndices(coreOp);
|
||||
result.usedWeights = std::move(usedWeights);
|
||||
}
|
||||
else {
|
||||
auto coreBatchOp = cast<pim::PimCoreBatchOp>(job.coreLikeOp);
|
||||
aliasMaterializedHostGlobals(coreBatchOp, moduleOp, materializedHostGlobals, jobMemory);
|
||||
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
|
||||
result.usedWeightIndices = getUsedWeightIndices(coreBatchOp);
|
||||
|
||||
for (unsigned lane : job.lanes) {
|
||||
StaticValueKnowledge knowledge;
|
||||
@@ -1335,6 +1337,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
coreCodeGen,
|
||||
knowledge,
|
||||
coreBatchOp.getOperation(),
|
||||
resolveWeightSlot,
|
||||
lane,
|
||||
static_cast<unsigned>(coreBatchOp.getLaneCount()));
|
||||
if (processedOperations < 0) {
|
||||
@@ -1345,6 +1348,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
}
|
||||
|
||||
result.reportRow = deviceMemory.getReportRow();
|
||||
result.usedWeights = std::move(usedWeights);
|
||||
}
|
||||
|
||||
pim_binary::patchInstructionCount(coreBinaryStream, coreCodeGen.getEmittedInstructionCount());
|
||||
@@ -1368,14 +1372,23 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
if (jobResults[jobIndex].status != CompilerSuccess)
|
||||
return jobResults[jobIndex].status;
|
||||
|
||||
llvm::SmallVector<WeightFileRequest, 8> weightRequests;
|
||||
weightRequests.reserve(jobs.size());
|
||||
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) {
|
||||
WeightFileRequest request;
|
||||
request.coreId = jobs[jobIndex].emittedCoreId;
|
||||
request.weights = jobResults[jobIndex].usedWeights;
|
||||
weightRequests.push_back(std::move(request));
|
||||
}
|
||||
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(weightRequests, outputDirPath);
|
||||
|
||||
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) {
|
||||
const CoreEmissionJob& job = jobs[jobIndex];
|
||||
const CoreEmissionResult& result = jobResults[jobIndex];
|
||||
json::Array xbarsPerGroup;
|
||||
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(job.coreLikeOp)) {
|
||||
if (auto err = linkCoreWeights(
|
||||
job.originalCoreId, job.emittedCoreId, result.usedWeightIndices, coreOp.getWeights(), coreOp.getOperation(), xbarsPerGroup))
|
||||
if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup))
|
||||
return err;
|
||||
xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup);
|
||||
memory.recordCoreReport(job.emittedCoreId, result.reportRow);
|
||||
@@ -1391,14 +1404,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
for (size_t jobIndex : group) {
|
||||
const CoreEmissionJob& job = jobs[jobIndex];
|
||||
const CoreEmissionResult& result = jobResults[jobIndex];
|
||||
auto coreBatchOp = cast<pim::PimCoreBatchOp>(job.coreLikeOp);
|
||||
json::Array xbarsPerGroup;
|
||||
if (auto err = linkCoreWeights(job.originalCoreId,
|
||||
job.emittedCoreId,
|
||||
result.usedWeightIndices,
|
||||
coreBatchOp.getWeights(),
|
||||
coreBatchOp.getOperation(),
|
||||
xbarsPerGroup))
|
||||
if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup))
|
||||
return err;
|
||||
xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup);
|
||||
reportedCoreIds.push_back(static_cast<int32_t>(job.emittedCoreId));
|
||||
|
||||
@@ -1,25 +1,16 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <type_traits>
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
@@ -27,240 +18,72 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct DenseWeightView {
|
||||
DenseElementsAttr denseAttr;
|
||||
SmallVector<int64_t> shape;
|
||||
SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
};
|
||||
|
||||
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||
SmallVector<Operation*> viewOps;
|
||||
mlir::Value current = weight;
|
||||
memref::GetGlobalOp getGlobalOp;
|
||||
|
||||
while (true) {
|
||||
Operation* defOp = current.getDefiningOp();
|
||||
if (!defOp)
|
||||
return failure();
|
||||
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
|
||||
break;
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!hasAllStaticSubviewParts(subview))
|
||||
return failure();
|
||||
viewOps.push_back(subview);
|
||||
current = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
|
||||
current = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(collapse);
|
||||
current = collapse.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(expand);
|
||||
current = expand.getSrc();
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
DenseWeightView view;
|
||||
view.denseAttr = denseAttr;
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
|
||||
for (Operation* viewOp : llvm::reverse(viewOps)) {
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
|
||||
SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getStaticStrides().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||
view.offset += offset * sourceStride;
|
||||
nextStrides.push_back(stride * sourceStride);
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Collapse/expand are accepted only as contiguous static reshapes of a
|
||||
// dense global view, so a row-major stride recomputation preserves layout.
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(collapse.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(expand.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return view;
|
||||
}
|
||||
|
||||
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
||||
SmallVector<Operation*> coreLikeOps;
|
||||
for (Operation& op : funcOp.getBody().front())
|
||||
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
|
||||
coreLikeOps.push_back(&op);
|
||||
return coreLikeOps;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
|
||||
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
||||
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
|
||||
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
|
||||
createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
|
||||
auto coreWeightsDirPath = outputDirPath + "/weights";
|
||||
auto error = sys::fs::create_directory(coreWeightsDirPath);
|
||||
assert(!error && "Error creating weights directory");
|
||||
size_t indexFileName = 0;
|
||||
|
||||
int64_t xbarSize = crossbarSize.getValue();
|
||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
|
||||
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
||||
llvm::DenseMap<mlir::Value, std::string> mapWeightValueToFileName;
|
||||
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> mapCoreWeightToFileName;
|
||||
llvm::SmallVector<std::pair<ResolvedWeightView, std::string>, 16> materializedWeights;
|
||||
|
||||
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
||||
auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string {
|
||||
if (auto it = llvm::find_if(materializedWeights, [&](const auto& entry) { return entry.first == weightView; });
|
||||
it != materializedWeights.end())
|
||||
return it->second;
|
||||
|
||||
for (Operation* op : coreLikeOps) {
|
||||
auto processWeight = [&](Operation* ownerOp,
|
||||
mlir::Value weight,
|
||||
size_t weightIndex,
|
||||
size_t coreId) -> LogicalResult {
|
||||
auto weightView = resolveDenseWeightView(moduleOp, weight);
|
||||
if (failed(weightView)) {
|
||||
ownerOp->emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex));
|
||||
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
|
||||
}
|
||||
auto globalOp = weightView.globalOp;
|
||||
auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
assert(denseAttr && "Weight global must have dense initial value");
|
||||
|
||||
if (mapCoreWeightToFileName[coreId].contains(weight))
|
||||
return success();
|
||||
ArrayRef<int64_t> shape = weightView.shape;
|
||||
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
||||
int64_t numRows = shape[0];
|
||||
int64_t numCols = shape[1];
|
||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||
|
||||
if (auto weightFile = mapWeightValueToFileName.find(weight); weightFile != mapWeightValueToFileName.end()) {
|
||||
mapCoreWeightToFileName[coreId].insert({weight, weightFile->second});
|
||||
return success();
|
||||
}
|
||||
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
|
||||
|
||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
|
||||
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
|
||||
auto& fileName = mapGlobalOpToFileName[globalOp];
|
||||
mapWeightValueToFileName[weight] = fileName;
|
||||
mapCoreWeightToFileName[coreId].insert({weight, fileName});
|
||||
return success();
|
||||
}
|
||||
|
||||
DenseElementsAttr denseAttr = weightView->denseAttr;
|
||||
ArrayRef<int64_t> shape = weightView->shape;
|
||||
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
||||
int64_t numRows = shape[0];
|
||||
int64_t numCols = shape[1];
|
||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||
|
||||
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
|
||||
|
||||
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||
std::error_code errorCode;
|
||||
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
||||
assert(errorCode);
|
||||
}
|
||||
|
||||
uint64_t zero = 0;
|
||||
for (int64_t row = 0; row < xbarSize; row++) {
|
||||
for (int64_t col = 0; col < xbarSize; col++) {
|
||||
if (row < numRows && col < numCols) {
|
||||
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
|
||||
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
|
||||
uint64_t word = bits.getZExtValue();
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
||||
}
|
||||
else {
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
weightFileStream.close();
|
||||
if (globalOp)
|
||||
mapGlobalOpToFileName.insert({globalOp, newFileName});
|
||||
mapWeightValueToFileName[weight] = newFileName;
|
||||
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
|
||||
return success();
|
||||
};
|
||||
|
||||
auto processCoreLike = [&](auto coreLikeOp) {
|
||||
auto usedIndices = getUsedWeightIndices(coreLikeOp);
|
||||
for (unsigned index : usedIndices) {
|
||||
if (index >= coreLikeOp.getWeights().size()) {
|
||||
coreLikeOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||
assert(index < coreLikeOp.getWeights().size() && "Weight index is out of range");
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<std::decay_t<decltype(coreLikeOp)>, pim::PimCoreOp>) {
|
||||
size_t coreId = static_cast<size_t>(coreLikeOp.getCoreId());
|
||||
for (unsigned index : usedIndices)
|
||||
if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId)))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
else {
|
||||
auto batchCoreIds = getBatchCoreIds(coreLikeOp);
|
||||
SmallVector<size_t> orderedCoreIds;
|
||||
llvm::SmallSet<size_t, 8> seenCoreIds;
|
||||
for (int32_t coreId : batchCoreIds)
|
||||
if (seenCoreIds.insert(static_cast<size_t>(coreId)).second)
|
||||
orderedCoreIds.push_back(static_cast<size_t>(coreId));
|
||||
|
||||
for (size_t coreId : orderedCoreIds)
|
||||
for (unsigned index : usedIndices)
|
||||
if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId)))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
(void) processCoreLike(coreOp);
|
||||
continue;
|
||||
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||
std::error_code errorCode;
|
||||
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
||||
assert(errorCode);
|
||||
}
|
||||
|
||||
(void) processCoreLike(cast<pim::PimCoreBatchOp>(op));
|
||||
uint64_t zero = 0;
|
||||
for (int64_t row = 0; row < xbarSize; row++) {
|
||||
for (int64_t col = 0; col < xbarSize; col++) {
|
||||
if (row < numRows && col < numCols) {
|
||||
int64_t elementIndex = weightView.offset + row * weightView.strides[0] + col * weightView.strides[1];
|
||||
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
|
||||
uint64_t word = bits.getZExtValue();
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
||||
}
|
||||
else {
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
weightFileStream.close();
|
||||
materializedWeights.push_back({weightView, newFileName});
|
||||
return newFileName;
|
||||
};
|
||||
|
||||
for (const WeightFileRequest& request : requests) {
|
||||
auto& coreFiles = mapCoreWeightToFileName[request.coreId];
|
||||
coreFiles.reserve(request.weights.size());
|
||||
for (const ResolvedWeightView& weight : request.weights)
|
||||
coreFiles.push_back(materializeWeight(weight));
|
||||
}
|
||||
|
||||
return mapCoreWeightToFileName;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
|
||||
createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath);
|
||||
struct WeightFileRequest {
|
||||
size_t coreId = 0;
|
||||
llvm::SmallVector<ResolvedWeightView, 8> weights;
|
||||
};
|
||||
|
||||
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
|
||||
createAndPopulateWeightFolder(llvm::ArrayRef<WeightFileRequest> requests, llvm::StringRef outputDirPath);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -30,6 +30,15 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
Value allocateContiguousMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
||||
return memref::AllocOp::create(rewriter, loc, contiguousType);
|
||||
}
|
||||
|
||||
FailureOr<Value>
|
||||
getBufferOrValue(RewriterBase& rewriter, Value value, const BufferizationOptions& options, BufferizationState& state) {
|
||||
if (isa<BufferLikeType>(value.getType()))
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
mlir::Value materializeContiguousMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||
mlir::Value allocateContiguousMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||
|
||||
llvm::FailureOr<mlir::Value> getBufferOrValue(mlir::RewriterBase& rewriter,
|
||||
mlir::Value value,
|
||||
|
||||
@@ -431,8 +431,11 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimTransposeOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), *inputOpt, transposeOp.getPermutation(), *outputBufferOpt);
|
||||
rewriter, op, contiguousOutput.getType(), contiguousInput, transposeOp.getPermutation(), contiguousOutput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -473,9 +476,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), *weightOpt, contiguousInput, *outputBufferOpt);
|
||||
rewriter, op, contiguousOutput.getType(), *weightOpt, contiguousInput, contiguousOutput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -512,9 +516,10 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
|
||||
|
||||
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<OpTy>(
|
||||
rewriter, op, outputBufferOpt->getType(), contiguousLhs, contiguousRhs, *outputBufferOpt);
|
||||
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -546,8 +551,9 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, outputBufferOpt->getType(), contiguousInput, *outputBufferOpt);
|
||||
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, contiguousOutput.getType(), contiguousInput, contiguousOutput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -240,10 +240,10 @@ void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
printer << " coreId " << coreIdAttr.getInt();
|
||||
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
|
||||
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(),
|
||||
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
||||
@@ -276,13 +276,13 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
|
||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||
if (hasCoreId && parser.parseInteger(coreId))
|
||||
return failure();
|
||||
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
(void) crossbarWeightCount;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
@@ -365,13 +365,14 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
|
||||
|
||||
if (getNumResults() != 0) {
|
||||
printer << " shared_outs";
|
||||
printBlockArgumentList(printer, outputArgs);
|
||||
}
|
||||
|
||||
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({getOperation(), 0, getLaneCount()}).size();
|
||||
|
||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
||||
printer << " coreIds ";
|
||||
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
||||
@@ -423,13 +424,13 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
if (parseBlockArgumentList(parser, outputArgs))
|
||||
return failure();
|
||||
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
|
||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
return failure();
|
||||
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
(void) crossbarWeightCount;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
@@ -346,6 +347,11 @@ private:
|
||||
if (isCoreWeightBlockArgument(operand))
|
||||
continue;
|
||||
|
||||
if (auto vmmOp = dyn_cast<pim::PimVMMOp>(&op);
|
||||
vmmOp && operandIndex == 0 && resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
||||
if (failed(resolvedAddress)) {
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
|
||||
Reference in New Issue
Block a user