This commit is contained in:
@@ -23,11 +23,11 @@ namespace {
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 1;
|
||||
return operandIndex == 3;
|
||||
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 0;
|
||||
return operandIndex == 2;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -39,7 +39,10 @@ static int64_t getValueSizeInBytes(Value value) {
|
||||
}
|
||||
|
||||
template <typename CoreOpTy>
|
||||
static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter, bool& hasFailure) {
|
||||
static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder,
|
||||
bool& hasFailure) {
|
||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||
SmallVector<Operation*> ops;
|
||||
coreOp.getBody().front().walk([&](Operation* op) {
|
||||
@@ -48,6 +51,9 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter
|
||||
});
|
||||
|
||||
for (Operation* op : ops) {
|
||||
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
|
||||
continue;
|
||||
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value originalValue = operand.get();
|
||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber()))
|
||||
@@ -105,16 +111,17 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter
|
||||
.getOutput();
|
||||
}
|
||||
else {
|
||||
copiedValue = pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
op->getLoc(),
|
||||
originalType,
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
||||
.getOutput();
|
||||
copiedValue =
|
||||
pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
op->getLoc(),
|
||||
originalType,
|
||||
getOrCreateHostIndexConstant(op, 0, constantFolder),
|
||||
getOrCreateHostIndexConstant(op, static_cast<int64_t>(resolvedAddress->byteOffset), constantFolder),
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
cachedByType[originalType] = copiedValue;
|
||||
@@ -134,6 +141,7 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
IRRewriter rewriter(moduleOp.getContext());
|
||||
OperationFolder constantFolder(moduleOp.getContext());
|
||||
bool hasFailure = false;
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
@@ -141,10 +149,10 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
|
||||
continue;
|
||||
|
||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
|
||||
materializeHostConstantsInCore(coreOp, rewriter, hasFailure);
|
||||
materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure);
|
||||
|
||||
for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
|
||||
materializeHostConstantsInCore(coreBatchOp, rewriter, hasFailure);
|
||||
materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure);
|
||||
|
||||
SmallVector<Operation*> hostCompactOps;
|
||||
for (Operation& op : funcOp.getBody().front())
|
||||
|
||||
Reference in New Issue
Block a user