add pim.vmm verifier and fix vmm lowering

reuse code for subviews
This commit is contained in:
NiccoloN
2026-05-12 15:13:50 +02:00
parent 628dc630a4
commit 4f3570520c
15 changed files with 358 additions and 207 deletions
+11 -8
View File
@@ -24,8 +24,8 @@
#include <string>
#include <utility>
#include "Common/PimCommon.hpp"
#include "Common/IR/CompactAsmUtils.hpp"
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
@@ -206,9 +206,7 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu
return iter->second.address + resolvedAddress->byteOffset;
}
void PimAcceleratorMemory::reportHost() {
hostReportRow = hostMem.getReportRow();
}
void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); }
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row});
@@ -810,7 +808,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
else {
op.emitError("Unsupported codegen for this operation");
InFlightDiagnostic diag = op.emitError()
<< "unsupported codegen for op '" << op.getName().getStringRef() << "'";
if (auto coreOp = op.getParentOfType<pim::PimCoreOp>())
diag << " inside pim.core " << coreOp.getCoreId();
else if (auto coreBatchOp = op.getParentOfType<pim::PimCoreBatchOp>())
diag << " inside pim.core_batch with laneCount " << coreBatchOp.getLaneCount();
return failure();
}
processedOperations++;
@@ -935,9 +938,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
if (auto err = emitCore(coreOp, false))
return err;
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())))
.getReportRow());
memory.recordCoreReport(
emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId()))).getReportRow());
continue;
}
+1 -1
View File
@@ -29,7 +29,7 @@ struct MemoryReportRow {
bool operator==(const MemoryReportRow& other) const {
return numAlloca == other.numAlloca && sizeAlloca == other.sizeAlloca && numGlobal == other.numGlobal
&& sizeGlobal == other.sizeGlobal;
&& sizeGlobal == other.sizeGlobal;
}
};
+4 -15
View File
@@ -10,6 +10,8 @@
#include <cassert>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
@@ -30,19 +32,6 @@ struct DenseWeightView {
int64_t offset = 0;
};
SmallVector<int64_t> computeRowMajorStridesForShape(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t index = static_cast<int64_t>(shape.size()) - 2; index >= 0; --index)
strides[index] = strides[index + 1] * shape[index + 1];
return strides;
}
bool allStaticSubviewParts(memref::SubViewOp subview) {
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
}
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
SmallVector<memref::SubViewOp> subviews;
mlir::Value current = weight;
@@ -55,7 +44,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
break;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
if (!allStaticSubviewParts(subview))
if (!hasAllStaticSubviewParts(subview))
return failure();
subviews.push_back(subview);
current = subview.getSource();
@@ -79,7 +68,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
DenseWeightView view;
view.denseAttr = denseAttr;
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStridesForShape(view.shape);
view.strides = computeRowMajorStrides(view.shape);
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
SmallVector<int64_t> nextStrides;