Remove the useless MaterializeHostConstantsPass.cpp and fix the earlier lowering instead
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
avoid spammy pim codegen diagnostics
This commit is contained in:
@@ -4,7 +4,6 @@ add_onnx_mlir_dialect_doc(pim Pim.td)
|
||||
add_subdirectory(Transforms/Bufferization)
|
||||
add_subdirectory(Transforms/MemoryCoalescing)
|
||||
add_subdirectory(Transforms/HostConstantFolding)
|
||||
add_subdirectory(Transforms/HostConstantMaterialization)
|
||||
add_subdirectory(Transforms/Verification)
|
||||
|
||||
add_pim_library(PimOps
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
@@ -13,7 +14,9 @@ using namespace bufferization;
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
|
||||
bool isContiguous =
|
||||
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
|
||||
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
@@ -29,13 +32,21 @@ FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location lo
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
|
||||
if (isHostBackedPimAddress(memrefValue)) {
|
||||
return PimMemCopyHostToDevOp::create(
|
||||
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
return PimMemCopyOp::create(
|
||||
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
|
||||
bool isContiguous =
|
||||
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
|
||||
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
|
||||
@@ -1,9 +1,70 @@
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Tells whether `value` is a block argument of a block whose parent op is a
/// PimCoreBatchOp and whose argument number lies at or beyond the first input
/// slot of that op.
///
/// The first input slot is computed as `1 + number of weights`.
/// NOTE(review): the leading `1` presumably accounts for the lane argument
/// that precedes the weight arguments — confirm against the op definition.
static bool isCoreBatchInputArgument(Value value) {
  if (auto arg = dyn_cast<BlockArgument>(value)) {
    // Only arguments of a block directly nested in a core-batch op qualify.
    Operation* parent = arg.getOwner()->getParentOp();
    if (auto batchOp = dyn_cast_or_null<onnx_mlir::pim::PimCoreBatchOp>(parent)) {
      const unsigned inputsBegin = 1 + batchOp.getWeights().size();
      return static_cast<unsigned>(arg.getArgNumber()) >= inputsBegin;
    }
  }
  return false;
}
|
||||
|
||||
/// Walks from `value` towards the value that ultimately provides its storage.
///
/// On every step the walk, in order:
///   1. replaces the current value by its loop-carried alias, if one exists;
///   2. follows an entry in `knowledge.aliases`, if one exists;
///   3. returns the result of getPimAddressBase when that succeeds;
///   4. stops at block arguments and at values without a defining op;
///   5. looks through memref subview / collapse_shape / expand_shape / cast;
///   6. otherwise stops at the current value.
///
/// A set of already-visited values bounds the walk. When the walk ends because
/// a value repeats, that value is returned; failure is returned only if the
/// walk reaches a null value.
static FailureOr<Value> getPimStorageBase(Value value, const onnx_mlir::StaticValueKnowledge& knowledge) {
  llvm::SmallPtrSet<Value, 8> seen;
  Value current = value;

  while (current && seen.insert(current).second) {
    if (Value carried = resolveLoopCarriedAlias(current, knowledge))
      current = carried;

    // An explicit alias restarts the walk from the aliased value.
    if (auto mapped = knowledge.aliases.lookup(current)) {
      current = mapped;
      continue;
    }

    auto addressBase = onnx_mlir::pim::getPimAddressBase(current, knowledge);
    if (succeeded(addressBase))
      return addressBase;

    // Block arguments have no defining op, so both cases end the walk here.
    Operation* producer = current.getDefiningOp();
    if (isa<BlockArgument>(current) || !producer)
      return current;

    // Step through view-like ops to the memref they are derived from.
    if (auto subview = dyn_cast<memref::SubViewOp>(producer))
      current = subview.getSource();
    else if (auto collapse = dyn_cast<memref::CollapseShapeOp>(producer))
      current = collapse.getSrc();
    else if (auto expand = dyn_cast<memref::ExpandShapeOp>(producer))
      current = expand.getSrc();
    else if (auto memrefCast = dyn_cast<memref::CastOp>(producer))
      current = memrefCast.getSource();
    else
      return current;
  }

  if (!current)
    return failure();
  return current;
}
|
||||
|
||||
FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Operation* anchor, Value memref) {
|
||||
auto type = mlir::cast<MemRefType>(memref.getType());
|
||||
auto byteSize = getCheckedShapedTypeSizeInBytes(type, anchor, "memref byte size");
|
||||
@@ -11,3 +72,40 @@ FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& build
|
||||
return failure();
|
||||
return getCheckedI32Attr(builder, anchor, *byteSize, "memref byte size");
|
||||
}
|
||||
|
||||
/// Returns the base value of the contiguous address that `value` resolves to.
///
/// `value` is first replaced by its loop-carried alias when one exists. The
/// base is then taken from resolveContiguousAddress (using `knowledge`) and,
/// if that fails, from compileContiguousAddressExpr. When both fail, a block
/// argument is returned as its own base; any other value yields failure.
FailureOr<Value> onnx_mlir::pim::getPimAddressBase(Value value, const StaticValueKnowledge& knowledge) {
  if (Value carried = resolveLoopCarriedAlias(value, knowledge))
    value = carried;

  if (auto resolved = resolveContiguousAddress(value, knowledge); succeeded(resolved))
    return resolved->base;

  auto compiled = compileContiguousAddressExpr(value);
  if (succeeded(compiled))
    return compiled->base;

  // Neither analysis produced an address: only block arguments are accepted.
  if (isa<BlockArgument>(value))
    return value;
  return failure();
}
|
||||
|
||||
/// Returns true when the storage base of `value` (as found by
/// getPimStorageBase) is either a core-batch input block argument or the
/// result of a memref.get_global op. Returns false when no storage base can
/// be determined.
bool onnx_mlir::pim::isHostBackedPimAddress(Value value, const StaticValueKnowledge& knowledge) {
  auto storage = getPimStorageBase(value, knowledge);
  if (failed(storage))
    return false;

  Value root = *storage;
  return isCoreBatchInputArgument(root) || isa_and_nonnull<memref::GetGlobalOp>(root.getDefiningOp());
}
|
||||
|
||||
/// Returns true when the storage base of `value` (as found by
/// getPimStorageBase) is produced by a memref.alloc op. Returns false when no
/// storage base can be determined or when the base has a different origin.
bool onnx_mlir::pim::isDeviceLocalPimAddress(Value value, const StaticValueKnowledge& knowledge) {
  auto storage = getPimStorageBase(value, knowledge);
  return succeeded(storage) && isa_and_nonnull<memref::AllocOp>(storage->getDefiningOp());
}
|
||||
|
||||
@@ -2,11 +2,19 @@
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
mlir::FailureOr<mlir::IntegerAttr>
|
||||
getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Operation* anchor, mlir::Value memref);
|
||||
|
||||
mlir::FailureOr<mlir::Value> getPimAddressBase(mlir::Value value, const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
bool isHostBackedPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
bool isDeviceLocalPimAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Rewrite/PatternApplicator.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
@@ -15,6 +16,7 @@
|
||||
#include "Dialect/Pim/PimOps.hpp"
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
#include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
@@ -27,24 +29,71 @@ namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
struct MemRefCopyWorkItem {
|
||||
memref::CopyOp copyOp;
|
||||
StaticValueKnowledge knowledge;
|
||||
};
|
||||
|
||||
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
return failure();
|
||||
static StaticValueKnowledge seedCoreKnowledge(pim::PimCoreOp coreOp) {
|
||||
StaticValueKnowledge knowledge;
|
||||
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights()))
|
||||
knowledge.aliases[coreOp.getWeightArgument(index)] = weight;
|
||||
return knowledge;
|
||||
}
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
|
||||
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
|
||||
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != targetType.getElementType())
|
||||
return failure();
|
||||
static StaticValueKnowledge seedCoreBatchKnowledge(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
|
||||
StaticValueKnowledge knowledge;
|
||||
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
|
||||
for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights()))
|
||||
knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight;
|
||||
for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs()))
|
||||
knowledge.aliases[coreBatchOp.getInputArgument(index)] = input;
|
||||
return knowledge;
|
||||
}
|
||||
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
||||
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
static LogicalResult
|
||||
lowerMemRefCopyToPimCopy(memref::CopyOp copyOp, PatternRewriter& rewriter, const StaticValueKnowledge& knowledge) {
|
||||
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
|
||||
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
|
||||
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != targetType.getElementType())
|
||||
return failure();
|
||||
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
||||
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
|
||||
bool sourceIsHost = isHostBackedPimAddress(copyOp.getSource(), knowledge);
|
||||
bool targetIsHost = isHostBackedPimAddress(copyOp.getTarget(), knowledge);
|
||||
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getSource(), knowledge);
|
||||
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getTarget(), knowledge);
|
||||
|
||||
if (targetIsDevice && sourceIsHost) {
|
||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
copyOp.getTarget().getType(),
|
||||
zeroOffset,
|
||||
zeroOffset,
|
||||
copyOp.getTarget(),
|
||||
copyOp.getSource(),
|
||||
*sizeAttr);
|
||||
}
|
||||
else if (targetIsHost && sourceIsDevice) {
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
copyOp.getTarget().getType(),
|
||||
zeroOffset,
|
||||
zeroOffset,
|
||||
copyOp.getTarget(),
|
||||
copyOp.getSource(),
|
||||
*sizeAttr);
|
||||
}
|
||||
else if (targetIsDevice && sourceIsDevice) {
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
copyOp.getTarget().getType(),
|
||||
@@ -53,10 +102,19 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
copyOp.getTarget(),
|
||||
copyOp.getSource(),
|
||||
*sizeAttr);
|
||||
rewriter.eraseOp(copyOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
else {
|
||||
copyOp.emitOpError() << "failed to classify memref.copy endpoints: source=" << copyOp.getSource()
|
||||
<< " type=" << copyOp.getSource().getType() << " host=" << sourceIsHost
|
||||
<< " device=" << sourceIsDevice << ", target=" << copyOp.getTarget()
|
||||
<< " type=" << copyOp.getTarget().getType() << " host=" << targetIsHost
|
||||
<< " device=" << targetIsDevice;
|
||||
return failure();
|
||||
}
|
||||
|
||||
rewriter.eraseOp(copyOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||
@@ -100,25 +158,46 @@ void PimBufferizationPass::runOnOperation() {
|
||||
}
|
||||
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
RewritePatternSet memrefCopyPatterns(ctx);
|
||||
memrefCopyPatterns.add<MemRefCopyToPimMemCopyPattern>(ctx);
|
||||
FrozenRewritePatternSet frozenMemrefCopyPatterns(std::move(memrefCopyPatterns));
|
||||
PatternApplicator memrefCopyApplicator(frozenMemrefCopyPatterns);
|
||||
memrefCopyApplicator.applyDefaultCostModel();
|
||||
PatternRewriter rewriter(ctx);
|
||||
|
||||
SmallVector<memref::CopyOp> copyWorklist;
|
||||
moduleOp.walk([&](memref::CopyOp copyOp) {
|
||||
if (copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
copyWorklist.push_back(copyOp);
|
||||
SmallVector<MemRefCopyWorkItem> copyWorklist;
|
||||
llvm::SmallPtrSet<Operation*, 16> seenCopyOps;
|
||||
auto addCopyOp = [&](memref::CopyOp copyOp, const StaticValueKnowledge& knowledge) {
|
||||
if (seenCopyOps.insert(copyOp.getOperation()).second)
|
||||
copyWorklist.push_back({copyOp, knowledge});
|
||||
};
|
||||
|
||||
moduleOp.walk([&](pim::PimCoreOp coreOp) {
|
||||
StaticValueKnowledge knowledge = seedCoreKnowledge(coreOp);
|
||||
(void) walkPimCoreBlockStructurally(
|
||||
coreOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) {
|
||||
if (auto copyOp = dyn_cast<memref::CopyOp>(&op))
|
||||
addCopyOp(copyOp, opKnowledge);
|
||||
return success();
|
||||
});
|
||||
});
|
||||
moduleOp.walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||
llvm::SmallVector<unsigned, 2> lanes;
|
||||
lanes.push_back(0);
|
||||
if (coreBatchOp.getLaneCount() > 1)
|
||||
lanes.push_back(static_cast<unsigned>(coreBatchOp.getLaneCount() - 1));
|
||||
for (unsigned lane : lanes) {
|
||||
StaticValueKnowledge knowledge = seedCoreBatchKnowledge(coreBatchOp, lane);
|
||||
(void) walkPimCoreBlockStructurally(
|
||||
coreBatchOp.getBody().front(), knowledge, [&](Operation& op, const StaticValueKnowledge& opKnowledge) {
|
||||
if (auto copyOp = dyn_cast<memref::CopyOp>(&op))
|
||||
addCopyOp(copyOp, opKnowledge);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
bool hasFailed = false;
|
||||
for (memref::CopyOp copyOp : copyWorklist) {
|
||||
if (failed(applyPatternsOnce(copyOp, memrefCopyApplicator, rewriter))) {
|
||||
copyOp.emitOpError("failed to lower memref.copy inside PIM core body");
|
||||
for (const MemRefCopyWorkItem& workItem : copyWorklist) {
|
||||
memref::CopyOp copyOp = workItem.copyOp;
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
if (failed(lowerMemRefCopyToPimCopy(copyOp, rewriter, workItem.knowledge)))
|
||||
hasFailed = true;
|
||||
}
|
||||
}
|
||||
if (hasFailed) {
|
||||
signalPassFailure();
|
||||
|
||||
@@ -128,7 +128,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter, mapOp, *sizeInBytes, "host constant folding byte size");
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
pim::PimMemCopyOp::create(
|
||||
pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter, mapOp.getLoc(), initType, zeroOffset, zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), *sizeAttr);
|
||||
rewriter.eraseOp(mapOp);
|
||||
return success();
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
add_pim_library(OMPimHostConstantMaterialization
|
||||
MaterializeHostConstantsPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMPimCommon
|
||||
PimOps
|
||||
)
|
||||
-161
@@ -1,161 +0,0 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dominance.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename CoreOpTy>
|
||||
static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder,
|
||||
bool& hasFailure) {
|
||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||
DominanceInfo dominance(coreOp);
|
||||
SmallVector<Operation*> ops;
|
||||
coreOp.getBody().front().walk([&](Operation* op) {
|
||||
if (!isa<pim::PimHaltOp, scf::YieldOp>(op))
|
||||
ops.push_back(op);
|
||||
});
|
||||
|
||||
for (Operation* op : ops) {
|
||||
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
|
||||
continue;
|
||||
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value originalValue = operand.get();
|
||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(originalValue);
|
||||
if (failed(resolvedAddress))
|
||||
continue;
|
||||
|
||||
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
if (!getGlobalOp)
|
||||
continue;
|
||||
|
||||
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
|
||||
if (!originalType || !originalType.hasStaticShape()) {
|
||||
op->emitOpError("host constant materialization requires a static memref operand");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& cachedByOffset = materializedValues[resolvedAddress->base];
|
||||
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
|
||||
auto cachedValue = cachedByType.find(originalType);
|
||||
if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) {
|
||||
operand.set(cachedValue->second);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto type = dyn_cast<ShapedType>(originalValue.getType());
|
||||
auto totalBytes = type ? pim::getCheckedShapedTypeSizeInBytes(type, op, "host constant materialization byte size")
|
||||
: FailureOr<uint64_t>(failure());
|
||||
auto totalBytesAttr =
|
||||
succeeded(totalBytes)
|
||||
? pim::getCheckedI32Attr(rewriter, op, *totalBytes, "host constant materialization byte size")
|
||||
: FailureOr<IntegerAttr>(failure());
|
||||
if (failed(totalBytesAttr)
|
||||
|| failed(pim::checkedSize(resolvedAddress->byteOffset, op, "host constant materialization byte offset"))) {
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType);
|
||||
Value deviceDst = localAlloc;
|
||||
if (contiguousType != originalType)
|
||||
deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
|
||||
|
||||
Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0);
|
||||
Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset);
|
||||
Value copiedValue = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
originalType,
|
||||
zeroOffset,
|
||||
hostOffset,
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
*totalBytesAttr)
|
||||
.getOutput();
|
||||
|
||||
cachedByType[originalType] = copiedValue;
|
||||
operand.set(copiedValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
|
||||
|
||||
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
IRRewriter rewriter(moduleOp.getContext());
|
||||
OperationFolder constantFolder(moduleOp.getContext());
|
||||
bool hasFailure = false;
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
if (funcOp.isExternal())
|
||||
continue;
|
||||
|
||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
|
||||
materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure);
|
||||
|
||||
for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
|
||||
materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure);
|
||||
|
||||
SmallVector<Operation*> hostCompactOps;
|
||||
for (Operation& op : funcOp.getBody().front())
|
||||
if (isa<pim::PimConcatOp>(op))
|
||||
hostCompactOps.push_back(&op);
|
||||
|
||||
for (Operation* op : hostCompactOps) {
|
||||
rewriter.setInsertionPoint(op);
|
||||
auto concatOp = cast<pim::PimConcatOp>(op);
|
||||
concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFailure) {
|
||||
moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(moduleOp, "pim4_materialized");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
|
||||
return std::make_unique<MaterializeHostConstantsPass>();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -46,19 +46,6 @@ static bool isCodegenAddressableValue(Value value) {
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
|
||||
}
|
||||
|
||||
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
|
||||
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
|
||||
if (succeeded(resolvedAddress))
|
||||
return isa<BlockArgument>(resolvedAddress->base)
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
|
||||
auto compiledAddress = compileContiguousAddressExpr(value);
|
||||
if (failed(compiledAddress))
|
||||
return false;
|
||||
return isa<BlockArgument>(compiledAddress->base)
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
|
||||
}
|
||||
|
||||
static bool isConstantGlobalView(Value value) {
|
||||
while (true) {
|
||||
Operation* defOp = value.getDefiningOp();
|
||||
@@ -138,6 +125,24 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
|
||||
memref::GetGlobalOp>(op);
|
||||
}
|
||||
|
||||
/// Returns true when `value` has a contiguous address whose base is either a
/// block argument or the result of a memref.get_global op.
///
/// The base comes from resolveContiguousAddress (using `knowledge`) and, if
/// that fails, from compileContiguousAddressExpr. When neither succeeds the
/// value is not considered host addressable.
static bool isHostAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
  // Finds the address base, trying the knowledge-aware resolution first.
  auto findBase = [&]() -> FailureOr<Value> {
    if (auto resolved = resolveContiguousAddress(value, knowledge); succeeded(resolved))
      return resolved->base;

    auto compiled = compileContiguousAddressExpr(value);
    if (failed(compiled))
      return failure();
    return compiled->base;
  };

  auto base = findBase();
  if (failed(base))
    return false;

  return isa<BlockArgument>(*base) || isa_and_nonnull<memref::GetGlobalOp>(base->getDefiningOp());
}
|
||||
|
||||
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
||||
|
||||
@@ -311,10 +316,10 @@ private:
|
||||
}
|
||||
|
||||
if (isExplicitHostMemCopyOperand(&op, operandIndex)) {
|
||||
if (!isCodegenAddressableValue(operand, knowledge)) {
|
||||
if (!isHostAddressableValue(operand, knowledge)) {
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "host operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
<< " must be backed by host-addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user