constant fold linalg.map (generated from tensor.pad for padding)

refactor pim helpers in PimCommon
This commit is contained in:
NiccoloN
2026-03-20 20:51:20 +01:00
parent dbe646ac0d
commit 6933804003
14 changed files with 751 additions and 263 deletions

View File

@@ -15,7 +15,6 @@
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/SpatialToPim/SpatialToPimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -382,12 +381,9 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
size_t totalElements = srcType.getNumElements();
// Read permutation and compute its inverse
// Read permutation. Destination dim i corresponds to source dim perm[i].
SmallVector<int64_t> perm =
map_to_vector(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
SmallVector<int64_t> permInv(rank);
for (size_t i = 0; i < rank; i++)
permInv[perm[i]] = i;
// Destination shape: dstShape[i] = srcShape[perm[i]]
SmallVector<int64_t> dstShape(rank);
@@ -412,10 +408,10 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
remaining %= srcStrides[d];
}
// Compute flat destination index: dstIdx[d] = srcIdx[permInv[d]]
// Compute flat destination index: dstIdx[d] = srcIdx[perm[d]]
size_t dstFlat = 0;
for (size_t d = 0; d < rank; d++)
dstFlat += srcIdx[permInv[d]] * dstStrides[d];
dstFlat += srcIdx[perm[d]] * dstStrides[d];
emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
}