fix missed failing tests for channels
moderate refactor
This commit is contained in:
@@ -0,0 +1,133 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
static int64_t getValueSizeInBytes(Value value) {
|
||||
auto type = dyn_cast<ShapedType>(value.getType());
|
||||
if (!type || !type.hasStaticShape())
|
||||
return -1;
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
|
||||
|
||||
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
OpBuilder rewriter(moduleOp.getContext());
|
||||
bool hasFailure = false;
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
if (funcOp.isExternal())
|
||||
continue;
|
||||
|
||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(coreOp.getBody().front())) {
|
||||
if (isa<pim::PimHaltOp>(op))
|
||||
continue;
|
||||
|
||||
for (OpOperand& operand : op.getOpOperands()) {
|
||||
Value originalValue = operand.get();
|
||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(&op, operand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(originalValue);
|
||||
if (failed(resolvedAddress))
|
||||
continue;
|
||||
|
||||
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
if (!getGlobalOp)
|
||||
continue;
|
||||
|
||||
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
|
||||
if (!originalType || !originalType.hasStaticShape()) {
|
||||
op.emitOpError("host constant materialization requires a static memref operand");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& cachedByOffset = materializedValues[resolvedAddress->base];
|
||||
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
|
||||
auto cachedValue = cachedByType.find(originalType);
|
||||
if (cachedValue != cachedByType.end()) {
|
||||
operand.set(cachedValue->second);
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t totalBytes = getValueSizeInBytes(originalValue);
|
||||
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
|
||||
op.emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
|
||||
|
||||
rewriter.setInsertionPoint(&op);
|
||||
Value localAlloc = memref::AllocOp::create(rewriter, op.getLoc(), contiguousType);
|
||||
Value deviceDst = localAlloc;
|
||||
if (contiguousType != originalType)
|
||||
deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
|
||||
|
||||
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
op.getLoc(),
|
||||
originalType,
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
|
||||
|
||||
cachedByType[originalType] = hostToDevCopy.getResult();
|
||||
operand.set(hostToDevCopy.getResult());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFailure) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(moduleOp, "pim3_materialized");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
|
||||
return std::make_unique<MaterializeHostConstantsPass>();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user