94 lines
3.7 KiB
C++
94 lines
3.7 KiB
C++
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
|
|
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
|
|
|
|
#include <cassert>
|
|
|
|
#include "Common/Support/CheckedArithmetic.hpp"
|
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
|
|
|
|
using namespace llvm;
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {} // namespace
|
|
|
|
WeightEmissionResult createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
|
|
auto coreWeightsDirPath = outputDirPath + "/weights";
|
|
auto error = sys::fs::create_directory(coreWeightsDirPath);
|
|
assert(!error && "Error creating weights directory");
|
|
size_t indexFileName = 0;
|
|
|
|
int64_t xbarSize = crossbarSize.getValue();
|
|
WeightEmissionResult result;
|
|
llvm::SmallVector<std::pair<ResolvedWeightView, std::string>, 16> materializedWeights;
|
|
|
|
auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string {
|
|
if (auto it = llvm::find_if(materializedWeights, [&](const auto& entry) { return entry.first == weightView; });
|
|
it != materializedWeights.end())
|
|
return it->second;
|
|
|
|
auto globalOp = weightView.globalOp;
|
|
auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
|
assert(denseAttr && "Weight global must have dense initial value");
|
|
|
|
ArrayRef<int64_t> shape = weightView.shape;
|
|
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
|
int64_t numRows = shape[0];
|
|
int64_t numCols = shape[1];
|
|
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
|
|
|
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
|
|
|
|
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
|
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
|
std::error_code errorCode;
|
|
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
|
if (errorCode) {
|
|
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
|
assert(errorCode);
|
|
}
|
|
|
|
uint64_t zero = 0;
|
|
for (int64_t row = 0; row < xbarSize; row++) {
|
|
for (int64_t col = 0; col < xbarSize; col++) {
|
|
if (row < numRows && col < numCols) {
|
|
int64_t elementIndex = weightView.offset + row * weightView.strides[0] + col * weightView.strides[1];
|
|
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
|
|
uint64_t word = bits.getZExtValue();
|
|
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
|
}
|
|
else {
|
|
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
|
}
|
|
}
|
|
}
|
|
|
|
weightFileStream.close();
|
|
materializedWeights.push_back({weightView, newFileName});
|
|
uint64_t weightBytes = pim::checkedMulOrCrash(
|
|
pim::checkedMulOrCrash(static_cast<size_t>(xbarSize), static_cast<size_t>(xbarSize), "weight element count"),
|
|
elementByteWidth,
|
|
"weight byte size");
|
|
result.totalWeightBytes = pim::checkedAddOrCrash(result.totalWeightBytes, weightBytes, "total weight bytes");
|
|
return newFileName;
|
|
};
|
|
|
|
for (const WeightFileRequest& request : requests) {
|
|
auto& coreFiles = result.mapCoreWeightToFileName[request.coreId];
|
|
coreFiles.reserve(request.weights.size());
|
|
for (const ResolvedWeightView& weight : request.weights)
|
|
coreFiles.push_back(materializeWeight(weight));
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
} // namespace onnx_mlir
|