compact memory contiguity with for loops
Validate Operations / validate-operations (push) Waiting to run
Validate Operations / validate-operations (push) Waiting to run
This commit is contained in:
@@ -69,6 +69,16 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
|
||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
||||
|
||||
/// Returns the per-dimension strides recorded in the layout of `type`.
///
/// Fails when the layout cannot be decomposed into strides and an offset, or
/// when any of the strides is dynamic. The offset component is queried only
/// because the API requires it; its value is not inspected here.
static llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefStrides(mlir::MemRefType type) {
  int64_t ignoredOffset = 0;
  llvm::SmallVector<int64_t> layoutStrides;

  // Short-circuit keeps the stride scan from running on a failed query.
  const bool unusable = failed(type.getStridesAndOffset(layoutStrides, ignoredOffset))
                     || llvm::any_of(layoutStrides, mlir::ShapedType::isDynamic);
  if (unusable)
    return mlir::failure();

  return layoutStrides;
}
|
||||
|
||||
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
|
||||
const StaticValueKnowledge* knowledge) {
|
||||
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||
@@ -539,8 +549,10 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
||||
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||
return mlir::failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||
if (failed(sourceStrides))
|
||||
return mlir::failure();
|
||||
byteOffset += linearizeIndex(offsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||
continue;
|
||||
}
|
||||
@@ -651,12 +663,16 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
||||
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
|
||||
return mlir::failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||
if (failed(sourceStrides))
|
||||
return mlir::failure();
|
||||
constantByteOffset +=
|
||||
linearizeIndex(staticOffsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
linearizeIndex(staticOffsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
}
|
||||
else {
|
||||
llvm::SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||
if (failed(sourceStrides))
|
||||
return mlir::failure();
|
||||
CompiledIndexExpr offsetExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
@@ -665,7 +681,7 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
||||
offsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), sourceStrides)) {
|
||||
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
|
||||
CompiledIndexExpr operandExpr;
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
|
||||
CompiledIndexExprNode expr;
|
||||
|
||||
@@ -5,8 +5,6 @@
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -15,16 +13,12 @@ namespace onnx_mlir {
|
||||
Block* getConstantInsertionBlock(Operation* anchorOp) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
|
||||
for (Operation* current = anchorOp; current; current = current->getParentOp())
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
|
||||
return current->getBlock();
|
||||
|
||||
if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp))
|
||||
return &funcOp.getBody().front();
|
||||
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
|
||||
return &funcOp.getBody().front();
|
||||
if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
|
||||
return moduleOp.getBody();
|
||||
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
|
||||
return &funcOp.getBody().front();
|
||||
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
|
||||
return moduleOp.getBody();
|
||||
return anchorOp->getBlock();
|
||||
|
||||
@@ -235,6 +235,7 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
|
||||
if (failed(compiledExpr)) {
|
||||
errs() << "Failed to compile contiguous address for value: ";
|
||||
value.print(errs());
|
||||
errs() << " : " << value.getType();
|
||||
errs() << "\n";
|
||||
llvm_unreachable("Failed to compile contiguous address");
|
||||
}
|
||||
@@ -245,6 +246,7 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
|
||||
if (failed(resolvedAddress)) {
|
||||
errs() << "Failed to evaluate contiguous address for value: ";
|
||||
value.print(errs());
|
||||
errs() << " : " << value.getType();
|
||||
errs() << "\n";
|
||||
if (auto* definingOp = value.getDefiningOp()) {
|
||||
errs() << "Defining op:\n";
|
||||
@@ -493,11 +495,15 @@ void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const Static
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto targetOffset = indexOf(lmvOp.getTargetOffset(), knowledge);
|
||||
auto sourceOffset = indexOf(lmvOp.getSourceOffset(), knowledge);
|
||||
assert(succeeded(targetOffset) && succeeded(sourceOffset)
|
||||
&& "pim.memcp offsets must be statically resolvable during codegen");
|
||||
emitMemCopyOp("lmv",
|
||||
addressOf(lmvOp.getTarget(), knowledge),
|
||||
lmvOp.getTargetOffset(),
|
||||
*targetOffset,
|
||||
addressOf(lmvOp.getSource(), knowledge),
|
||||
lmvOp.getSourceOffset(),
|
||||
*sourceOffset,
|
||||
lmvOp.getSize(),
|
||||
"len");
|
||||
}
|
||||
|
||||
@@ -101,9 +101,9 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
|
||||
auto paddedType = RankedTensorType::get(
|
||||
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
|
||||
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
|
||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
||||
Value zeroIndex = getOrCreateIndexConstant(constantFolder, zeroed.getDefiningOp(), 0);
|
||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, zeroed, vector, sizeAttr).getOutput();
|
||||
}
|
||||
|
||||
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
|
||||
@@ -176,10 +176,10 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region within the same memory space";
|
||||
|
||||
let arguments = (ins
|
||||
Index:$targetOffset,
|
||||
Index:$sourceOffset,
|
||||
PimTensor:$target,
|
||||
PimTensor:$source,
|
||||
I32Attr:$targetOffset,
|
||||
I32Attr:$sourceOffset,
|
||||
I32Attr:$size
|
||||
);
|
||||
|
||||
@@ -194,7 +194,9 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $target `,` $source `)` attr-dict `:` `(` type($target) `,` type($source) `)` `->` type($output)
|
||||
`[` $targetOffset `,` $sourceOffset `]`
|
||||
`(` $target `,` $source `)` attr-dict
|
||||
`:` type($target) `,` type($source) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
@@ -19,14 +19,15 @@ Value materializeContiguousInputMemRef(Value memrefValue, Location loc, Rewriter
|
||||
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
||||
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
||||
auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, contiguousBuffer.getDefiningOp(), 0);
|
||||
|
||||
return PimMemCopyOp::create(rewriter,
|
||||
loc,
|
||||
contiguousType,
|
||||
zeroOffset,
|
||||
zeroOffset,
|
||||
contiguousBuffer,
|
||||
memrefValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
set(LLVM_TARGET_DEFINITIONS PimBufferization.td)
|
||||
mlir_tablegen(PimBufferization.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(PimBufferizationIncGen)
|
||||
|
||||
add_pim_library(OMPimBufferization
|
||||
PimBufferizationPass.cpp
|
||||
BufferizationUtils.hpp
|
||||
@@ -15,9 +11,6 @@ add_pim_library(OMPimBufferization
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
DEPENDS
|
||||
PimBufferizationIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMPimCommon
|
||||
PimOps
|
||||
|
||||
@@ -3,220 +3,360 @@
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "ContiguityPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
namespace {
|
||||
|
||||
static bool isStaticSubviewContiguous(const StaticSubviewInfo& info) {
|
||||
if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
|
||||
return false;
|
||||
struct ByteOffsetTerm {
|
||||
Value value;
|
||||
int64_t scale = 0;
|
||||
};
|
||||
|
||||
return isContiguousSubviewWithDynamicOffsets(info.sourceShape, info.offsets, info.sizes, info.strides);
|
||||
struct ByteOffsetExpr {
|
||||
int64_t constant = 0;
|
||||
SmallVector<ByteOffsetTerm> terms;
|
||||
};
|
||||
|
||||
struct CopyEndpointPlan {
|
||||
Value base;
|
||||
MemRefType originalType;
|
||||
MemRefType baseType;
|
||||
ByteOffsetExpr offset;
|
||||
};
|
||||
|
||||
struct CopyLoopPlan {
|
||||
SmallVector<int64_t> outerShape;
|
||||
int64_t chunkBytes = 0;
|
||||
ByteOffsetExpr targetBaseOffset;
|
||||
ByteOffsetExpr sourceBaseOffset;
|
||||
SmallVector<int64_t> targetOuterByteStrides;
|
||||
SmallVector<int64_t> sourceOuterByteStrides;
|
||||
};
|
||||
|
||||
struct CopyRewritePlan {
|
||||
enum class Kind {
|
||||
Direct,
|
||||
Loop
|
||||
} kind = Kind::Direct;
|
||||
CopyEndpointPlan target;
|
||||
CopyEndpointPlan source;
|
||||
int64_t directBytes = 0;
|
||||
CopyLoopPlan loop;
|
||||
};
|
||||
|
||||
/// Returns true when `value` is produced by one of the memref view operations
/// that copy normalization looks through (subview, reinterpret_cast,
/// collapse_shape, expand_shape, cast). Values without a defining operation
/// are never view-like.
static bool isViewLike(Value value) {
  if (Operation* producer = value.getDefiningOp())
    return isa<memref::SubViewOp,
               memref::ReinterpretCastOp,
               memref::CollapseShapeOp,
               memref::ExpandShapeOp,
               memref::CastOp>(producer);
  return false;
}
|
||||
|
||||
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
|
||||
if (extraOffset == 0)
|
||||
return baseOffset;
|
||||
|
||||
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
||||
auto integerAttr = dyn_cast<IntegerAttr>(attr);
|
||||
assert(integerAttr && "expected integer offset attribute");
|
||||
return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
|
||||
template <typename CopyOp>
|
||||
static bool isNormalizedCopyLikeOp(CopyOp copyOp, Value target, Value source, Value targetOffset, Value sourceOffset) {
|
||||
auto targetType = dyn_cast<MemRefType>(target.getType());
|
||||
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||
return targetType && sourceType && !isViewLike(target) && !isViewLike(source) && targetOffset.getType().isIndex()
|
||||
&& sourceOffset.getType().isIndex() && copyOp.getSize() > 0;
|
||||
}
|
||||
|
||||
auto value = cast<Value>(baseOffset);
|
||||
auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset);
|
||||
return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult();
|
||||
/// Adds the summand `value * scale` to `expr`. A zero scale contributes
/// nothing, so no term is recorded in that case.
static void appendTerm(ByteOffsetExpr& expr, Value value, int64_t scale) {
  if (scale == 0)
    return;

  ByteOffsetTerm term {value, scale};
  expr.terms.push_back(term);
}
|
||||
|
||||
static Value buildSubviewChunk(const StaticSubviewInfo& info,
|
||||
ArrayRef<int64_t> outerIndices,
|
||||
Location loc,
|
||||
PatternRewriter& rewriter) {
|
||||
SmallVector<OpFoldResult> chunkOffsets;
|
||||
SmallVector<OpFoldResult> chunkSizes;
|
||||
SmallVector<OpFoldResult> chunkStrides;
|
||||
chunkOffsets.reserve(info.offsets.size());
|
||||
chunkSizes.reserve(info.sizes.size());
|
||||
chunkStrides.reserve(info.strides.size());
|
||||
|
||||
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
|
||||
int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0;
|
||||
chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter));
|
||||
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back()));
|
||||
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
|
||||
/// Returns the per-dimension strides recorded in the layout of `type`.
///
/// Fails when the layout cannot be decomposed into strides and an offset, or
/// when any of the strides is dynamic. The offset component is queried only
/// because the API requires it; its value is not inspected here.
static FailureOr<SmallVector<int64_t>> getStaticMemRefStrides(MemRefType type) {
  int64_t ignoredOffset = 0;
  SmallVector<int64_t> layoutStrides;

  // Short-circuit keeps the stride scan from running on a failed query.
  const bool unusable = failed(type.getStridesAndOffset(layoutStrides, ignoredOffset))
                     || llvm::any_of(layoutStrides, ShapedType::isDynamic);
  if (unusable)
    return failure();

  return layoutStrides;
}
|
||||
|
||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
||||
/// Returns the size of `type` in bytes as a signed value.
///
/// Fails unless the shape is fully static and the element type passes
/// hasByteSizedElementType.
static FailureOr<int64_t> getShapedByteSize(MemRefType type) {
  const bool sizeIsComputable = type.hasStaticShape() && hasByteSizedElementType(type.getElementType());
  if (!sizeIsComputable)
    return failure();

  return static_cast<int64_t>(getShapedTypeSizeInBytes(type));
}
|
||||
|
||||
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
|
||||
ArrayRef<int64_t> shape,
|
||||
ArrayRef<int64_t> strides,
|
||||
PatternRewriter& rewriter) {
|
||||
/// Determines the logical shape of a copy of `size` bytes between two memrefs.
///
/// Both types must have static shapes, identical element types and identical
/// ranks. The shape returned is that of whichever endpoint occupies exactly
/// `size` bytes, with the target taking precedence. If both endpoints occupy
/// `size` bytes but their shapes differ, or if neither does, the result is a
/// failure.
static FailureOr<SmallVector<int64_t>>
inferLogicalCopyShape(MemRefType targetType, MemRefType sourceType, int64_t size) {
  const bool bothStatic = targetType.hasStaticShape() && sourceType.hasStaticShape();
  if (!bothStatic)
    return failure();

  const bool compatible =
      targetType.getElementType() == sourceType.getElementType() && targetType.getRank() == sourceType.getRank();
  if (!compatible)
    return failure();

  FailureOr<int64_t> targetBytes = getShapedByteSize(targetType);
  FailureOr<int64_t> sourceBytes = getShapedByteSize(sourceType);
  if (failed(targetBytes) || failed(sourceBytes))
    return failure();

  ArrayRef<int64_t> targetShape = targetType.getShape();
  ArrayRef<int64_t> sourceShape = sourceType.getShape();

  if (*targetBytes == size) {
    // Two full-size endpoints with different shapes are ambiguous.
    if (*sourceBytes == size && targetShape != sourceShape)
      return failure();
    return SmallVector<int64_t>(targetShape.begin(), targetShape.end());
  }

  if (*sourceBytes == size)
    return SmallVector<int64_t>(sourceShape.begin(), sourceShape.end());

  return failure();
}
|
||||
|
||||
/// Counts how many trailing dimensions of `type` are densely packed with
/// respect to `copyShape`.
///
/// Walking from the innermost dimension outwards, a dimension counts as dense
/// while its stride equals the product of the `copyShape` extents of all
/// dimensions inside it (1 for the innermost). Counting stops at the first
/// mismatch. Fails when the type is not statically shaped, its element type is
/// not byte sized, its rank differs from `copyShape`, or its strides are not
/// static.
static FailureOr<int64_t> getContiguousSuffixRank(MemRefType type, ArrayRef<int64_t> copyShape) {
  const int64_t rank = type.getRank();
  const bool analyzable = type.hasStaticShape() && hasByteSizedElementType(type.getElementType())
                       && rank == static_cast<int64_t>(copyShape.size());
  if (!analyzable)
    return failure();

  auto stridesOr = getStaticMemRefStrides(type);
  if (failed(stridesOr))
    return failure();
  ArrayRef<int64_t> strides = *stridesOr;

  // `firstDense` ends up as the index of the outermost dense dimension.
  int64_t denseStride = 1;
  int64_t firstDense = rank;
  while (firstDense > 0 && strides[firstDense - 1] == denseStride) {
    --firstDense;
    denseStride *= copyShape[firstDense];
  }

  return rank - firstDense;
}
|
||||
|
||||
/// Resolves one side of a copy to a base value plus a symbolic byte offset.
///
/// Starting from `value`, the chain of defining ops is walked backwards:
///  - memref.cast, collapse_shape, expand_shape and reinterpret_cast are
///    stepped through without touching the offset;
///  - memref.subview contributes `offset * parentStride * elementBytes` for
///    each of its mixed offsets, added to the constant part when the offset is
///    an attribute and recorded as a dynamic term otherwise.
/// `initialByteOffset` is recorded up front as a term with scale 1.
///
/// Fails when `logicalType` is not statically shaped with a byte-sized element
/// type, when a subview's source is not such a memref or has non-static
/// strides, or when the final base is not a MemRefType.
///
/// NOTE(review): offsets carried by memref.reinterpret_cast are not folded
/// into the result here; presumably producers only emit zero offsets. Confirm.
static FailureOr<CopyEndpointPlan> analyzeCopyEndpoint(Value value, Value initialByteOffset, MemRefType logicalType) {
  const bool analyzable = logicalType.hasStaticShape() && hasByteSizedElementType(logicalType.getElementType());
  if (!analyzable)
    return failure();

  CopyEndpointPlan endpoint;
  endpoint.base = value;
  endpoint.originalType = logicalType;
  appendTerm(endpoint.offset, initialByteOffset, 1);

  for (;;) {
    Operation* producer = endpoint.base.getDefiningOp();
    if (!producer)
      break;

    // Views that are stepped through without an offset contribution.
    if (auto castOp = dyn_cast<memref::CastOp>(producer)) {
      endpoint.base = castOp.getSource();
      continue;
    }
    if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(producer)) {
      endpoint.base = collapseOp.getSrc();
      continue;
    }
    if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(producer)) {
      endpoint.base = expandOp.getSrc();
      continue;
    }
    if (auto reinterpretOp = dyn_cast<memref::ReinterpretCastOp>(producer)) {
      endpoint.base = reinterpretOp.getSource();
      continue;
    }

    auto subviewOp = dyn_cast<memref::SubViewOp>(producer);
    if (!subviewOp)
      break;

    auto parentType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
    const bool parentUsable =
        parentType && parentType.hasStaticShape() && hasByteSizedElementType(parentType.getElementType());
    if (!parentUsable)
      return failure();

    auto parentStrides = getStaticMemRefStrides(parentType);
    if (failed(parentStrides))
      return failure();

    const int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(parentType.getElementType()));
    for (auto [mixedOffset, parentStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *parentStrides)) {
      const int64_t bytesPerStep = parentStride * elementBytes;
      if (auto staticOffset = dyn_cast<Attribute>(mixedOffset))
        endpoint.offset.constant += cast<IntegerAttr>(staticOffset).getInt() * bytesPerStep;
      else
        appendTerm(endpoint.offset, cast<Value>(mixedOffset), bytesPerStep);
    }

    endpoint.base = subviewOp.getSource();
  }

  endpoint.baseType = dyn_cast<MemRefType>(endpoint.base.getType());
  if (!endpoint.baseType)
    return failure();
  return endpoint;
}
|
||||
|
||||
/// Builds the plan used to re-emit a copy of `size` bytes from `source` to
/// `target`.
///
/// Both endpoints are resolved with analyzeCopyEndpoint, and the number of
/// trailing dimensions dense on *both* sides is taken as the chunk rank.
/// When every dimension is dense the plan is Kind::Direct with `size` bytes.
/// Otherwise the plan is Kind::Loop: the leading non-dense dimensions form the
/// iteration space, the trailing dense dimensions form the chunk, and the
/// outer strides of each side are converted to bytes.
///
/// Fails when either endpoint is not a memref, `size` is not positive, any of
/// the sub-analyses fails, or the computed chunk size is not positive.
static FailureOr<CopyRewritePlan>
analyzeCopyRewrite(Value target, Value source, Value targetOffset, Value sourceOffset, int64_t size) {
  auto targetType = dyn_cast<MemRefType>(target.getType());
  auto sourceType = dyn_cast<MemRefType>(source.getType());
  const bool validInputs = targetType && sourceType && size > 0;
  if (!validInputs)
    return failure();

  auto copyShapeOr = inferLogicalCopyShape(targetType, sourceType, size);
  if (failed(copyShapeOr))
    return failure();
  ArrayRef<int64_t> copyShape = *copyShapeOr;

  auto targetEndpoint = analyzeCopyEndpoint(target, targetOffset, targetType);
  auto sourceEndpoint = analyzeCopyEndpoint(source, sourceOffset, sourceType);
  if (failed(targetEndpoint) || failed(sourceEndpoint))
    return failure();

  auto targetDenseRank = getContiguousSuffixRank(targetType, copyShape);
  auto sourceDenseRank = getContiguousSuffixRank(sourceType, copyShape);
  if (failed(targetDenseRank) || failed(sourceDenseRank))
    return failure();

  CopyRewritePlan plan;
  plan.target = *targetEndpoint;
  plan.source = *sourceEndpoint;

  // A chunk may only span dimensions that are dense on both sides.
  const int64_t denseRank = std::min(*targetDenseRank, *sourceDenseRank);
  if (denseRank == static_cast<int64_t>(copyShape.size())) {
    plan.kind = CopyRewritePlan::Kind::Direct;
    plan.directBytes = size;
    return plan;
  }

  auto targetStrides = getStaticMemRefStrides(targetType);
  auto sourceStrides = getStaticMemRefStrides(sourceType);
  if (failed(targetStrides) || failed(sourceStrides))
    return failure();

  const int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(targetType.getElementType()));
  ArrayRef<int64_t> outerDims = copyShape.drop_back(denseRank);
  ArrayRef<int64_t> innerDims = copyShape.take_back(denseRank);

  CopyLoopPlan& loop = plan.loop;
  plan.kind = CopyRewritePlan::Kind::Loop;
  loop.targetBaseOffset = plan.target.offset;
  loop.sourceBaseOffset = plan.source.offset;
  loop.outerShape.assign(outerDims.begin(), outerDims.end());

  SmallVector<int64_t> chunkShape(innerDims.begin(), innerDims.end());
  loop.chunkBytes = getNumElements(chunkShape) * elementBytes;

  const size_t outerRank = loop.outerShape.size();
  for (int64_t elementStride : ArrayRef<int64_t>(*targetStrides).take_front(outerRank))
    loop.targetOuterByteStrides.push_back(elementStride * elementBytes);
  for (int64_t elementStride : ArrayRef<int64_t>(*sourceStrides).take_front(outerRank))
    loop.sourceOuterByteStrides.push_back(elementStride * elementBytes);

  if (loop.chunkBytes <= 0)
    return failure();
  return plan;
}
|
||||
|
||||
/// Local shorthand for getOrCreateIndexConstant: yields an index constant with
/// the given `value`, using `anchorOp` as the anchor passed to that helper.
static Value createIndexConstant(PatternRewriter& rewriter, Operation* anchorOp, int64_t value) {
  Value indexConstant = getOrCreateIndexConstant(rewriter, anchorOp, value);
  return indexConstant;
}
|
||||
|
||||
/// Returns `lhs + rhs`, skipping the arith.addi when one operand is a known
/// constant zero. `lhs` is checked first, so if both are zero `rhs` is
/// returned.
static Value addIndexValues(PatternRewriter& rewriter, Location loc, Value lhs, Value rhs) {
  auto isKnownZero = [](Value operand) {
    auto constant = getConstantIntValue(operand);
    return constant && *constant == 0;
  };

  if (isKnownZero(lhs))
    return rhs;
  if (isKnownZero(rhs))
    return lhs;

  return arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult();
}
|
||||
|
||||
/// Returns `value * scale`. A scale of 0 yields an index constant zero and a
/// scale of 1 yields `value` itself; any other scale emits an arith.muli
/// against a constant.
static Value mulIndexValue(PatternRewriter& rewriter, Location loc, Operation* anchorOp, Value value, int64_t scale) {
  switch (scale) {
  case 0:
    return createIndexConstant(rewriter, anchorOp, 0);
  case 1:
    return value;
  default:
    break;
  }

  Value factor = createIndexConstant(rewriter, anchorOp, scale);
  return arith::MulIOp::create(rewriter, loc, value, factor).getResult();
}
|
||||
|
||||
/// Emits IR computing `expr.constant + sum(term.value * term.scale)` and
/// returns the resulting index value. Terms are folded in their stored order.
static Value
materializeByteOffset(PatternRewriter& rewriter, Location loc, Operation* anchorOp, const ByteOffsetExpr& expr) {
  Value runningSum = createIndexConstant(rewriter, anchorOp, expr.constant);

  for (const ByteOffsetTerm& term : expr.terms) {
    Value scaledTerm = mulIndexValue(rewriter, loc, anchorOp, term.value, term.scale);
    runningSum = addIndexValues(rewriter, loc, runningSum, scaledTerm);
  }

  return runningSum;
}
|
||||
|
||||
static SmallVector<Value> materializeDelinearizedIndices(
|
||||
PatternRewriter& rewriter, Location loc, Operation* anchorOp, Value linearIndex, ArrayRef<int64_t> shape) {
|
||||
SmallVector<Value> indices;
|
||||
indices.reserve(shape.size());
|
||||
if (shape.empty())
|
||||
return indices;
|
||||
|
||||
auto rowMajorStrides = computeRowMajorStrides(shape);
|
||||
Value remaining = linearIndex;
|
||||
for (auto [_dim, stride] : llvm::enumerate(strides)) {
|
||||
auto cStride = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride);
|
||||
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
||||
indices.push_back(index);
|
||||
remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
||||
for (auto [dim, stride] : llvm::enumerate(rowMajorStrides)) {
|
||||
if (dim + 1 == rowMajorStrides.size()) {
|
||||
indices.push_back(remaining);
|
||||
break;
|
||||
}
|
||||
Value strideValue = createIndexConstant(rewriter, anchorOp, stride);
|
||||
Value index = arith::DivUIOp::create(rewriter, loc, remaining, strideValue);
|
||||
indices.push_back(index);
|
||||
remaining = arith::RemUIOp::create(rewriter, loc, remaining, strideValue);
|
||||
}
|
||||
|
||||
return indices;
|
||||
}
|
||||
|
||||
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
|
||||
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
||||
auto integerAttr = cast<IntegerAttr>(attr);
|
||||
if (integerAttr.getInt() == 0)
|
||||
return extraOffset;
|
||||
|
||||
auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), integerAttr.getInt());
|
||||
return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult();
|
||||
}
|
||||
|
||||
auto value = cast<Value>(baseOffset);
|
||||
return arith::AddIOp::create(rewriter, value.getLoc(), value, extraOffset).getResult();
|
||||
}
|
||||
|
||||
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
|
||||
ArrayRef<Value> outerIndices,
|
||||
static Value materializeOuterByteOffset(PatternRewriter& rewriter,
|
||||
Location loc,
|
||||
PatternRewriter& rewriter) {
|
||||
SmallVector<OpFoldResult> chunkOffsets;
|
||||
SmallVector<OpFoldResult> chunkSizes;
|
||||
SmallVector<OpFoldResult> chunkStrides;
|
||||
chunkOffsets.reserve(info.offsets.size());
|
||||
chunkSizes.reserve(info.sizes.size());
|
||||
chunkStrides.reserve(info.strides.size());
|
||||
|
||||
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
|
||||
if (dim + 1 < info.sizes.size()) {
|
||||
assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
|
||||
chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
|
||||
chunkSizes.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
else {
|
||||
chunkOffsets.push_back(info.offsets[dim]);
|
||||
chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
|
||||
}
|
||||
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
|
||||
}
|
||||
|
||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
||||
}
|
||||
|
||||
static Value buildContiguousChunk(
|
||||
Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
|
||||
SmallVector<OpFoldResult> chunkOffsets;
|
||||
SmallVector<OpFoldResult> chunkSizes;
|
||||
SmallVector<OpFoldResult> chunkStrides;
|
||||
chunkOffsets.reserve(copyShape.size());
|
||||
chunkSizes.reserve(copyShape.size());
|
||||
chunkStrides.reserve(copyShape.size());
|
||||
|
||||
for (size_t dim = 0; dim < copyShape.size(); ++dim) {
|
||||
chunkOffsets.push_back(dim + 1 < copyShape.size() ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0));
|
||||
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < copyShape.size() ? 1 : copyShape.back()));
|
||||
chunkStrides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides);
|
||||
Operation* anchorOp,
|
||||
const ByteOffsetExpr& baseOffset,
|
||||
ArrayRef<Value> outerIndices,
|
||||
ArrayRef<int64_t> outerByteStrides) {
|
||||
Value result = materializeByteOffset(rewriter, loc, anchorOp, baseOffset);
|
||||
for (auto [index, stride] : llvm::zip_equal(outerIndices, outerByteStrides))
|
||||
result = addIndexValues(rewriter, loc, result, mulIndexValue(rewriter, loc, anchorOp, index, stride));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename CopyOp, typename CreateCopyOp>
|
||||
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
Value dst,
|
||||
Value src,
|
||||
int64_t dstOffset,
|
||||
int64_t srcOffset,
|
||||
int64_t size,
|
||||
bool allowLoopRewrite,
|
||||
PatternRewriter& rewriter,
|
||||
CreateCopyOp createCopyOp) {
|
||||
auto srcSubview = getStaticSubviewInfo(src);
|
||||
auto dstSubview = getStaticSubviewInfo(dst);
|
||||
const bool splitSrc = succeeded(srcSubview) && !isStaticSubviewContiguous(*srcSubview);
|
||||
const bool splitDst = succeeded(dstSubview) && !isStaticSubviewContiguous(*dstSubview);
|
||||
if (!splitSrc && !splitDst)
|
||||
static LogicalResult rewriteCopyLikeOp(CopyOp copyOp,
|
||||
Value target,
|
||||
Value source,
|
||||
Value targetOffset,
|
||||
Value sourceOffset,
|
||||
Value replacementValue,
|
||||
CreateCopyOp createCopyOp,
|
||||
PatternRewriter& rewriter) {
|
||||
if (isNormalizedCopyLikeOp(copyOp, target, source, targetOffset, sourceOffset))
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(src.getType());
|
||||
auto dstType = dyn_cast<MemRefType>(dst.getType());
|
||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != dstType.getElementType())
|
||||
auto plan = analyzeCopyRewrite(target, source, targetOffset, sourceOffset, copyOp.getSize());
|
||||
if (failed(plan))
|
||||
return failure();
|
||||
|
||||
if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
|
||||
return failure();
|
||||
if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||
return failure();
|
||||
|
||||
if (!hasByteSizedElementType(sourceType.getElementType()))
|
||||
return failure();
|
||||
const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
|
||||
|
||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||
if (size != totalBytes)
|
||||
return failure();
|
||||
|
||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||
if (sliceBytes <= 0)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||
const bool sourceShapeMatchesCopyShape = llvm::equal(sourceType.getShape(), copyShape);
|
||||
const bool dstShapeMatchesCopyShape = llvm::equal(dstType.getShape(), copyShape);
|
||||
|
||||
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
|
||||
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
|
||||
&& dstType.getRank() == static_cast<int64_t>(copyShape.size()) && (splitSrc || sourceShapeMatchesCopyShape)
|
||||
&& (splitDst || dstShapeMatchesCopyShape)) {
|
||||
auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
||||
auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices);
|
||||
auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1);
|
||||
|
||||
auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
SmallVector<Value> outerIndices =
|
||||
outerShape.empty() ? SmallVector<Value> {}
|
||||
: delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
|
||||
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
|
||||
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
||||
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
|
||||
: buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
||||
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
|
||||
Location loc = copyOp.getLoc();
|
||||
Operation* anchorOp = copyOp.getOperation();
|
||||
if (plan->kind == CopyRewritePlan::Kind::Direct) {
|
||||
Value newTargetOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->target.offset);
|
||||
Value newSourceOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->source.offset);
|
||||
auto newCopyOp = createCopyOp(loc,
|
||||
plan->target.base,
|
||||
plan->source.base,
|
||||
newTargetOffset,
|
||||
newSourceOffset,
|
||||
static_cast<int32_t>(plan->directBytes));
|
||||
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
|
||||
rewriter.replaceOp(copyOp, replacementValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
|
||||
Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
|
||||
const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
|
||||
const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
|
||||
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
|
||||
}
|
||||
Value c0 = createIndexConstant(rewriter, anchorOp, 0);
|
||||
Value cUpper = createIndexConstant(rewriter, anchorOp, getNumElements(plan->loop.outerShape));
|
||||
Value cStep = createIndexConstant(rewriter, anchorOp, 1);
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, cStep, ValueRange {});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
SmallVector<Value> outerIndices =
|
||||
materializeDelinearizedIndices(rewriter, loc, anchorOp, loop.getInductionVar(), plan->loop.outerShape);
|
||||
Value loopTargetOffset = materializeOuterByteOffset(
|
||||
rewriter, loc, anchorOp, plan->loop.targetBaseOffset, outerIndices, plan->loop.targetOuterByteStrides);
|
||||
Value loopSourceOffset = materializeOuterByteOffset(
|
||||
rewriter, loc, anchorOp, plan->loop.sourceBaseOffset, outerIndices, plan->loop.sourceOuterByteStrides);
|
||||
auto newCopyOp = createCopyOp(loc,
|
||||
plan->target.base,
|
||||
plan->source.base,
|
||||
loopTargetOffset,
|
||||
loopSourceOffset,
|
||||
static_cast<int32_t>(plan->loop.chunkBytes));
|
||||
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
|
||||
rewriter.setInsertionPointAfter(loop);
|
||||
rewriter.replaceOp(copyOp, replacementValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -224,34 +364,24 @@ struct NormalizeCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyO
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
return failure();
|
||||
|
||||
auto status = rewriteSubviewCopyLikeOp(
|
||||
return rewriteCopyLikeOp(
|
||||
copyOp,
|
||||
copyOp.getTarget(),
|
||||
copyOp.getSource(),
|
||||
copyOp.getTargetOffset(),
|
||||
copyOp.getSourceOffset(),
|
||||
copyOp.getSize(),
|
||||
/*allowLoopRewrite=*/true,
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getTarget());
|
||||
return success();
|
||||
copyOp.getTarget(),
|
||||
[&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) {
|
||||
return pim::PimMemCopyOp::create(rewriter,
|
||||
loc,
|
||||
cast<MemRefType>(target.getType()),
|
||||
targetOffset,
|
||||
sourceOffset,
|
||||
target,
|
||||
source,
|
||||
rewriter.getI32IntegerAttr(size));
|
||||
},
|
||||
rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -259,38 +389,24 @@ struct NormalizeHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyH
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||
auto dstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset());
|
||||
auto srcOffset = resolveIndexValue(copyOp.getHostSourceOffset());
|
||||
if (failed(dstOffset) || failed(srcOffset))
|
||||
return failure();
|
||||
|
||||
auto status = rewriteSubviewCopyLikeOp(
|
||||
return rewriteCopyLikeOp(
|
||||
copyOp,
|
||||
copyOp.getDeviceTarget(),
|
||||
copyOp.getHostSource(),
|
||||
*dstOffset,
|
||||
*srcOffset,
|
||||
copyOp.getSize(),
|
||||
/*allowLoopRewrite=*/true,
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
|
||||
Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
|
||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dstOffsetValue,
|
||||
srcOffsetValue,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
|
||||
return success();
|
||||
copyOp.getDeviceTargetOffset(),
|
||||
copyOp.getHostSourceOffset(),
|
||||
copyOp.getDeviceTarget(),
|
||||
[&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) {
|
||||
return pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
loc,
|
||||
cast<MemRefType>(target.getType()),
|
||||
targetOffset,
|
||||
sourceOffset,
|
||||
target,
|
||||
source,
|
||||
rewriter.getI32IntegerAttr(size));
|
||||
},
|
||||
rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -298,43 +414,43 @@ struct NormalizeHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopy
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
|
||||
auto dstOffset = resolveIndexValue(copyOp.getHostTargetOffset());
|
||||
auto srcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset());
|
||||
if (failed(dstOffset) || failed(srcOffset))
|
||||
return failure();
|
||||
|
||||
auto status = rewriteSubviewCopyLikeOp(
|
||||
return rewriteCopyLikeOp(
|
||||
copyOp,
|
||||
copyOp.getHostTarget(),
|
||||
copyOp.getDeviceSource(),
|
||||
*dstOffset,
|
||||
*srcOffset,
|
||||
copyOp.getSize(),
|
||||
/*allowLoopRewrite=*/false,
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
|
||||
Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dstOffsetValue,
|
||||
srcOffsetValue,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getHostTarget());
|
||||
return success();
|
||||
copyOp.getHostTargetOffset(),
|
||||
copyOp.getDeviceSourceOffset(),
|
||||
copyOp.getHostTarget(),
|
||||
[&](Location loc, Value target, Value source, Value targetOffset, Value sourceOffset, int32_t size) {
|
||||
return pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
loc,
|
||||
cast<MemRefType>(target.getType()),
|
||||
targetOffset,
|
||||
sourceOffset,
|
||||
target,
|
||||
source,
|
||||
rewriter.getI32IntegerAttr(size));
|
||||
},
|
||||
rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Reports whether a core-level pim::PimMemCopyOp is in normalized form.
//
// The check itself lives in isNormalizedCopyLikeOp (defined outside this view);
// this overload only forwards the op together with its target memref, source
// memref, target offset and source offset. Elsewhere in this change, an op for
// which this returns false is diagnosed with "must use base memref operands
// plus explicit byte offsets after bufferization".
bool isNormalizedCopyOp(pim::PimMemCopyOp op) {
|
||||
return isNormalizedCopyLikeOp(op, op.getTarget(), op.getSource(), op.getTargetOffset(), op.getSourceOffset());
|
||||
}
|
||||
|
||||
// Reports whether a host-to-device copy (pim::PimMemCopyHostToDevOp) is in
// normalized form.
//
// Forwards the op to isNormalizedCopyLikeOp (defined outside this view) with
// the device target playing the "target" role and the host source playing the
// "source" role, followed by the matching device-target and host-source
// offsets in that order.
bool isNormalizedCopyOp(pim::PimMemCopyHostToDevOp op) {
|
||||
return isNormalizedCopyLikeOp(
|
||||
op, op.getDeviceTarget(), op.getHostSource(), op.getDeviceTargetOffset(), op.getHostSourceOffset());
|
||||
}
|
||||
|
||||
// Reports whether a device-to-host copy (pim::PimMemCopyDevToHostOp) is in
// normalized form.
//
// Forwards the op to isNormalizedCopyLikeOp (defined outside this view) with
// the host target playing the "target" role and the device source playing the
// "source" role, followed by the matching host-target and device-source
// offsets in that order.
bool isNormalizedCopyOp(pim::PimMemCopyDevToHostOp op) {
|
||||
return isNormalizedCopyLikeOp(
|
||||
op, op.getHostTarget(), op.getDeviceSource(), op.getHostTargetOffset(), op.getDeviceSourceOffset());
|
||||
}
|
||||
|
||||
void populatePimContiguityNormalizationPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<NormalizeCoreSubviewCopyPattern, NormalizeHostSubviewLoadPattern, NormalizeHostSubviewStorePattern>(
|
||||
patterns.getContext());
|
||||
|
||||
@@ -2,8 +2,14 @@
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
// Public entry points of the PIM contiguity-normalization patterns.
namespace onnx_mlir::pim {
|
||||
|
||||
// Predicates telling whether a PIM copy op is already in normalized form.
// One overload exists per copy flavour: core-level copy, host-to-device load,
// and device-to-host store.
// NOTE(review): the exact normalization criteria live in the implementation
// file; callers in this change describe a failing op as one that does not use
// "base memref operands plus explicit byte offsets".
bool isNormalizedCopyOp(pim::PimMemCopyOp op);
|
||||
bool isNormalizedCopyOp(pim::PimMemCopyHostToDevOp op);
|
||||
bool isNormalizedCopyOp(pim::PimMemCopyDevToHostOp op);
|
||||
|
||||
// Adds the contiguity-normalization rewrite patterns to `patterns`.
void populatePimContiguityNormalizationPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
} // namespace onnx_mlir::pim
|
||||
|
||||
@@ -101,10 +101,10 @@ struct MemCopyOpInterface : DstBufferizableOpInterfaceExternalModel<MemCopyOpInt
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyOp>(rewriter,
|
||||
memCopyOp,
|
||||
targetOpt->getType(),
|
||||
memCopyOp.getTargetOffset(),
|
||||
memCopyOp.getSourceOffset(),
|
||||
*targetOpt,
|
||||
*sourceOpt,
|
||||
memCopyOp.getTargetOffsetAttr(),
|
||||
memCopyOp.getSourceOffsetAttr(),
|
||||
memCopyOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
#ifndef PIM_BUFFERIZATION
|
||||
#define PIM_BUFFERIZATION
|
||||
|
||||
#ifndef OP_BASE
|
||||
include "mlir/IR/PatternBase.td"
|
||||
include "mlir/Dialect/MemRef/IR/MemRefOps.td"
|
||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
// Declarative rewrite: replaces a matched `CopyOp $src, $dst` with a
// `PimMemCopyOp` whose first two operands are `$dst` and `$src` (note the
// swapped order relative to the matched op).
// The two constant I32 attributes are both 0 — presumably the target and
// source byte offsets; confirm against the PimMemCopyOp definition.
// The size attribute is computed from `$src` by the native helper
// pim::getMemRefSizeInBytesAttr, and the result type is taken from `$dst`.
def memrefCopyToPimMemCopyOp : Pat<
|
||||
(CopyOp $src, $dst),
|
||||
(PimMemCopyOp $dst, $src,
|
||||
ConstantAttr<I32Attr, "0">,
|
||||
ConstantAttr<I32Attr, "0">,
|
||||
(NativeCodeCall<"pim::getMemRefSizeInBytesAttr($_builder, $0)"> $src),
|
||||
(returnType $dst))
|
||||
>;
|
||||
|
||||
#endif // PIM_BUFFERIZATION
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Rewrite/PatternApplicator.h"
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
@@ -27,7 +27,34 @@ namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "Dialect/Pim/Transforms/Bufferization/PimBufferization.hpp.inc"
|
||||
// Rewrite pattern that turns a memref.copy nested inside a pim.core /
// pim.core_batch op into a pim::PimMemCopyOp covering the whole source memref,
// with both offsets set to the same zero index value.
//
// The pattern fails to match (leaving the op untouched) when:
//   * the copy is not nested in a PimCoreOp or PimCoreBatchOp,
//   * source or target is not a MemRefType with a static shape, or
//   * source and target element types differ.
struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
// Only copies that live inside a PIM core body are lowered here.
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
|
||||
auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
|
||||
// Both sides must be ranked memrefs with fully static shapes.
if (!sourceType || !targetType || !sourceType.hasStaticShape() || !targetType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
if (sourceType.getElementType() != targetType.getElementType())
|
||||
return failure();
|
||||
|
||||
// A single zero index value is reused for both offset operands below.
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
||||
// Size attribute derived from the source memref (helper defined elsewhere;
// presumably the total size in bytes, as its name says — confirm).
IntegerAttr sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getSource());
|
||||
// Operand order after the result type: two offsets, then target, then source, then size.
pim::PimMemCopyOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
copyOp.getTarget().getType(),
|
||||
zeroOffset,
|
||||
zeroOffset,
|
||||
copyOp.getTarget(),
|
||||
copyOp.getSource(),
|
||||
sizeAttr);
|
||||
// The original copy is erased rather than replaced; the new op's result is not wired to any user here.
rewriter.eraseOp(copyOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||
@@ -44,6 +71,14 @@ private:
|
||||
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
|
||||
};
|
||||
|
||||
// Runs the pattern applicator a single time on `op`.
//
// Returns failure() without touching anything when `op` is null or is not
// attached to a block. Otherwise positions `rewriter` at `op` and returns
// whatever PatternApplicator::matchAndRewrite reports for that op.
static LogicalResult applyPatternsOnce(Operation* op, PatternApplicator& applicator, PatternRewriter& rewriter) {
|
||||
if (!op || !op->getBlock())
|
||||
return failure();
|
||||
|
||||
// New ops created by the matched pattern are inserted at the position of `op`.
rewriter.setInsertionPoint(op);
|
||||
return applicator.matchAndRewrite(op, rewriter);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void PimBufferizationPass::runOnOperation() {
|
||||
@@ -63,35 +98,60 @@ void PimBufferizationPass::runOnOperation() {
|
||||
}
|
||||
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<PimDialect>();
|
||||
RewritePatternSet memrefCopyPatterns(ctx);
|
||||
memrefCopyPatterns.add<MemRefCopyToPimMemCopyPattern>(ctx);
|
||||
FrozenRewritePatternSet frozenMemrefCopyPatterns(std::move(memrefCopyPatterns));
|
||||
PatternApplicator memrefCopyApplicator(frozenMemrefCopyPatterns);
|
||||
memrefCopyApplicator.applyDefaultCostModel();
|
||||
PatternRewriter rewriter(ctx);
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
|
||||
// Only convert memref.copy → pim.memcp inside pim.core / pim.core_batch bodies.
|
||||
// Host-level copies (e.g. from split/slice ops) must remain as memref.copy for CPU lowering.
|
||||
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
|
||||
bool hasFailed = false;
|
||||
moduleOp.walk<WalkOrder::PreOrder>([&](Operation* op) {
|
||||
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
|
||||
return WalkResult::advance();
|
||||
if (failed(applyPartialConversion(op, target, frozenPatterns)))
|
||||
hasFailed = true;
|
||||
return WalkResult::skip();
|
||||
SmallVector<memref::CopyOp> copyWorklist;
|
||||
moduleOp.walk([&](memref::CopyOp copyOp) {
|
||||
if (copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
copyWorklist.push_back(copyOp);
|
||||
});
|
||||
|
||||
bool hasFailed = false;
|
||||
for (memref::CopyOp copyOp : copyWorklist) {
|
||||
if (failed(applyPatternsOnce(copyOp, memrefCopyApplicator, rewriter))) {
|
||||
copyOp.emitOpError("failed to lower memref.copy inside PIM core body");
|
||||
hasFailed = true;
|
||||
}
|
||||
}
|
||||
if (hasFailed) {
|
||||
moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
// After this pass, executable PIM ops must only use contiguous/addressable memrefs.
|
||||
// Later PIM codegen passes may verify this invariant but must not repair it.
|
||||
RewritePatternSet contiguityPatterns(ctx);
|
||||
populatePimContiguityNormalizationPatterns(contiguityPatterns);
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(contiguityPatterns)))) {
|
||||
moduleOp.emitError("failed to normalize PIM runtime operand contiguity during bufferization");
|
||||
FrozenRewritePatternSet frozenContiguityPatterns(std::move(contiguityPatterns));
|
||||
PatternApplicator contiguityApplicator(frozenContiguityPatterns);
|
||||
contiguityApplicator.applyDefaultCostModel();
|
||||
|
||||
SmallVector<Operation*> contiguityWorklist;
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
if (isa<pim::PimMemCopyOp, pim::PimMemCopyHostToDevOp, pim::PimMemCopyDevToHostOp>(op))
|
||||
contiguityWorklist.push_back(op);
|
||||
});
|
||||
|
||||
hasFailed = false;
|
||||
for (Operation* op : contiguityWorklist) {
|
||||
if (auto copyOp = dyn_cast<pim::PimMemCopyOp>(op); copyOp && pim::isNormalizedCopyOp(copyOp))
|
||||
continue;
|
||||
if (auto copyOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op); copyOp && pim::isNormalizedCopyOp(copyOp))
|
||||
continue;
|
||||
if (auto copyOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op); copyOp && pim::isNormalizedCopyOp(copyOp))
|
||||
continue;
|
||||
|
||||
if (failed(applyPatternsOnce(op, contiguityApplicator, rewriter))) {
|
||||
op->emitOpError("failed to normalize PIM copy contiguity");
|
||||
hasFailed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFailed) {
|
||||
moduleOp.emitError("failed to normalize PIM copy contiguity during bufferization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -138,16 +198,28 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod
|
||||
};
|
||||
|
||||
if (auto memCopyOp = dyn_cast<PimMemCopyOp>(op)) {
|
||||
if (!pim::isNormalizedCopyOp(memCopyOp)) {
|
||||
memCopyOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||
hasFailure = true;
|
||||
}
|
||||
verifyOperand(memCopyOp.getTarget(), 0);
|
||||
verifyOperand(memCopyOp.getSource(), 1);
|
||||
return;
|
||||
}
|
||||
if (auto loadOp = dyn_cast<PimMemCopyHostToDevOp>(op)) {
|
||||
if (!pim::isNormalizedCopyOp(loadOp)) {
|
||||
loadOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||
hasFailure = true;
|
||||
}
|
||||
verifyOperand(loadOp.getDeviceTarget(), 2);
|
||||
verifyOperand(loadOp.getHostSource(), 3);
|
||||
return;
|
||||
}
|
||||
if (auto storeOp = dyn_cast<PimMemCopyDevToHostOp>(op)) {
|
||||
if (!pim::isNormalizedCopyOp(storeOp)) {
|
||||
storeOp.emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||
hasFailure = true;
|
||||
}
|
||||
verifyOperand(storeOp.getHostTarget(), 2);
|
||||
verifyOperand(storeOp.getDeviceSource(), 3);
|
||||
return;
|
||||
|
||||
@@ -121,13 +121,14 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||
auto sizeInBytes = getShapedTypeSizeInBytes(initType);
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, mapOp, 0);
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
mapOp.getLoc(),
|
||||
initType,
|
||||
zeroOffset,
|
||||
zeroOffset,
|
||||
mapOp.getInit(),
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes));
|
||||
rewriter.eraseOp(mapOp);
|
||||
return success();
|
||||
@@ -487,7 +488,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
if (!allocType || !allocType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
if (copyOp.getTargetOffset() != 0 || copyOp.getSourceOffset() != 0)
|
||||
auto targetOffset = resolveIndexValue(copyOp.getTargetOffset());
|
||||
auto sourceOffset = resolveIndexValue(copyOp.getSourceOffset());
|
||||
if (failed(targetOffset) || failed(sourceOffset) || *targetOffset != 0 || *sourceOffset != 0)
|
||||
return failure();
|
||||
|
||||
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -333,6 +334,12 @@ private:
|
||||
}
|
||||
|
||||
if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op)) {
|
||||
if (!pim::isNormalizedCopyOp(storeOp)) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
if (failed(resolveIndexValue(storeOp.getHostTargetOffset(), knowledge))
|
||||
|| failed(resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge))) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
@@ -343,6 +350,12 @@ private:
|
||||
}
|
||||
|
||||
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op)) {
|
||||
if (!pim::isNormalizedCopyOp(loadOp)) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
if (failed(resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge))
|
||||
|| failed(resolveIndexValue(loadOp.getHostSourceOffset(), knowledge))) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
@@ -351,6 +364,22 @@ private:
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto copyOp = dyn_cast<pim::PimMemCopyOp>(op)) {
|
||||
if (!pim::isNormalizedCopyOp(copyOp)) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
if (failed(resolveIndexValue(copyOp.getTargetOffset(), knowledge))
|
||||
|| failed(resolveIndexValue(copyOp.getSourceOffset(), knowledge))) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
return success(!hasFailure);
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user