fix bufferization and weight emission after new gemm patterns
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-27 16:15:10 +02:00
parent 013ae0ac2a
commit 1a5d7d2a3f
10 changed files with 349 additions and 317 deletions
+53 -46
View File
@@ -814,7 +814,7 @@ static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
struct CoreEmissionResult {
OnnxMlirCompilerErrorCodes status = CompilerSuccess;
MemoryReportRow reportRow;
llvm::SmallVector<unsigned, 8> usedWeightIndices;
llvm::SmallVector<ResolvedWeightView, 8> usedWeights;
};
template <typename MapTy>
@@ -879,7 +879,6 @@ struct CompiledCoreNode {
Kind kind = Kind::Op;
Operation* op = nullptr;
CompiledCoreOpKind opKind = CompiledCoreOpKind::Load;
std::optional<unsigned> weightIndex;
CompiledIndexExpr lowerBound;
CompiledIndexExpr upperBound;
CompiledIndexExpr step;
@@ -978,12 +977,6 @@ static LogicalResult compileCoreEmissionPlan(Block& block,
opNode.kind = CompiledCoreNode::Kind::Op;
opNode.op = &op;
opNode.opKind = *opKind;
if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) {
auto weightIndex = onnx_mlir::resolveWeightIndex(weightOwner, vmmOp);
if (!weightIndex)
return failure();
opNode.weightIndex = *weightIndex;
}
plan.push_back(std::move(opNode));
}
return success();
@@ -992,6 +985,9 @@ static LogicalResult compileCoreEmissionPlan(Block& block,
static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<CompiledCoreNode>& plan,
PimCodeGen& coreCodeGen,
StaticValueKnowledge& knowledge,
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp,
const StaticValueKnowledge&)>
resolveWeightSlot,
size_t& processedOperations,
std::optional<unsigned> batchLane = std::nullopt,
std::optional<unsigned> batchLaneCount = std::nullopt) {
@@ -1015,7 +1011,7 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
aliasBindings.bind(iterArg, iterValue);
if (failed(executeCompiledCorePlan(
*node.loopBody, coreCodeGen, knowledge, processedOperations, batchLane, batchLaneCount)))
*node.loopBody, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount)))
return failure();
auto yieldOp = cast<mlir::scf::YieldOp>(forOp.getRegion().front().getTerminator());
@@ -1048,9 +1044,10 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl<Compile
coreCodeGen.codeGenConcatOp(cast<pim::PimConcatOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::Vmm:
assert(node.weightIndex && "compiled VMM op must have cached weight index");
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(
*node.weightIndex, cast<pim::PimVMMOp>(node.op), true, knowledge);
if (auto weightSlot = resolveWeightSlot(cast<pim::PimVMMOp>(node.op), knowledge); succeeded(weightSlot))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightSlot, cast<pim::PimVMMOp>(node.op), true, knowledge);
else
return failure();
break;
case CompiledCoreOpKind::Transpose:
coreCodeGen.codeGenTransposeOp(cast<pim::PimTransposeOp>(node.op), knowledge);
@@ -1138,6 +1135,9 @@ static int64_t codeGenCoreOps(Block& block,
PimCodeGen& coreCodeGen,
const StaticValueKnowledge& initialKnowledge,
Operation* weightOwner,
llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp,
const StaticValueKnowledge&)>
resolveWeightSlot,
std::optional<unsigned> batchLane = std::nullopt,
std::optional<unsigned> batchLaneCount = std::nullopt) {
llvm::SmallVector<CompiledCoreNode, 32> plan;
@@ -1146,7 +1146,8 @@ static int64_t codeGenCoreOps(Block& block,
size_t processedOperations = 0;
StaticValueKnowledge knowledge = initialKnowledge;
auto result = executeCompiledCorePlan(plan, coreCodeGen, knowledge, processedOperations, batchLane, batchLaneCount);
auto result =
executeCompiledCorePlan(plan, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount);
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
}
@@ -1174,9 +1175,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
size_t maxCoreId = 0;
uint64_t nextBatchReportId = 0;
// Create Weight Folder
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
SmallDenseMap<memref::GlobalOp, MemEntry, 16> materializedHostGlobals =
collectMaterializedHostGlobals(moduleOp, funcOp, memory);
@@ -1238,11 +1236,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
++nextBatchReportId;
}
auto linkCoreWeights = [&](size_t originalCoreId,
size_t coreId,
ArrayRef<unsigned> usedIndices,
ValueRange weights,
Operation* weightOwner,
auto linkCoreWeights = [&](size_t coreId,
ArrayRef<std::string> weightFiles,
json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes {
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
@@ -1250,20 +1245,12 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
return InvalidOutputFileAccess;
}
auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId];
for (unsigned index : usedIndices) {
if (index >= weights.size()) {
weightOwner->emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < weights.size() && "Weight index is out of range");
}
mlir::Value weight = weights[index];
xbarsPerGroup.push_back(index);
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
auto& fileName = mapWeightToFile[weight];
for (auto [slot, fileName] : llvm::enumerate(weightFiles)) {
xbarsPerGroup.push_back(static_cast<int64_t>(slot));
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")) {
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")
<< "\nError:" << error.message() << '\n';
return InvalidOutputFileAccess;
}
@@ -1275,6 +1262,22 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
auto emitJob = [&](const CoreEmissionJob& job) -> CoreEmissionResult {
CoreEmissionResult result;
PimAcceleratorMemory jobMemory(memory.memEntriesMap, false);
llvm::SmallVector<ResolvedWeightView, 8> usedWeights;
auto resolveWeightSlot = [&](pim::PimVMMOp vmmOp,
const StaticValueKnowledge& knowledge) -> llvm::FailureOr<unsigned> {
auto weightView = onnx_mlir::resolveWeightView(job.coreLikeOp, vmmOp.getWeight(), knowledge);
if (failed(weightView)) {
vmmOp.emitOpError("requires a statically resolvable dense global weight view during PIM codegen");
return failure();
}
if (auto it = llvm::find(usedWeights, *weightView); it != usedWeights.end())
return static_cast<unsigned>(std::distance(usedWeights.begin(), it));
usedWeights.push_back(*weightView);
return static_cast<unsigned>(usedWeights.size() - 1);
};
std::error_code errorCode;
auto outputCorePath = outputDirPath + "/core_" + std::to_string(job.emittedCoreId) + ".pim";
@@ -1307,21 +1310,20 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
deviceMemory.allocateCore(coreOp);
int64_t processedOperations =
codeGenCoreOps(coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation());
int64_t processedOperations = codeGenCoreOps(
coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation(), resolveWeightSlot);
if (processedOperations < 0) {
result.status = CompilerFailure;
return result;
}
assert(processedOperations > 0);
result.reportRow = deviceMemory.getReportRow();
result.usedWeightIndices = getUsedWeightIndices(coreOp);
result.usedWeights = std::move(usedWeights);
}
else {
auto coreBatchOp = cast<pim::PimCoreBatchOp>(job.coreLikeOp);
aliasMaterializedHostGlobals(coreBatchOp, moduleOp, materializedHostGlobals, jobMemory);
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
result.usedWeightIndices = getUsedWeightIndices(coreBatchOp);
for (unsigned lane : job.lanes) {
StaticValueKnowledge knowledge;
@@ -1335,6 +1337,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
coreCodeGen,
knowledge,
coreBatchOp.getOperation(),
resolveWeightSlot,
lane,
static_cast<unsigned>(coreBatchOp.getLaneCount()));
if (processedOperations < 0) {
@@ -1345,6 +1348,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
}
result.reportRow = deviceMemory.getReportRow();
result.usedWeights = std::move(usedWeights);
}
pim_binary::patchInstructionCount(coreBinaryStream, coreCodeGen.getEmittedInstructionCount());
@@ -1368,14 +1372,23 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
if (jobResults[jobIndex].status != CompilerSuccess)
return jobResults[jobIndex].status;
llvm::SmallVector<WeightFileRequest, 8> weightRequests;
weightRequests.reserve(jobs.size());
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) {
WeightFileRequest request;
request.coreId = jobs[jobIndex].emittedCoreId;
request.weights = jobResults[jobIndex].usedWeights;
weightRequests.push_back(std::move(request));
}
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(weightRequests, outputDirPath);
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) {
const CoreEmissionJob& job = jobs[jobIndex];
const CoreEmissionResult& result = jobResults[jobIndex];
json::Array xbarsPerGroup;
if (auto coreOp = dyn_cast<pim::PimCoreOp>(job.coreLikeOp)) {
if (auto err = linkCoreWeights(
job.originalCoreId, job.emittedCoreId, result.usedWeightIndices, coreOp.getWeights(), coreOp.getOperation(), xbarsPerGroup))
if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup))
return err;
xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup);
memory.recordCoreReport(job.emittedCoreId, result.reportRow);
@@ -1391,14 +1404,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
for (size_t jobIndex : group) {
const CoreEmissionJob& job = jobs[jobIndex];
const CoreEmissionResult& result = jobResults[jobIndex];
auto coreBatchOp = cast<pim::PimCoreBatchOp>(job.coreLikeOp);
json::Array xbarsPerGroup;
if (auto err = linkCoreWeights(job.originalCoreId,
job.emittedCoreId,
result.usedWeightIndices,
coreBatchOp.getWeights(),
coreBatchOp.getOperation(),
xbarsPerGroup))
if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup))
return err;
xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup);
reportedCoreIds.push_back(static_cast<int32_t>(job.emittedCoreId));
+50 -227
View File
@@ -1,25 +1,16 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <type_traits>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
@@ -27,240 +18,72 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
struct DenseWeightView {
DenseElementsAttr denseAttr;
SmallVector<int64_t> shape;
SmallVector<int64_t> strides;
int64_t offset = 0;
};
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
SmallVector<Operation*> viewOps;
mlir::Value current = weight;
memref::GetGlobalOp getGlobalOp;
while (true) {
Operation* defOp = current.getDefiningOp();
if (!defOp)
return failure();
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
break;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
if (!hasAllStaticSubviewParts(subview))
return failure();
viewOps.push_back(subview);
current = subview.getSource();
continue;
}
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
current = cast.getSource();
continue;
}
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
viewOps.push_back(collapse);
current = collapse.getSrc();
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
viewOps.push_back(expand);
current = expand.getSrc();
continue;
}
return failure();
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
DenseWeightView view;
view.denseAttr = denseAttr;
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStrides(view.shape);
for (Operation* viewOp : llvm::reverse(viewOps)) {
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getStaticStrides().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
continue;
}
// Collapse/expand are accepted only as contiguous static reshapes of a
// dense global view, so a row-major stride recomputation preserves layout.
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return failure();
auto resultType = cast<MemRefType>(collapse.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return failure();
auto resultType = cast<MemRefType>(expand.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
continue;
}
}
return view;
}
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
SmallVector<Operation*> coreLikeOps;
for (Operation& op : funcOp.getBody().front())
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
coreLikeOps.push_back(&op);
return coreLikeOps;
}
} // namespace
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath);
assert(!error && "Error creating weights directory");
size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
llvm::DenseMap<mlir::Value, std::string> mapWeightValueToFileName;
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> mapCoreWeightToFileName;
llvm::SmallVector<std::pair<ResolvedWeightView, std::string>, 16> materializedWeights;
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string {
if (auto it = llvm::find_if(materializedWeights, [&](const auto& entry) { return entry.first == weightView; });
it != materializedWeights.end())
return it->second;
for (Operation* op : coreLikeOps) {
auto processWeight = [&](Operation* ownerOp,
mlir::Value weight,
size_t weightIndex,
size_t coreId) -> LogicalResult {
auto weightView = resolveDenseWeightView(moduleOp, weight);
if (failed(weightView)) {
ownerOp->emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex));
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
}
auto globalOp = weightView.globalOp;
auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
assert(denseAttr && "Weight global must have dense initial value");
if (mapCoreWeightToFileName[coreId].contains(weight))
return success();
ArrayRef<int64_t> shape = weightView.shape;
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
if (auto weightFile = mapWeightValueToFileName.find(weight); weightFile != mapWeightValueToFileName.end()) {
mapCoreWeightToFileName[coreId].insert({weight, weightFile->second});
return success();
}
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
mapWeightValueToFileName[weight] = fileName;
mapCoreWeightToFileName[coreId].insert({weight, fileName});
return success();
}
DenseElementsAttr denseAttr = weightView->denseAttr;
ArrayRef<int64_t> shape = weightView->shape;
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
std::error_code errorCode;
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
assert(errorCode);
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
if (globalOp)
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapWeightValueToFileName[weight] = newFileName;
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
return success();
};
auto processCoreLike = [&](auto coreLikeOp) {
auto usedIndices = getUsedWeightIndices(coreLikeOp);
for (unsigned index : usedIndices) {
if (index >= coreLikeOp.getWeights().size()) {
coreLikeOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreLikeOp.getWeights().size() && "Weight index is out of range");
}
}
if constexpr (std::is_same_v<std::decay_t<decltype(coreLikeOp)>, pim::PimCoreOp>) {
size_t coreId = static_cast<size_t>(coreLikeOp.getCoreId());
for (unsigned index : usedIndices)
if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId)))
return failure();
return success();
}
else {
auto batchCoreIds = getBatchCoreIds(coreLikeOp);
SmallVector<size_t> orderedCoreIds;
llvm::SmallSet<size_t, 8> seenCoreIds;
for (int32_t coreId : batchCoreIds)
if (seenCoreIds.insert(static_cast<size_t>(coreId)).second)
orderedCoreIds.push_back(static_cast<size_t>(coreId));
for (size_t coreId : orderedCoreIds)
for (unsigned index : usedIndices)
if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId)))
return failure();
return success();
}
};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
(void) processCoreLike(coreOp);
continue;
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
std::error_code errorCode;
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
assert(errorCode);
}
(void) processCoreLike(cast<pim::PimCoreBatchOp>(op));
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t elementIndex = weightView.offset + row * weightView.strides[0] + col * weightView.strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
materializedWeights.push_back({weightView, newFileName});
return newFileName;
};
for (const WeightFileRequest& request : requests) {
auto& coreFiles = mapCoreWeightToFileName[request.coreId];
coreFiles.reserve(request.weights.size());
for (const ResolvedWeightView& weight : request.weights)
coreFiles.push_back(materializeWeight(weight));
}
return mapCoreWeightToFileName;
}
+10 -3
View File
@@ -1,16 +1,23 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <string>
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
namespace onnx_mlir {
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath);
struct WeightFileRequest {
size_t coreId = 0;
llvm::SmallVector<ResolvedWeightView, 8> weights;
};
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
createAndPopulateWeightFolder(llvm::ArrayRef<WeightFileRequest> requests, llvm::StringRef outputDirPath);
} // namespace onnx_mlir