refactor Pim constant folding pass
share contiguous address resolution in PimCommon group patterns in subdir for each pass with pattern files
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
#include <cmath>
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
@@ -86,48 +86,9 @@ PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
|
||||
}
|
||||
|
||||
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
||||
size_t offset = 0;
|
||||
while (true) {
|
||||
auto definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
break;
|
||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
||||
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
break;
|
||||
value = tiedOperand->get();
|
||||
}
|
||||
else if (auto subviewDefiningOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||
auto source = subviewDefiningOp.getSource();
|
||||
auto srcShape = source.getType().getShape();
|
||||
auto subviewOffsets = subviewDefiningOp.getStaticOffsets();
|
||||
auto subviewSizes = subviewDefiningOp.getStaticSizes();
|
||||
auto subviewStrides = subviewDefiningOp.getStaticStrides();
|
||||
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
|
||||
for (unsigned i = 0; i < subviewOffsets.size(); i++) {
|
||||
size_t localOffset = subviewOffsets[i];
|
||||
for (unsigned j = i + 1; j < subviewSizes.size(); j++)
|
||||
localOffset *= subviewSizes[j];
|
||||
offset += localOffset * subviewDefiningOp.getType().getElementTypeBitWidth() / 8;
|
||||
}
|
||||
value = source;
|
||||
}
|
||||
else if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
}
|
||||
else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
}
|
||||
else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
|
||||
auto iter = memEntriesMap.find(value);
|
||||
if (iter == memEntriesMap.end()) {
|
||||
errs() << "Missing mem entry for value: ";
|
||||
auto resolvedAddress = resolveContiguousAddress(value);
|
||||
if (failed(resolvedAddress)) {
|
||||
errs() << "Failed to resolve contiguous address for value: ";
|
||||
value.print(errs());
|
||||
errs() << "\n";
|
||||
if (auto* definingOp = value.getDefiningOp()) {
|
||||
@@ -135,10 +96,23 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
||||
definingOp->print(errs());
|
||||
errs() << "\n";
|
||||
}
|
||||
llvm_unreachable("Failed to resolve contiguous address");
|
||||
}
|
||||
|
||||
auto iter = memEntriesMap.find(resolvedAddress->base);
|
||||
if (iter == memEntriesMap.end()) {
|
||||
errs() << "Missing mem entry for value: ";
|
||||
resolvedAddress->base.print(errs());
|
||||
errs() << "\n";
|
||||
if (auto* definingOp = resolvedAddress->base.getDefiningOp()) {
|
||||
errs() << "Defining op:\n";
|
||||
definingOp->print(errs());
|
||||
errs() << "\n";
|
||||
}
|
||||
llvm_unreachable("Missing mem entry");
|
||||
}
|
||||
|
||||
return iter->second.address + offset;
|
||||
return iter->second.address + resolvedAddress->byteOffset;
|
||||
}
|
||||
|
||||
json::Object PimCodeGen::createEmptyOffset() {
|
||||
|
||||
Reference in New Issue
Block a user