Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone

This commit is contained in:
ilgeco
2026-06-05 10:20:09 +02:00
20 changed files with 458 additions and 256 deletions
+2 -2
View File
@@ -168,8 +168,8 @@ Each validation run writes artifacts in the model workspace, for example under
The compiler currently dumps dialect snapshots such as `spatial0.mlir`, The compiler currently dumps dialect snapshots such as `spatial0.mlir`,
`spatial1_dcp_merged.mlir`, `pim0.mlir`, `pim1_buff.mlir`, `spatial1_dcp_merged.mlir`, `pim0.mlir`, `pim1_buff.mlir`,
`pim2_coalesced.mlir`, `pim3_folded.mlir`, and `pim2_coalesced.mlir`, and `pim3_folded.mlir` when an output directory is
`pim4_materialized.mlir` when an output directory is available. available.
To rerun the simulator manually with tracing after validation has produced a To rerun the simulator manually with tracing after validation has produced a
`raptor/pim/` directory: `raptor/pim/` directory:
-1
View File
@@ -123,7 +123,6 @@ add_pim_library(OMPIMAccel
OMPimBufferization OMPimBufferization
OMPimMemoryCoalescing OMPimMemoryCoalescing
OMPimHostConstantFolding OMPimHostConstantFolding
OMPimHostConstantMaterialization
OMPimVerification OMPimVerification
MLIRTensorInferTypeOpInterfaceImpl MLIRTensorInferTypeOpInterfaceImpl
) )
+62 -18
View File
@@ -47,6 +47,16 @@ CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs)); return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
} }
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefTypeStrides(mlir::MemRefType type) {
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
if (failed(type.getStridesAndOffset(strides, offset)))
return mlir::failure();
if (llvm::is_contained(strides, mlir::ShapedType::kDynamic))
return mlir::failure();
return strides;
}
template <typename VMMOpTy, typename ParentOpTy> template <typename VMMOpTy, typename ParentOpTy>
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
auto weightArg = parentOp.getWeightArgument(weightIndex); auto weightArg = parentOp.getWeightArgument(weightIndex);
@@ -162,6 +172,11 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
mlir::Value current = weight; mlir::Value current = weight;
while (true) { while (true) {
if (mlir::Value directAlias = knowledge.aliases.lookup(current); directAlias && directAlias != current) {
current = directAlias;
continue;
}
if (auto defOp = current.getDefiningOp()) { if (auto defOp = current.getDefiningOp()) {
if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) { if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {}; auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
@@ -181,8 +196,6 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
CompiledIndexExpr offsetExpr = makeConstantExpr(0); CompiledIndexExpr offsetExpr = makeConstantExpr(0);
for (mlir::Operation* viewOp : llvm::reverse(viewOps)) { for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) { if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
llvm::SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getMixedOffsets().size());
for (auto [offset, stride, sourceStride] : for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) { llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) {
CompiledIndexExpr offsetValue = makeConstantExpr(0); CompiledIndexExpr offsetValue = makeConstantExpr(0);
@@ -202,29 +215,47 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
return mlir::failure(); return mlir::failure();
} }
offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride)); offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
nextStrides.push_back(stride * sourceStride);
} }
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end()); auto resultType = mlir::cast<mlir::MemRefType>(subview.getResult().getType());
view.strides = std::move(nextStrides); auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue; continue;
} }
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) { if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return mlir::failure();
auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType()); auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end()); view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape); view.strides = std::move(*resultStrides);
continue; continue;
} }
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) { if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return mlir::failure();
auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType()); auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end()); view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape); view.strides = std::move(*resultStrides);
continue;
} }
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(viewOp)) {
auto resultType = mlir::cast<mlir::MemRefType>(castOp.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue;
}
return mlir::failure();
} }
auto resolvedOffset = offsetExpr.evaluate(knowledge); auto resolvedOffset = offsetExpr.evaluate(knowledge);
@@ -234,18 +265,26 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
return view; return view;
} }
if (mlir::isa<mlir::memref::SubViewOp, mlir::memref::CollapseShapeOp, mlir::memref::ExpandShapeOp>(defOp)) { if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp)) {
viewOps.push_back(defOp); viewOps.push_back(defOp);
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp)) current = subview.getSource();
current = subview.getSource(); continue;
else if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp)) }
current = collapse.getSrc();
else if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp)) {
current = mlir::cast<mlir::memref::ExpandShapeOp>(defOp).getSrc(); viewOps.push_back(defOp);
current = collapse.getSrc();
continue;
}
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(defOp)) {
viewOps.push_back(defOp);
current = expand.getSrc();
continue; continue;
} }
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) { if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
viewOps.push_back(defOp);
current = castOp.getSource(); current = castOp.getSource();
continue; continue;
} }
@@ -253,6 +292,11 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
return mlir::failure(); return mlir::failure();
} }
if (mlir::Value loopAlias = resolveLoopCarriedAlias(current, knowledge); loopAlias && loopAlias != current) {
current = loopAlias;
continue;
}
auto weightIndex = resolveWeightIndex(weightOwner, current); auto weightIndex = resolveWeightIndex(weightOwner, current);
if (!weightIndex) if (!weightIndex)
return mlir::failure(); return mlir::failure();
+2
View File
@@ -28,6 +28,8 @@ struct CappedDiagnosticReporter {
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription; op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
} }
void noteFailures(int64_t count) { numFailures += count; }
bool hasFailure() const { return numFailures != 0; } bool hasFailure() const { return numFailures != 0; }
private: private:
-1
View File
@@ -31,7 +31,6 @@ add_pim_library(OMPimCompilerUtils
OMPimBufferization OMPimBufferization
OMPimMemoryCoalescing OMPimMemoryCoalescing
OMPimHostConstantFolding OMPimHostConstantFolding
OMPimHostConstantMaterialization
OMPimVerification OMPimVerification
OMPimPasses OMPimPasses
OMONNXToSpatial OMONNXToSpatial
+66 -9
View File
@@ -32,6 +32,7 @@
#include "Common/IR/CompactAsmUtils.hpp" #include "Common/IR/CompactAsmUtils.hpp"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Common/Support/Diagnostics.hpp"
#include "Common/Support/CheckedArithmetic.hpp" #include "Common/Support/CheckedArithmetic.hpp"
#include "Common/Support/ReportUtils.hpp" #include "Common/Support/ReportUtils.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
@@ -996,12 +997,44 @@ static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
} }
struct CoreEmissionResult { struct CoreEmissionResult {
static constexpr size_t kMaxStoredCodegenDiagnostics = 8;
struct DiagnosticRecord {
Operation* op = nullptr;
std::string message;
};
OnnxMlirCompilerErrorCodes status = CompilerSuccess; OnnxMlirCompilerErrorCodes status = CompilerSuccess;
MemoryReportRow reportRow; MemoryReportRow reportRow;
llvm::SmallVector<ResolvedWeightView, 8> usedWeights; llvm::SmallVector<ResolvedWeightView, 8> usedWeights;
MemoryPlanArtifacts livenessArtifacts; MemoryPlanArtifacts livenessArtifacts;
llvm::SmallVector<DiagnosticRecord, kMaxStoredCodegenDiagnostics> diagnostics;
size_t diagnosticCount = 0;
void recordDiagnostic(Operation* op, StringRef message) {
++diagnosticCount;
if (diagnostics.size() < kMaxStoredCodegenDiagnostics)
diagnostics.push_back({op, message.str()});
}
}; };
static StaticValueKnowledge seedCoreCodegenKnowledge(pim::PimCoreOp coreOp) {
StaticValueKnowledge knowledge;
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights()))
knowledge.aliases[coreOp.getWeightArgument(index)] = weight;
return knowledge;
}
static StaticValueKnowledge seedCoreBatchCodegenKnowledge(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
StaticValueKnowledge knowledge;
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights()))
knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight;
for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs()))
knowledge.aliases[coreBatchOp.getInputArgument(index)] = input;
return knowledge;
}
template <typename MapTy> template <typename MapTy>
class ScopedMapBindings { class ScopedMapBindings {
using KeyTy = typename MapTy::key_type; using KeyTy = typename MapTy::key_type;
@@ -1422,7 +1455,20 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
const StaticValueKnowledge& knowledge) -> llvm::FailureOr<unsigned> { const StaticValueKnowledge& knowledge) -> llvm::FailureOr<unsigned> {
auto weightView = onnx_mlir::resolveWeightView(job.coreLikeOp, vmmOp.getWeight(), knowledge); auto weightView = onnx_mlir::resolveWeightView(job.coreLikeOp, vmmOp.getWeight(), knowledge);
if (failed(weightView)) { if (failed(weightView)) {
vmmOp.emitOpError("requires a statically resolvable dense global weight view during PIM codegen"); std::string message;
llvm::raw_string_ostream os(message);
os << "requires a statically resolvable dense global weight view during PIM codegen; weight="
<< vmmOp.getWeight() << " type=" << vmmOp.getWeight().getType();
result.recordDiagnostic(vmmOp, os.str());
return failure();
}
if (weightView->shape.size() != 2) {
std::string message;
llvm::raw_string_ostream os(message);
os << "requires a rank-2 matrix weight view during PIM codegen; resolved shape=[";
llvm::interleaveComma(weightView->shape, os);
os << "] weight=" << vmmOp.getWeight() << " type=" << vmmOp.getWeight().getType();
result.recordDiagnostic(vmmOp, os.str());
return failure(); return failure();
} }
@@ -1463,13 +1509,13 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId); auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
deviceMemory.allocateCore(coreOp); deviceMemory.allocateCore(coreOp);
int64_t processedOperations = codeGenCoreOps( StaticValueKnowledge knowledge = seedCoreCodegenKnowledge(coreOp);
coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation(), resolveWeightSlot); int64_t processedOperations =
codeGenCoreOps(coreOp.getBody().front(), coreCodeGen, knowledge, coreOp.getOperation(), resolveWeightSlot);
if (processedOperations < 0) { if (processedOperations < 0) {
result.status = CompilerFailure; result.status = CompilerFailure;
return result; return result;
} }
assert(processedOperations > 0);
result.reportRow = deviceMemory.getReportRow(); result.reportRow = deviceMemory.getReportRow();
result.usedWeights = std::move(usedWeights); result.usedWeights = std::move(usedWeights);
result.livenessArtifacts = deviceMemory.getLivenessArtifacts(); result.livenessArtifacts = deviceMemory.getLivenessArtifacts();
@@ -1480,10 +1526,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId); auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
for (unsigned lane : job.lanes) { for (unsigned lane : job.lanes) {
StaticValueKnowledge knowledge; StaticValueKnowledge knowledge = seedCoreBatchCodegenKnowledge(coreBatchOp, lane);
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
for (unsigned i = 0; i < coreBatchOp.getInputs().size(); ++i)
knowledge.aliases[coreBatchOp.getInputArgument(i)] = coreBatchOp.getInputs()[i];
deviceMemory.allocateCore(coreBatchOp, lane); deviceMemory.allocateCore(coreBatchOp, lane);
coreCodeGen.setBatchLane(lane); coreCodeGen.setBatchLane(lane);
@@ -1498,7 +1541,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
result.status = CompilerFailure; result.status = CompilerFailure;
return result; return result;
} }
assert(processedOperations > 0);
} }
result.reportRow = deviceMemory.getReportRow(); result.reportRow = deviceMemory.getReportRow();
@@ -1522,6 +1564,21 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
mlir::parallelFor( mlir::parallelFor(
moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { jobResults[index] = emitJob(jobs[index]); }); moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { jobResults[index] = emitJob(jobs[index]); });
pim::CappedDiagnosticReporter diagnostics;
Operation* summaryAnchor = nullptr;
for (const CoreEmissionResult& result : jobResults) {
if (!summaryAnchor && !result.diagnostics.empty())
summaryAnchor = result.diagnostics.front().op;
for (const CoreEmissionResult::DiagnosticRecord& diagnostic : result.diagnostics) {
diagnostics.report(diagnostic.op, [&](Operation* op) { op->emitError() << diagnostic.message; });
}
size_t unreportedCount = result.diagnosticCount - result.diagnostics.size();
diagnostics.noteFailures(static_cast<int64_t>(unreportedCount));
}
if (diagnostics.hasFailure())
diagnostics.emitSuppressedSummary(summaryAnchor ? summaryAnchor : moduleOp.getOperation(),
"PIM codegen diagnostic(s)");
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex)
if (jobResults[jobIndex].status != CompilerSuccess) if (jobResults[jobIndex].status != CompilerSuccess)
return jobResults[jobIndex].status; return jobResults[jobIndex].status;
-1
View File
@@ -46,7 +46,6 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitPimCodegen) { if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimHostConstantFoldingPass()); pm.addPass(createPimHostConstantFoldingPass());
pm.addPass(createMessagePass("Pim host constants folded")); pm.addPass(createMessagePass("Pim host constants folded"));
pm.addPass(createPimMaterializeHostConstantsPass());
pm.addPass(createPimMemoryCoalescingPass()); pm.addPass(createPimMemoryCoalescingPass());
pm.addPass(createPimVerificationPass()); pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified")); pm.addPass(createMessagePass("Pim verified"));
@@ -375,6 +375,57 @@ static void cloneHelperChain(Value sourceValue,
} }
} }
static bool isHostStaticReturnValue(Value value) {
llvm::SmallPtrSet<Operation*, 8> visited;
while (Operation* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (isa<arith::ConstantOp>(definingOp) || definingOp->hasTrait<OpTrait::ConstantLike>())
return true;
if (!isReturnHelperChainOp(definingOp) || definingOp->getNumOperands() != 1)
return false;
value = definingOp->getOperand(0);
}
return false;
}
static FailureOr<Value>
materializeHostStaticReturnValue(IRRewriter& rewriter, Value value, OperationFolder& constantFolder) {
llvm::SmallVector<Operation*> chain;
llvm::SmallPtrSet<Operation*, 8> visited;
while (Operation* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return failure();
chain.push_back(definingOp);
if (isa<arith::ConstantOp>(definingOp) || definingOp->hasTrait<OpTrait::ConstantLike>())
break;
if (!isReturnHelperChainOp(definingOp) || definingOp->getNumOperands() != 1)
return failure();
value = definingOp->getOperand(0);
}
if (chain.empty())
return failure();
IRMapping mapping;
Value clonedValue;
for (Operation* op : llvm::reverse(chain)) {
if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
clonedValue = getOrCreateConstantLike(constantFolder, constantOp);
mapping.map(op->getResult(0), clonedValue);
continue;
}
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
clonedValue = clonedOp->getResult(0);
rewriter.setInsertionPointAfter(clonedOp);
}
return clonedValue;
}
static FailureOr<Value> emitHostCopy(IRRewriter& rewriter, static FailureOr<Value> emitHostCopy(IRRewriter& rewriter,
Location loc, Location loc,
Value outputTensor, Value outputTensor,
@@ -444,7 +495,30 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
OperationFolder constantFolder(producerOp->getContext()); OperationFolder constantFolder(producerOp->getContext());
auto storedTensorType = cast<TensorType>(storedValue.getType()); auto storedTensorType = cast<TensorType>(storedValue.getType());
auto materializeDirectHostReturn = [&](size_t returnIndex,
Value sourceValue,
ArrayRef<Operation*> helperChain) -> ReturnPathLoweringResult {
rewriter.setInsertionPointAfter(producerOp);
auto hostStaticValue = materializeHostStaticReturnValue(rewriter, sourceValue, constantFolder);
if (failed(hostStaticValue))
return ReturnPathLoweringResult::Failure;
Value hostReturnValue = *hostStaticValue;
if (!helperChain.empty())
cloneHelperChain(hostReturnValue, helperChain, rewriter, constantFolder, hostReturnValue);
outputTensors[returnIndex] =
[hostReturnValue](IRRewriter& rewriter, Location loc) -> Value { return hostReturnValue; };
return ReturnPathLoweringResult::Handled;
};
if (auto returnUse = analyzeReturnUse(producedValue)) { if (auto returnUse = analyzeReturnUse(producedValue)) {
if (isHostStaticReturnValue(storedValue)) {
for (Operation* op : returnUse->helperChain)
markOpToRemove(op);
return materializeDirectHostReturn(returnUse->returnIndex, storedValue, returnUse->helperChain);
}
Value currentStoredValue = storedValue; Value currentStoredValue = storedValue;
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue); cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
for (Operation* op : returnUse->helperChain) for (Operation* op : returnUse->helperChain)
@@ -470,6 +544,8 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
if (isa<func::ReturnOp>(resultUser)) { if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber(); size_t resultIndexInReturn = resultUse.getOperandNumber();
if (isHostStaticReturnValue(storedValue))
return materializeDirectHostReturn(resultIndexInReturn, storedValue, {});
auto byteSize = auto byteSize =
pim::getCheckedShapedTypeSizeInBytes(storedTensorType, producerOp, "return-path host copy byte size"); pim::getCheckedShapedTypeSizeInBytes(storedTensorType, producerOp, "return-path host copy byte size");
if (failed(byteSize)) if (failed(byteSize))
-1
View File
@@ -4,7 +4,6 @@ add_onnx_mlir_dialect_doc(pim Pim.td)
add_subdirectory(Transforms/Bufferization) add_subdirectory(Transforms/Bufferization)
add_subdirectory(Transforms/MemoryCoalescing) add_subdirectory(Transforms/MemoryCoalescing)
add_subdirectory(Transforms/HostConstantFolding) add_subdirectory(Transforms/HostConstantFolding)
add_subdirectory(Transforms/HostConstantMaterialization)
add_subdirectory(Transforms/Verification) add_subdirectory(Transforms/Verification)
add_pim_library(PimOps add_pim_library(PimOps
@@ -6,6 +6,7 @@
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp"
using namespace mlir; using namespace mlir;
using namespace bufferization; using namespace bufferization;
@@ -13,7 +14,9 @@ using namespace bufferization;
namespace onnx_mlir::pim { namespace onnx_mlir::pim {
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue))) bool isContiguous =
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
return memrefValue; return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType()); auto shapedType = cast<ShapedType>(memrefValue.getType());
@@ -29,13 +32,21 @@ FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location lo
if (failed(sizeAttr)) if (failed(sizeAttr))
return failure(); return failure();
if (isHostBackedPimAddress(memrefValue)) {
return PimMemCopyHostToDevOp::create(
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
.getOutput();
}
return PimMemCopyOp::create( return PimMemCopyOp::create(
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr) rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
.getOutput(); .getOutput();
} }
Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) { Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue))) bool isContiguous =
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
return memrefValue; return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType()); auto shapedType = cast<ShapedType>(memrefValue.getType());
@@ -1,9 +1,70 @@
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
static bool isCoreBatchInputArgument(Value value) {
auto blockArg = dyn_cast<BlockArgument>(value);
if (!blockArg)
return false;
auto coreBatchOp = dyn_cast_or_null<onnx_mlir::pim::PimCoreBatchOp>(blockArg.getOwner()->getParentOp());
if (!coreBatchOp)
return false;
unsigned firstInputArg = 1 + coreBatchOp.getWeights().size();
return static_cast<unsigned>(blockArg.getArgNumber()) >= firstInputArg;
}
static FailureOr<Value> getPimStorageBase(Value value, const onnx_mlir::StaticValueKnowledge& knowledge) {
llvm::SmallPtrSet<Value, 8> visited;
while (value && visited.insert(value).second) {
Value alias = resolveLoopCarriedAlias(value, knowledge);
if (alias)
value = alias;
if (auto aliased = knowledge.aliases.lookup(value)) {
value = aliased;
continue;
}
if (auto base = onnx_mlir::pim::getPimAddressBase(value, knowledge); succeeded(base))
return base;
if (isa<BlockArgument>(value))
return value;
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return value;
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
value = subviewOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
return value;
}
if (value)
return value;
return failure();
}
FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Operation* anchor, Value memref) { FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Operation* anchor, Value memref) {
auto type = mlir::cast<MemRefType>(memref.getType()); auto type = mlir::cast<MemRefType>(memref.getType());
auto byteSize = getCheckedShapedTypeSizeInBytes(type, anchor, "memref byte size"); auto byteSize = getCheckedShapedTypeSizeInBytes(type, anchor, "memref byte size");
@@ -11,3 +72,40 @@ FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& build
return failure(); return failure();
return getCheckedI32Attr(builder, anchor, *byteSize, "memref byte size"); return getCheckedI32Attr(builder, anchor, *byteSize, "memref byte size");
} }
FailureOr<Value> onnx_mlir::pim::getPimAddressBase(Value value, const StaticValueKnowledge& knowledge) {
Value alias = resolveLoopCarriedAlias(value, knowledge);
if (alias)
value = alias;
auto resolved = resolveContiguousAddress(value, knowledge);
if (succeeded(resolved))
return resolved->base;
auto compiled = compileContiguousAddressExpr(value);
if (failed(compiled)) {
if (isa<BlockArgument>(value))
return value;
return failure();
}
return compiled->base;
}
bool onnx_mlir::pim::isHostBackedPimAddress(Value value, const StaticValueKnowledge& knowledge) {
auto base = getPimStorageBase(value, knowledge);
if (failed(base))
return false;
if (isCoreBatchInputArgument(*base))
return true;
return isa_and_nonnull<memref::GetGlobalOp>(base->getDefiningOp());
}
bool onnx_mlir::pim::isDeviceLocalPimAddress(Value value, const StaticValueKnowledge& knowledge) {
auto base = getPimStorageBase(value, knowledge);
if (failed(base))
return false;
return isa_and_nonnull<memref::AllocOp>(base->getDefiningOp());
}
@@ -2,11 +2,19 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
namespace onnx_mlir { namespace onnx_mlir {
namespace pim { namespace pim {
mlir::FailureOr<mlir::IntegerAttr> mlir::FailureOr<mlir::IntegerAttr>
getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Operation* anchor, mlir::Value memref); getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Operation* anchor, mlir::Value memref);
mlir::FailureOr<mlir::Value> getPimAddressBase(mlir::Value value, const StaticValueKnowledge& knowledge = {});
bool isHostBackedPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {});
bool isDeviceLocalPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {});
} // namespace pim } // namespace pim
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -8,6 +8,7 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/PatternApplicator.h" #include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
@@ -15,6 +16,7 @@
#include "Dialect/Pim/PimOps.hpp" #include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp" #include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
@@ -27,24 +29,71 @@ namespace onnx_mlir {
namespace { namespace {
struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> { struct MemRefCopyWorkItem {
using OpRewritePattern::OpRewritePattern; memref::CopyOp copyOp;
StaticValueKnowledge knowledge;
};
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override { static StaticValueKnowledge seedCoreKnowledge(pim::PimCoreOp coreOp) {
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>()) StaticValueKnowledge knowledge;
return failure(); for (auto [index, weight] : llvm::enumerate(coreOp.getWeights()))
knowledge.aliases[coreOp.getWeightArgument(index)] = weight;
return knowledge;
}
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType()); static StaticValueKnowledge seedCoreBatchKnowledge(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType()); StaticValueKnowledge knowledge;
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape()) knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
return failure(); for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights()))
if (sourceType.getElementType() != targetType.getElementType()) knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight;
return failure(); for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs()))
knowledge.aliases[coreBatchOp.getInputArgument(index)] = input;
return knowledge;
}
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0); static LogicalResult
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource()); lowerMemRefCopyToPimCopy(memref::CopyOp copyOp, PatternRewriter& rewriter, const StaticValueKnowledge& knowledge) {
if (failed(sizeAttr)) if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
return failure(); return failure();
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape())
return failure();
if (sourceType.getElementType() != targetType.getElementType())
return failure();
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
if (failed(sizeAttr))
return failure();
bool sourceIsHost = isHostBackedPimAddress(copyOp.getSource(), knowledge);
bool targetIsHost = isHostBackedPimAddress(copyOp.getTarget(), knowledge);
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getSource(), knowledge);
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getTarget(), knowledge);
if (targetIsDevice && sourceIsHost) {
pim::PimMemCopyHostToDevOp::create(rewriter,
copyOp.getLoc(),
copyOp.getTarget().getType(),
zeroOffset,
zeroOffset,
copyOp.getTarget(),
copyOp.getSource(),
*sizeAttr);
}
else if (targetIsHost && sourceIsDevice) {
pim::PimMemCopyDevToHostOp::create(rewriter,
copyOp.getLoc(),
copyOp.getTarget().getType(),
zeroOffset,
zeroOffset,
copyOp.getTarget(),
copyOp.getSource(),
*sizeAttr);
}
else if (targetIsDevice && sourceIsDevice) {
pim::PimMemCopyOp::create(rewriter, pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(), copyOp.getLoc(),
copyOp.getTarget().getType(), copyOp.getTarget().getType(),
@@ -53,10 +102,19 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
copyOp.getTarget(), copyOp.getTarget(),
copyOp.getSource(), copyOp.getSource(),
*sizeAttr); *sizeAttr);
rewriter.eraseOp(copyOp);
return success();
} }
}; else {
copyOp.emitOpError() << "failed to classify memref.copy endpoints: source=" << copyOp.getSource()
<< " type=" << copyOp.getSource().getType() << " host=" << sourceIsHost
<< " device=" << sourceIsDevice << ", target=" << copyOp.getTarget()
<< " type=" << copyOp.getTarget().getType() << " host=" << targetIsHost
<< " device=" << targetIsDevice;
return failure();
}
rewriter.eraseOp(copyOp);
return success();
}
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> { struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
@@ -100,25 +158,46 @@ void PimBufferizationPass::runOnOperation() {
} }
MLIRContext* ctx = moduleOp.getContext(); MLIRContext* ctx = moduleOp.getContext();
RewritePatternSet memrefCopyPatterns(ctx);
memrefCopyPatterns.add<MemRefCopyToPimMemCopyPattern>(ctx);
FrozenRewritePatternSet frozenMemrefCopyPatterns(std::move(memrefCopyPatterns));
PatternApplicator memrefCopyApplicator(frozenMemrefCopyPatterns);
memrefCopyApplicator.applyDefaultCostModel();
PatternRewriter rewriter(ctx); PatternRewriter rewriter(ctx);
SmallVector<memref::CopyOp> copyWorklist; SmallVector<MemRefCopyWorkItem> copyWorklist;
moduleOp.walk([&](memref::CopyOp copyOp) { llvm::SmallPtrSet<Operation*, 16> seenCopyOps;
if (copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>()) auto addCopyOp = [&](memref::CopyOp copyOp, const StaticValueKnowledge& knowledge) {
copyWorklist.push_back(copyOp); if (seenCopyOps.insert(copyOp.getOperation()).second)
copyWorklist.push_back({copyOp, knowledge});
};
moduleOp.walk([&](pim::PimCoreOp coreOp) {
StaticValueKnowledge knowledge = seedCoreKnowledge(coreOp);
(void) walkPimCoreBlockStructurally(
coreOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) {
if (auto copyOp = dyn_cast<memref::CopyOp>(&op))
addCopyOp(copyOp, opKnowledge);
return success();
});
});
moduleOp.walk([&](pim::PimCoreBatchOp coreBatchOp) {
llvm::SmallVector<unsigned, 2> lanes;
lanes.push_back(0);
if (coreBatchOp.getLaneCount() > 1)
lanes.push_back(static_cast<unsigned>(coreBatchOp.getLaneCount() - 1));
for (unsigned lane : lanes) {
StaticValueKnowledge knowledge = seedCoreBatchKnowledge(coreBatchOp, lane);
(void) walkPimCoreBlockStructurally(
coreBatchOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) {
if (auto copyOp = dyn_cast<memref::CopyOp>(&op))
addCopyOp(copyOp, opKnowledge);
return success();
});
}
}); });
bool hasFailed = false; bool hasFailed = false;
for (memref::CopyOp copyOp : copyWorklist) { for (const MemRefCopyWorkItem& workItem : copyWorklist) {
if (failed(applyPatternsOnce(copyOp, memrefCopyApplicator, rewriter))) { memref::CopyOp copyOp = workItem.copyOp;
copyOp.emitOpError("failed to lower memref.copy inside PIM core body"); rewriter.setInsertionPoint(copyOp);
if (failed(lowerMemRefCopyToPimCopy(copyOp, rewriter, workItem.knowledge)))
hasFailed = true; hasFailed = true;
}
} }
if (hasFailed) { if (hasFailed) {
signalPassFailure(); signalPassFailure();
@@ -128,7 +128,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
auto sizeAttr = pim::getCheckedI32Attr(rewriter, mapOp, *sizeInBytes, "host constant folding byte size"); auto sizeAttr = pim::getCheckedI32Attr(rewriter, mapOp, *sizeInBytes, "host constant folding byte size");
if (failed(sizeAttr)) if (failed(sizeAttr))
return failure(); return failure();
pim::PimMemCopyOp::create( pim::PimMemCopyHostToDevOp::create(
rewriter, mapOp.getLoc(), initType, zeroOffset, zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), *sizeAttr); rewriter, mapOp.getLoc(), initType, zeroOffset, zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), *sizeAttr);
rewriter.eraseOp(mapOp); rewriter.eraseOp(mapOp);
return success(); return success();
@@ -1,9 +0,0 @@
add_pim_library(OMPimHostConstantMaterialization
MaterializeHostConstantsPass.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
OMPimCommon
PimOps
)
@@ -1,161 +0,0 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
template <typename CoreOpTy>
static void materializeHostConstantsInCore(CoreOpTy coreOp,
IRRewriter& rewriter,
OperationFolder& constantFolder,
bool& hasFailure) {
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
DominanceInfo dominance(coreOp);
SmallVector<Operation*> ops;
coreOp.getBody().front().walk([&](Operation* op) {
if (!isa<pim::PimHaltOp, scf::YieldOp>(op))
ops.push_back(op);
});
for (Operation* op : ops) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
continue;
for (OpOperand& operand : op->getOpOperands()) {
Value originalValue = operand.get();
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber()))
continue;
auto resolvedAddress = resolveContiguousAddress(originalValue);
if (failed(resolvedAddress))
continue;
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
if (!getGlobalOp)
continue;
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
if (!originalType || !originalType.hasStaticShape()) {
op->emitOpError("host constant materialization requires a static memref operand");
hasFailure = true;
continue;
}
auto& cachedByOffset = materializedValues[resolvedAddress->base];
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
auto cachedValue = cachedByType.find(originalType);
if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) {
operand.set(cachedValue->second);
continue;
}
auto type = dyn_cast<ShapedType>(originalValue.getType());
auto totalBytes = type ? pim::getCheckedShapedTypeSizeInBytes(type, op, "host constant materialization byte size")
: FailureOr<uint64_t>(failure());
auto totalBytesAttr =
succeeded(totalBytes)
? pim::getCheckedI32Attr(rewriter, op, *totalBytes, "host constant materialization byte size")
: FailureOr<IntegerAttr>(failure());
if (failed(totalBytesAttr)
|| failed(pim::checkedSize(resolvedAddress->byteOffset, op, "host constant materialization byte offset"))) {
hasFailure = true;
continue;
}
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
rewriter.setInsertionPoint(op);
Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType);
Value deviceDst = localAlloc;
if (contiguousType != originalType)
deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0);
Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset);
Value copiedValue = pim::PimMemCopyHostToDevOp::create(rewriter,
op->getLoc(),
originalType,
zeroOffset,
hostOffset,
deviceDst,
getGlobalOp.getResult(),
*totalBytesAttr)
.getOutput();
cachedByType[originalType] = copiedValue;
operand.set(copiedValue);
}
}
}
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
StringRef getDescription() const override {
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
IRRewriter rewriter(moduleOp.getContext());
OperationFolder constantFolder(moduleOp.getContext());
bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
if (funcOp.isExternal())
continue;
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure);
for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure);
SmallVector<Operation*> hostCompactOps;
for (Operation& op : funcOp.getBody().front())
if (isa<pim::PimConcatOp>(op))
hostCompactOps.push_back(&op);
for (Operation* op : hostCompactOps) {
rewriter.setInsertionPoint(op);
auto concatOp = cast<pim::PimConcatOp>(op);
concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
hasFailure = true;
}
}
if (hasFailure) {
moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above");
signalPassFailure();
return;
}
dumpModule(moduleOp, "pim4_materialized");
}
};
} // namespace
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
return std::make_unique<MaterializeHostConstantsPass>();
}
} // namespace onnx_mlir
@@ -46,19 +46,6 @@ static bool isCodegenAddressableValue(Value value) {
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp()); || isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
} }
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
if (succeeded(resolvedAddress))
return isa<BlockArgument>(resolvedAddress->base)
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
auto compiledAddress = compileContiguousAddressExpr(value);
if (failed(compiledAddress))
return false;
return isa<BlockArgument>(compiledAddress->base)
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
}
static bool isConstantGlobalView(Value value) { static bool isConstantGlobalView(Value value) {
while (true) { while (true) {
Operation* defOp = value.getDefiningOp(); Operation* defOp = value.getDefiningOp();
@@ -138,6 +125,24 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
memref::GetGlobalOp>(op); memref::GetGlobalOp>(op);
} }
static bool isHostAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
Value base;
if (succeeded(resolvedAddress)) {
base = resolvedAddress->base;
}
else {
auto compiledAddress = compileContiguousAddressExpr(value);
if (failed(compiledAddress))
return false;
base = compiledAddress->base;
}
if (isa<BlockArgument>(base))
return true;
return isa_and_nonnull<memref::GetGlobalOp>(base.getDefiningOp());
}
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> { struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
@@ -311,10 +316,10 @@ private:
} }
if (isExplicitHostMemCopyOperand(&op, operandIndex)) { if (isExplicitHostMemCopyOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand, knowledge)) { if (!isHostAddressableValue(operand, knowledge)) {
diagnostics.report(&op, [&](Operation* illegalOp) { diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << "host operand #" << operandIndex illegalOp->emitOpError() << "host operand #" << operandIndex
<< " is not backed by contiguous addressable storage"; << " must be backed by host-addressable storage";
}); });
hasFailure = true; hasFailure = true;
} }
-2
View File
@@ -21,8 +21,6 @@ std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass(); std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimMaterializeHostConstantsPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass(); std::unique_ptr<mlir::Pass> createPimVerificationPass();
std::unique_ptr<mlir::Pass> createEmitPimCodePass(); std::unique_ptr<mlir::Pass> createEmitPimCodePass();
-1
View File
@@ -78,7 +78,6 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createPimMemoryCoalescingPass); registerPass(createPimMemoryCoalescingPass);
registerPass(createMergeComputeNodesPass); registerPass(createMergeComputeNodesPass);
registerPass(createPimHostConstantFoldingPass); registerPass(createPimHostConstantFoldingPass);
registerPass(createPimMaterializeHostConstantsPass);
registerPass(createPimVerificationPass); registerPass(createPimVerificationPass);
registerPass(createEmitPimCodePass); registerPass(createEmitPimCodePass);
} }
-1
View File
@@ -11,7 +11,6 @@ PIM_PASS_LABELS = (
("SpatialToPimPass", "Spatial to PIM"), ("SpatialToPimPass", "Spatial to PIM"),
("PimBufferizationPass", "Bufferize PIM"), ("PimBufferizationPass", "Bufferize PIM"),
("HostConstantFoldingPass", "Fold Host Constants"), ("HostConstantFoldingPass", "Fold Host Constants"),
("MaterializeHostConstantsPass", "Materialize Host Constants"),
("VerificationPass", "Verify PIM"), ("VerificationPass", "Verify PIM"),
("EmitPimCodePass", "Emit PIM Code"), ("EmitPimCodePass", "Emit PIM Code"),
) )