add pim.vmm verifier and fix vmm lowering
reuse code for subviews
This commit is contained in:
@@ -24,8 +24,8 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Common/IR/CompactAsmUtils.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
@@ -206,9 +206,7 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu
|
||||
return iter->second.address + resolvedAddress->byteOffset;
|
||||
}
|
||||
|
||||
void PimAcceleratorMemory::reportHost() {
|
||||
hostReportRow = hostMem.getReportRow();
|
||||
}
|
||||
void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); }
|
||||
|
||||
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
|
||||
reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row});
|
||||
@@ -810,7 +808,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
|
||||
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
|
||||
else {
|
||||
op.emitError("Unsupported codegen for this operation");
|
||||
InFlightDiagnostic diag = op.emitError()
|
||||
<< "unsupported codegen for op '" << op.getName().getStringRef() << "'";
|
||||
if (auto coreOp = op.getParentOfType<pim::PimCoreOp>())
|
||||
diag << " inside pim.core " << coreOp.getCoreId();
|
||||
else if (auto coreBatchOp = op.getParentOfType<pim::PimCoreBatchOp>())
|
||||
diag << " inside pim.core_batch with laneCount " << coreBatchOp.getLaneCount();
|
||||
return failure();
|
||||
}
|
||||
processedOperations++;
|
||||
@@ -935,9 +938,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
if (auto err = emitCore(coreOp, false))
|
||||
return err;
|
||||
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
|
||||
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())))
|
||||
.getReportRow());
|
||||
memory.recordCoreReport(
|
||||
emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
|
||||
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId()))).getReportRow());
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ struct MemoryReportRow {
|
||||
|
||||
bool operator==(const MemoryReportRow& other) const {
|
||||
return numAlloca == other.numAlloca && sizeAlloca == other.sizeAlloca && numGlobal == other.numGlobal
|
||||
&& sizeGlobal == other.sizeGlobal;
|
||||
&& sizeGlobal == other.sizeGlobal;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
@@ -30,19 +32,6 @@ struct DenseWeightView {
|
||||
int64_t offset = 0;
|
||||
};
|
||||
|
||||
SmallVector<int64_t> computeRowMajorStridesForShape(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t index = static_cast<int64_t>(shape.size()) - 2; index >= 0; --index)
|
||||
strides[index] = strides[index + 1] * shape[index + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
bool allStaticSubviewParts(memref::SubViewOp subview) {
|
||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
}
|
||||
|
||||
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||
SmallVector<memref::SubViewOp> subviews;
|
||||
mlir::Value current = weight;
|
||||
@@ -55,7 +44,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
|
||||
break;
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!allStaticSubviewParts(subview))
|
||||
if (!hasAllStaticSubviewParts(subview))
|
||||
return failure();
|
||||
subviews.push_back(subview);
|
||||
current = subview.getSource();
|
||||
@@ -79,7 +68,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
DenseWeightView view;
|
||||
view.denseAttr = denseAttr;
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStridesForShape(view.shape);
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
|
||||
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
|
||||
SmallVector<int64_t> nextStrides;
|
||||
|
||||
Reference in New Issue
Block a user