use uniqued constant helpers everywhere materialize transposed constants directly
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -22,16 +23,6 @@ namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 3;
|
||||
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 2;
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename CoreOpTy>
|
||||
static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
||||
IRRewriter& rewriter,
|
||||
@@ -51,7 +42,7 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
||||
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value originalValue = operand.get();
|
||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber()))
|
||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(originalValue);
|
||||
@@ -113,8 +104,8 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
||||
rewriter,
|
||||
op->getLoc(),
|
||||
originalType,
|
||||
getOrCreateHostIndexConstant(constantFolder, op, 0),
|
||||
getOrCreateHostIndexConstant(constantFolder, op, static_cast<int64_t>(resolvedAddress->byteOffset) ),
|
||||
getOrCreateIndexConstant(constantFolder, op, 0),
|
||||
getOrCreateIndexConstant(constantFolder, op, static_cast<int64_t>(resolvedAddress->byteOffset) ),
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
||||
|
||||
Reference in New Issue
Block a user