Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone
This commit is contained in:
@@ -168,8 +168,8 @@ Each validation run writes artifacts in the model workspace, for example under
|
|||||||
|
|
||||||
The compiler currently dumps dialect snapshots such as `spatial0.mlir`,
|
The compiler currently dumps dialect snapshots such as `spatial0.mlir`,
|
||||||
`spatial1_dcp_merged.mlir`, `pim0.mlir`, `pim1_buff.mlir`,
|
`spatial1_dcp_merged.mlir`, `pim0.mlir`, `pim1_buff.mlir`,
|
||||||
`pim2_coalesced.mlir`, `pim3_folded.mlir`, and
|
`pim2_coalesced.mlir`, and `pim3_folded.mlir` when an output directory is
|
||||||
`pim4_materialized.mlir` when an output directory is available.
|
available.
|
||||||
|
|
||||||
To rerun the simulator manually with tracing after validation has produced a
|
To rerun the simulator manually with tracing after validation has produced a
|
||||||
`raptor/pim/` directory:
|
`raptor/pim/` directory:
|
||||||
|
|||||||
@@ -123,7 +123,6 @@ add_pim_library(OMPIMAccel
|
|||||||
OMPimBufferization
|
OMPimBufferization
|
||||||
OMPimMemoryCoalescing
|
OMPimMemoryCoalescing
|
||||||
OMPimHostConstantFolding
|
OMPimHostConstantFolding
|
||||||
OMPimHostConstantMaterialization
|
|
||||||
OMPimVerification
|
OMPimVerification
|
||||||
MLIRTensorInferTypeOpInterfaceImpl
|
MLIRTensorInferTypeOpInterfaceImpl
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -47,6 +47,16 @@ CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
|
|||||||
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
|
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefTypeStrides(mlir::MemRefType type) {
|
||||||
|
llvm::SmallVector<int64_t> strides;
|
||||||
|
int64_t offset = 0;
|
||||||
|
if (failed(type.getStridesAndOffset(strides, offset)))
|
||||||
|
return mlir::failure();
|
||||||
|
if (llvm::is_contained(strides, mlir::ShapedType::kDynamic))
|
||||||
|
return mlir::failure();
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename VMMOpTy, typename ParentOpTy>
|
template <typename VMMOpTy, typename ParentOpTy>
|
||||||
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||||
@@ -162,6 +172,11 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
|
|||||||
mlir::Value current = weight;
|
mlir::Value current = weight;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
|
if (mlir::Value directAlias = knowledge.aliases.lookup(current); directAlias && directAlias != current) {
|
||||||
|
current = directAlias;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto defOp = current.getDefiningOp()) {
|
if (auto defOp = current.getDefiningOp()) {
|
||||||
if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
|
if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
|
||||||
auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
|
auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
|
||||||
@@ -181,8 +196,6 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
|
|||||||
CompiledIndexExpr offsetExpr = makeConstantExpr(0);
|
CompiledIndexExpr offsetExpr = makeConstantExpr(0);
|
||||||
for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
|
for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
|
||||||
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
|
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
|
||||||
llvm::SmallVector<int64_t> nextStrides;
|
|
||||||
nextStrides.reserve(subview.getMixedOffsets().size());
|
|
||||||
for (auto [offset, stride, sourceStride] :
|
for (auto [offset, stride, sourceStride] :
|
||||||
llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) {
|
llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||||
CompiledIndexExpr offsetValue = makeConstantExpr(0);
|
CompiledIndexExpr offsetValue = makeConstantExpr(0);
|
||||||
@@ -202,29 +215,47 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
|
|||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
}
|
}
|
||||||
offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
|
offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
|
||||||
nextStrides.push_back(stride * sourceStride);
|
|
||||||
}
|
}
|
||||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
auto resultType = mlir::cast<mlir::MemRefType>(subview.getResult().getType());
|
||||||
view.strides = std::move(nextStrides);
|
auto resultStrides = getStaticMemRefTypeStrides(resultType);
|
||||||
|
if (failed(resultStrides))
|
||||||
|
return mlir::failure();
|
||||||
|
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||||
|
view.strides = std::move(*resultStrides);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
|
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
|
||||||
if (view.strides != computeRowMajorStrides(view.shape))
|
|
||||||
return mlir::failure();
|
|
||||||
auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
|
auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
|
||||||
|
auto resultStrides = getStaticMemRefTypeStrides(resultType);
|
||||||
|
if (failed(resultStrides))
|
||||||
|
return mlir::failure();
|
||||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||||
view.strides = computeRowMajorStrides(view.shape);
|
view.strides = std::move(*resultStrides);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
|
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
|
||||||
if (view.strides != computeRowMajorStrides(view.shape))
|
|
||||||
return mlir::failure();
|
|
||||||
auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
|
auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
|
||||||
|
auto resultStrides = getStaticMemRefTypeStrides(resultType);
|
||||||
|
if (failed(resultStrides))
|
||||||
|
return mlir::failure();
|
||||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||||
view.strides = computeRowMajorStrides(view.shape);
|
view.strides = std::move(*resultStrides);
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(viewOp)) {
|
||||||
|
auto resultType = mlir::cast<mlir::MemRefType>(castOp.getResult().getType());
|
||||||
|
auto resultStrides = getStaticMemRefTypeStrides(resultType);
|
||||||
|
if (failed(resultStrides))
|
||||||
|
return mlir::failure();
|
||||||
|
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||||
|
view.strides = std::move(*resultStrides);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return mlir::failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto resolvedOffset = offsetExpr.evaluate(knowledge);
|
auto resolvedOffset = offsetExpr.evaluate(knowledge);
|
||||||
@@ -234,18 +265,26 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
|
|||||||
return view;
|
return view;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mlir::isa<mlir::memref::SubViewOp, mlir::memref::CollapseShapeOp, mlir::memref::ExpandShapeOp>(defOp)) {
|
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp)) {
|
||||||
viewOps.push_back(defOp);
|
viewOps.push_back(defOp);
|
||||||
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp))
|
current = subview.getSource();
|
||||||
current = subview.getSource();
|
continue;
|
||||||
else if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp))
|
}
|
||||||
current = collapse.getSrc();
|
|
||||||
else
|
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp)) {
|
||||||
current = mlir::cast<mlir::memref::ExpandShapeOp>(defOp).getSrc();
|
viewOps.push_back(defOp);
|
||||||
|
current = collapse.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(defOp)) {
|
||||||
|
viewOps.push_back(defOp);
|
||||||
|
current = expand.getSrc();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
|
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
|
||||||
|
viewOps.push_back(defOp);
|
||||||
current = castOp.getSource();
|
current = castOp.getSource();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -253,6 +292,11 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
|
|||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (mlir::Value loopAlias = resolveLoopCarriedAlias(current, knowledge); loopAlias && loopAlias != current) {
|
||||||
|
current = loopAlias;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
auto weightIndex = resolveWeightIndex(weightOwner, current);
|
auto weightIndex = resolveWeightIndex(weightOwner, current);
|
||||||
if (!weightIndex)
|
if (!weightIndex)
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ struct CappedDiagnosticReporter {
|
|||||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
|
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void noteFailures(int64_t count) { numFailures += count; }
|
||||||
|
|
||||||
bool hasFailure() const { return numFailures != 0; }
|
bool hasFailure() const { return numFailures != 0; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ add_pim_library(OMPimCompilerUtils
|
|||||||
OMPimBufferization
|
OMPimBufferization
|
||||||
OMPimMemoryCoalescing
|
OMPimMemoryCoalescing
|
||||||
OMPimHostConstantFolding
|
OMPimHostConstantFolding
|
||||||
OMPimHostConstantMaterialization
|
|
||||||
OMPimVerification
|
OMPimVerification
|
||||||
OMPimPasses
|
OMPimPasses
|
||||||
OMONNXToSpatial
|
OMONNXToSpatial
|
||||||
|
|||||||
@@ -32,6 +32,7 @@
|
|||||||
|
|
||||||
#include "Common/IR/CompactAsmUtils.hpp"
|
#include "Common/IR/CompactAsmUtils.hpp"
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
|
#include "Common/Support/Diagnostics.hpp"
|
||||||
#include "Common/Support/CheckedArithmetic.hpp"
|
#include "Common/Support/CheckedArithmetic.hpp"
|
||||||
#include "Common/Support/ReportUtils.hpp"
|
#include "Common/Support/ReportUtils.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
@@ -996,12 +997,44 @@ static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct CoreEmissionResult {
|
struct CoreEmissionResult {
|
||||||
|
static constexpr size_t kMaxStoredCodegenDiagnostics = 8;
|
||||||
|
|
||||||
|
struct DiagnosticRecord {
|
||||||
|
Operation* op = nullptr;
|
||||||
|
std::string message;
|
||||||
|
};
|
||||||
|
|
||||||
OnnxMlirCompilerErrorCodes status = CompilerSuccess;
|
OnnxMlirCompilerErrorCodes status = CompilerSuccess;
|
||||||
MemoryReportRow reportRow;
|
MemoryReportRow reportRow;
|
||||||
llvm::SmallVector<ResolvedWeightView, 8> usedWeights;
|
llvm::SmallVector<ResolvedWeightView, 8> usedWeights;
|
||||||
MemoryPlanArtifacts livenessArtifacts;
|
MemoryPlanArtifacts livenessArtifacts;
|
||||||
|
llvm::SmallVector<DiagnosticRecord, kMaxStoredCodegenDiagnostics> diagnostics;
|
||||||
|
size_t diagnosticCount = 0;
|
||||||
|
|
||||||
|
void recordDiagnostic(Operation* op, StringRef message) {
|
||||||
|
++diagnosticCount;
|
||||||
|
if (diagnostics.size() < kMaxStoredCodegenDiagnostics)
|
||||||
|
diagnostics.push_back({op, message.str()});
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static StaticValueKnowledge seedCoreCodegenKnowledge(pim::PimCoreOp coreOp) {
|
||||||
|
StaticValueKnowledge knowledge;
|
||||||
|
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights()))
|
||||||
|
knowledge.aliases[coreOp.getWeightArgument(index)] = weight;
|
||||||
|
return knowledge;
|
||||||
|
}
|
||||||
|
|
||||||
|
static StaticValueKnowledge seedCoreBatchCodegenKnowledge(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
|
||||||
|
StaticValueKnowledge knowledge;
|
||||||
|
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
|
||||||
|
for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights()))
|
||||||
|
knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight;
|
||||||
|
for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs()))
|
||||||
|
knowledge.aliases[coreBatchOp.getInputArgument(index)] = input;
|
||||||
|
return knowledge;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename MapTy>
|
template <typename MapTy>
|
||||||
class ScopedMapBindings {
|
class ScopedMapBindings {
|
||||||
using KeyTy = typename MapTy::key_type;
|
using KeyTy = typename MapTy::key_type;
|
||||||
@@ -1422,7 +1455,20 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
const StaticValueKnowledge& knowledge) -> llvm::FailureOr<unsigned> {
|
const StaticValueKnowledge& knowledge) -> llvm::FailureOr<unsigned> {
|
||||||
auto weightView = onnx_mlir::resolveWeightView(job.coreLikeOp, vmmOp.getWeight(), knowledge);
|
auto weightView = onnx_mlir::resolveWeightView(job.coreLikeOp, vmmOp.getWeight(), knowledge);
|
||||||
if (failed(weightView)) {
|
if (failed(weightView)) {
|
||||||
vmmOp.emitOpError("requires a statically resolvable dense global weight view during PIM codegen");
|
std::string message;
|
||||||
|
llvm::raw_string_ostream os(message);
|
||||||
|
os << "requires a statically resolvable dense global weight view during PIM codegen; weight="
|
||||||
|
<< vmmOp.getWeight() << " type=" << vmmOp.getWeight().getType();
|
||||||
|
result.recordDiagnostic(vmmOp, os.str());
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (weightView->shape.size() != 2) {
|
||||||
|
std::string message;
|
||||||
|
llvm::raw_string_ostream os(message);
|
||||||
|
os << "requires a rank-2 matrix weight view during PIM codegen; resolved shape=[";
|
||||||
|
llvm::interleaveComma(weightView->shape, os);
|
||||||
|
os << "] weight=" << vmmOp.getWeight() << " type=" << vmmOp.getWeight().getType();
|
||||||
|
result.recordDiagnostic(vmmOp, os.str());
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1463,13 +1509,13 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
|
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
|
||||||
deviceMemory.allocateCore(coreOp);
|
deviceMemory.allocateCore(coreOp);
|
||||||
|
|
||||||
int64_t processedOperations = codeGenCoreOps(
|
StaticValueKnowledge knowledge = seedCoreCodegenKnowledge(coreOp);
|
||||||
coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation(), resolveWeightSlot);
|
int64_t processedOperations =
|
||||||
|
codeGenCoreOps(coreOp.getBody().front(), coreCodeGen, knowledge, coreOp.getOperation(), resolveWeightSlot);
|
||||||
if (processedOperations < 0) {
|
if (processedOperations < 0) {
|
||||||
result.status = CompilerFailure;
|
result.status = CompilerFailure;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
assert(processedOperations > 0);
|
|
||||||
result.reportRow = deviceMemory.getReportRow();
|
result.reportRow = deviceMemory.getReportRow();
|
||||||
result.usedWeights = std::move(usedWeights);
|
result.usedWeights = std::move(usedWeights);
|
||||||
result.livenessArtifacts = deviceMemory.getLivenessArtifacts();
|
result.livenessArtifacts = deviceMemory.getLivenessArtifacts();
|
||||||
@@ -1480,10 +1526,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
|
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
|
||||||
|
|
||||||
for (unsigned lane : job.lanes) {
|
for (unsigned lane : job.lanes) {
|
||||||
StaticValueKnowledge knowledge;
|
StaticValueKnowledge knowledge = seedCoreBatchCodegenKnowledge(coreBatchOp, lane);
|
||||||
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
|
|
||||||
for (unsigned i = 0; i < coreBatchOp.getInputs().size(); ++i)
|
|
||||||
knowledge.aliases[coreBatchOp.getInputArgument(i)] = coreBatchOp.getInputs()[i];
|
|
||||||
|
|
||||||
deviceMemory.allocateCore(coreBatchOp, lane);
|
deviceMemory.allocateCore(coreBatchOp, lane);
|
||||||
coreCodeGen.setBatchLane(lane);
|
coreCodeGen.setBatchLane(lane);
|
||||||
@@ -1498,7 +1541,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
result.status = CompilerFailure;
|
result.status = CompilerFailure;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
assert(processedOperations > 0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result.reportRow = deviceMemory.getReportRow();
|
result.reportRow = deviceMemory.getReportRow();
|
||||||
@@ -1522,6 +1564,21 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
mlir::parallelFor(
|
mlir::parallelFor(
|
||||||
moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { jobResults[index] = emitJob(jobs[index]); });
|
moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { jobResults[index] = emitJob(jobs[index]); });
|
||||||
|
|
||||||
|
pim::CappedDiagnosticReporter diagnostics;
|
||||||
|
Operation* summaryAnchor = nullptr;
|
||||||
|
for (const CoreEmissionResult& result : jobResults) {
|
||||||
|
if (!summaryAnchor && !result.diagnostics.empty())
|
||||||
|
summaryAnchor = result.diagnostics.front().op;
|
||||||
|
for (const CoreEmissionResult::DiagnosticRecord& diagnostic : result.diagnostics) {
|
||||||
|
diagnostics.report(diagnostic.op, [&](Operation* op) { op->emitError() << diagnostic.message; });
|
||||||
|
}
|
||||||
|
size_t unreportedCount = result.diagnosticCount - result.diagnostics.size();
|
||||||
|
diagnostics.noteFailures(static_cast<int64_t>(unreportedCount));
|
||||||
|
}
|
||||||
|
if (diagnostics.hasFailure())
|
||||||
|
diagnostics.emitSuppressedSummary(summaryAnchor ? summaryAnchor : moduleOp.getOperation(),
|
||||||
|
"PIM codegen diagnostic(s)");
|
||||||
|
|
||||||
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex)
|
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex)
|
||||||
if (jobResults[jobIndex].status != CompilerSuccess)
|
if (jobResults[jobIndex].status != CompilerSuccess)
|
||||||
return jobResults[jobIndex].status;
|
return jobResults[jobIndex].status;
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||||
pm.addPass(createPimHostConstantFoldingPass());
|
pm.addPass(createPimHostConstantFoldingPass());
|
||||||
pm.addPass(createMessagePass("Pim host constants folded"));
|
pm.addPass(createMessagePass("Pim host constants folded"));
|
||||||
pm.addPass(createPimMaterializeHostConstantsPass());
|
|
||||||
pm.addPass(createPimMemoryCoalescingPass());
|
pm.addPass(createPimMemoryCoalescingPass());
|
||||||
pm.addPass(createPimVerificationPass());
|
pm.addPass(createPimVerificationPass());
|
||||||
pm.addPass(createMessagePass("Pim verified"));
|
pm.addPass(createMessagePass("Pim verified"));
|
||||||
|
|||||||
@@ -375,6 +375,57 @@ static void cloneHelperChain(Value sourceValue,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool isHostStaticReturnValue(Value value) {
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
|
while (Operation* definingOp = value.getDefiningOp()) {
|
||||||
|
if (!visited.insert(definingOp).second)
|
||||||
|
return false;
|
||||||
|
if (isa<arith::ConstantOp>(definingOp) || definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||||
|
return true;
|
||||||
|
if (!isReturnHelperChainOp(definingOp) || definingOp->getNumOperands() != 1)
|
||||||
|
return false;
|
||||||
|
value = definingOp->getOperand(0);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Value>
|
||||||
|
materializeHostStaticReturnValue(IRRewriter& rewriter, Value value, OperationFolder& constantFolder) {
|
||||||
|
llvm::SmallVector<Operation*> chain;
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
|
while (Operation* definingOp = value.getDefiningOp()) {
|
||||||
|
if (!visited.insert(definingOp).second)
|
||||||
|
return failure();
|
||||||
|
chain.push_back(definingOp);
|
||||||
|
if (isa<arith::ConstantOp>(definingOp) || definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||||
|
break;
|
||||||
|
if (!isReturnHelperChainOp(definingOp) || definingOp->getNumOperands() != 1)
|
||||||
|
return failure();
|
||||||
|
value = definingOp->getOperand(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chain.empty())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
IRMapping mapping;
|
||||||
|
Value clonedValue;
|
||||||
|
for (Operation* op : llvm::reverse(chain)) {
|
||||||
|
if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
|
||||||
|
clonedValue = getOrCreateConstantLike(constantFolder, constantOp);
|
||||||
|
mapping.map(op->getResult(0), clonedValue);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||||
|
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||||
|
mapping.map(originalResult, newResult);
|
||||||
|
clonedValue = clonedOp->getResult(0);
|
||||||
|
rewriter.setInsertionPointAfter(clonedOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
return clonedValue;
|
||||||
|
}
|
||||||
|
|
||||||
static FailureOr<Value> emitHostCopy(IRRewriter& rewriter,
|
static FailureOr<Value> emitHostCopy(IRRewriter& rewriter,
|
||||||
Location loc,
|
Location loc,
|
||||||
Value outputTensor,
|
Value outputTensor,
|
||||||
@@ -444,7 +495,30 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
OperationFolder constantFolder(producerOp->getContext());
|
OperationFolder constantFolder(producerOp->getContext());
|
||||||
auto storedTensorType = cast<TensorType>(storedValue.getType());
|
auto storedTensorType = cast<TensorType>(storedValue.getType());
|
||||||
|
|
||||||
|
auto materializeDirectHostReturn = [&](size_t returnIndex,
|
||||||
|
Value sourceValue,
|
||||||
|
ArrayRef<Operation*> helperChain) -> ReturnPathLoweringResult {
|
||||||
|
rewriter.setInsertionPointAfter(producerOp);
|
||||||
|
auto hostStaticValue = materializeHostStaticReturnValue(rewriter, sourceValue, constantFolder);
|
||||||
|
if (failed(hostStaticValue))
|
||||||
|
return ReturnPathLoweringResult::Failure;
|
||||||
|
|
||||||
|
Value hostReturnValue = *hostStaticValue;
|
||||||
|
if (!helperChain.empty())
|
||||||
|
cloneHelperChain(hostReturnValue, helperChain, rewriter, constantFolder, hostReturnValue);
|
||||||
|
|
||||||
|
outputTensors[returnIndex] =
|
||||||
|
[hostReturnValue](IRRewriter& rewriter, Location loc) -> Value { return hostReturnValue; };
|
||||||
|
return ReturnPathLoweringResult::Handled;
|
||||||
|
};
|
||||||
|
|
||||||
if (auto returnUse = analyzeReturnUse(producedValue)) {
|
if (auto returnUse = analyzeReturnUse(producedValue)) {
|
||||||
|
if (isHostStaticReturnValue(storedValue)) {
|
||||||
|
for (Operation* op : returnUse->helperChain)
|
||||||
|
markOpToRemove(op);
|
||||||
|
return materializeDirectHostReturn(returnUse->returnIndex, storedValue, returnUse->helperChain);
|
||||||
|
}
|
||||||
|
|
||||||
Value currentStoredValue = storedValue;
|
Value currentStoredValue = storedValue;
|
||||||
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
|
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
|
||||||
for (Operation* op : returnUse->helperChain)
|
for (Operation* op : returnUse->helperChain)
|
||||||
@@ -470,6 +544,8 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
|
|
||||||
if (isa<func::ReturnOp>(resultUser)) {
|
if (isa<func::ReturnOp>(resultUser)) {
|
||||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||||
|
if (isHostStaticReturnValue(storedValue))
|
||||||
|
return materializeDirectHostReturn(resultIndexInReturn, storedValue, {});
|
||||||
auto byteSize =
|
auto byteSize =
|
||||||
pim::getCheckedShapedTypeSizeInBytes(storedTensorType, producerOp, "return-path host copy byte size");
|
pim::getCheckedShapedTypeSizeInBytes(storedTensorType, producerOp, "return-path host copy byte size");
|
||||||
if (failed(byteSize))
|
if (failed(byteSize))
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ add_onnx_mlir_dialect_doc(pim Pim.td)
|
|||||||
add_subdirectory(Transforms/Bufferization)
|
add_subdirectory(Transforms/Bufferization)
|
||||||
add_subdirectory(Transforms/MemoryCoalescing)
|
add_subdirectory(Transforms/MemoryCoalescing)
|
||||||
add_subdirectory(Transforms/HostConstantFolding)
|
add_subdirectory(Transforms/HostConstantFolding)
|
||||||
add_subdirectory(Transforms/HostConstantMaterialization)
|
|
||||||
add_subdirectory(Transforms/Verification)
|
add_subdirectory(Transforms/Verification)
|
||||||
|
|
||||||
add_pim_library(PimOps
|
add_pim_library(PimOps
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace bufferization;
|
using namespace bufferization;
|
||||||
@@ -13,7 +14,9 @@ using namespace bufferization;
|
|||||||
namespace onnx_mlir::pim {
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||||
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
|
bool isContiguous =
|
||||||
|
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
|
||||||
|
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
|
||||||
return memrefValue;
|
return memrefValue;
|
||||||
|
|
||||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||||
@@ -29,13 +32,21 @@ FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location lo
|
|||||||
if (failed(sizeAttr))
|
if (failed(sizeAttr))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
if (isHostBackedPimAddress(memrefValue)) {
|
||||||
|
return PimMemCopyHostToDevOp::create(
|
||||||
|
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
|
||||||
|
.getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
return PimMemCopyOp::create(
|
return PimMemCopyOp::create(
|
||||||
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
|
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
|
||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||||
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
|
bool isContiguous =
|
||||||
|
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
|
||||||
|
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
|
||||||
return memrefValue;
|
return memrefValue;
|
||||||
|
|
||||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||||
|
|||||||
@@ -1,9 +1,70 @@
|
|||||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
static bool isCoreBatchInputArgument(Value value) {
|
||||||
|
auto blockArg = dyn_cast<BlockArgument>(value);
|
||||||
|
if (!blockArg)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto coreBatchOp = dyn_cast_or_null<onnx_mlir::pim::PimCoreBatchOp>(blockArg.getOwner()->getParentOp());
|
||||||
|
if (!coreBatchOp)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
unsigned firstInputArg = 1 + coreBatchOp.getWeights().size();
|
||||||
|
return static_cast<unsigned>(blockArg.getArgNumber()) >= firstInputArg;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Value> getPimStorageBase(Value value, const onnx_mlir::StaticValueKnowledge& knowledge) {
|
||||||
|
llvm::SmallPtrSet<Value, 8> visited;
|
||||||
|
while (value && visited.insert(value).second) {
|
||||||
|
Value alias = resolveLoopCarriedAlias(value, knowledge);
|
||||||
|
if (alias)
|
||||||
|
value = alias;
|
||||||
|
|
||||||
|
if (auto aliased = knowledge.aliases.lookup(value)) {
|
||||||
|
value = aliased;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto base = onnx_mlir::pim::getPimAddressBase(value, knowledge); succeeded(base))
|
||||||
|
return base;
|
||||||
|
|
||||||
|
if (isa<BlockArgument>(value))
|
||||||
|
return value;
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||||
|
value = subviewOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = collapseOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = expandOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||||
|
value = castOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (value)
|
||||||
|
return value;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Operation* anchor, Value memref) {
|
FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Operation* anchor, Value memref) {
|
||||||
auto type = mlir::cast<MemRefType>(memref.getType());
|
auto type = mlir::cast<MemRefType>(memref.getType());
|
||||||
auto byteSize = getCheckedShapedTypeSizeInBytes(type, anchor, "memref byte size");
|
auto byteSize = getCheckedShapedTypeSizeInBytes(type, anchor, "memref byte size");
|
||||||
@@ -11,3 +72,40 @@ FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& build
|
|||||||
return failure();
|
return failure();
|
||||||
return getCheckedI32Attr(builder, anchor, *byteSize, "memref byte size");
|
return getCheckedI32Attr(builder, anchor, *byteSize, "memref byte size");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FailureOr<Value> onnx_mlir::pim::getPimAddressBase(Value value, const StaticValueKnowledge& knowledge) {
|
||||||
|
Value alias = resolveLoopCarriedAlias(value, knowledge);
|
||||||
|
if (alias)
|
||||||
|
value = alias;
|
||||||
|
|
||||||
|
auto resolved = resolveContiguousAddress(value, knowledge);
|
||||||
|
if (succeeded(resolved))
|
||||||
|
return resolved->base;
|
||||||
|
|
||||||
|
auto compiled = compileContiguousAddressExpr(value);
|
||||||
|
if (failed(compiled)) {
|
||||||
|
if (isa<BlockArgument>(value))
|
||||||
|
return value;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return compiled->base;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool onnx_mlir::pim::isHostBackedPimAddress(Value value, const StaticValueKnowledge& knowledge) {
|
||||||
|
auto base = getPimStorageBase(value, knowledge);
|
||||||
|
if (failed(base))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (isCoreBatchInputArgument(*base))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
return isa_and_nonnull<memref::GetGlobalOp>(base->getDefiningOp());
|
||||||
|
}
|
||||||
|
|
||||||
|
bool onnx_mlir::pim::isDeviceLocalPimAddress(Value value, const StaticValueKnowledge& knowledge) {
|
||||||
|
auto base = getPimStorageBase(value, knowledge);
|
||||||
|
if (failed(base))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return isa_and_nonnull<memref::AllocOp>(base->getDefiningOp());
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,11 +2,19 @@
|
|||||||
|
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace pim {
|
namespace pim {
|
||||||
|
|
||||||
mlir::FailureOr<mlir::IntegerAttr>
|
mlir::FailureOr<mlir::IntegerAttr>
|
||||||
getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Operation* anchor, mlir::Value memref);
|
getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Operation* anchor, mlir::Value memref);
|
||||||
|
|
||||||
|
mlir::FailureOr<mlir::Value> getPimAddressBase(mlir::Value value, const StaticValueKnowledge& knowledge = {});
|
||||||
|
|
||||||
|
bool isHostBackedPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {});
|
||||||
|
|
||||||
|
bool isDeviceLocalPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {});
|
||||||
|
|
||||||
} // namespace pim
|
} // namespace pim
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Rewrite/PatternApplicator.h"
|
#include "mlir/Rewrite/PatternApplicator.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
@@ -15,6 +16,7 @@
|
|||||||
#include "Dialect/Pim/PimOps.hpp"
|
#include "Dialect/Pim/PimOps.hpp"
|
||||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||||
#include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
|
#include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
@@ -27,24 +29,71 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
struct MemRefCopyWorkItem {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
memref::CopyOp copyOp;
|
||||||
|
StaticValueKnowledge knowledge;
|
||||||
|
};
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
|
static StaticValueKnowledge seedCoreKnowledge(pim::PimCoreOp coreOp) {
|
||||||
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
StaticValueKnowledge knowledge;
|
||||||
return failure();
|
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights()))
|
||||||
|
knowledge.aliases[coreOp.getWeightArgument(index)] = weight;
|
||||||
|
return knowledge;
|
||||||
|
}
|
||||||
|
|
||||||
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
|
static StaticValueKnowledge seedCoreBatchKnowledge(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
|
||||||
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
|
StaticValueKnowledge knowledge;
|
||||||
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape())
|
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
|
||||||
return failure();
|
for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights()))
|
||||||
if (sourceType.getElementType() != targetType.getElementType())
|
knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight;
|
||||||
return failure();
|
for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs()))
|
||||||
|
knowledge.aliases[coreBatchOp.getInputArgument(index)] = input;
|
||||||
|
return knowledge;
|
||||||
|
}
|
||||||
|
|
||||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
static LogicalResult
|
||||||
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
|
lowerMemRefCopyToPimCopy(memref::CopyOp copyOp, PatternRewriter& rewriter, const StaticValueKnowledge& knowledge) {
|
||||||
if (failed(sizeAttr))
|
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
|
||||||
|
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
|
||||||
|
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (sourceType.getElementType() != targetType.getElementType())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
||||||
|
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
|
||||||
|
if (failed(sizeAttr))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
bool sourceIsHost = isHostBackedPimAddress(copyOp.getSource(), knowledge);
|
||||||
|
bool targetIsHost = isHostBackedPimAddress(copyOp.getTarget(), knowledge);
|
||||||
|
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getSource(), knowledge);
|
||||||
|
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getTarget(), knowledge);
|
||||||
|
|
||||||
|
if (targetIsDevice && sourceIsHost) {
|
||||||
|
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||||
|
copyOp.getLoc(),
|
||||||
|
copyOp.getTarget().getType(),
|
||||||
|
zeroOffset,
|
||||||
|
zeroOffset,
|
||||||
|
copyOp.getTarget(),
|
||||||
|
copyOp.getSource(),
|
||||||
|
*sizeAttr);
|
||||||
|
}
|
||||||
|
else if (targetIsHost && sourceIsDevice) {
|
||||||
|
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||||
|
copyOp.getLoc(),
|
||||||
|
copyOp.getTarget().getType(),
|
||||||
|
zeroOffset,
|
||||||
|
zeroOffset,
|
||||||
|
copyOp.getTarget(),
|
||||||
|
copyOp.getSource(),
|
||||||
|
*sizeAttr);
|
||||||
|
}
|
||||||
|
else if (targetIsDevice && sourceIsDevice) {
|
||||||
pim::PimMemCopyOp::create(rewriter,
|
pim::PimMemCopyOp::create(rewriter,
|
||||||
copyOp.getLoc(),
|
copyOp.getLoc(),
|
||||||
copyOp.getTarget().getType(),
|
copyOp.getTarget().getType(),
|
||||||
@@ -53,10 +102,19 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
|||||||
copyOp.getTarget(),
|
copyOp.getTarget(),
|
||||||
copyOp.getSource(),
|
copyOp.getSource(),
|
||||||
*sizeAttr);
|
*sizeAttr);
|
||||||
rewriter.eraseOp(copyOp);
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
};
|
else {
|
||||||
|
copyOp.emitOpError() << "failed to classify memref.copy endpoints: source=" << copyOp.getSource()
|
||||||
|
<< " type=" << copyOp.getSource().getType() << " host=" << sourceIsHost
|
||||||
|
<< " device=" << sourceIsDevice << ", target=" << copyOp.getTarget()
|
||||||
|
<< " type=" << copyOp.getTarget().getType() << " host=" << targetIsHost
|
||||||
|
<< " device=" << targetIsDevice;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.eraseOp(copyOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||||
@@ -100,25 +158,46 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
MLIRContext* ctx = moduleOp.getContext();
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
RewritePatternSet memrefCopyPatterns(ctx);
|
|
||||||
memrefCopyPatterns.add<MemRefCopyToPimMemCopyPattern>(ctx);
|
|
||||||
FrozenRewritePatternSet frozenMemrefCopyPatterns(std::move(memrefCopyPatterns));
|
|
||||||
PatternApplicator memrefCopyApplicator(frozenMemrefCopyPatterns);
|
|
||||||
memrefCopyApplicator.applyDefaultCostModel();
|
|
||||||
PatternRewriter rewriter(ctx);
|
PatternRewriter rewriter(ctx);
|
||||||
|
|
||||||
SmallVector<memref::CopyOp> copyWorklist;
|
SmallVector<MemRefCopyWorkItem> copyWorklist;
|
||||||
moduleOp.walk([&](memref::CopyOp copyOp) {
|
llvm::SmallPtrSet<Operation*, 16> seenCopyOps;
|
||||||
if (copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
auto addCopyOp = [&](memref::CopyOp copyOp, const StaticValueKnowledge& knowledge) {
|
||||||
copyWorklist.push_back(copyOp);
|
if (seenCopyOps.insert(copyOp.getOperation()).second)
|
||||||
|
copyWorklist.push_back({copyOp, knowledge});
|
||||||
|
};
|
||||||
|
|
||||||
|
moduleOp.walk([&](pim::PimCoreOp coreOp) {
|
||||||
|
StaticValueKnowledge knowledge = seedCoreKnowledge(coreOp);
|
||||||
|
(void) walkPimCoreBlockStructurally(
|
||||||
|
coreOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) {
|
||||||
|
if (auto copyOp = dyn_cast<memref::CopyOp>(&op))
|
||||||
|
addCopyOp(copyOp, opKnowledge);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
moduleOp.walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||||
|
llvm::SmallVector<unsigned, 2> lanes;
|
||||||
|
lanes.push_back(0);
|
||||||
|
if (coreBatchOp.getLaneCount() > 1)
|
||||||
|
lanes.push_back(static_cast<unsigned>(coreBatchOp.getLaneCount() - 1));
|
||||||
|
for (unsigned lane : lanes) {
|
||||||
|
StaticValueKnowledge knowledge = seedCoreBatchKnowledge(coreBatchOp, lane);
|
||||||
|
(void) walkPimCoreBlockStructurally(
|
||||||
|
coreBatchOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) {
|
||||||
|
if (auto copyOp = dyn_cast<memref::CopyOp>(&op))
|
||||||
|
addCopyOp(copyOp, opKnowledge);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
bool hasFailed = false;
|
bool hasFailed = false;
|
||||||
for (memref::CopyOp copyOp : copyWorklist) {
|
for (const MemRefCopyWorkItem& workItem : copyWorklist) {
|
||||||
if (failed(applyPatternsOnce(copyOp, memrefCopyApplicator, rewriter))) {
|
memref::CopyOp copyOp = workItem.copyOp;
|
||||||
copyOp.emitOpError("failed to lower memref.copy inside PIM core body");
|
rewriter.setInsertionPoint(copyOp);
|
||||||
|
if (failed(lowerMemRefCopyToPimCopy(copyOp, rewriter, workItem.knowledge)))
|
||||||
hasFailed = true;
|
hasFailed = true;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (hasFailed) {
|
if (hasFailed) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
|||||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter, mapOp, *sizeInBytes, "host constant folding byte size");
|
auto sizeAttr = pim::getCheckedI32Attr(rewriter, mapOp, *sizeInBytes, "host constant folding byte size");
|
||||||
if (failed(sizeAttr))
|
if (failed(sizeAttr))
|
||||||
return failure();
|
return failure();
|
||||||
pim::PimMemCopyOp::create(
|
pim::PimMemCopyHostToDevOp::create(
|
||||||
rewriter, mapOp.getLoc(), initType, zeroOffset, zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), *sizeAttr);
|
rewriter, mapOp.getLoc(), initType, zeroOffset, zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), *sizeAttr);
|
||||||
rewriter.eraseOp(mapOp);
|
rewriter.eraseOp(mapOp);
|
||||||
return success();
|
return success();
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
add_pim_library(OMPimHostConstantMaterialization
|
|
||||||
MaterializeHostConstantsPass.cpp
|
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
|
||||||
OMPimCommon
|
|
||||||
PimOps
|
|
||||||
)
|
|
||||||
-161
@@ -1,161 +0,0 @@
|
|||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
||||||
#include "mlir/IR/Builders.h"
|
|
||||||
#include "mlir/IR/Dominance.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
#include "llvm/Support/MathExtras.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
template <typename CoreOpTy>
|
|
||||||
static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
|
||||||
IRRewriter& rewriter,
|
|
||||||
OperationFolder& constantFolder,
|
|
||||||
bool& hasFailure) {
|
|
||||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
|
||||||
DominanceInfo dominance(coreOp);
|
|
||||||
SmallVector<Operation*> ops;
|
|
||||||
coreOp.getBody().front().walk([&](Operation* op) {
|
|
||||||
if (!isa<pim::PimHaltOp, scf::YieldOp>(op))
|
|
||||||
ops.push_back(op);
|
|
||||||
});
|
|
||||||
|
|
||||||
for (Operation* op : ops) {
|
|
||||||
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
for (OpOperand& operand : op->getOpOperands()) {
|
|
||||||
Value originalValue = operand.get();
|
|
||||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber()))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
auto resolvedAddress = resolveContiguousAddress(originalValue);
|
|
||||||
if (failed(resolvedAddress))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
|
||||||
if (!getGlobalOp)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
|
|
||||||
if (!originalType || !originalType.hasStaticShape()) {
|
|
||||||
op->emitOpError("host constant materialization requires a static memref operand");
|
|
||||||
hasFailure = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& cachedByOffset = materializedValues[resolvedAddress->base];
|
|
||||||
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
|
|
||||||
auto cachedValue = cachedByType.find(originalType);
|
|
||||||
if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) {
|
|
||||||
operand.set(cachedValue->second);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto type = dyn_cast<ShapedType>(originalValue.getType());
|
|
||||||
auto totalBytes = type ? pim::getCheckedShapedTypeSizeInBytes(type, op, "host constant materialization byte size")
|
|
||||||
: FailureOr<uint64_t>(failure());
|
|
||||||
auto totalBytesAttr =
|
|
||||||
succeeded(totalBytes)
|
|
||||||
? pim::getCheckedI32Attr(rewriter, op, *totalBytes, "host constant materialization byte size")
|
|
||||||
: FailureOr<IntegerAttr>(failure());
|
|
||||||
if (failed(totalBytesAttr)
|
|
||||||
|| failed(pim::checkedSize(resolvedAddress->byteOffset, op, "host constant materialization byte offset"))) {
|
|
||||||
hasFailure = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(op);
|
|
||||||
Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType);
|
|
||||||
Value deviceDst = localAlloc;
|
|
||||||
if (contiguousType != originalType)
|
|
||||||
deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
|
|
||||||
|
|
||||||
Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0);
|
|
||||||
Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset);
|
|
||||||
Value copiedValue = pim::PimMemCopyHostToDevOp::create(rewriter,
|
|
||||||
op->getLoc(),
|
|
||||||
originalType,
|
|
||||||
zeroOffset,
|
|
||||||
hostOffset,
|
|
||||||
deviceDst,
|
|
||||||
getGlobalOp.getResult(),
|
|
||||||
*totalBytesAttr)
|
|
||||||
.getOutput();
|
|
||||||
|
|
||||||
cachedByType[originalType] = copiedValue;
|
|
||||||
operand.set(copiedValue);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
|
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
|
|
||||||
|
|
||||||
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
|
|
||||||
StringRef getDescription() const override {
|
|
||||||
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
|
||||||
}
|
|
||||||
|
|
||||||
void runOnOperation() override {
|
|
||||||
ModuleOp moduleOp = getOperation();
|
|
||||||
IRRewriter rewriter(moduleOp.getContext());
|
|
||||||
OperationFolder constantFolder(moduleOp.getContext());
|
|
||||||
bool hasFailure = false;
|
|
||||||
|
|
||||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
|
||||||
if (funcOp.isExternal())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
|
|
||||||
materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure);
|
|
||||||
|
|
||||||
for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
|
|
||||||
materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure);
|
|
||||||
|
|
||||||
SmallVector<Operation*> hostCompactOps;
|
|
||||||
for (Operation& op : funcOp.getBody().front())
|
|
||||||
if (isa<pim::PimConcatOp>(op))
|
|
||||||
hostCompactOps.push_back(&op);
|
|
||||||
|
|
||||||
for (Operation* op : hostCompactOps) {
|
|
||||||
rewriter.setInsertionPoint(op);
|
|
||||||
auto concatOp = cast<pim::PimConcatOp>(op);
|
|
||||||
concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
|
|
||||||
hasFailure = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (hasFailure) {
|
|
||||||
moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above");
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dumpModule(moduleOp, "pim4_materialized");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
|
|
||||||
return std::make_unique<MaterializeHostConstantsPass>();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -46,19 +46,6 @@ static bool isCodegenAddressableValue(Value value) {
|
|||||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
|
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
|
|
||||||
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
|
|
||||||
if (succeeded(resolvedAddress))
|
|
||||||
return isa<BlockArgument>(resolvedAddress->base)
|
|
||||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
|
||||||
|
|
||||||
auto compiledAddress = compileContiguousAddressExpr(value);
|
|
||||||
if (failed(compiledAddress))
|
|
||||||
return false;
|
|
||||||
return isa<BlockArgument>(compiledAddress->base)
|
|
||||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool isConstantGlobalView(Value value) {
|
static bool isConstantGlobalView(Value value) {
|
||||||
while (true) {
|
while (true) {
|
||||||
Operation* defOp = value.getDefiningOp();
|
Operation* defOp = value.getDefiningOp();
|
||||||
@@ -138,6 +125,24 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
|
|||||||
memref::GetGlobalOp>(op);
|
memref::GetGlobalOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool isHostAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
|
||||||
|
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
|
||||||
|
Value base;
|
||||||
|
if (succeeded(resolvedAddress)) {
|
||||||
|
base = resolvedAddress->base;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto compiledAddress = compileContiguousAddressExpr(value);
|
||||||
|
if (failed(compiledAddress))
|
||||||
|
return false;
|
||||||
|
base = compiledAddress->base;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<BlockArgument>(base))
|
||||||
|
return true;
|
||||||
|
return isa_and_nonnull<memref::GetGlobalOp>(base.getDefiningOp());
|
||||||
|
}
|
||||||
|
|
||||||
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
||||||
|
|
||||||
@@ -311,10 +316,10 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (isExplicitHostMemCopyOperand(&op, operandIndex)) {
|
if (isExplicitHostMemCopyOperand(&op, operandIndex)) {
|
||||||
if (!isCodegenAddressableValue(operand, knowledge)) {
|
if (!isHostAddressableValue(operand, knowledge)) {
|
||||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
illegalOp->emitOpError() << "host operand #" << operandIndex
|
illegalOp->emitOpError() << "host operand #" << operandIndex
|
||||||
<< " is not backed by contiguous addressable storage";
|
<< " must be backed by host-addressable storage";
|
||||||
});
|
});
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,8 +21,6 @@ std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
|
|||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
|
std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimMaterializeHostConstantsPass();
|
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createEmitPimCodePass();
|
std::unique_ptr<mlir::Pass> createEmitPimCodePass();
|
||||||
|
|||||||
@@ -78,7 +78,6 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
|||||||
registerPass(createPimMemoryCoalescingPass);
|
registerPass(createPimMemoryCoalescingPass);
|
||||||
registerPass(createMergeComputeNodesPass);
|
registerPass(createMergeComputeNodesPass);
|
||||||
registerPass(createPimHostConstantFoldingPass);
|
registerPass(createPimHostConstantFoldingPass);
|
||||||
registerPass(createPimMaterializeHostConstantsPass);
|
|
||||||
registerPass(createPimVerificationPass);
|
registerPass(createPimVerificationPass);
|
||||||
registerPass(createEmitPimCodePass);
|
registerPass(createEmitPimCodePass);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ PIM_PASS_LABELS = (
|
|||||||
("SpatialToPimPass", "Spatial to PIM"),
|
("SpatialToPimPass", "Spatial to PIM"),
|
||||||
("PimBufferizationPass", "Bufferize PIM"),
|
("PimBufferizationPass", "Bufferize PIM"),
|
||||||
("HostConstantFoldingPass", "Fold Host Constants"),
|
("HostConstantFoldingPass", "Fold Host Constants"),
|
||||||
("MaterializeHostConstantsPass", "Materialize Host Constants"),
|
|
||||||
("VerificationPass", "Verify PIM"),
|
("VerificationPass", "Verify PIM"),
|
||||||
("EmitPimCodePass", "Emit PIM Code"),
|
("EmitPimCodePass", "Emit PIM Code"),
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user