fix bufferization and weight emission after new gemm patterns
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -1,25 +1,16 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <type_traits>
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
@@ -27,240 +18,72 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct DenseWeightView {
|
||||
DenseElementsAttr denseAttr;
|
||||
SmallVector<int64_t> shape;
|
||||
SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
};
|
||||
|
||||
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||
SmallVector<Operation*> viewOps;
|
||||
mlir::Value current = weight;
|
||||
memref::GetGlobalOp getGlobalOp;
|
||||
|
||||
while (true) {
|
||||
Operation* defOp = current.getDefiningOp();
|
||||
if (!defOp)
|
||||
return failure();
|
||||
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
|
||||
break;
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!hasAllStaticSubviewParts(subview))
|
||||
return failure();
|
||||
viewOps.push_back(subview);
|
||||
current = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
|
||||
current = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(collapse);
|
||||
current = collapse.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(expand);
|
||||
current = expand.getSrc();
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
DenseWeightView view;
|
||||
view.denseAttr = denseAttr;
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
|
||||
for (Operation* viewOp : llvm::reverse(viewOps)) {
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
|
||||
SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getStaticStrides().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||
view.offset += offset * sourceStride;
|
||||
nextStrides.push_back(stride * sourceStride);
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Collapse/expand are accepted only as contiguous static reshapes of a
|
||||
// dense global view, so a row-major stride recomputation preserves layout.
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(collapse.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(expand.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return view;
|
||||
}
|
||||
|
||||
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
||||
SmallVector<Operation*> coreLikeOps;
|
||||
for (Operation& op : funcOp.getBody().front())
|
||||
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
|
||||
coreLikeOps.push_back(&op);
|
||||
return coreLikeOps;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
|
||||
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
||||
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
|
||||
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
|
||||
createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
|
||||
auto coreWeightsDirPath = outputDirPath + "/weights";
|
||||
auto error = sys::fs::create_directory(coreWeightsDirPath);
|
||||
assert(!error && "Error creating weights directory");
|
||||
size_t indexFileName = 0;
|
||||
|
||||
int64_t xbarSize = crossbarSize.getValue();
|
||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
|
||||
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
||||
llvm::DenseMap<mlir::Value, std::string> mapWeightValueToFileName;
|
||||
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> mapCoreWeightToFileName;
|
||||
llvm::SmallVector<std::pair<ResolvedWeightView, std::string>, 16> materializedWeights;
|
||||
|
||||
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
||||
auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string {
|
||||
if (auto it = llvm::find_if(materializedWeights, [&](const auto& entry) { return entry.first == weightView; });
|
||||
it != materializedWeights.end())
|
||||
return it->second;
|
||||
|
||||
for (Operation* op : coreLikeOps) {
|
||||
auto processWeight = [&](Operation* ownerOp,
|
||||
mlir::Value weight,
|
||||
size_t weightIndex,
|
||||
size_t coreId) -> LogicalResult {
|
||||
auto weightView = resolveDenseWeightView(moduleOp, weight);
|
||||
if (failed(weightView)) {
|
||||
ownerOp->emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex));
|
||||
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
|
||||
}
|
||||
auto globalOp = weightView.globalOp;
|
||||
auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
assert(denseAttr && "Weight global must have dense initial value");
|
||||
|
||||
if (mapCoreWeightToFileName[coreId].contains(weight))
|
||||
return success();
|
||||
ArrayRef<int64_t> shape = weightView.shape;
|
||||
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
||||
int64_t numRows = shape[0];
|
||||
int64_t numCols = shape[1];
|
||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||
|
||||
if (auto weightFile = mapWeightValueToFileName.find(weight); weightFile != mapWeightValueToFileName.end()) {
|
||||
mapCoreWeightToFileName[coreId].insert({weight, weightFile->second});
|
||||
return success();
|
||||
}
|
||||
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
|
||||
|
||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
|
||||
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
|
||||
auto& fileName = mapGlobalOpToFileName[globalOp];
|
||||
mapWeightValueToFileName[weight] = fileName;
|
||||
mapCoreWeightToFileName[coreId].insert({weight, fileName});
|
||||
return success();
|
||||
}
|
||||
|
||||
DenseElementsAttr denseAttr = weightView->denseAttr;
|
||||
ArrayRef<int64_t> shape = weightView->shape;
|
||||
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
||||
int64_t numRows = shape[0];
|
||||
int64_t numCols = shape[1];
|
||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||
|
||||
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
|
||||
|
||||
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||
std::error_code errorCode;
|
||||
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
||||
assert(errorCode);
|
||||
}
|
||||
|
||||
uint64_t zero = 0;
|
||||
for (int64_t row = 0; row < xbarSize; row++) {
|
||||
for (int64_t col = 0; col < xbarSize; col++) {
|
||||
if (row < numRows && col < numCols) {
|
||||
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
|
||||
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
|
||||
uint64_t word = bits.getZExtValue();
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
||||
}
|
||||
else {
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
weightFileStream.close();
|
||||
if (globalOp)
|
||||
mapGlobalOpToFileName.insert({globalOp, newFileName});
|
||||
mapWeightValueToFileName[weight] = newFileName;
|
||||
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
|
||||
return success();
|
||||
};
|
||||
|
||||
auto processCoreLike = [&](auto coreLikeOp) {
|
||||
auto usedIndices = getUsedWeightIndices(coreLikeOp);
|
||||
for (unsigned index : usedIndices) {
|
||||
if (index >= coreLikeOp.getWeights().size()) {
|
||||
coreLikeOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||
assert(index < coreLikeOp.getWeights().size() && "Weight index is out of range");
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<std::decay_t<decltype(coreLikeOp)>, pim::PimCoreOp>) {
|
||||
size_t coreId = static_cast<size_t>(coreLikeOp.getCoreId());
|
||||
for (unsigned index : usedIndices)
|
||||
if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId)))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
else {
|
||||
auto batchCoreIds = getBatchCoreIds(coreLikeOp);
|
||||
SmallVector<size_t> orderedCoreIds;
|
||||
llvm::SmallSet<size_t, 8> seenCoreIds;
|
||||
for (int32_t coreId : batchCoreIds)
|
||||
if (seenCoreIds.insert(static_cast<size_t>(coreId)).second)
|
||||
orderedCoreIds.push_back(static_cast<size_t>(coreId));
|
||||
|
||||
for (size_t coreId : orderedCoreIds)
|
||||
for (unsigned index : usedIndices)
|
||||
if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId)))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
(void) processCoreLike(coreOp);
|
||||
continue;
|
||||
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||
std::error_code errorCode;
|
||||
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
||||
assert(errorCode);
|
||||
}
|
||||
|
||||
(void) processCoreLike(cast<pim::PimCoreBatchOp>(op));
|
||||
uint64_t zero = 0;
|
||||
for (int64_t row = 0; row < xbarSize; row++) {
|
||||
for (int64_t col = 0; col < xbarSize; col++) {
|
||||
if (row < numRows && col < numCols) {
|
||||
int64_t elementIndex = weightView.offset + row * weightView.strides[0] + col * weightView.strides[1];
|
||||
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
|
||||
uint64_t word = bits.getZExtValue();
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
||||
}
|
||||
else {
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
weightFileStream.close();
|
||||
materializedWeights.push_back({weightView, newFileName});
|
||||
return newFileName;
|
||||
};
|
||||
|
||||
for (const WeightFileRequest& request : requests) {
|
||||
auto& coreFiles = mapCoreWeightToFileName[request.coreId];
|
||||
coreFiles.reserve(request.weights.size());
|
||||
for (const ResolvedWeightView& weight : request.weights)
|
||||
coreFiles.push_back(materializeWeight(weight));
|
||||
}
|
||||
|
||||
return mapCoreWeightToFileName;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user