use uniqued constant helpers everywhere materialize transposed constants directly
This commit is contained in:
@@ -749,18 +749,12 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
||||
|
||||
} // namespace
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
|
||||
return resolveIndexValueImpl(value, &knowledge);
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); }
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
|
||||
return resolveContiguousAddressImpl(value, nullptr);
|
||||
}
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||
const StaticValueKnowledge& knowledge) {
|
||||
return resolveContiguousAddressImpl(value, &knowledge);
|
||||
|
||||
@@ -77,14 +77,12 @@ mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::m
|
||||
|
||||
/// Resolves a value to contiguous backing storage when that storage can be
|
||||
/// proven statically from aliases, DPS ties, casts, and subviews.
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||
const StaticValueKnowledge& knowledge);
|
||||
const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
/// Statically evaluates index-like SSA values, including simple integer
|
||||
/// arithmetic and loop facts recorded in `knowledge`.
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {});
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value);
|
||||
|
||||
/// Follows alias, view, and DPS chains to recover the backing value of a
|
||||
|
||||
@@ -17,4 +17,18 @@ llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds,
|
||||
return laneCoreIds;
|
||||
}
|
||||
|
||||
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) {
|
||||
if (mlir::isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 3;
|
||||
if (mlir::isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (mlir::isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 2;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex) {
|
||||
return mlir::isa<pim::PimMemCopyDevToHostOp>(op) && operandIndex == 2;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -11,4 +11,8 @@ llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
|
||||
|
||||
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
|
||||
|
||||
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex);
|
||||
|
||||
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -14,13 +14,17 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Block* getHostConstantBlock(Operation* anchorOp) {
|
||||
Block* getConstantInsertionBlock(Operation* anchorOp) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
|
||||
for (Operation* current = anchorOp; current; current = current->getParentOp())
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
|
||||
return current->getBlock();
|
||||
|
||||
if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp))
|
||||
return &funcOp.getBody().front();
|
||||
if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
|
||||
return moduleOp.getBody();
|
||||
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
|
||||
return &funcOp.getBody().front();
|
||||
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
|
||||
@@ -28,9 +32,9 @@ Block* getHostConstantBlock(Operation* anchorOp) {
|
||||
return anchorOp->getBlock();
|
||||
}
|
||||
|
||||
Value getOrCreateHostConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
|
||||
Value getOrCreateConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
Block* hostBlock = getHostConstantBlock(anchorOp);
|
||||
Block* hostBlock = getConstantInsertionBlock(anchorOp);
|
||||
for (Operation& op : *hostBlock) {
|
||||
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||
@@ -42,9 +46,9 @@ Value getOrCreateHostConstant(OperationFolder& folder, Operation* anchorOp, Attr
|
||||
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
||||
}
|
||||
|
||||
Value getOrCreateHostConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
|
||||
Value getOrCreateConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
Block* hostBlock = getHostConstantBlock(anchorOp);
|
||||
Block* hostBlock = getConstantInsertionBlock(anchorOp);
|
||||
for (Operation& op : *hostBlock) {
|
||||
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||
@@ -57,28 +61,18 @@ Value getOrCreateHostConstant(RewriterBase& rewriter, Operation* anchorOp, Attri
|
||||
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
|
||||
}
|
||||
|
||||
Value getOrCreateHostConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
|
||||
return getOrCreateHostConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType());
|
||||
Value getOrCreateConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
|
||||
return getOrCreateConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType());
|
||||
}
|
||||
|
||||
Value getOrCreateHostIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
|
||||
Value getOrCreateIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType() );
|
||||
return getOrCreateConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
|
||||
}
|
||||
|
||||
Value getOrCreateHostIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
|
||||
Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
|
||||
}
|
||||
|
||||
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(folder, anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type() );
|
||||
}
|
||||
|
||||
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(folder, anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type() );
|
||||
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
|
||||
}
|
||||
|
||||
Value createAffineApplyOrFoldedConstant(
|
||||
@@ -95,7 +89,7 @@ Value createAffineApplyOrFoldedConstant(
|
||||
SmallVector<Attribute> foldedResults;
|
||||
if (succeeded(map.constantFold(operandConstants, foldedResults))) {
|
||||
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
|
||||
return getOrCreateHostIndexConstant(rewriter, anchorOp, constantResult.getInt());
|
||||
return getOrCreateIndexConstant(rewriter, anchorOp, constantResult.getInt());
|
||||
}
|
||||
|
||||
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
|
||||
|
||||
@@ -8,27 +8,23 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp);
|
||||
mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp);
|
||||
|
||||
mlir::Value getOrCreateHostConstant(mlir::OperationFolder& folder,
|
||||
mlir::Operation* anchorOp,
|
||||
mlir::Attribute value,
|
||||
mlir::Type type);
|
||||
mlir::Value getOrCreateConstant(mlir::OperationFolder& folder,
|
||||
mlir::Operation* anchorOp,
|
||||
mlir::Attribute value,
|
||||
mlir::Type type);
|
||||
|
||||
mlir::Value getOrCreateHostConstant(mlir::RewriterBase& rewriter,
|
||||
mlir::Operation* anchorOp,
|
||||
mlir::Attribute value,
|
||||
mlir::Type type);
|
||||
mlir::Value getOrCreateConstant(mlir::RewriterBase& rewriter,
|
||||
mlir::Operation* anchorOp,
|
||||
mlir::Attribute value,
|
||||
mlir::Type type);
|
||||
|
||||
mlir::Value getOrCreateHostConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp);
|
||||
mlir::Value getOrCreateConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp);
|
||||
|
||||
mlir::Value getOrCreateHostIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
|
||||
mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
|
||||
|
||||
mlir::Value getOrCreateHostIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
|
||||
|
||||
mlir::Value getOrCreateHostI32Constant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int32_t value);
|
||||
|
||||
mlir::Value getOrCreateHostI64Constant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
|
||||
mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
|
||||
|
||||
mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
|
||||
@@ -178,10 +178,6 @@ std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::V
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp) {
|
||||
return resolveWeightIndex(weightOwner, vmmOp.getWeight());
|
||||
}
|
||||
|
||||
llvm::FailureOr<ResolvedWeightView>
|
||||
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge) {
|
||||
llvm::SmallVector<mlir::Operation*> viewOps;
|
||||
|
||||
@@ -46,7 +46,6 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight);
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp);
|
||||
llvm::FailureOr<ResolvedWeightView>
|
||||
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user