add constant folding and verification pass for pim host operations
better validation scripts output big refactors
This commit is contained in:
@@ -13,11 +13,12 @@
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include "Common/PIMCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Compiler/CompilerPasses.hpp"
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
|
||||
@@ -49,8 +50,8 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
||||
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
|
||||
SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
if (!getGlobalOp->hasAttr("weightAlways")) {
|
||||
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||
if (!hasWeightAlways(getGlobalOp)) {
|
||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
auto iter = globalConstants.find(globalMemrefOp);
|
||||
if (iter == globalConstants.end())
|
||||
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp);
|
||||
@@ -81,7 +82,7 @@ MemEntry PimMemory::getMemEntry(mlir::Value value) const {
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
|
||||
PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
|
||||
return deviceMem.try_emplace(id, memEntriesMap).first->second;
|
||||
}
|
||||
|
||||
@@ -112,10 +113,33 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
||||
}
|
||||
value = source;
|
||||
}
|
||||
else if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
}
|
||||
else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
}
|
||||
else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
return memEntriesMap.at(value).address + offset;
|
||||
|
||||
auto iter = memEntriesMap.find(value);
|
||||
if (iter == memEntriesMap.end()) {
|
||||
errs() << "Missing mem entry for value: ";
|
||||
value.print(errs());
|
||||
errs() << "\n";
|
||||
if (auto* definingOp = value.getDefiningOp()) {
|
||||
errs() << "Defining op:\n";
|
||||
definingOp->print(errs());
|
||||
errs() << "\n";
|
||||
}
|
||||
llvm_unreachable("Missing mem entry");
|
||||
}
|
||||
|
||||
return iter->second.address + offset;
|
||||
}
|
||||
|
||||
json::Object PimCodeGen::createEmptyOffset() {
|
||||
@@ -348,6 +372,55 @@ void PimCodeGen::codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) co
|
||||
}
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
|
||||
auto srcAddr = memory.getValueAddress(transposeOp.getData());
|
||||
auto dstAddr = memory.getValueAddress(transposeOp.getOutBuf());
|
||||
|
||||
auto srcType = cast<ShapedType>(transposeOp.getData().getType());
|
||||
auto srcShape = srcType.getShape();
|
||||
size_t rank = srcShape.size();
|
||||
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
|
||||
size_t totalElements = srcType.getNumElements();
|
||||
|
||||
// Read permutation and compute its inverse
|
||||
SmallVector<int64_t> perm =
|
||||
map_to_vector(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
|
||||
SmallVector<int64_t> permInv(rank);
|
||||
for (size_t i = 0; i < rank; i++)
|
||||
permInv[perm[i]] = i;
|
||||
|
||||
// Destination shape: dstShape[i] = srcShape[perm[i]]
|
||||
SmallVector<int64_t> dstShape(rank);
|
||||
for (size_t i = 0; i < rank; i++)
|
||||
dstShape[i] = srcShape[perm[i]];
|
||||
|
||||
// Row-major strides for source and destination
|
||||
SmallVector<size_t> srcStrides(rank, 1);
|
||||
SmallVector<size_t> dstStrides(rank, 1);
|
||||
for (int64_t i = rank - 2; i >= 0; i--) {
|
||||
srcStrides[i] = srcStrides[i + 1] * srcShape[i + 1];
|
||||
dstStrides[i] = dstStrides[i + 1] * dstShape[i + 1];
|
||||
}
|
||||
|
||||
// Emit element-by-element copy with transposed addressing
|
||||
for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
|
||||
// Decompose flat source index into multi-dimensional index
|
||||
SmallVector<size_t> srcIdx(rank);
|
||||
size_t remaining = srcFlat;
|
||||
for (size_t d = 0; d < rank; d++) {
|
||||
srcIdx[d] = remaining / srcStrides[d];
|
||||
remaining %= srcStrides[d];
|
||||
}
|
||||
|
||||
// Compute flat destination index: dstIdx[d] = srcIdx[permInv[d]]
|
||||
size_t dstFlat = 0;
|
||||
for (size_t d = 0; d < rank; d++)
|
||||
dstFlat += srcIdx[permInv[d]] * dstStrides[d];
|
||||
|
||||
emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
|
||||
}
|
||||
}
|
||||
|
||||
size_t getMatrixSize(ShapedType matrixShape) {
|
||||
if (matrixShape.getRank() != 2 && matrixShape.getRank() != 4)
|
||||
assert(false && "Unsupported matrix shape");
|
||||
@@ -378,9 +451,9 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
|
||||
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
||||
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
if (getGlobalOp->hasAttr("weightAlways"))
|
||||
if (hasWeightAlways(getGlobalOp))
|
||||
return;
|
||||
auto globalOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp)
|
||||
return;
|
||||
auto initialValue = globalOp.getInitialValue();
|
||||
@@ -416,7 +489,7 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
|
||||
static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
||||
size_t processedOperations = 0;
|
||||
for (auto& op : coreOp.getBody().front()) {
|
||||
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp>(op))
|
||||
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp, memref::ExpandShapeOp, memref::CollapseShapeOp>(op))
|
||||
continue;
|
||||
|
||||
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
|
||||
@@ -435,6 +508,8 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
|
||||
else if (auto applyFiltersOp = dyn_cast<pim::PimApplyFiltersOp>(op))
|
||||
coreCodeGen.codeGenApplyFiltersOp(applyFiltersOp);
|
||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||
coreCodeGen.codeGenTransposeOp(transposeOp);
|
||||
else if (auto vaddOp = dyn_cast<pim::PimVAddOp>(op))
|
||||
coreCodeGen.codeGenVAddOp(vaddOp);
|
||||
else if (auto vmaxOp = dyn_cast<pim::PimVMaxOp>(op))
|
||||
@@ -475,7 +550,7 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
|
||||
continue;
|
||||
}
|
||||
|
||||
auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(moduleOp, getGlobalOp.getNameAttr());
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp) {
|
||||
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
|
||||
weightIndex++;
|
||||
@@ -589,9 +664,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
}
|
||||
}
|
||||
|
||||
auto funcOps = moduleOp.getOps<func::FuncOp>();
|
||||
assert(!funcOps.empty() && "No function found in the module");
|
||||
auto funcOp = *funcOps.begin();
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc))
|
||||
return CompilerFailure;
|
||||
auto funcOp = *entryFunc;
|
||||
|
||||
PimAcceleratorMemory memory;
|
||||
memory.hostMem.allocateHost(moduleOp, funcOp);
|
||||
|
||||
Reference in New Issue
Block a user