compact memory contiguity with for loops
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -69,6 +69,16 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
|
|||||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
||||||
|
|
||||||
|
static llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefStrides(mlir::MemRefType type) {
|
||||||
|
llvm::SmallVector<int64_t> strides;
|
||||||
|
int64_t offset = 0;
|
||||||
|
if (failed(type.getStridesAndOffset(strides, offset)))
|
||||||
|
return mlir::failure();
|
||||||
|
if (llvm::any_of(strides, mlir::ShapedType::isDynamic))
|
||||||
|
return mlir::failure();
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
|
||||||
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
|
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
|
||||||
const StaticValueKnowledge* knowledge) {
|
const StaticValueKnowledge* knowledge) {
|
||||||
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||||
@@ -539,8 +549,10 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
|||||||
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|
||||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||||
byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
if (failed(sourceStrides))
|
||||||
|
return mlir::failure();
|
||||||
|
byteOffset += linearizeIndex(offsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -651,12 +663,16 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
|||||||
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
|
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|
||||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||||
|
if (failed(sourceStrides))
|
||||||
|
return mlir::failure();
|
||||||
constantByteOffset +=
|
constantByteOffset +=
|
||||||
linearizeIndex(staticOffsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
linearizeIndex(staticOffsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
llvm::SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||||
|
if (failed(sourceStrides))
|
||||||
|
return mlir::failure();
|
||||||
CompiledIndexExpr offsetExpr;
|
CompiledIndexExpr offsetExpr;
|
||||||
{
|
{
|
||||||
CompiledIndexExprNode expr;
|
CompiledIndexExprNode expr;
|
||||||
@@ -665,7 +681,7 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
|||||||
offsetExpr = makeCompiledIndexExpr(std::move(expr));
|
offsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), sourceStrides)) {
|
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
|
||||||
CompiledIndexExpr operandExpr;
|
CompiledIndexExpr operandExpr;
|
||||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
|
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
|
||||||
CompiledIndexExprNode expr;
|
CompiledIndexExprNode expr;
|
||||||
|
|||||||
@@ -5,8 +5,6 @@
|
|||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
|
||||||
#include "ConstantUtils.hpp"
|
#include "ConstantUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -15,16 +13,12 @@ namespace onnx_mlir {
|
|||||||
Block* getConstantInsertionBlock(Operation* anchorOp) {
|
Block* getConstantInsertionBlock(Operation* anchorOp) {
|
||||||
assert(anchorOp && "expected a valid anchor operation");
|
assert(anchorOp && "expected a valid anchor operation");
|
||||||
|
|
||||||
for (Operation* current = anchorOp; current; current = current->getParentOp())
|
|
||||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
|
|
||||||
return current->getBlock();
|
|
||||||
|
|
||||||
if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp))
|
if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp))
|
||||||
return &funcOp.getBody().front();
|
return &funcOp.getBody().front();
|
||||||
|
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
|
||||||
|
return &funcOp.getBody().front();
|
||||||
if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
|
if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
|
||||||
return moduleOp.getBody();
|
return moduleOp.getBody();
|
||||||
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
|
|
||||||
return &funcOp.getBody().front();
|
|
||||||
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
|
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
|
||||||
return moduleOp.getBody();
|
return moduleOp.getBody();
|
||||||
return anchorOp->getBlock();
|
return anchorOp->getBlock();
|
||||||
|
|||||||
@@ -235,6 +235,7 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
|
|||||||
if (failed(compiledExpr)) {
|
if (failed(compiledExpr)) {
|
||||||
errs() << "Failed to compile contiguous address for value: ";
|
errs() << "Failed to compile contiguous address for value: ";
|
||||||
value.print(errs());
|
value.print(errs());
|
||||||
|
errs() << " : " << value.getType();
|
||||||
errs() << "\n";
|
errs() << "\n";
|
||||||
llvm_unreachable("Failed to compile contiguous address");
|
llvm_unreachable("Failed to compile contiguous address");
|
||||||
}
|
}
|
||||||
@@ -245,6 +246,7 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
|
|||||||
if (failed(resolvedAddress)) {
|
if (failed(resolvedAddress)) {
|
||||||
errs() << "Failed to evaluate contiguous address for value: ";
|
errs() << "Failed to evaluate contiguous address for value: ";
|
||||||
value.print(errs());
|
value.print(errs());
|
||||||
|
errs() << " : " << value.getType();
|
||||||
errs() << "\n";
|
errs() << "\n";
|
||||||
if (auto* definingOp = value.getDefiningOp()) {
|
if (auto* definingOp = value.getDefiningOp()) {
|
||||||
errs() << "Defining op:\n";
|
errs() << "Defining op:\n";
|
||||||
@@ -493,11 +495,15 @@ void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const Static
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const {
|
||||||
|
auto targetOffset = indexOf(lmvOp.getTargetOffset(), knowledge);
|
||||||
|
auto sourceOffset = indexOf(lmvOp.getSourceOffset(), knowledge);
|
||||||
|
assert(succeeded(targetOffset) && succeeded(sourceOffset)
|
||||||
|
&& "pim.memcp offsets must be statically resolvable during codegen");
|
||||||
emitMemCopyOp("lmv",
|
emitMemCopyOp("lmv",
|
||||||
addressOf(lmvOp.getTarget(), knowledge),
|
addressOf(lmvOp.getTarget(), knowledge),
|
||||||
lmvOp.getTargetOffset(),
|
*targetOffset,
|
||||||
addressOf(lmvOp.getSource(), knowledge),
|
addressOf(lmvOp.getSource(), knowledge),
|
||||||
lmvOp.getSourceOffset(),
|
*sourceOffset,
|
||||||
lmvOp.getSize(),
|
lmvOp.getSize(),
|
||||||
"len");
|
"len");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -101,9 +101,9 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
|
|||||||
auto paddedType = RankedTensorType::get(
|
auto paddedType = RankedTensorType::get(
|
||||||
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
|
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
|
||||||
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
|
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
|
||||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed.getDefiningOp(), 0);
|
||||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
|
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
|
||||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, zeroed, vector, sizeAttr).getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||||
|
|||||||
@@ -176,10 +176,10 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
|
|||||||
let summary = "Copy a memory region within the same memory space";
|
let summary = "Copy a memory region within the same memory space";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
Index:$targetOffset,
|
||||||
|
Index:$sourceOffset,
|
||||||
PimTensor:$target,
|
PimTensor:$target,
|
||||||
PimTensor:$source,
|
PimTensor:$source,
|
||||||
I32Attr:$targetOffset,
|
|
||||||
I32Attr:$sourceOffset,
|
|
||||||
I32Attr:$size
|
I32Attr:$size
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -194,7 +194,9 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
`(` $target `,` $source `)` attr-dict `:` `(` type($target) `,` type($source) `)` `->` type($output)
|
`[` $targetOffset `,` $sourceOffset `]`
|
||||||
|
`(` $target `,` $source `)` attr-dict
|
||||||
|
`:` type($target) `,` type($source) `->` type($output)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,14 +19,15 @@ Value materializeContiguousInputMemRef(Value memrefValue, Location loc, Rewriter
|
|||||||
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
||||||
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
||||||
auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
|
auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
|
||||||
|
Value zeroOffset = getOrCreateIndexConstant(rewriter, contiguousBuffer.getDefiningOp(), 0);
|
||||||
|
|
||||||
return PimMemCopyOp::create(rewriter,
|
return PimMemCopyOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
contiguousType,
|
contiguousType,
|
||||||
|
zeroOffset,
|
||||||
|
zeroOffset,
|
||||||
contiguousBuffer,
|
contiguousBuffer,
|
||||||
memrefValue,
|
memrefValue,
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
|
|
||||||
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
|
||||||
add_public_tablegen_target(PimBufferizationIncGen)
|
|
||||||
|
|
||||||
add_pim_library(OMPimBufferization
|
add_pim_library(OMPimBufferization
|
||||||
PimBufferizationPass.cpp
|
PimBufferizationPass.cpp
|
||||||
BufferizationUtils.hpp
|
BufferizationUtils.hpp
|
||||||
@@ -15,9 +11,6 @@ add_pim_library(OMPimBufferization
|
|||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
DEPENDS
|
|
||||||
PimBufferizationIncGen
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
OMPimCommon
|
OMPimCommon
|
||||||
PimOps
|
PimOps
|
||||||
|
|||||||
@@ -3,220 +3,360 @@
|
|||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
|
||||||
#include "ContiguityPatterns.hpp"
|
#include "ContiguityPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir::pim {
|
namespace onnx_mlir::pim {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static bool isStaticSubviewContiguous(const StaticSubviewInfo& info) {
|
struct ByteOffsetTerm {
|
||||||
if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
|
Value value;
|
||||||
return false;
|
int64_t scale = 0;
|
||||||
|
};
|
||||||
|
|
||||||
return isContiguousSubviewWithDynamicOffsets(info.sourceShape, info.offsets, info.sizes, info.strides);
|
struct ByteOffsetExpr {
|
||||||
|
int64_t constant = 0;
|
||||||
|
SmallVector<ByteOffsetTerm> terms;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct CopyEndpointPlan {
|
||||||
|
Value base;
|
||||||
|
MemRefType originalType;
|
||||||
|
MemRefType baseType;
|
||||||
|
ByteOffsetExpr offset;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct CopyLoopPlan {
|
||||||
|
SmallVector<int64_t> outerShape;
|
||||||
|
int64_t chunkBytes = 0;
|
||||||
|
ByteOffsetExpr targetBaseOffset;
|
||||||
|
ByteOffsetExpr sourceBaseOffset;
|
||||||
|
SmallVector<int64_t> targetOuterByteStrides;
|
||||||
|
SmallVector<int64_t> sourceOuterByteStrides;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct CopyRewritePlan {
|
||||||
|
enum class Kind {
|
||||||
|
Direct,
|
||||||
|
Loop
|
||||||
|
} kind = Kind::Direct;
|
||||||
|
CopyEndpointPlan target;
|
||||||
|
CopyEndpointPlan source;
|
||||||
|
int64_t directBytes = 0;
|
||||||
|
CopyLoopPlan loop;
|
||||||
|
};
|
||||||
|
|
||||||
|
static bool isViewLike(Value value) {
|
||||||
|
Operation* defOp = value.getDefiningOp();
|
||||||
|
return defOp
|
||||||
|
&& isa<memref::SubViewOp,
|
||||||
|
memref::ReinterpretCastOp,
|
||||||
|
memref::CollapseShapeOp,
|
||||||
|
memref::ExpandShapeOp,
|
||||||
|
memref::CastOp>(defOp);
|
||||||
}
|
}
|
||||||
|
|
||||||
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
|
template <typename CopyOp>
|
||||||
if (extraOffset == 0)
|
static bool isNormalizedCopyLikeOp(CopyOp copyOp, Value target, Value source, Value targetOffset, Value sourceOffset) {
|
||||||
return baseOffset;
|
auto targetType = dyn_cast<MemRefType>(target.getType());
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||||
|
return targetType && sourceType && !isViewLike(target) && !isViewLike(source) && targetOffset.getType().isIndex()
|
||||||
|
&& sourceOffset.getType().isIndex() && copyOp.getSize() > 0;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
static void appendTerm(ByteOffsetExpr& expr, Value value, int64_t scale) {
|
||||||
auto integerAttr = dyn_cast<IntegerAttr>(attr);
|
if (scale != 0)
|
||||||
assert(integerAttr && "expected integer offset attribute");
|
expr.terms.push_back(ByteOffsetTerm {value, scale});
|
||||||
return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
|
}
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<int64_t>> getStaticMemRefStrides(MemRefType type) {
|
||||||
|
SmallVector<int64_t> strides;
|
||||||
|
int64_t offset = 0;
|
||||||
|
if (failed(type.getStridesAndOffset(strides, offset)))
|
||||||
|
return failure();
|
||||||
|
if (llvm::any_of(strides, ShapedType::isDynamic))
|
||||||
|
return failure();
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<int64_t> getShapedByteSize(MemRefType type) {
|
||||||
|
if (!type.hasStaticShape() || !hasByteSizedElementType(type.getElementType()))
|
||||||
|
return failure();
|
||||||
|
return static_cast<int64_t>(getShapedTypeSizeInBytes(type));
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<int64_t>>
|
||||||
|
inferLogicalCopyShape(MemRefType targetType, MemRefType sourceType, int64_t size) {
|
||||||
|
if (!targetType.hasStaticShape() || !sourceType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (targetType.getElementType() != sourceType.getElementType() || targetType.getRank() != sourceType.getRank())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto targetBytes = getShapedByteSize(targetType);
|
||||||
|
auto sourceBytes = getShapedByteSize(sourceType);
|
||||||
|
if (failed(targetBytes) || failed(sourceBytes))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
bool targetMatches = *targetBytes == size;
|
||||||
|
bool sourceMatches = *sourceBytes == size;
|
||||||
|
if (targetMatches && sourceMatches && targetType.getShape() != sourceType.getShape())
|
||||||
|
return failure();
|
||||||
|
if (targetMatches)
|
||||||
|
return SmallVector<int64_t>(targetType.getShape().begin(), targetType.getShape().end());
|
||||||
|
if (sourceMatches)
|
||||||
|
return SmallVector<int64_t>(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<int64_t> getContiguousSuffixRank(MemRefType type, ArrayRef<int64_t> copyShape) {
|
||||||
|
if (!type.hasStaticShape() || !hasByteSizedElementType(type.getElementType())
|
||||||
|
|| type.getRank() != static_cast<int64_t>(copyShape.size()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto strides = getStaticMemRefStrides(type);
|
||||||
|
if (failed(strides))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t expectedStride = 1;
|
||||||
|
int64_t contiguousSuffixRank = 0;
|
||||||
|
for (int64_t dim = type.getRank() - 1; dim >= 0; --dim) {
|
||||||
|
if ((*strides)[dim] != expectedStride)
|
||||||
|
break;
|
||||||
|
++contiguousSuffixRank;
|
||||||
|
expectedStride *= copyShape[dim];
|
||||||
|
}
|
||||||
|
return contiguousSuffixRank;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<CopyEndpointPlan> analyzeCopyEndpoint(Value value, Value initialByteOffset, MemRefType logicalType) {
|
||||||
|
if (!logicalType.hasStaticShape() || !hasByteSizedElementType(logicalType.getElementType()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
CopyEndpointPlan endpoint;
|
||||||
|
endpoint.base = value;
|
||||||
|
endpoint.originalType = logicalType;
|
||||||
|
appendTerm(endpoint.offset, initialByteOffset, 1);
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (auto castOp = endpoint.base.getDefiningOp<memref::CastOp>()) {
|
||||||
|
endpoint.base = castOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = endpoint.base.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||||
|
endpoint.base = collapseOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = endpoint.base.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||||
|
endpoint.base = expandOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto reinterpretOp = endpoint.base.getDefiningOp<memref::ReinterpretCastOp>()) {
|
||||||
|
endpoint.base = reinterpretOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto subviewOp = endpoint.base.getDefiningOp<memref::SubViewOp>();
|
||||||
|
if (!subviewOp)
|
||||||
|
break;
|
||||||
|
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
|
||||||
|
if (!sourceType || !sourceType.hasStaticShape() || !hasByteSizedElementType(sourceType.getElementType()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||||
|
if (failed(sourceStrides))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
|
||||||
|
for (auto [offset, stride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
|
||||||
|
int64_t byteScale = stride * elementByteWidth;
|
||||||
|
if (auto attr = dyn_cast<Attribute>(offset)) {
|
||||||
|
endpoint.offset.constant += cast<IntegerAttr>(attr).getInt() * byteScale;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
appendTerm(endpoint.offset, cast<Value>(offset), byteScale);
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint.base = subviewOp.getSource();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto value = cast<Value>(baseOffset);
|
endpoint.baseType = dyn_cast<MemRefType>(endpoint.base.getType());
|
||||||
auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset);
|
if (!endpoint.baseType)
|
||||||
return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult();
|
return failure();
|
||||||
|
return endpoint;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value buildSubviewChunk(const StaticSubviewInfo& info,
|
static FailureOr<CopyRewritePlan>
|
||||||
ArrayRef<int64_t> outerIndices,
|
analyzeCopyRewrite(Value target, Value source, Value targetOffset, Value sourceOffset, int64_t size) {
|
||||||
Location loc,
|
auto targetType = dyn_cast<MemRefType>(target.getType());
|
||||||
PatternRewriter& rewriter) {
|
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||||
SmallVector<OpFoldResult> chunkOffsets;
|
if (!targetType || !sourceType || size <= 0)
|
||||||
SmallVector<OpFoldResult> chunkSizes;
|
return failure();
|
||||||
SmallVector<OpFoldResult> chunkStrides;
|
|
||||||
chunkOffsets.reserve(info.offsets.size());
|
|
||||||
chunkSizes.reserve(info.sizes.size());
|
|
||||||
chunkStrides.reserve(info.strides.size());
|
|
||||||
|
|
||||||
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
|
auto logicalCopyShape = inferLogicalCopyShape(targetType, sourceType, size);
|
||||||
int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0;
|
if (failed(logicalCopyShape))
|
||||||
chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter));
|
return failure();
|
||||||
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back()));
|
|
||||||
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
|
auto targetPlan = analyzeCopyEndpoint(target, targetOffset, targetType);
|
||||||
|
auto sourcePlan = analyzeCopyEndpoint(source, sourceOffset, sourceType);
|
||||||
|
if (failed(targetPlan) || failed(sourcePlan))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto targetSuffixRank = getContiguousSuffixRank(targetType, *logicalCopyShape);
|
||||||
|
auto sourceSuffixRank = getContiguousSuffixRank(sourceType, *logicalCopyShape);
|
||||||
|
if (failed(targetSuffixRank) || failed(sourceSuffixRank))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
CopyRewritePlan plan;
|
||||||
|
plan.target = *targetPlan;
|
||||||
|
plan.source = *sourcePlan;
|
||||||
|
|
||||||
|
int64_t contiguousSuffixRank = std::min(*targetSuffixRank, *sourceSuffixRank);
|
||||||
|
if (contiguousSuffixRank == static_cast<int64_t>(logicalCopyShape->size())) {
|
||||||
|
plan.kind = CopyRewritePlan::Kind::Direct;
|
||||||
|
plan.directBytes = size;
|
||||||
|
return plan;
|
||||||
}
|
}
|
||||||
|
|
||||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
auto targetStrides = getStaticMemRefStrides(targetType);
|
||||||
|
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||||
|
if (failed(targetStrides) || failed(sourceStrides))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(targetType.getElementType()));
|
||||||
|
plan.kind = CopyRewritePlan::Kind::Loop;
|
||||||
|
plan.loop.targetBaseOffset = plan.target.offset;
|
||||||
|
plan.loop.sourceBaseOffset = plan.source.offset;
|
||||||
|
plan.loop.outerShape.assign(logicalCopyShape->begin(), logicalCopyShape->end() - contiguousSuffixRank);
|
||||||
|
SmallVector<int64_t> chunkShape(logicalCopyShape->end() - contiguousSuffixRank, logicalCopyShape->end());
|
||||||
|
plan.loop.chunkBytes = getNumElements(chunkShape) * elementByteWidth;
|
||||||
|
for (int64_t stride : ArrayRef<int64_t>(*targetStrides).take_front(plan.loop.outerShape.size()))
|
||||||
|
plan.loop.targetOuterByteStrides.push_back(stride * elementByteWidth);
|
||||||
|
for (int64_t stride : ArrayRef<int64_t>(*sourceStrides).take_front(plan.loop.outerShape.size()))
|
||||||
|
plan.loop.sourceOuterByteStrides.push_back(stride * elementByteWidth);
|
||||||
|
if (plan.loop.chunkBytes <= 0)
|
||||||
|
return failure();
|
||||||
|
return plan;
|
||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
|
static Value createIndexConstant(PatternRewriter& rewriter, Operation* anchorOp, int64_t value) {
|
||||||
ArrayRef<int64_t> shape,
|
return getOrCreateIndexConstant(rewriter, anchorOp, value);
|
||||||
ArrayRef<int64_t> strides,
|
}
|
||||||
PatternRewriter& rewriter) {
|
|
||||||
|
static Value addIndexValues(PatternRewriter& rewriter, Location loc, Value lhs, Value rhs) {
|
||||||
|
if (auto constant = getConstantIntValue(lhs); constant && *constant == 0)
|
||||||
|
return rhs;
|
||||||
|
if (auto constant = getConstantIntValue(rhs); constant && *constant == 0)
|
||||||
|
return lhs;
|
||||||
|
return arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value mulIndexValue(PatternRewriter& rewriter, Location loc, Operation* anchorOp, Value value, int64_t scale) {
|
||||||
|
if (scale == 0)
|
||||||
|
return createIndexConstant(rewriter, anchorOp, 0);
|
||||||
|
if (scale == 1)
|
||||||
|
return value;
|
||||||
|
Value scaleValue = createIndexConstant(rewriter, anchorOp, scale);
|
||||||
|
return arith::MulIOp::create(rewriter, loc, value, scaleValue).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value
|
||||||
|
materializeByteOffset(PatternRewriter& rewriter, Location loc, Operation* anchorOp, const ByteOffsetExpr& expr) {
|
||||||
|
Value result = createIndexConstant(rewriter, anchorOp, expr.constant);
|
||||||
|
for (const ByteOffsetTerm& term : expr.terms)
|
||||||
|
result = addIndexValues(rewriter, loc, result, mulIndexValue(rewriter, loc, anchorOp, term.value, term.scale));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<Value> materializeDelinearizedIndices(
|
||||||
|
PatternRewriter& rewriter, Location loc, Operation* anchorOp, Value linearIndex, ArrayRef<int64_t> shape) {
|
||||||
SmallVector<Value> indices;
|
SmallVector<Value> indices;
|
||||||
indices.reserve(shape.size());
|
if (shape.empty())
|
||||||
|
return indices;
|
||||||
|
|
||||||
|
auto rowMajorStrides = computeRowMajorStrides(shape);
|
||||||
Value remaining = linearIndex;
|
Value remaining = linearIndex;
|
||||||
for (auto [_dim, stride] : llvm::enumerate(strides)) {
|
for (auto [dim, stride] : llvm::enumerate(rowMajorStrides)) {
|
||||||
auto cStride = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride);
|
if (dim + 1 == rowMajorStrides.size()) {
|
||||||
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
indices.push_back(remaining);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Value strideValue = createIndexConstant(rewriter, anchorOp, stride);
|
||||||
|
Value index = arith::DivUIOp::create(rewriter, loc, remaining, strideValue);
|
||||||
indices.push_back(index);
|
indices.push_back(index);
|
||||||
remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
remaining = arith::RemUIOp::create(rewriter, loc, remaining, strideValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
return indices;
|
return indices;
|
||||||
}
|
}
|
||||||
|
|
||||||
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
|
static Value materializeOuterByteOffset(PatternRewriter& rewriter,
|
||||||
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
Location loc,
|
||||||
auto integerAttr = cast<IntegerAttr>(attr);
|
Operation* anchorOp,
|
||||||
if (integerAttr.getInt() == 0)
|
const ByteOffsetExpr& baseOffset,
|
||||||
return extraOffset;
|
ArrayRef<Value> outerIndices,
|
||||||
|
ArrayRef<int64_t> outerByteStrides) {
|
||||||
auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), integerAttr.getInt());
|
Value result = materializeByteOffset(rewriter, loc, anchorOp, baseOffset);
|
||||||
return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult();
|
for (auto [index, stride] : llvm::zip_equal(outerIndices, outerByteStrides))
|
||||||
}
|
result = addIndexValues(rewriter, loc, result, mulIndexValue(rewriter, loc, anchorOp, index, stride));
|
||||||
|
return result;
|
||||||
auto value = cast<Value>(baseOffset);
|
|
||||||
return arith::AddIOp::create(rewriter, value.getLoc(), value, extraOffset).getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
|
|
||||||
ArrayRef<Value> outerIndices,
|
|
||||||
Location loc,
|
|
||||||
PatternRewriter& rewriter) {
|
|
||||||
SmallVector<OpFoldResult> chunkOffsets;
|
|
||||||
SmallVector<OpFoldResult> chunkSizes;
|
|
||||||
SmallVector<OpFoldResult> chunkStrides;
|
|
||||||
chunkOffsets.reserve(info.offsets.size());
|
|
||||||
chunkSizes.reserve(info.sizes.size());
|
|
||||||
chunkStrides.reserve(info.strides.size());
|
|
||||||
|
|
||||||
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
|
|
||||||
if (dim + 1 < info.sizes.size()) {
|
|
||||||
assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
|
|
||||||
chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
|
|
||||||
chunkSizes.push_back(rewriter.getIndexAttr(1));
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
chunkOffsets.push_back(info.offsets[dim]);
|
|
||||||
chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
|
|
||||||
}
|
|
||||||
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
|
|
||||||
}
|
|
||||||
|
|
||||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value buildContiguousChunk(
|
|
||||||
Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
|
|
||||||
SmallVector<OpFoldResult> chunkOffsets;
|
|
||||||
SmallVector<OpFoldResult> chunkSizes;
|
|
||||||
SmallVector<OpFoldResult> chunkStrides;
|
|
||||||
chunkOffsets.reserve(copyShape.size());
|
|
||||||
chunkSizes.reserve(copyShape.size());
|
|
||||||
chunkStrides.reserve(copyShape.size());
|
|
||||||
|
|
||||||
for (size_t dim = 0; dim < copyShape.size(); ++dim) {
|
|
||||||
chunkOffsets.push_back(dim + 1 < copyShape.size() ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0));
|
|
||||||
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < copyShape.size() ? 1 : copyShape.back()));
|
|
||||||
chunkStrides.push_back(rewriter.getIndexAttr(1));
|
|
||||||
}
|
|
||||||
|
|
||||||
return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename CopyOp, typename CreateCopyOp>
|
template <typename CopyOp, typename CreateCopyOp>
|
||||||
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
static LogicalResult rewriteCopyLikeOp(CopyOp copyOp,
|
||||||
Value dst,
|
Value target,
|
||||||
Value src,
|
Value source,
|
||||||
int64_t dstOffset,
|
Value targetOffset,
|
||||||
int64_t srcOffset,
|
Value sourceOffset,
|
||||||
int64_t size,
|
Value replacementValue,
|
||||||
bool allowLoopRewrite,
|
CreateCopyOp createCopyOp,
|
||||||
PatternRewriter& rewriter,
|
PatternRewriter& rewriter) {
|
||||||
CreateCopyOp createCopyOp) {
|
if (isNormalizedCopyLikeOp(copyOp, target, source, targetOffset, sourceOffset))
|
||||||
auto srcSubview = getStaticSubviewInfo(src);
|
|
||||||
auto dstSubview = getStaticSubviewInfo(dst);
|
|
||||||
const bool splitSrc = succeeded(srcSubview) && !isStaticSubviewContiguous(*srcSubview);
|
|
||||||
const bool splitDst = succeeded(dstSubview) && !isStaticSubviewContiguous(*dstSubview);
|
|
||||||
if (!splitSrc && !splitDst)
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto sourceType = dyn_cast<MemRefType>(src.getType());
|
auto plan = analyzeCopyRewrite(target, source, targetOffset, sourceOffset, copyOp.getSize());
|
||||||
auto dstType = dyn_cast<MemRefType>(dst.getType());
|
if (failed(plan))
|
||||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
if (sourceType.getElementType() != dstType.getElementType())
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
|
Location loc = copyOp.getLoc();
|
||||||
return failure();
|
Operation* anchorOp = copyOp.getOperation();
|
||||||
if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
|
if (plan->kind == CopyRewritePlan::Kind::Direct) {
|
||||||
return failure();
|
Value newTargetOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->target.offset);
|
||||||
|
Value newSourceOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->source.offset);
|
||||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
auto newCopyOp = createCopyOp(loc,
|
||||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
plan->target.base,
|
||||||
return failure();
|
plan->source.base,
|
||||||
|
newTargetOffset,
|
||||||
if (!hasByteSizedElementType(sourceType.getElementType()))
|
newSourceOffset,
|
||||||
return failure();
|
static_cast<int32_t>(plan->directBytes));
|
||||||
const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
|
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
|
||||||
|
rewriter.replaceOp(copyOp, replacementValue);
|
||||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
|
||||||
if (size != totalBytes)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
|
||||||
if (sliceBytes <= 0)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
|
||||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
|
||||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
|
||||||
const bool sourceShapeMatchesCopyShape = llvm::equal(sourceType.getShape(), copyShape);
|
|
||||||
const bool dstShapeMatchesCopyShape = llvm::equal(dstType.getShape(), copyShape);
|
|
||||||
|
|
||||||
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
|
|
||||||
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
|
|
||||||
&& dstType.getRank() == static_cast<int64_t>(copyShape.size()) && (splitSrc || sourceShapeMatchesCopyShape)
|
|
||||||
&& (splitDst || dstShapeMatchesCopyShape)) {
|
|
||||||
auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
|
||||||
auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices);
|
|
||||||
auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1);
|
|
||||||
|
|
||||||
auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
|
|
||||||
rewriter.setInsertionPointToStart(loop.getBody());
|
|
||||||
|
|
||||||
SmallVector<Value> outerIndices =
|
|
||||||
outerShape.empty() ? SmallVector<Value> {}
|
|
||||||
: delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
|
|
||||||
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
|
|
||||||
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
|
||||||
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
|
|
||||||
: buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
|
||||||
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(copyOp);
|
Value c0 = createIndexConstant(rewriter, anchorOp, 0);
|
||||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
Value cUpper = createIndexConstant(rewriter, anchorOp, getNumElements(plan->loop.outerShape));
|
||||||
SmallVector<int64_t> outerIndices =
|
Value cStep = createIndexConstant(rewriter, anchorOp, 1);
|
||||||
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, cStep, ValueRange {});
|
||||||
Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
|
rewriter.setInsertionPointToStart(loop.getBody());
|
||||||
Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
|
|
||||||
const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
|
|
||||||
const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
|
|
||||||
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
SmallVector<Value> outerIndices =
|
||||||
|
materializeDelinearizedIndices(rewriter, loc, anchorOp, loop.getInductionVar(), plan->loop.outerShape);
|
||||||
|
Value loopTargetOffset = materializeOuterByteOffset(
|
||||||
|
rewriter, loc, anchorOp, plan->loop.targetBaseOffset, outerIndices, plan->loop.targetOuterByteStrides);
|
||||||
|
Value loopSourceOffset = materializeOuterByteOffset(
|
||||||
|
rewriter, loc, anchorOp, plan->loop.sourceBaseOffset, outerIndices, plan->loop.sourceOuterByteStrides);
|
||||||
|
auto newCopyOp = createCopyOp(loc,
|
||||||
|
plan->target.base,
|
||||||
|
plan->source.base,
|
||||||
|
loopTargetOffset,
|
||||||
|
loopSourceOffset,
|
||||||
|
static_cast<int32_t>(plan->loop.chunkBytes));
|
||||||
|
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
|
||||||
|
rewriter.setInsertionPointAfter(loop);
|
||||||
|
rewriter.replaceOp(copyOp, replacementValue);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,34 +364,24 @@ struct NormalizeCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyO
|
|||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
return rewriteCopyLikeOp(
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto status = rewriteSubviewCopyLikeOp(
|
|
||||||
copyOp,
|
copyOp,
|
||||||
copyOp.getTarget(),
|
copyOp.getTarget(),
|
||||||
copyOp.getSource(),
|
copyOp.getSource(),
|
||||||
copyOp.getTargetOffset(),
|
copyOp.getTargetOffset(),
|
||||||
copyOp.getSourceOffset(),
|
copyOp.getSourceOffset(),
|
||||||
copyOp.getSize(),
|
copyOp.getTarget(),
|
||||||
/*allowLoopRewrite=*/true,
|
[&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) {
|
||||||
rewriter,
|
return pim::PimMemCopyOp::create(rewriter,
|
||||||
[&](
|
loc,
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
cast<MemRefType>(target.getType()),
|
||||||
pim::PimMemCopyOp::create(rewriter,
|
targetOffset,
|
||||||
copyOp.getLoc(),
|
sourceOffset,
|
||||||
resultType,
|
target,
|
||||||
dst,
|
source,
|
||||||
src,
|
rewriter.getI32IntegerAttr(size));
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
},
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
rewriter);
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
|
||||||
});
|
|
||||||
if (failed(status))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOp(copyOp, copyOp.getTarget());
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -259,38 +389,24 @@ struct NormalizeHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyH
|
|||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
auto dstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset());
|
return rewriteCopyLikeOp(
|
||||||
auto srcOffset = resolveIndexValue(copyOp.getHostSourceOffset());
|
|
||||||
if (failed(dstOffset) || failed(srcOffset))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto status = rewriteSubviewCopyLikeOp(
|
|
||||||
copyOp,
|
copyOp,
|
||||||
copyOp.getDeviceTarget(),
|
copyOp.getDeviceTarget(),
|
||||||
copyOp.getHostSource(),
|
copyOp.getHostSource(),
|
||||||
*dstOffset,
|
copyOp.getDeviceTargetOffset(),
|
||||||
*srcOffset,
|
copyOp.getHostSourceOffset(),
|
||||||
copyOp.getSize(),
|
copyOp.getDeviceTarget(),
|
||||||
/*allowLoopRewrite=*/true,
|
[&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) {
|
||||||
rewriter,
|
return pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||||
[&](
|
loc,
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
cast<MemRefType>(target.getType()),
|
||||||
Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
|
targetOffset,
|
||||||
Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
|
sourceOffset,
|
||||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
target,
|
||||||
copyOp.getLoc(),
|
source,
|
||||||
resultType,
|
rewriter.getI32IntegerAttr(size));
|
||||||
dstOffsetValue,
|
},
|
||||||
srcOffsetValue,
|
rewriter);
|
||||||
dst,
|
|
||||||
src,
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
|
||||||
});
|
|
||||||
if (failed(status))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -298,43 +414,43 @@ struct NormalizeHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopy
|
|||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
auto dstOffset = resolveIndexValue(copyOp.getHostTargetOffset());
|
return rewriteCopyLikeOp(
|
||||||
auto srcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset());
|
|
||||||
if (failed(dstOffset) || failed(srcOffset))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto status = rewriteSubviewCopyLikeOp(
|
|
||||||
copyOp,
|
copyOp,
|
||||||
copyOp.getHostTarget(),
|
copyOp.getHostTarget(),
|
||||||
copyOp.getDeviceSource(),
|
copyOp.getDeviceSource(),
|
||||||
*dstOffset,
|
copyOp.getHostTargetOffset(),
|
||||||
*srcOffset,
|
copyOp.getDeviceSourceOffset(),
|
||||||
copyOp.getSize(),
|
copyOp.getHostTarget(),
|
||||||
/*allowLoopRewrite=*/false,
|
[&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) {
|
||||||
rewriter,
|
return pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||||
[&](
|
loc,
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
cast<MemRefType>(target.getType()),
|
||||||
Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
|
targetOffset,
|
||||||
Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
|
sourceOffset,
|
||||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
target,
|
||||||
copyOp.getLoc(),
|
source,
|
||||||
resultType,
|
rewriter.getI32IntegerAttr(size));
|
||||||
dstOffsetValue,
|
},
|
||||||
srcOffsetValue,
|
rewriter);
|
||||||
dst,
|
|
||||||
src,
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
|
||||||
});
|
|
||||||
if (failed(status))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOp(copyOp, copyOp.getHostTarget());
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
bool isNormalizedCopyOp(pim::PimMemCopyOp op) {
|
||||||
|
return isNormalizedCopyLikeOp(op, op.getTarget(), op.getSource(), op.getTargetOffset(), op.getSourceOffset());
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isNormalizedCopyOp(pim::PimMemCopyHostToDevOp op) {
|
||||||
|
return isNormalizedCopyLikeOp(
|
||||||
|
op, op.getDeviceTarget(), op.getHostSource(), op.getDeviceTargetOffset(), op.getHostSourceOffset());
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isNormalizedCopyOp(pim::PimMemCopyDevToHostOp op) {
|
||||||
|
return isNormalizedCopyLikeOp(
|
||||||
|
op, op.getHostTarget(), op.getDeviceSource(), op.getHostTargetOffset(), op.getDeviceSourceOffset());
|
||||||
|
}
|
||||||
|
|
||||||
void populatePimContiguityNormalizationPatterns(RewritePatternSet& patterns) {
|
void populatePimContiguityNormalizationPatterns(RewritePatternSet& patterns) {
|
||||||
patterns.add<NormalizeCoreSubviewCopyPattern, NormalizeHostSubviewLoadPattern, NormalizeHostSubviewStorePattern>(
|
patterns.add<NormalizeCoreSubviewCopyPattern, NormalizeHostSubviewLoadPattern, NormalizeHostSubviewStorePattern>(
|
||||||
patterns.getContext());
|
patterns.getContext());
|
||||||
|
|||||||
@@ -2,8 +2,14 @@
|
|||||||
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir::pim {
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
|
bool isNormalizedCopyOp(pim::PimMemCopyOp op);
|
||||||
|
bool isNormalizedCopyOp(pim::PimMemCopyHostToDevOp op);
|
||||||
|
bool isNormalizedCopyOp(pim::PimMemCopyDevToHostOp op);
|
||||||
|
|
||||||
void populatePimContiguityNormalizationPatterns(mlir::RewritePatternSet& patterns);
|
void populatePimContiguityNormalizationPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
|
||||||
} // namespace onnx_mlir::pim
|
} // namespace onnx_mlir::pim
|
||||||
|
|||||||
@@ -101,10 +101,10 @@ struct MemCopyOpInterface : DstBufferizableOpInterfaceExternalModel<MemCopyOpInt
|
|||||||
replaceOpWithNewBufferizedOp<PimMemCopyOp>(rewriter,
|
replaceOpWithNewBufferizedOp<PimMemCopyOp>(rewriter,
|
||||||
memCopyOp,
|
memCopyOp,
|
||||||
targetOpt->getType(),
|
targetOpt->getType(),
|
||||||
|
memCopyOp.getTargetOffset(),
|
||||||
|
memCopyOp.getSourceOffset(),
|
||||||
*targetOpt,
|
*targetOpt,
|
||||||
*sourceOpt,
|
*sourceOpt,
|
||||||
memCopyOp.getTargetOffsetAttr(),
|
|
||||||
memCopyOp.getSourceOffsetAttr(),
|
|
||||||
memCopyOp.getSizeAttr());
|
memCopyOp.getSizeAttr());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +0,0 @@
|
|||||||
#ifndef PIM_BUFFERIZATION
|
|
||||||
#define PIM_BUFFERIZATION
|
|
||||||
|
|
||||||
#ifndef OP_BASE
|
|
||||||
include "mlir/IR/PatternBase.td"
|
|
||||||
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
|
|
||||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
|
||||||
#endif // OP_BASE
|
|
||||||
|
|
||||||
def memrefCopyToPimMemCopyOp : Pat<
|
|
||||||
(CopyOp $src, $dst),
|
|
||||||
(PimMemCopyOp $dst, $src,
|
|
||||||
ConstantAttr<I32Attr, "0">,
|
|
||||||
ConstantAttr<I32Attr, "0">,
|
|
||||||
(NativeCodeCall<"pim::getMemRefSizeInBytesAttr($_builder, $0)"> $src),
|
|
||||||
(returnType $dst))
|
|
||||||
>;
|
|
||||||
|
|
||||||
#endif // PIM_BUFFERIZATION
|
|
||||||
@@ -6,7 +6,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Rewrite/PatternApplicator.h"
|
||||||
|
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
|
|
||||||
@@ -27,7 +27,34 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc"
|
struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
|
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
|
||||||
|
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
|
||||||
|
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (sourceType.getElementType() != targetType.getElementType())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
||||||
|
IntegerAttr sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getSource());
|
||||||
|
pim::PimMemCopyOp::create(rewriter,
|
||||||
|
copyOp.getLoc(),
|
||||||
|
copyOp.getTarget().getType(),
|
||||||
|
zeroOffset,
|
||||||
|
zeroOffset,
|
||||||
|
copyOp.getTarget(),
|
||||||
|
copyOp.getSource(),
|
||||||
|
sizeAttr);
|
||||||
|
rewriter.eraseOp(copyOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||||
@@ -44,6 +71,14 @@ private:
|
|||||||
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
|
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static LogicalResult applyPatternsOnce(Operation* op, PatternApplicator& applicator, PatternRewriter& rewriter) {
|
||||||
|
if (!op || !op->getBlock())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(op);
|
||||||
|
return applicator.matchAndRewrite(op, rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void PimBufferizationPass::runOnOperation() {
|
void PimBufferizationPass::runOnOperation() {
|
||||||
@@ -63,35 +98,60 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
MLIRContext* ctx = moduleOp.getContext();
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
ConversionTarget target(*ctx);
|
RewritePatternSet memrefCopyPatterns(ctx);
|
||||||
target.addLegalDialect<PimDialect>();
|
memrefCopyPatterns.add<MemRefCopyToPimMemCopyPattern>(ctx);
|
||||||
|
FrozenRewritePatternSet frozenMemrefCopyPatterns(std::move(memrefCopyPatterns));
|
||||||
|
PatternApplicator memrefCopyApplicator(frozenMemrefCopyPatterns);
|
||||||
|
memrefCopyApplicator.applyDefaultCostModel();
|
||||||
|
PatternRewriter rewriter(ctx);
|
||||||
|
|
||||||
RewritePatternSet patterns(ctx);
|
SmallVector<memref::CopyOp> copyWorklist;
|
||||||
populateWithGenerated(patterns);
|
moduleOp.walk([&](memref::CopyOp copyOp) {
|
||||||
|
if (copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||||
// Only convert memref.copy → pim.memcp inside pim.core / pim.core_batch bodies.
|
copyWorklist.push_back(copyOp);
|
||||||
// Host-level copies (e.g. from split/slice ops) must remain as memref.copy for CPU lowering.
|
|
||||||
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
|
|
||||||
bool hasFailed = false;
|
|
||||||
moduleOp.walk<WalkOrder::PreOrder>([&](Operation* op) {
|
|
||||||
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
|
|
||||||
return WalkResult::advance();
|
|
||||||
if (failed(applyPartialConversion(op, target, frozenPatterns)))
|
|
||||||
hasFailed = true;
|
|
||||||
return WalkResult::skip();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
bool hasFailed = false;
|
||||||
|
for (memref::CopyOp copyOp : copyWorklist) {
|
||||||
|
if (failed(applyPatternsOnce(copyOp, memrefCopyApplicator, rewriter))) {
|
||||||
|
copyOp.emitOpError("failed to lower memref.copy inside PIM core body");
|
||||||
|
hasFailed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
if (hasFailed) {
|
if (hasFailed) {
|
||||||
moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization");
|
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// After this pass, executable PIM ops must only use contiguous/addressable memrefs.
|
|
||||||
// Later PIM codegen passes may verify this invariant but must not repair it.
|
|
||||||
RewritePatternSet contiguityPatterns(ctx);
|
RewritePatternSet contiguityPatterns(ctx);
|
||||||
populatePimContiguityNormalizationPatterns(contiguityPatterns);
|
populatePimContiguityNormalizationPatterns(contiguityPatterns);
|
||||||
if (failed(applyPatternsGreedily(moduleOp, std::move(contiguityPatterns)))) {
|
FrozenRewritePatternSet frozenContiguityPatterns(std::move(contiguityPatterns));
|
||||||
moduleOp.emitError("failed to normalize PIM runtime operand contiguity during bufferization");
|
PatternApplicator contiguityApplicator(frozenContiguityPatterns);
|
||||||
|
contiguityApplicator.applyDefaultCostModel();
|
||||||
|
|
||||||
|
SmallVector<Operation*> contiguityWorklist;
|
||||||
|
moduleOp.walk([&](Operation* op) {
|
||||||
|
if (isa<pim::PimMemCopyOp, pim::PimMemCopyHostToDevOp, pim::PimMemCopyDevToHostOp>(op))
|
||||||
|
contiguityWorklist.push_back(op);
|
||||||
|
});
|
||||||
|
|
||||||
|
hasFailed = false;
|
||||||
|
for (Operation* op : contiguityWorklist) {
|
||||||
|
if (auto copyOp = dyn_cast<pim::PimMemCopyOp>(op); copyOp && pim::isNormalizedCopyOp(copyOp))
|
||||||
|
continue;
|
||||||
|
if (auto copyOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op); copyOp && pim::isNormalizedCopyOp(copyOp))
|
||||||
|
continue;
|
||||||
|
if (auto copyOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op); copyOp && pim::isNormalizedCopyOp(copyOp))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (failed(applyPatternsOnce(op, contiguityApplicator, rewriter))) {
|
||||||
|
op->emitOpError("failed to normalize PIM copy contiguity");
|
||||||
|
hasFailed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hasFailed) {
|
||||||
|
moduleOp.emitError("failed to normalize PIM copy contiguity during bufferization");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -138,16 +198,28 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod
|
|||||||
};
|
};
|
||||||
|
|
||||||
if (auto memCopyOp = dyn_cast<PimMemCopyOp>(op)) {
|
if (auto memCopyOp = dyn_cast<PimMemCopyOp>(op)) {
|
||||||
|
if (!pim::isNormalizedCopyOp(memCopyOp)) {
|
||||||
|
memCopyOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
verifyOperand(memCopyOp.getTarget(), 0);
|
verifyOperand(memCopyOp.getTarget(), 0);
|
||||||
verifyOperand(memCopyOp.getSource(), 1);
|
verifyOperand(memCopyOp.getSource(), 1);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (auto loadOp = dyn_cast<PimMemCopyHostToDevOp>(op)) {
|
if (auto loadOp = dyn_cast<PimMemCopyHostToDevOp>(op)) {
|
||||||
|
if (!pim::isNormalizedCopyOp(loadOp)) {
|
||||||
|
loadOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
verifyOperand(loadOp.getDeviceTarget(), 2);
|
verifyOperand(loadOp.getDeviceTarget(), 2);
|
||||||
verifyOperand(loadOp.getHostSource(), 3);
|
verifyOperand(loadOp.getHostSource(), 3);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (auto storeOp = dyn_cast<PimMemCopyDevToHostOp>(op)) {
|
if (auto storeOp = dyn_cast<PimMemCopyDevToHostOp>(op)) {
|
||||||
|
if (!pim::isNormalizedCopyOp(storeOp)) {
|
||||||
|
storeOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
verifyOperand(storeOp.getHostTarget(), 2);
|
verifyOperand(storeOp.getHostTarget(), 2);
|
||||||
verifyOperand(storeOp.getDeviceSource(), 3);
|
verifyOperand(storeOp.getDeviceSource(), 3);
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -121,13 +121,14 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
|||||||
rewriter.setInsertionPoint(mapOp);
|
rewriter.setInsertionPoint(mapOp);
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||||
auto sizeInBytes = getShapedTypeSizeInBytes(initType);
|
auto sizeInBytes = getShapedTypeSizeInBytes(initType);
|
||||||
|
Value zeroOffset = getOrCreateIndexConstant(rewriter, mapOp, 0);
|
||||||
pim::PimMemCopyOp::create(rewriter,
|
pim::PimMemCopyOp::create(rewriter,
|
||||||
mapOp.getLoc(),
|
mapOp.getLoc(),
|
||||||
initType,
|
initType,
|
||||||
|
zeroOffset,
|
||||||
|
zeroOffset,
|
||||||
mapOp.getInit(),
|
mapOp.getInit(),
|
||||||
getGlobalOp.getResult(),
|
getGlobalOp.getResult(),
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
rewriter.getI32IntegerAttr(sizeInBytes));
|
rewriter.getI32IntegerAttr(sizeInBytes));
|
||||||
rewriter.eraseOp(mapOp);
|
rewriter.eraseOp(mapOp);
|
||||||
return success();
|
return success();
|
||||||
@@ -487,7 +488,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|||||||
if (!allocType || !allocType.hasStaticShape())
|
if (!allocType || !allocType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (copyOp.getTargetOffset() != 0 || copyOp.getSourceOffset() != 0)
|
auto targetOffset = resolveIndexValue(copyOp.getTargetOffset());
|
||||||
|
auto sourceOffset = resolveIndexValue(copyOp.getSourceOffset());
|
||||||
|
if (failed(targetOffset) || failed(sourceOffset) || *targetOffset != 0 || *sourceOffset != 0)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -333,6 +334,12 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op)) {
|
if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op)) {
|
||||||
|
if (!pim::isNormalizedCopyOp(storeOp)) {
|
||||||
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||||
|
});
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
if (failed(resolveIndexValue(storeOp.getHostTargetOffset(), knowledge))
|
if (failed(resolveIndexValue(storeOp.getHostTargetOffset(), knowledge))
|
||||||
|| failed(resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge))) {
|
|| failed(resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge))) {
|
||||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
@@ -343,6 +350,12 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op)) {
|
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op)) {
|
||||||
|
if (!pim::isNormalizedCopyOp(loadOp)) {
|
||||||
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||||
|
});
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
if (failed(resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge))
|
if (failed(resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge))
|
||||||
|| failed(resolveIndexValue(loadOp.getHostSourceOffset(), knowledge))) {
|
|| failed(resolveIndexValue(loadOp.getHostSourceOffset(), knowledge))) {
|
||||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
@@ -351,6 +364,22 @@ private:
|
|||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto copyOp = dyn_cast<pim::PimMemCopyOp>(op)) {
|
||||||
|
if (!pim::isNormalizedCopyOp(copyOp)) {
|
||||||
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||||
|
});
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
|
if (failed(resolveIndexValue(copyOp.getTargetOffset(), knowledge))
|
||||||
|
|| failed(resolveIndexValue(copyOp.getSourceOffset(), knowledge))) {
|
||||||
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen");
|
||||||
|
});
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
return success(!hasFailure);
|
return success(!hasFailure);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user