Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone

This commit is contained in:
ilgeco
2026-06-05 10:20:09 +02:00
20 changed files with 458 additions and 256 deletions
+2 -2
View File
@@ -168,8 +168,8 @@ Each validation run writes artifacts in the model workspace, for example under
The compiler currently dumps dialect snapshots such as `spatial0.mlir`,
`spatial1_dcp_merged.mlir`, `pim0.mlir`, `pim1_buff.mlir`,
`pim2_coalesced.mlir`, `pim3_folded.mlir`, and
`pim4_materialized.mlir` when an output directory is available.
`pim2_coalesced.mlir`, and `pim3_folded.mlir` when an output directory is
available.
To rerun the simulator manually with tracing after validation has produced a
`raptor/pim/` directory:
-1
View File
@@ -123,7 +123,6 @@ add_pim_library(OMPIMAccel
OMPimBufferization
OMPimMemoryCoalescing
OMPimHostConstantFolding
OMPimHostConstantMaterialization
OMPimVerification
MLIRTensorInferTypeOpInterfaceImpl
)
+62 -18
View File
@@ -47,6 +47,16 @@ CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
}
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefTypeStrides(mlir::MemRefType type) {
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
if (failed(type.getStridesAndOffset(strides, offset)))
return mlir::failure();
if (llvm::is_contained(strides, mlir::ShapedType::kDynamic))
return mlir::failure();
return strides;
}
template <typename VMMOpTy, typename ParentOpTy>
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
auto weightArg = parentOp.getWeightArgument(weightIndex);
@@ -162,6 +172,11 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
mlir::Value current = weight;
while (true) {
if (mlir::Value directAlias = knowledge.aliases.lookup(current); directAlias && directAlias != current) {
current = directAlias;
continue;
}
if (auto defOp = current.getDefiningOp()) {
if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
@@ -181,8 +196,6 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
CompiledIndexExpr offsetExpr = makeConstantExpr(0);
for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
llvm::SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getMixedOffsets().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) {
CompiledIndexExpr offsetValue = makeConstantExpr(0);
@@ -202,29 +215,47 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
return mlir::failure();
}
offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
auto resultType = mlir::cast<mlir::MemRefType>(subview.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue;
}
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return mlir::failure();
auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
view.strides = std::move(*resultStrides);
continue;
}
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return mlir::failure();
auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
view.strides = std::move(*resultStrides);
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(viewOp)) {
auto resultType = mlir::cast<mlir::MemRefType>(castOp.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue;
}
return mlir::failure();
}
auto resolvedOffset = offsetExpr.evaluate(knowledge);
@@ -234,18 +265,26 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
return view;
}
if (mlir::isa<mlir::memref::SubViewOp, mlir::memref::CollapseShapeOp, mlir::memref::ExpandShapeOp>(defOp)) {
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp)) {
viewOps.push_back(defOp);
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp))
current = subview.getSource();
else if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp))
current = collapse.getSrc();
else
current = mlir::cast<mlir::memref::ExpandShapeOp>(defOp).getSrc();
current = subview.getSource();
continue;
}
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp)) {
viewOps.push_back(defOp);
current = collapse.getSrc();
continue;
}
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(defOp)) {
viewOps.push_back(defOp);
current = expand.getSrc();
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
viewOps.push_back(defOp);
current = castOp.getSource();
continue;
}
@@ -253,6 +292,11 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
return mlir::failure();
}
if (mlir::Value loopAlias = resolveLoopCarriedAlias(current, knowledge); loopAlias && loopAlias != current) {
current = loopAlias;
continue;
}
auto weightIndex = resolveWeightIndex(weightOwner, current);
if (!weightIndex)
return mlir::failure();
+2
View File
@@ -28,6 +28,8 @@ struct CappedDiagnosticReporter {
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
}
void noteFailures(int64_t count) { numFailures += count; }
bool hasFailure() const { return numFailures != 0; }
private:
-1
View File
@@ -31,7 +31,6 @@ add_pim_library(OMPimCompilerUtils
OMPimBufferization
OMPimMemoryCoalescing
OMPimHostConstantFolding
OMPimHostConstantMaterialization
OMPimVerification
OMPimPasses
OMONNXToSpatial
+66 -9
View File
@@ -32,6 +32,7 @@
#include "Common/IR/CompactAsmUtils.hpp"
#include "Common/PimCommon.hpp"
#include "Common/Support/Diagnostics.hpp"
#include "Common/Support/CheckedArithmetic.hpp"
#include "Common/Support/ReportUtils.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
@@ -996,12 +997,44 @@ static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
}
struct CoreEmissionResult {
static constexpr size_t kMaxStoredCodegenDiagnostics = 8;
struct DiagnosticRecord {
Operation* op = nullptr;
std::string message;
};
OnnxMlirCompilerErrorCodes status = CompilerSuccess;
MemoryReportRow reportRow;
llvm::SmallVector<ResolvedWeightView, 8> usedWeights;
MemoryPlanArtifacts livenessArtifacts;
llvm::SmallVector<DiagnosticRecord, kMaxStoredCodegenDiagnostics> diagnostics;
size_t diagnosticCount = 0;
void recordDiagnostic(Operation* op, StringRef message) {
++diagnosticCount;
if (diagnostics.size() < kMaxStoredCodegenDiagnostics)
diagnostics.push_back({op, message.str()});
}
};
static StaticValueKnowledge seedCoreCodegenKnowledge(pim::PimCoreOp coreOp) {
StaticValueKnowledge knowledge;
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights()))
knowledge.aliases[coreOp.getWeightArgument(index)] = weight;
return knowledge;
}
static StaticValueKnowledge seedCoreBatchCodegenKnowledge(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
StaticValueKnowledge knowledge;
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights()))
knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight;
for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs()))
knowledge.aliases[coreBatchOp.getInputArgument(index)] = input;
return knowledge;
}
template <typename MapTy>
class ScopedMapBindings {
using KeyTy = typename MapTy::key_type;
@@ -1422,7 +1455,20 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
const StaticValueKnowledge& knowledge) -> llvm::FailureOr<unsigned> {
auto weightView = onnx_mlir::resolveWeightView(job.coreLikeOp, vmmOp.getWeight(), knowledge);
if (failed(weightView)) {
vmmOp.emitOpError("requires a statically resolvable dense global weight view during PIM codegen");
std::string message;
llvm::raw_string_ostream os(message);
os << "requires a statically resolvable dense global weight view during PIM codegen; weight="
<< vmmOp.getWeight() << " type=" << vmmOp.getWeight().getType();
result.recordDiagnostic(vmmOp, os.str());
return failure();
}
if (weightView->shape.size() != 2) {
std::string message;
llvm::raw_string_ostream os(message);
os << "requires a rank-2 matrix weight view during PIM codegen; resolved shape=[";
llvm::interleaveComma(weightView->shape, os);
os << "] weight=" << vmmOp.getWeight() << " type=" << vmmOp.getWeight().getType();
result.recordDiagnostic(vmmOp, os.str());
return failure();
}
@@ -1463,13 +1509,13 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
deviceMemory.allocateCore(coreOp);
int64_t processedOperations = codeGenCoreOps(
coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation(), resolveWeightSlot);
StaticValueKnowledge knowledge = seedCoreCodegenKnowledge(coreOp);
int64_t processedOperations =
codeGenCoreOps(coreOp.getBody().front(), coreCodeGen, knowledge, coreOp.getOperation(), resolveWeightSlot);
if (processedOperations < 0) {
result.status = CompilerFailure;
return result;
}
assert(processedOperations > 0);
result.reportRow = deviceMemory.getReportRow();
result.usedWeights = std::move(usedWeights);
result.livenessArtifacts = deviceMemory.getLivenessArtifacts();
@@ -1480,10 +1526,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
for (unsigned lane : job.lanes) {
StaticValueKnowledge knowledge;
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
for (unsigned i = 0; i < coreBatchOp.getInputs().size(); ++i)
knowledge.aliases[coreBatchOp.getInputArgument(i)] = coreBatchOp.getInputs()[i];
StaticValueKnowledge knowledge = seedCoreBatchCodegenKnowledge(coreBatchOp, lane);
deviceMemory.allocateCore(coreBatchOp, lane);
coreCodeGen.setBatchLane(lane);
@@ -1498,7 +1541,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
result.status = CompilerFailure;
return result;
}
assert(processedOperations > 0);
}
result.reportRow = deviceMemory.getReportRow();
@@ -1522,6 +1564,21 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
mlir::parallelFor(
moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { jobResults[index] = emitJob(jobs[index]); });
pim::CappedDiagnosticReporter diagnostics;
Operation* summaryAnchor = nullptr;
for (const CoreEmissionResult& result : jobResults) {
if (!summaryAnchor && !result.diagnostics.empty())
summaryAnchor = result.diagnostics.front().op;
for (const CoreEmissionResult::DiagnosticRecord& diagnostic : result.diagnostics) {
diagnostics.report(diagnostic.op, [&](Operation* op) { op->emitError() << diagnostic.message; });
}
size_t unreportedCount = result.diagnosticCount - result.diagnostics.size();
diagnostics.noteFailures(static_cast<int64_t>(unreportedCount));
}
if (diagnostics.hasFailure())
diagnostics.emitSuppressedSummary(summaryAnchor ? summaryAnchor : moduleOp.getOperation(),
"PIM codegen diagnostic(s)");
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex)
if (jobResults[jobIndex].status != CompilerSuccess)
return jobResults[jobIndex].status;
-1
View File
@@ -46,7 +46,6 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimHostConstantFoldingPass());
pm.addPass(createMessagePass("Pim host constants folded"));
pm.addPass(createPimMaterializeHostConstantsPass());
pm.addPass(createPimMemoryCoalescingPass());
pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
@@ -375,6 +375,57 @@ static void cloneHelperChain(Value sourceValue,
}
}
static bool isHostStaticReturnValue(Value value) {
llvm::SmallPtrSet<Operation*, 8> visited;
while (Operation* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (isa<arith::ConstantOp>(definingOp) || definingOp->hasTrait<OpTrait::ConstantLike>())
return true;
if (!isReturnHelperChainOp(definingOp) || definingOp->getNumOperands() != 1)
return false;
value = definingOp->getOperand(0);
}
return false;
}
static FailureOr<Value>
materializeHostStaticReturnValue(IRRewriter& rewriter, Value value, OperationFolder& constantFolder) {
llvm::SmallVector<Operation*> chain;
llvm::SmallPtrSet<Operation*, 8> visited;
while (Operation* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return failure();
chain.push_back(definingOp);
if (isa<arith::ConstantOp>(definingOp) || definingOp->hasTrait<OpTrait::ConstantLike>())
break;
if (!isReturnHelperChainOp(definingOp) || definingOp->getNumOperands() != 1)
return failure();
value = definingOp->getOperand(0);
}
if (chain.empty())
return failure();
IRMapping mapping;
Value clonedValue;
for (Operation* op : llvm::reverse(chain)) {
if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
clonedValue = getOrCreateConstantLike(constantFolder, constantOp);
mapping.map(op->getResult(0), clonedValue);
continue;
}
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
clonedValue = clonedOp->getResult(0);
rewriter.setInsertionPointAfter(clonedOp);
}
return clonedValue;
}
static FailureOr<Value> emitHostCopy(IRRewriter& rewriter,
Location loc,
Value outputTensor,
@@ -444,7 +495,30 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
OperationFolder constantFolder(producerOp->getContext());
auto storedTensorType = cast<TensorType>(storedValue.getType());
auto materializeDirectHostReturn = [&](size_t returnIndex,
Value sourceValue,
ArrayRef<Operation*> helperChain) -> ReturnPathLoweringResult {
rewriter.setInsertionPointAfter(producerOp);
auto hostStaticValue = materializeHostStaticReturnValue(rewriter, sourceValue, constantFolder);
if (failed(hostStaticValue))
return ReturnPathLoweringResult::Failure;
Value hostReturnValue = *hostStaticValue;
if (!helperChain.empty())
cloneHelperChain(hostReturnValue, helperChain, rewriter, constantFolder, hostReturnValue);
outputTensors[returnIndex] =
[hostReturnValue](IRRewriter& rewriter, Location loc) -> Value { return hostReturnValue; };
return ReturnPathLoweringResult::Handled;
};
if (auto returnUse = analyzeReturnUse(producedValue)) {
if (isHostStaticReturnValue(storedValue)) {
for (Operation* op : returnUse->helperChain)
markOpToRemove(op);
return materializeDirectHostReturn(returnUse->returnIndex, storedValue, returnUse->helperChain);
}
Value currentStoredValue = storedValue;
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
for (Operation* op : returnUse->helperChain)
@@ -470,6 +544,8 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber();
if (isHostStaticReturnValue(storedValue))
return materializeDirectHostReturn(resultIndexInReturn, storedValue, {});
auto byteSize =
pim::getCheckedShapedTypeSizeInBytes(storedTensorType, producerOp, "return-path host copy byte size");
if (failed(byteSize))
-1
View File
@@ -4,7 +4,6 @@ add_onnx_mlir_dialect_doc(pim Pim.td)
add_subdirectory(Transforms/Bufferization)
add_subdirectory(Transforms/MemoryCoalescing)
add_subdirectory(Transforms/HostConstantFolding)
add_subdirectory(Transforms/HostConstantMaterialization)
add_subdirectory(Transforms/Verification)
add_pim_library(PimOps
@@ -6,6 +6,7 @@
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp"
using namespace mlir;
using namespace bufferization;
@@ -13,7 +14,9 @@ using namespace bufferization;
namespace onnx_mlir::pim {
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
bool isContiguous =
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
@@ -29,13 +32,21 @@ FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location lo
if (failed(sizeAttr))
return failure();
if (isHostBackedPimAddress(memrefValue)) {
return PimMemCopyHostToDevOp::create(
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
.getOutput();
}
return PimMemCopyOp::create(
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
.getOutput();
}
Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
bool isContiguous =
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
@@ -1,9 +1,70 @@
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
static bool isCoreBatchInputArgument(Value value) {
auto blockArg = dyn_cast<BlockArgument>(value);
if (!blockArg)
return false;
auto coreBatchOp = dyn_cast_or_null<onnx_mlir::pim::PimCoreBatchOp>(blockArg.getOwner()->getParentOp());
if (!coreBatchOp)
return false;
unsigned firstInputArg = 1 + coreBatchOp.getWeights().size();
return static_cast<unsigned>(blockArg.getArgNumber()) >= firstInputArg;
}
static FailureOr<Value> getPimStorageBase(Value value, const onnx_mlir::StaticValueKnowledge& knowledge) {
llvm::SmallPtrSet<Value, 8> visited;
while (value && visited.insert(value).second) {
Value alias = resolveLoopCarriedAlias(value, knowledge);
if (alias)
value = alias;
if (auto aliased = knowledge.aliases.lookup(value)) {
value = aliased;
continue;
}
if (auto base = onnx_mlir::pim::getPimAddressBase(value, knowledge); succeeded(base))
return base;
if (isa<BlockArgument>(value))
return value;
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return value;
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
value = subviewOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
return value;
}
if (value)
return value;
return failure();
}
FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Operation* anchor, Value memref) {
auto type = mlir::cast<MemRefType>(memref.getType());
auto byteSize = getCheckedShapedTypeSizeInBytes(type, anchor, "memref byte size");
@@ -11,3 +72,40 @@ FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& build
return failure();
return getCheckedI32Attr(builder, anchor, *byteSize, "memref byte size");
}
FailureOr<Value> onnx_mlir::pim::getPimAddressBase(Value value, const StaticValueKnowledge& knowledge) {
Value alias = resolveLoopCarriedAlias(value, knowledge);
if (alias)
value = alias;
auto resolved = resolveContiguousAddress(value, knowledge);
if (succeeded(resolved))
return resolved->base;
auto compiled = compileContiguousAddressExpr(value);
if (failed(compiled)) {
if (isa<BlockArgument>(value))
return value;
return failure();
}
return compiled->base;
}
bool onnx_mlir::pim::isHostBackedPimAddress(Value value, const StaticValueKnowledge& knowledge) {
auto base = getPimStorageBase(value, knowledge);
if (failed(base))
return false;
if (isCoreBatchInputArgument(*base))
return true;
return isa_and_nonnull<memref::GetGlobalOp>(base->getDefiningOp());
}
bool onnx_mlir::pim::isDeviceLocalPimAddress(Value value, const StaticValueKnowledge& knowledge) {
auto base = getPimStorageBase(value, knowledge);
if (failed(base))
return false;
return isa_and_nonnull<memref::AllocOp>(base->getDefiningOp());
}
@@ -2,11 +2,19 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
namespace onnx_mlir {
namespace pim {
mlir::FailureOr<mlir::IntegerAttr>
getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Operation* anchor, mlir::Value memref);
mlir::FailureOr<mlir::Value> getPimAddressBase(mlir::Value value, const StaticValueKnowledge& knowledge = {});
bool isHostBackedPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {});
bool isDeviceLocalPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {});
} // namespace pim
} // namespace onnx_mlir
@@ -8,6 +8,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "Common/PimCommon.hpp"
@@ -15,6 +16,7 @@
#include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
@@ -27,24 +29,71 @@ namespace onnx_mlir {
namespace {
struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
using OpRewritePattern::OpRewritePattern;
struct MemRefCopyWorkItem {
memref::CopyOp copyOp;
StaticValueKnowledge knowledge;
};
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
return failure();
static StaticValueKnowledge seedCoreKnowledge(pim::PimCoreOp coreOp) {
StaticValueKnowledge knowledge;
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights()))
knowledge.aliases[coreOp.getWeightArgument(index)] = weight;
return knowledge;
}
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape())
return failure();
if (sourceType.getElementType() != targetType.getElementType())
return failure();
static StaticValueKnowledge seedCoreBatchKnowledge(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
StaticValueKnowledge knowledge;
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights()))
knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight;
for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs()))
knowledge.aliases[coreBatchOp.getInputArgument(index)] = input;
return knowledge;
}
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
if (failed(sizeAttr))
return failure();
static LogicalResult
lowerMemRefCopyToPimCopy(memref::CopyOp copyOp, PatternRewriter& rewriter, const StaticValueKnowledge& knowledge) {
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
return failure();
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape())
return failure();
if (sourceType.getElementType() != targetType.getElementType())
return failure();
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
if (failed(sizeAttr))
return failure();
bool sourceIsHost = isHostBackedPimAddress(copyOp.getSource(), knowledge);
bool targetIsHost = isHostBackedPimAddress(copyOp.getTarget(), knowledge);
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getSource(), knowledge);
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getTarget(), knowledge);
if (targetIsDevice && sourceIsHost) {
pim::PimMemCopyHostToDevOp::create(rewriter,
copyOp.getLoc(),
copyOp.getTarget().getType(),
zeroOffset,
zeroOffset,
copyOp.getTarget(),
copyOp.getSource(),
*sizeAttr);
}
else if (targetIsHost && sourceIsDevice) {
pim::PimMemCopyDevToHostOp::create(rewriter,
copyOp.getLoc(),
copyOp.getTarget().getType(),
zeroOffset,
zeroOffset,
copyOp.getTarget(),
copyOp.getSource(),
*sizeAttr);
}
else if (targetIsDevice && sourceIsDevice) {
pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(),
copyOp.getTarget().getType(),
@@ -53,10 +102,19 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
copyOp.getTarget(),
copyOp.getSource(),
*sizeAttr);
rewriter.eraseOp(copyOp);
return success();
}
};
else {
copyOp.emitOpError() << "failed to classify memref.copy endpoints: source=" << copyOp.getSource()
<< " type=" << copyOp.getSource().getType() << " host=" << sourceIsHost
<< " device=" << sourceIsDevice << ", target=" << copyOp.getTarget()
<< " type=" << copyOp.getTarget().getType() << " host=" << targetIsHost
<< " device=" << targetIsDevice;
return failure();
}
rewriter.eraseOp(copyOp);
return success();
}
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
@@ -100,25 +158,46 @@ void PimBufferizationPass::runOnOperation() {
}
MLIRContext* ctx = moduleOp.getContext();
RewritePatternSet memrefCopyPatterns(ctx);
memrefCopyPatterns.add<MemRefCopyToPimMemCopyPattern>(ctx);
FrozenRewritePatternSet frozenMemrefCopyPatterns(std::move(memrefCopyPatterns));
PatternApplicator memrefCopyApplicator(frozenMemrefCopyPatterns);
memrefCopyApplicator.applyDefaultCostModel();
PatternRewriter rewriter(ctx);
SmallVector<memref::CopyOp> copyWorklist;
moduleOp.walk([&](memref::CopyOp copyOp) {
if (copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>())
copyWorklist.push_back(copyOp);
SmallVector<MemRefCopyWorkItem> copyWorklist;
llvm::SmallPtrSet<Operation*, 16> seenCopyOps;
auto addCopyOp = [&](memref::CopyOp copyOp, const StaticValueKnowledge& knowledge) {
if (seenCopyOps.insert(copyOp.getOperation()).second)
copyWorklist.push_back({copyOp, knowledge});
};
moduleOp.walk([&](pim::PimCoreOp coreOp) {
StaticValueKnowledge knowledge = seedCoreKnowledge(coreOp);
(void) walkPimCoreBlockStructurally(
coreOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) {
if (auto copyOp = dyn_cast<memref::CopyOp>(&op))
addCopyOp(copyOp, opKnowledge);
return success();
});
});
moduleOp.walk([&](pim::PimCoreBatchOp coreBatchOp) {
llvm::SmallVector<unsigned, 2> lanes;
lanes.push_back(0);
if (coreBatchOp.getLaneCount() > 1)
lanes.push_back(static_cast<unsigned>(coreBatchOp.getLaneCount() - 1));
for (unsigned lane : lanes) {
StaticValueKnowledge knowledge = seedCoreBatchKnowledge(coreBatchOp, lane);
(void) walkPimCoreBlockStructurally(
coreBatchOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) {
if (auto copyOp = dyn_cast<memref::CopyOp>(&op))
addCopyOp(copyOp, opKnowledge);
return success();
});
}
});
bool hasFailed = false;
for (memref::CopyOp copyOp : copyWorklist) {
if (failed(applyPatternsOnce(copyOp, memrefCopyApplicator, rewriter))) {
copyOp.emitOpError("failed to lower memref.copy inside PIM core body");
for (const MemRefCopyWorkItem& workItem : copyWorklist) {
memref::CopyOp copyOp = workItem.copyOp;
rewriter.setInsertionPoint(copyOp);
if (failed(lowerMemRefCopyToPimCopy(copyOp, rewriter, workItem.knowledge)))
hasFailed = true;
}
}
if (hasFailed) {
signalPassFailure();
@@ -128,7 +128,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
auto sizeAttr = pim::getCheckedI32Attr(rewriter, mapOp, *sizeInBytes, "host constant folding byte size");
if (failed(sizeAttr))
return failure();
pim::PimMemCopyOp::create(
pim::PimMemCopyHostToDevOp::create(
rewriter, mapOp.getLoc(), initType, zeroOffset, zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), *sizeAttr);
rewriter.eraseOp(mapOp);
return success();
@@ -1,9 +0,0 @@
add_pim_library(OMPimHostConstantMaterialization
MaterializeHostConstantsPass.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
OMPimCommon
PimOps
)
@@ -1,161 +0,0 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
template <typename CoreOpTy>
static void materializeHostConstantsInCore(CoreOpTy coreOp,
IRRewriter& rewriter,
OperationFolder& constantFolder,
bool& hasFailure) {
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
DominanceInfo dominance(coreOp);
SmallVector<Operation*> ops;
coreOp.getBody().front().walk([&](Operation* op) {
if (!isa<pim::PimHaltOp, scf::YieldOp>(op))
ops.push_back(op);
});
for (Operation* op : ops) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
continue;
for (OpOperand& operand : op->getOpOperands()) {
Value originalValue = operand.get();
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber()))
continue;
auto resolvedAddress = resolveContiguousAddress(originalValue);
if (failed(resolvedAddress))
continue;
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
if (!getGlobalOp)
continue;
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
if (!originalType || !originalType.hasStaticShape()) {
op->emitOpError("host constant materialization requires a static memref operand");
hasFailure = true;
continue;
}
auto& cachedByOffset = materializedValues[resolvedAddress->base];
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
auto cachedValue = cachedByType.find(originalType);
if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) {
operand.set(cachedValue->second);
continue;
}
auto type = dyn_cast<ShapedType>(originalValue.getType());
auto totalBytes = type ? pim::getCheckedShapedTypeSizeInBytes(type, op, "host constant materialization byte size")
: FailureOr<uint64_t>(failure());
auto totalBytesAttr =
succeeded(totalBytes)
? pim::getCheckedI32Attr(rewriter, op, *totalBytes, "host constant materialization byte size")
: FailureOr<IntegerAttr>(failure());
if (failed(totalBytesAttr)
|| failed(pim::checkedSize(resolvedAddress->byteOffset, op, "host constant materialization byte offset"))) {
hasFailure = true;
continue;
}
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
rewriter.setInsertionPoint(op);
Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType);
Value deviceDst = localAlloc;
if (contiguousType != originalType)
deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0);
Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset);
Value copiedValue = pim::PimMemCopyHostToDevOp::create(rewriter,
op->getLoc(),
originalType,
zeroOffset,
hostOffset,
deviceDst,
getGlobalOp.getResult(),
*totalBytesAttr)
.getOutput();
cachedByType[originalType] = copiedValue;
operand.set(copiedValue);
}
}
}
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
StringRef getDescription() const override {
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
IRRewriter rewriter(moduleOp.getContext());
OperationFolder constantFolder(moduleOp.getContext());
bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
if (funcOp.isExternal())
continue;
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure);
for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure);
SmallVector<Operation*> hostCompactOps;
for (Operation& op : funcOp.getBody().front())
if (isa<pim::PimConcatOp>(op))
hostCompactOps.push_back(&op);
for (Operation* op : hostCompactOps) {
rewriter.setInsertionPoint(op);
auto concatOp = cast<pim::PimConcatOp>(op);
concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
hasFailure = true;
}
}
if (hasFailure) {
moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above");
signalPassFailure();
return;
}
dumpModule(moduleOp, "pim4_materialized");
}
};
} // namespace
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
return std::make_unique<MaterializeHostConstantsPass>();
}
} // namespace onnx_mlir
@@ -46,19 +46,6 @@ static bool isCodegenAddressableValue(Value value) {
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
}
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
if (succeeded(resolvedAddress))
return isa<BlockArgument>(resolvedAddress->base)
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
auto compiledAddress = compileContiguousAddressExpr(value);
if (failed(compiledAddress))
return false;
return isa<BlockArgument>(compiledAddress->base)
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
}
static bool isConstantGlobalView(Value value) {
while (true) {
Operation* defOp = value.getDefiningOp();
@@ -138,6 +125,24 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
memref::GetGlobalOp>(op);
}
static bool isHostAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
Value base;
if (succeeded(resolvedAddress)) {
base = resolvedAddress->base;
}
else {
auto compiledAddress = compileContiguousAddressExpr(value);
if (failed(compiledAddress))
return false;
base = compiledAddress->base;
}
if (isa<BlockArgument>(base))
return true;
return isa_and_nonnull<memref::GetGlobalOp>(base.getDefiningOp());
}
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
@@ -311,10 +316,10 @@ private:
}
if (isExplicitHostMemCopyOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand, knowledge)) {
if (!isHostAddressableValue(operand, knowledge)) {
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << "host operand #" << operandIndex
<< " is not backed by contiguous addressable storage";
<< " must be backed by host-addressable storage";
});
hasFailure = true;
}
-2
View File
@@ -21,8 +21,6 @@ std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimMaterializeHostConstantsPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass();
std::unique_ptr<mlir::Pass> createEmitPimCodePass();
-1
View File
@@ -78,7 +78,6 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createPimMemoryCoalescingPass);
registerPass(createMergeComputeNodesPass);
registerPass(createPimHostConstantFoldingPass);
registerPass(createPimMaterializeHostConstantsPass);
registerPass(createPimVerificationPass);
registerPass(createEmitPimCodePass);
}
-1
View File
@@ -11,7 +11,6 @@ PIM_PASS_LABELS = (
("SpatialToPimPass", "Spatial to PIM"),
("PimBufferizationPass", "Bufferize PIM"),
("HostConstantFoldingPass", "Fold Host Constants"),
("MaterializeHostConstantsPass", "Materialize Host Constants"),
("VerificationPass", "Verify PIM"),
("EmitPimCodePass", "Emit PIM Code"),
)