ff36729140
fix codegen symlinks overwrite remove deprecated pim memcp_hd_batch op
158 lines
5.9 KiB
C++
158 lines
5.9 KiB
C++
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/Dominance.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/Support/MathExtras.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace {
|
|
|
|
template <typename CoreOpTy>
|
|
static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
|
IRRewriter& rewriter,
|
|
OperationFolder& constantFolder,
|
|
bool& hasFailure) {
|
|
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
|
DominanceInfo dominance(coreOp);
|
|
SmallVector<Operation*> ops;
|
|
coreOp.getBody().front().walk([&](Operation* op) {
|
|
if (!isa<pim::PimHaltOp, scf::YieldOp>(op))
|
|
ops.push_back(op);
|
|
});
|
|
|
|
for (Operation* op : ops) {
|
|
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
|
|
continue;
|
|
|
|
for (OpOperand& operand : op->getOpOperands()) {
|
|
Value originalValue = operand.get();
|
|
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber()))
|
|
continue;
|
|
|
|
auto resolvedAddress = resolveContiguousAddress(originalValue);
|
|
if (failed(resolvedAddress))
|
|
continue;
|
|
|
|
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
|
if (!getGlobalOp)
|
|
continue;
|
|
|
|
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
|
|
if (!originalType || !originalType.hasStaticShape()) {
|
|
op->emitOpError("host constant materialization requires a static memref operand");
|
|
hasFailure = true;
|
|
continue;
|
|
}
|
|
|
|
auto& cachedByOffset = materializedValues[resolvedAddress->base];
|
|
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
|
|
auto cachedValue = cachedByType.find(originalType);
|
|
if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) {
|
|
operand.set(cachedValue->second);
|
|
continue;
|
|
}
|
|
|
|
int64_t totalBytes = -1;
|
|
if (auto type = dyn_cast<ShapedType>(originalValue.getType()); type && type.hasStaticShape())
|
|
totalBytes = static_cast<int64_t>(getShapedTypeSizeInBytes(type));
|
|
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
|
|
op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
|
|
hasFailure = true;
|
|
continue;
|
|
}
|
|
|
|
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
|
|
|
|
rewriter.setInsertionPoint(op);
|
|
Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType);
|
|
Value deviceDst = localAlloc;
|
|
if (contiguousType != originalType)
|
|
deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
|
|
|
|
Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0);
|
|
Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset);
|
|
Value copiedValue =
|
|
pim::PimMemCopyHostToDevOp::create(rewriter,
|
|
op->getLoc(),
|
|
originalType,
|
|
zeroOffset,
|
|
hostOffset,
|
|
deviceDst,
|
|
getGlobalOp.getResult(),
|
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
|
.getOutput();
|
|
|
|
cachedByType[originalType] = copiedValue;
|
|
operand.set(copiedValue);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Module pass that materializes explicit host-to-device copies for constant
/// globals used inside pim.core / pim.core_batch ops of every non-external
/// func.func.
///
/// It also rejects any pim.concat left directly in a function body, since
/// host-side concats must be eliminated before this pass runs.
///
/// On failure, a summary error is emitted on the module and the pass fails.
/// On success, the module is dumped under the "pim4_materialized" tag.
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)

  StringRef getArgument() const override { return "materialize-pim-host-constants"; }

  StringRef getDescription() const override {
    return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
  }

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    IRRewriter rewriter(moduleOp.getContext());
    // A single folder is shared across all functions and cores.
    OperationFolder constantFolder(moduleOp.getContext());
    bool hasFailure = false;

    for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
      if (funcOp.isExternal())
        continue;

      for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
        materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure);

      for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
        materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure);

      // Scan every block of the body, matching the core loops above, so that a
      // concat outside the entry block is not silently accepted. Emitting a
      // diagnostic does not mutate the IR, so no snapshot or insertion point
      // is needed.
      for (pim::PimConcatOp concatOp : funcOp.getOps<pim::PimConcatOp>()) {
        concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
        hasFailure = true;
      }
    }

    if (hasFailure) {
      moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above");
      signalPassFailure();
      return;
    }

    dumpModule(moduleOp, "pim4_materialized");
  }
};
|
|
|
|
} // namespace
|
|
|
|
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
|
|
return std::make_unique<MaterializeHostConstantsPass>();
|
|
}
|
|
|
|
} // namespace onnx_mlir
|