Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone
This commit is contained in:
@@ -168,8 +168,8 @@ Each validation run writes artifacts in the model workspace, for example under
|
||||
|
||||
The compiler currently dumps dialect snapshots such as `spatial0.mlir`,
|
||||
`spatial1_dcp_merged.mlir`, `pim0.mlir`, `pim1_buff.mlir`,
|
||||
`pim2_coalesced.mlir`, `pim3_folded.mlir`, and
|
||||
`pim4_materialized.mlir` when an output directory is available.
|
||||
`pim2_coalesced.mlir`, and `pim3_folded.mlir` when an output directory is
|
||||
available.
|
||||
|
||||
To rerun the simulator manually with tracing after validation has produced a
|
||||
`raptor/pim/` directory:
|
||||
|
||||
@@ -123,7 +123,6 @@ add_pim_library(OMPIMAccel
|
||||
OMPimBufferization
|
||||
OMPimMemoryCoalescing
|
||||
OMPimHostConstantFolding
|
||||
OMPimHostConstantMaterialization
|
||||
OMPimVerification
|
||||
MLIRTensorInferTypeOpInterfaceImpl
|
||||
)
|
||||
|
||||
@@ -47,6 +47,16 @@ CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
|
||||
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
|
||||
}
|
||||
|
||||
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefTypeStrides(mlir::MemRefType type) {
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
if (failed(type.getStridesAndOffset(strides, offset)))
|
||||
return mlir::failure();
|
||||
if (llvm::is_contained(strides, mlir::ShapedType::kDynamic))
|
||||
return mlir::failure();
|
||||
return strides;
|
||||
}
|
||||
|
||||
template <typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
@@ -162,6 +172,11 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
|
||||
mlir::Value current = weight;
|
||||
|
||||
while (true) {
|
||||
if (mlir::Value directAlias = knowledge.aliases.lookup(current); directAlias && directAlias != current) {
|
||||
current = directAlias;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto defOp = current.getDefiningOp()) {
|
||||
if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
|
||||
auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
|
||||
@@ -181,8 +196,6 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
|
||||
CompiledIndexExpr offsetExpr = makeConstantExpr(0);
|
||||
for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
|
||||
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
|
||||
llvm::SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getMixedOffsets().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||
CompiledIndexExpr offsetValue = makeConstantExpr(0);
|
||||
@@ -202,29 +215,47 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
|
||||
return mlir::failure();
|
||||
}
|
||||
offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
|
||||
nextStrides.push_back(stride * sourceStride);
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
auto resultType = mlir::cast<mlir::MemRefType>(subview.getResult().getType());
|
||||
auto resultStrides = getStaticMemRefTypeStrides(resultType);
|
||||
if (failed(resultStrides))
|
||||
return mlir::failure();
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = std::move(*resultStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return mlir::failure();
|
||||
auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
|
||||
auto resultStrides = getStaticMemRefTypeStrides(resultType);
|
||||
if (failed(resultStrides))
|
||||
return mlir::failure();
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
view.strides = std::move(*resultStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return mlir::failure();
|
||||
auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
|
||||
auto resultStrides = getStaticMemRefTypeStrides(resultType);
|
||||
if (failed(resultStrides))
|
||||
return mlir::failure();
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
view.strides = std::move(*resultStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(viewOp)) {
|
||||
auto resultType = mlir::cast<mlir::MemRefType>(castOp.getResult().getType());
|
||||
auto resultStrides = getStaticMemRefTypeStrides(resultType);
|
||||
if (failed(resultStrides))
|
||||
return mlir::failure();
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = std::move(*resultStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
auto resolvedOffset = offsetExpr.evaluate(knowledge);
|
||||
@@ -234,18 +265,26 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
|
||||
return view;
|
||||
}
|
||||
|
||||
if (mlir::isa<mlir::memref::SubViewOp, mlir::memref::CollapseShapeOp, mlir::memref::ExpandShapeOp>(defOp)) {
|
||||
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp)) {
|
||||
viewOps.push_back(defOp);
|
||||
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp))
|
||||
current = subview.getSource();
|
||||
else if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp))
|
||||
current = collapse.getSrc();
|
||||
else
|
||||
current = mlir::cast<mlir::memref::ExpandShapeOp>(defOp).getSrc();
|
||||
current = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp)) {
|
||||
viewOps.push_back(defOp);
|
||||
current = collapse.getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(defOp)) {
|
||||
viewOps.push_back(defOp);
|
||||
current = expand.getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
|
||||
viewOps.push_back(defOp);
|
||||
current = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
@@ -253,6 +292,11 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (mlir::Value loopAlias = resolveLoopCarriedAlias(current, knowledge); loopAlias && loopAlias != current) {
|
||||
current = loopAlias;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto weightIndex = resolveWeightIndex(weightOwner, current);
|
||||
if (!weightIndex)
|
||||
return mlir::failure();
|
||||
|
||||
@@ -28,6 +28,8 @@ struct CappedDiagnosticReporter {
|
||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
|
||||
}
|
||||
|
||||
void noteFailures(int64_t count) { numFailures += count; }
|
||||
|
||||
bool hasFailure() const { return numFailures != 0; }
|
||||
|
||||
private:
|
||||
|
||||
@@ -31,7 +31,6 @@ add_pim_library(OMPimCompilerUtils
|
||||
OMPimBufferization
|
||||
OMPimMemoryCoalescing
|
||||
OMPimHostConstantFolding
|
||||
OMPimHostConstantMaterialization
|
||||
OMPimVerification
|
||||
OMPimPasses
|
||||
OMONNXToSpatial
|
||||
|
||||
@@ -32,6 +32,7 @@
|
||||
|
||||
#include "Common/IR/CompactAsmUtils.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Common/Support/Diagnostics.hpp"
|
||||
#include "Common/Support/CheckedArithmetic.hpp"
|
||||
#include "Common/Support/ReportUtils.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
@@ -996,12 +997,44 @@ static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
||||
}
|
||||
|
||||
struct CoreEmissionResult {
|
||||
static constexpr size_t kMaxStoredCodegenDiagnostics = 8;
|
||||
|
||||
struct DiagnosticRecord {
|
||||
Operation* op = nullptr;
|
||||
std::string message;
|
||||
};
|
||||
|
||||
OnnxMlirCompilerErrorCodes status = CompilerSuccess;
|
||||
MemoryReportRow reportRow;
|
||||
llvm::SmallVector<ResolvedWeightView, 8> usedWeights;
|
||||
MemoryPlanArtifacts livenessArtifacts;
|
||||
llvm::SmallVector<DiagnosticRecord, kMaxStoredCodegenDiagnostics> diagnostics;
|
||||
size_t diagnosticCount = 0;
|
||||
|
||||
void recordDiagnostic(Operation* op, StringRef message) {
|
||||
++diagnosticCount;
|
||||
if (diagnostics.size() < kMaxStoredCodegenDiagnostics)
|
||||
diagnostics.push_back({op, message.str()});
|
||||
}
|
||||
};
|
||||
|
||||
static StaticValueKnowledge seedCoreCodegenKnowledge(pim::PimCoreOp coreOp) {
|
||||
StaticValueKnowledge knowledge;
|
||||
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights()))
|
||||
knowledge.aliases[coreOp.getWeightArgument(index)] = weight;
|
||||
return knowledge;
|
||||
}
|
||||
|
||||
static StaticValueKnowledge seedCoreBatchCodegenKnowledge(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
|
||||
StaticValueKnowledge knowledge;
|
||||
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
|
||||
for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights()))
|
||||
knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight;
|
||||
for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs()))
|
||||
knowledge.aliases[coreBatchOp.getInputArgument(index)] = input;
|
||||
return knowledge;
|
||||
}
|
||||
|
||||
template <typename MapTy>
|
||||
class ScopedMapBindings {
|
||||
using KeyTy = typename MapTy::key_type;
|
||||
@@ -1422,7 +1455,20 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
const StaticValueKnowledge& knowledge) -> llvm::FailureOr<unsigned> {
|
||||
auto weightView = onnx_mlir::resolveWeightView(job.coreLikeOp, vmmOp.getWeight(), knowledge);
|
||||
if (failed(weightView)) {
|
||||
vmmOp.emitOpError("requires a statically resolvable dense global weight view during PIM codegen");
|
||||
std::string message;
|
||||
llvm::raw_string_ostream os(message);
|
||||
os << "requires a statically resolvable dense global weight view during PIM codegen; weight="
|
||||
<< vmmOp.getWeight() << " type=" << vmmOp.getWeight().getType();
|
||||
result.recordDiagnostic(vmmOp, os.str());
|
||||
return failure();
|
||||
}
|
||||
if (weightView->shape.size() != 2) {
|
||||
std::string message;
|
||||
llvm::raw_string_ostream os(message);
|
||||
os << "requires a rank-2 matrix weight view during PIM codegen; resolved shape=[";
|
||||
llvm::interleaveComma(weightView->shape, os);
|
||||
os << "] weight=" << vmmOp.getWeight() << " type=" << vmmOp.getWeight().getType();
|
||||
result.recordDiagnostic(vmmOp, os.str());
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -1463,13 +1509,13 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
|
||||
deviceMemory.allocateCore(coreOp);
|
||||
|
||||
int64_t processedOperations = codeGenCoreOps(
|
||||
coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation(), resolveWeightSlot);
|
||||
StaticValueKnowledge knowledge = seedCoreCodegenKnowledge(coreOp);
|
||||
int64_t processedOperations =
|
||||
codeGenCoreOps(coreOp.getBody().front(), coreCodeGen, knowledge, coreOp.getOperation(), resolveWeightSlot);
|
||||
if (processedOperations < 0) {
|
||||
result.status = CompilerFailure;
|
||||
return result;
|
||||
}
|
||||
assert(processedOperations > 0);
|
||||
result.reportRow = deviceMemory.getReportRow();
|
||||
result.usedWeights = std::move(usedWeights);
|
||||
result.livenessArtifacts = deviceMemory.getLivenessArtifacts();
|
||||
@@ -1480,10 +1526,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
|
||||
|
||||
for (unsigned lane : job.lanes) {
|
||||
StaticValueKnowledge knowledge;
|
||||
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
|
||||
for (unsigned i = 0; i < coreBatchOp.getInputs().size(); ++i)
|
||||
knowledge.aliases[coreBatchOp.getInputArgument(i)] = coreBatchOp.getInputs()[i];
|
||||
StaticValueKnowledge knowledge = seedCoreBatchCodegenKnowledge(coreBatchOp, lane);
|
||||
|
||||
deviceMemory.allocateCore(coreBatchOp, lane);
|
||||
coreCodeGen.setBatchLane(lane);
|
||||
@@ -1498,7 +1541,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
result.status = CompilerFailure;
|
||||
return result;
|
||||
}
|
||||
assert(processedOperations > 0);
|
||||
}
|
||||
|
||||
result.reportRow = deviceMemory.getReportRow();
|
||||
@@ -1522,6 +1564,21 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
mlir::parallelFor(
|
||||
moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { jobResults[index] = emitJob(jobs[index]); });
|
||||
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
Operation* summaryAnchor = nullptr;
|
||||
for (const CoreEmissionResult& result : jobResults) {
|
||||
if (!summaryAnchor && !result.diagnostics.empty())
|
||||
summaryAnchor = result.diagnostics.front().op;
|
||||
for (const CoreEmissionResult::DiagnosticRecord& diagnostic : result.diagnostics) {
|
||||
diagnostics.report(diagnostic.op, [&](Operation* op) { op->emitError() << diagnostic.message; });
|
||||
}
|
||||
size_t unreportedCount = result.diagnosticCount - result.diagnostics.size();
|
||||
diagnostics.noteFailures(static_cast<int64_t>(unreportedCount));
|
||||
}
|
||||
if (diagnostics.hasFailure())
|
||||
diagnostics.emitSuppressedSummary(summaryAnchor ? summaryAnchor : moduleOp.getOperation(),
|
||||
"PIM codegen diagnostic(s)");
|
||||
|
||||
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex)
|
||||
if (jobResults[jobIndex].status != CompilerSuccess)
|
||||
return jobResults[jobIndex].status;
|
||||
|
||||
@@ -46,7 +46,6 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||
pm.addPass(createPimHostConstantFoldingPass());
|
||||
pm.addPass(createMessagePass("Pim host constants folded"));
|
||||
pm.addPass(createPimMaterializeHostConstantsPass());
|
||||
pm.addPass(createPimMemoryCoalescingPass());
|
||||
pm.addPass(createPimVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim verified"));
|
||||
|
||||
@@ -375,6 +375,57 @@ static void cloneHelperChain(Value sourceValue,
|
||||
}
|
||||
}
|
||||
|
||||
static bool isHostStaticReturnValue(Value value) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
while (Operation* definingOp = value.getDefiningOp()) {
|
||||
if (!visited.insert(definingOp).second)
|
||||
return false;
|
||||
if (isa<arith::ConstantOp>(definingOp) || definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
return true;
|
||||
if (!isReturnHelperChainOp(definingOp) || definingOp->getNumOperands() != 1)
|
||||
return false;
|
||||
value = definingOp->getOperand(0);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static FailureOr<Value>
|
||||
materializeHostStaticReturnValue(IRRewriter& rewriter, Value value, OperationFolder& constantFolder) {
|
||||
llvm::SmallVector<Operation*> chain;
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
while (Operation* definingOp = value.getDefiningOp()) {
|
||||
if (!visited.insert(definingOp).second)
|
||||
return failure();
|
||||
chain.push_back(definingOp);
|
||||
if (isa<arith::ConstantOp>(definingOp) || definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
break;
|
||||
if (!isReturnHelperChainOp(definingOp) || definingOp->getNumOperands() != 1)
|
||||
return failure();
|
||||
value = definingOp->getOperand(0);
|
||||
}
|
||||
|
||||
if (chain.empty())
|
||||
return failure();
|
||||
|
||||
IRMapping mapping;
|
||||
Value clonedValue;
|
||||
for (Operation* op : llvm::reverse(chain)) {
|
||||
if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
|
||||
clonedValue = getOrCreateConstantLike(constantFolder, constantOp);
|
||||
mapping.map(op->getResult(0), clonedValue);
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
clonedValue = clonedOp->getResult(0);
|
||||
rewriter.setInsertionPointAfter(clonedOp);
|
||||
}
|
||||
|
||||
return clonedValue;
|
||||
}
|
||||
|
||||
static FailureOr<Value> emitHostCopy(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
Value outputTensor,
|
||||
@@ -444,7 +495,30 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
OperationFolder constantFolder(producerOp->getContext());
|
||||
auto storedTensorType = cast<TensorType>(storedValue.getType());
|
||||
|
||||
auto materializeDirectHostReturn = [&](size_t returnIndex,
|
||||
Value sourceValue,
|
||||
ArrayRef<Operation*> helperChain) -> ReturnPathLoweringResult {
|
||||
rewriter.setInsertionPointAfter(producerOp);
|
||||
auto hostStaticValue = materializeHostStaticReturnValue(rewriter, sourceValue, constantFolder);
|
||||
if (failed(hostStaticValue))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
|
||||
Value hostReturnValue = *hostStaticValue;
|
||||
if (!helperChain.empty())
|
||||
cloneHelperChain(hostReturnValue, helperChain, rewriter, constantFolder, hostReturnValue);
|
||||
|
||||
outputTensors[returnIndex] =
|
||||
[hostReturnValue](IRRewriter& rewriter, Location loc) -> Value { return hostReturnValue; };
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
};
|
||||
|
||||
if (auto returnUse = analyzeReturnUse(producedValue)) {
|
||||
if (isHostStaticReturnValue(storedValue)) {
|
||||
for (Operation* op : returnUse->helperChain)
|
||||
markOpToRemove(op);
|
||||
return materializeDirectHostReturn(returnUse->returnIndex, storedValue, returnUse->helperChain);
|
||||
}
|
||||
|
||||
Value currentStoredValue = storedValue;
|
||||
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
|
||||
for (Operation* op : returnUse->helperChain)
|
||||
@@ -470,6 +544,8 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
|
||||
if (isa<func::ReturnOp>(resultUser)) {
|
||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||
if (isHostStaticReturnValue(storedValue))
|
||||
return materializeDirectHostReturn(resultIndexInReturn, storedValue, {});
|
||||
auto byteSize =
|
||||
pim::getCheckedShapedTypeSizeInBytes(storedTensorType, producerOp, "return-path host copy byte size");
|
||||
if (failed(byteSize))
|
||||
|
||||
@@ -4,7 +4,6 @@ add_onnx_mlir_dialect_doc(pim Pim.td)
|
||||
add_subdirectory(Transforms/Bufferization)
|
||||
add_subdirectory(Transforms/MemoryCoalescing)
|
||||
add_subdirectory(Transforms/HostConstantFolding)
|
||||
add_subdirectory(Transforms/HostConstantMaterialization)
|
||||
add_subdirectory(Transforms/Verification)
|
||||
|
||||
add_pim_library(PimOps
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
@@ -13,7 +14,9 @@ using namespace bufferization;
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
|
||||
bool isContiguous =
|
||||
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
|
||||
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
@@ -29,13 +32,21 @@ FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location lo
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
|
||||
if (isHostBackedPimAddress(memrefValue)) {
|
||||
return PimMemCopyHostToDevOp::create(
|
||||
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
return PimMemCopyOp::create(
|
||||
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
|
||||
bool isContiguous =
|
||||
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
|
||||
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
|
||||
@@ -1,9 +1,70 @@
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static bool isCoreBatchInputArgument(Value value) {
|
||||
auto blockArg = dyn_cast<BlockArgument>(value);
|
||||
if (!blockArg)
|
||||
return false;
|
||||
|
||||
auto coreBatchOp = dyn_cast_or_null<onnx_mlir::pim::PimCoreBatchOp>(blockArg.getOwner()->getParentOp());
|
||||
if (!coreBatchOp)
|
||||
return false;
|
||||
|
||||
unsigned firstInputArg = 1 + coreBatchOp.getWeights().size();
|
||||
return static_cast<unsigned>(blockArg.getArgNumber()) >= firstInputArg;
|
||||
}
|
||||
|
||||
static FailureOr<Value> getPimStorageBase(Value value, const onnx_mlir::StaticValueKnowledge& knowledge) {
|
||||
llvm::SmallPtrSet<Value, 8> visited;
|
||||
while (value && visited.insert(value).second) {
|
||||
Value alias = resolveLoopCarriedAlias(value, knowledge);
|
||||
if (alias)
|
||||
value = alias;
|
||||
|
||||
if (auto aliased = knowledge.aliases.lookup(value)) {
|
||||
value = aliased;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto base = onnx_mlir::pim::getPimAddressBase(value, knowledge); succeeded(base))
|
||||
return base;
|
||||
|
||||
if (isa<BlockArgument>(value))
|
||||
return value;
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return value;
|
||||
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
if (value)
|
||||
return value;
|
||||
return failure();
|
||||
}
|
||||
|
||||
FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Operation* anchor, Value memref) {
|
||||
auto type = mlir::cast<MemRefType>(memref.getType());
|
||||
auto byteSize = getCheckedShapedTypeSizeInBytes(type, anchor, "memref byte size");
|
||||
@@ -11,3 +72,40 @@ FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& build
|
||||
return failure();
|
||||
return getCheckedI32Attr(builder, anchor, *byteSize, "memref byte size");
|
||||
}
|
||||
|
||||
FailureOr<Value> onnx_mlir::pim::getPimAddressBase(Value value, const StaticValueKnowledge& knowledge) {
|
||||
Value alias = resolveLoopCarriedAlias(value, knowledge);
|
||||
if (alias)
|
||||
value = alias;
|
||||
|
||||
auto resolved = resolveContiguousAddress(value, knowledge);
|
||||
if (succeeded(resolved))
|
||||
return resolved->base;
|
||||
|
||||
auto compiled = compileContiguousAddressExpr(value);
|
||||
if (failed(compiled)) {
|
||||
if (isa<BlockArgument>(value))
|
||||
return value;
|
||||
return failure();
|
||||
}
|
||||
return compiled->base;
|
||||
}
|
||||
|
||||
bool onnx_mlir::pim::isHostBackedPimAddress(Value value, const StaticValueKnowledge& knowledge) {
|
||||
auto base = getPimStorageBase(value, knowledge);
|
||||
if (failed(base))
|
||||
return false;
|
||||
|
||||
if (isCoreBatchInputArgument(*base))
|
||||
return true;
|
||||
|
||||
return isa_and_nonnull<memref::GetGlobalOp>(base->getDefiningOp());
|
||||
}
|
||||
|
||||
bool onnx_mlir::pim::isDeviceLocalPimAddress(Value value, const StaticValueKnowledge& knowledge) {
|
||||
auto base = getPimStorageBase(value, knowledge);
|
||||
if (failed(base))
|
||||
return false;
|
||||
|
||||
return isa_and_nonnull<memref::AllocOp>(base->getDefiningOp());
|
||||
}
|
||||
|
||||
@@ -2,11 +2,19 @@
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
mlir::FailureOr<mlir::IntegerAttr>
|
||||
getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Operation* anchor, mlir::Value memref);
|
||||
|
||||
mlir::FailureOr<mlir::Value> getPimAddressBase(mlir::Value value, const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
bool isHostBackedPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
bool isDeviceLocalPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Rewrite/PatternApplicator.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
@@ -15,6 +16,7 @@
|
||||
#include "Dialect/Pim/PimOps.hpp"
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
#include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
@@ -27,24 +29,71 @@ namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
struct MemRefCopyWorkItem {
|
||||
memref::CopyOp copyOp;
|
||||
StaticValueKnowledge knowledge;
|
||||
};
|
||||
|
||||
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
return failure();
|
||||
static StaticValueKnowledge seedCoreKnowledge(pim::PimCoreOp coreOp) {
|
||||
StaticValueKnowledge knowledge;
|
||||
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights()))
|
||||
knowledge.aliases[coreOp.getWeightArgument(index)] = weight;
|
||||
return knowledge;
|
||||
}
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
|
||||
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
|
||||
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != targetType.getElementType())
|
||||
return failure();
|
||||
static StaticValueKnowledge seedCoreBatchKnowledge(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
|
||||
StaticValueKnowledge knowledge;
|
||||
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
|
||||
for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights()))
|
||||
knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight;
|
||||
for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs()))
|
||||
knowledge.aliases[coreBatchOp.getInputArgument(index)] = input;
|
||||
return knowledge;
|
||||
}
|
||||
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
||||
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
static LogicalResult
|
||||
lowerMemRefCopyToPimCopy(memref::CopyOp copyOp, PatternRewriter& rewriter, const StaticValueKnowledge& knowledge) {
|
||||
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
|
||||
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
|
||||
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != targetType.getElementType())
|
||||
return failure();
|
||||
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
||||
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
|
||||
bool sourceIsHost = isHostBackedPimAddress(copyOp.getSource(), knowledge);
|
||||
bool targetIsHost = isHostBackedPimAddress(copyOp.getTarget(), knowledge);
|
||||
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getSource(), knowledge);
|
||||
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getTarget(), knowledge);
|
||||
|
||||
if (targetIsDevice && sourceIsHost) {
|
||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
copyOp.getTarget().getType(),
|
||||
zeroOffset,
|
||||
zeroOffset,
|
||||
copyOp.getTarget(),
|
||||
copyOp.getSource(),
|
||||
*sizeAttr);
|
||||
}
|
||||
else if (targetIsHost && sourceIsDevice) {
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
copyOp.getTarget().getType(),
|
||||
zeroOffset,
|
||||
zeroOffset,
|
||||
copyOp.getTarget(),
|
||||
copyOp.getSource(),
|
||||
*sizeAttr);
|
||||
}
|
||||
else if (targetIsDevice && sourceIsDevice) {
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
copyOp.getTarget().getType(),
|
||||
@@ -53,10 +102,19 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
copyOp.getTarget(),
|
||||
copyOp.getSource(),
|
||||
*sizeAttr);
|
||||
rewriter.eraseOp(copyOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
else {
|
||||
copyOp.emitOpError() << "failed to classify memref.copy endpoints: source=" << copyOp.getSource()
|
||||
<< " type=" << copyOp.getSource().getType() << " host=" << sourceIsHost
|
||||
<< " device=" << sourceIsDevice << ", target=" << copyOp.getTarget()
|
||||
<< " type=" << copyOp.getTarget().getType() << " host=" << targetIsHost
|
||||
<< " device=" << targetIsDevice;
|
||||
return failure();
|
||||
}
|
||||
|
||||
rewriter.eraseOp(copyOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||
@@ -100,25 +158,46 @@ void PimBufferizationPass::runOnOperation() {
|
||||
}
|
||||
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
RewritePatternSet memrefCopyPatterns(ctx);
|
||||
memrefCopyPatterns.add<MemRefCopyToPimMemCopyPattern>(ctx);
|
||||
FrozenRewritePatternSet frozenMemrefCopyPatterns(std::move(memrefCopyPatterns));
|
||||
PatternApplicator memrefCopyApplicator(frozenMemrefCopyPatterns);
|
||||
memrefCopyApplicator.applyDefaultCostModel();
|
||||
PatternRewriter rewriter(ctx);
|
||||
|
||||
SmallVector<memref::CopyOp> copyWorklist;
|
||||
moduleOp.walk([&](memref::CopyOp copyOp) {
|
||||
if (copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
copyWorklist.push_back(copyOp);
|
||||
SmallVector<MemRefCopyWorkItem> copyWorklist;
|
||||
llvm::SmallPtrSet<Operation*, 16> seenCopyOps;
|
||||
auto addCopyOp = [&](memref::CopyOp copyOp, const StaticValueKnowledge& knowledge) {
|
||||
if (seenCopyOps.insert(copyOp.getOperation()).second)
|
||||
copyWorklist.push_back({copyOp, knowledge});
|
||||
};
|
||||
|
||||
moduleOp.walk([&](pim::PimCoreOp coreOp) {
|
||||
StaticValueKnowledge knowledge = seedCoreKnowledge(coreOp);
|
||||
(void) walkPimCoreBlockStructurally(
|
||||
coreOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) {
|
||||
if (auto copyOp = dyn_cast<memref::CopyOp>(&op))
|
||||
addCopyOp(copyOp, opKnowledge);
|
||||
return success();
|
||||
});
|
||||
});
|
||||
moduleOp.walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||
llvm::SmallVector<unsigned, 2> lanes;
|
||||
lanes.push_back(0);
|
||||
if (coreBatchOp.getLaneCount() > 1)
|
||||
lanes.push_back(static_cast<unsigned>(coreBatchOp.getLaneCount() - 1));
|
||||
for (unsigned lane : lanes) {
|
||||
StaticValueKnowledge knowledge = seedCoreBatchKnowledge(coreBatchOp, lane);
|
||||
(void) walkPimCoreBlockStructurally(
|
||||
coreBatchOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) {
|
||||
if (auto copyOp = dyn_cast<memref::CopyOp>(&op))
|
||||
addCopyOp(copyOp, opKnowledge);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
bool hasFailed = false;
|
||||
for (memref::CopyOp copyOp : copyWorklist) {
|
||||
if (failed(applyPatternsOnce(copyOp, memrefCopyApplicator, rewriter))) {
|
||||
copyOp.emitOpError("failed to lower memref.copy inside PIM core body");
|
||||
for (const MemRefCopyWorkItem& workItem : copyWorklist) {
|
||||
memref::CopyOp copyOp = workItem.copyOp;
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
if (failed(lowerMemRefCopyToPimCopy(copyOp, rewriter, workItem.knowledge)))
|
||||
hasFailed = true;
|
||||
}
|
||||
}
|
||||
if (hasFailed) {
|
||||
signalPassFailure();
|
||||
|
||||
@@ -128,7 +128,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter, mapOp, *sizeInBytes, "host constant folding byte size");
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
pim::PimMemCopyOp::create(
|
||||
pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter, mapOp.getLoc(), initType, zeroOffset, zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), *sizeAttr);
|
||||
rewriter.eraseOp(mapOp);
|
||||
return success();
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
add_pim_library(OMPimHostConstantMaterialization
|
||||
MaterializeHostConstantsPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMPimCommon
|
||||
PimOps
|
||||
)
|
||||
-161
@@ -1,161 +0,0 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dominance.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename CoreOpTy>
|
||||
static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder,
|
||||
bool& hasFailure) {
|
||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||
DominanceInfo dominance(coreOp);
|
||||
SmallVector<Operation*> ops;
|
||||
coreOp.getBody().front().walk([&](Operation* op) {
|
||||
if (!isa<pim::PimHaltOp, scf::YieldOp>(op))
|
||||
ops.push_back(op);
|
||||
});
|
||||
|
||||
for (Operation* op : ops) {
|
||||
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
|
||||
continue;
|
||||
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value originalValue = operand.get();
|
||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(originalValue);
|
||||
if (failed(resolvedAddress))
|
||||
continue;
|
||||
|
||||
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
if (!getGlobalOp)
|
||||
continue;
|
||||
|
||||
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
|
||||
if (!originalType || !originalType.hasStaticShape()) {
|
||||
op->emitOpError("host constant materialization requires a static memref operand");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& cachedByOffset = materializedValues[resolvedAddress->base];
|
||||
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
|
||||
auto cachedValue = cachedByType.find(originalType);
|
||||
if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) {
|
||||
operand.set(cachedValue->second);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto type = dyn_cast<ShapedType>(originalValue.getType());
|
||||
auto totalBytes = type ? pim::getCheckedShapedTypeSizeInBytes(type, op, "host constant materialization byte size")
|
||||
: FailureOr<uint64_t>(failure());
|
||||
auto totalBytesAttr =
|
||||
succeeded(totalBytes)
|
||||
? pim::getCheckedI32Attr(rewriter, op, *totalBytes, "host constant materialization byte size")
|
||||
: FailureOr<IntegerAttr>(failure());
|
||||
if (failed(totalBytesAttr)
|
||||
|| failed(pim::checkedSize(resolvedAddress->byteOffset, op, "host constant materialization byte offset"))) {
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType);
|
||||
Value deviceDst = localAlloc;
|
||||
if (contiguousType != originalType)
|
||||
deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
|
||||
|
||||
Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0);
|
||||
Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset);
|
||||
Value copiedValue = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
originalType,
|
||||
zeroOffset,
|
||||
hostOffset,
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
*totalBytesAttr)
|
||||
.getOutput();
|
||||
|
||||
cachedByType[originalType] = copiedValue;
|
||||
operand.set(copiedValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
|
||||
|
||||
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
IRRewriter rewriter(moduleOp.getContext());
|
||||
OperationFolder constantFolder(moduleOp.getContext());
|
||||
bool hasFailure = false;
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
if (funcOp.isExternal())
|
||||
continue;
|
||||
|
||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
|
||||
materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure);
|
||||
|
||||
for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
|
||||
materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure);
|
||||
|
||||
SmallVector<Operation*> hostCompactOps;
|
||||
for (Operation& op : funcOp.getBody().front())
|
||||
if (isa<pim::PimConcatOp>(op))
|
||||
hostCompactOps.push_back(&op);
|
||||
|
||||
for (Operation* op : hostCompactOps) {
|
||||
rewriter.setInsertionPoint(op);
|
||||
auto concatOp = cast<pim::PimConcatOp>(op);
|
||||
concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFailure) {
|
||||
moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(moduleOp, "pim4_materialized");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
|
||||
return std::make_unique<MaterializeHostConstantsPass>();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -46,19 +46,6 @@ static bool isCodegenAddressableValue(Value value) {
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
|
||||
}
|
||||
|
||||
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
|
||||
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
|
||||
if (succeeded(resolvedAddress))
|
||||
return isa<BlockArgument>(resolvedAddress->base)
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
|
||||
auto compiledAddress = compileContiguousAddressExpr(value);
|
||||
if (failed(compiledAddress))
|
||||
return false;
|
||||
return isa<BlockArgument>(compiledAddress->base)
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
|
||||
}
|
||||
|
||||
static bool isConstantGlobalView(Value value) {
|
||||
while (true) {
|
||||
Operation* defOp = value.getDefiningOp();
|
||||
@@ -138,6 +125,24 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
|
||||
memref::GetGlobalOp>(op);
|
||||
}
|
||||
|
||||
static bool isHostAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
|
||||
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
|
||||
Value base;
|
||||
if (succeeded(resolvedAddress)) {
|
||||
base = resolvedAddress->base;
|
||||
}
|
||||
else {
|
||||
auto compiledAddress = compileContiguousAddressExpr(value);
|
||||
if (failed(compiledAddress))
|
||||
return false;
|
||||
base = compiledAddress->base;
|
||||
}
|
||||
|
||||
if (isa<BlockArgument>(base))
|
||||
return true;
|
||||
return isa_and_nonnull<memref::GetGlobalOp>(base.getDefiningOp());
|
||||
}
|
||||
|
||||
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
||||
|
||||
@@ -311,10 +316,10 @@ private:
|
||||
}
|
||||
|
||||
if (isExplicitHostMemCopyOperand(&op, operandIndex)) {
|
||||
if (!isCodegenAddressableValue(operand, knowledge)) {
|
||||
if (!isHostAddressableValue(operand, knowledge)) {
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "host operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
<< " must be backed by host-addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
|
||||
@@ -21,8 +21,6 @@ std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimMaterializeHostConstantsPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createEmitPimCodePass();
|
||||
|
||||
@@ -78,7 +78,6 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
||||
registerPass(createPimMemoryCoalescingPass);
|
||||
registerPass(createMergeComputeNodesPass);
|
||||
registerPass(createPimHostConstantFoldingPass);
|
||||
registerPass(createPimMaterializeHostConstantsPass);
|
||||
registerPass(createPimVerificationPass);
|
||||
registerPass(createEmitPimCodePass);
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ PIM_PASS_LABELS = (
|
||||
("SpatialToPimPass", "Spatial to PIM"),
|
||||
("PimBufferizationPass", "Bufferize PIM"),
|
||||
("HostConstantFoldingPass", "Fold Host Constants"),
|
||||
("MaterializeHostConstantsPass", "Materialize Host Constants"),
|
||||
("VerificationPass", "Verify PIM"),
|
||||
("EmitPimCodePass", "Emit PIM Code"),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user