Constant-fold linalg.map (generated from tensor.pad for padding)
Refactor PIM helpers in PimCommon
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
@@ -382,12 +381,9 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
|
||||
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
|
||||
size_t totalElements = srcType.getNumElements();
|
||||
|
||||
// Read permutation and compute its inverse
|
||||
// Read permutation. Destination dim i corresponds to source dim perm[i].
|
||||
SmallVector<int64_t> perm =
|
||||
map_to_vector(transposeOp.getPerms().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
|
||||
SmallVector<int64_t> permInv(rank);
|
||||
for (size_t i = 0; i < rank; i++)
|
||||
permInv[perm[i]] = i;
|
||||
|
||||
// Destination shape: dstShape[i] = srcShape[perm[i]]
|
||||
SmallVector<int64_t> dstShape(rank);
|
||||
@@ -412,10 +408,10 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
|
||||
remaining %= srcStrides[d];
|
||||
}
|
||||
|
||||
// Compute flat destination index: dstIdx[d] = srcIdx[permInv[d]]
|
||||
// Compute flat destination index: dstIdx[d] = srcIdx[perm[d]]
|
||||
size_t dstFlat = 0;
|
||||
for (size_t d = 0; d < rank; d++)
|
||||
dstFlat += srcIdx[permInv[d]] * dstStrides[d];
|
||||
dstFlat += srcIdx[perm[d]] * dstStrides[d];
|
||||
|
||||
emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user