centralize logic for materializing contiguous memory into bufferization
fix codegen symlinks overwrite remove deprecated pim memcp_hd_batch op
This commit is contained in:
@@ -144,32 +144,6 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMemCopyHostToDevBatchOp : PimOp<"memcp_hd_batch", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a per-lane tensor from host memory into device memory inside a batched core";
|
||||
|
||||
let arguments = (ins
|
||||
PimTensor:$deviceTarget,
|
||||
PimTensor:$hostSource,
|
||||
I32Attr:$deviceTargetOffset,
|
||||
I32Attr:$hostSourceOffset,
|
||||
I32Attr:$size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
PimTensor:$output
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
||||
return getDeviceTargetMutable();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from device memory into host memory";
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
|
||||
@@ -10,8 +11,8 @@ using namespace bufferization;
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||
Value materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
@@ -30,7 +31,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
Value allocateContiguousMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||
return memrefValue;
|
||||
|
||||
|
||||
@@ -5,8 +5,9 @@
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
mlir::Value materializeContiguousMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||
mlir::Value allocateContiguousMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||
mlir::Value materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||
mlir::Value
|
||||
allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||
|
||||
llvm::FailureOr<mlir::Value> getBufferOrValue(mlir::RewriterBase& rewriter,
|
||||
mlir::Value value,
|
||||
|
||||
@@ -6,6 +6,8 @@ add_pim_library(OMPimBufferization
|
||||
PimBufferizationPass.cpp
|
||||
BufferizationUtils.hpp
|
||||
BufferizationUtils.cpp
|
||||
ContiguityPatterns.hpp
|
||||
ContiguityPatterns.cpp
|
||||
OpBufferizationInterfaces.hpp
|
||||
OpBufferizationInterfaces.cpp
|
||||
Common.hpp
|
||||
|
||||
@@ -0,0 +1,343 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "ContiguityPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
namespace {
|
||||
|
||||
static bool isStaticSubviewContiguous(const StaticSubviewInfo& info) {
|
||||
if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
|
||||
return false;
|
||||
|
||||
return isContiguousSubviewWithDynamicOffsets(info.sourceShape, info.offsets, info.sizes, info.strides);
|
||||
}
|
||||
|
||||
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
|
||||
if (extraOffset == 0)
|
||||
return baseOffset;
|
||||
|
||||
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
||||
auto integerAttr = dyn_cast<IntegerAttr>(attr);
|
||||
assert(integerAttr && "expected integer offset attribute");
|
||||
return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
|
||||
}
|
||||
|
||||
auto value = cast<Value>(baseOffset);
|
||||
auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset);
|
||||
return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult();
|
||||
}
|
||||
|
||||
static Value buildSubviewChunk(const StaticSubviewInfo& info,
|
||||
ArrayRef<int64_t> outerIndices,
|
||||
Location loc,
|
||||
PatternRewriter& rewriter) {
|
||||
SmallVector<OpFoldResult> chunkOffsets;
|
||||
SmallVector<OpFoldResult> chunkSizes;
|
||||
SmallVector<OpFoldResult> chunkStrides;
|
||||
chunkOffsets.reserve(info.offsets.size());
|
||||
chunkSizes.reserve(info.sizes.size());
|
||||
chunkStrides.reserve(info.strides.size());
|
||||
|
||||
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
|
||||
int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0;
|
||||
chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter));
|
||||
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back()));
|
||||
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
|
||||
}
|
||||
|
||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
||||
}
|
||||
|
||||
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
|
||||
ArrayRef<int64_t> shape,
|
||||
ArrayRef<int64_t> strides,
|
||||
PatternRewriter& rewriter) {
|
||||
SmallVector<Value> indices;
|
||||
indices.reserve(shape.size());
|
||||
|
||||
Value remaining = linearIndex;
|
||||
for (auto [_dim, stride] : llvm::enumerate(strides)) {
|
||||
auto cStride = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride);
|
||||
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
||||
indices.push_back(index);
|
||||
remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
||||
}
|
||||
|
||||
return indices;
|
||||
}
|
||||
|
||||
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
|
||||
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
||||
auto integerAttr = cast<IntegerAttr>(attr);
|
||||
if (integerAttr.getInt() == 0)
|
||||
return extraOffset;
|
||||
|
||||
auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), integerAttr.getInt());
|
||||
return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult();
|
||||
}
|
||||
|
||||
auto value = cast<Value>(baseOffset);
|
||||
return arith::AddIOp::create(rewriter, value.getLoc(), value, extraOffset).getResult();
|
||||
}
|
||||
|
||||
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
|
||||
ArrayRef<Value> outerIndices,
|
||||
Location loc,
|
||||
PatternRewriter& rewriter) {
|
||||
SmallVector<OpFoldResult> chunkOffsets;
|
||||
SmallVector<OpFoldResult> chunkSizes;
|
||||
SmallVector<OpFoldResult> chunkStrides;
|
||||
chunkOffsets.reserve(info.offsets.size());
|
||||
chunkSizes.reserve(info.sizes.size());
|
||||
chunkStrides.reserve(info.strides.size());
|
||||
|
||||
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
|
||||
if (dim + 1 < info.sizes.size()) {
|
||||
assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
|
||||
chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
|
||||
chunkSizes.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
else {
|
||||
chunkOffsets.push_back(info.offsets[dim]);
|
||||
chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
|
||||
}
|
||||
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
|
||||
}
|
||||
|
||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
||||
}
|
||||
|
||||
static Value buildContiguousChunk(
|
||||
Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
|
||||
SmallVector<OpFoldResult> chunkOffsets;
|
||||
SmallVector<OpFoldResult> chunkSizes;
|
||||
SmallVector<OpFoldResult> chunkStrides;
|
||||
chunkOffsets.reserve(copyShape.size());
|
||||
chunkSizes.reserve(copyShape.size());
|
||||
chunkStrides.reserve(copyShape.size());
|
||||
|
||||
for (size_t dim = 0; dim < copyShape.size(); ++dim) {
|
||||
chunkOffsets.push_back(dim + 1 < copyShape.size() ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0));
|
||||
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < copyShape.size() ? 1 : copyShape.back()));
|
||||
chunkStrides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides);
|
||||
}
|
||||
|
||||
template <typename CopyOp, typename CreateCopyOp>
|
||||
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
Value dst,
|
||||
Value src,
|
||||
int64_t dstOffset,
|
||||
int64_t srcOffset,
|
||||
int64_t size,
|
||||
bool allowLoopRewrite,
|
||||
PatternRewriter& rewriter,
|
||||
CreateCopyOp createCopyOp) {
|
||||
auto srcSubview = getStaticSubviewInfo(src);
|
||||
auto dstSubview = getStaticSubviewInfo(dst);
|
||||
const bool splitSrc = succeeded(srcSubview) && !isStaticSubviewContiguous(*srcSubview);
|
||||
const bool splitDst = succeeded(dstSubview) && !isStaticSubviewContiguous(*dstSubview);
|
||||
if (!splitSrc && !splitDst)
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(src.getType());
|
||||
auto dstType = dyn_cast<MemRefType>(dst.getType());
|
||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != dstType.getElementType())
|
||||
return failure();
|
||||
|
||||
if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
|
||||
return failure();
|
||||
if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||
return failure();
|
||||
|
||||
if (!hasByteSizedElementType(sourceType.getElementType()))
|
||||
return failure();
|
||||
const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
|
||||
|
||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||
if (size != totalBytes)
|
||||
return failure();
|
||||
|
||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||
if (sliceBytes <= 0)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||
const bool sourceShapeMatchesCopyShape = llvm::equal(sourceType.getShape(), copyShape);
|
||||
const bool dstShapeMatchesCopyShape = llvm::equal(dstType.getShape(), copyShape);
|
||||
|
||||
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
|
||||
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
|
||||
&& dstType.getRank() == static_cast<int64_t>(copyShape.size()) && (splitSrc || sourceShapeMatchesCopyShape)
|
||||
&& (splitDst || dstShapeMatchesCopyShape)) {
|
||||
auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
||||
auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices);
|
||||
auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1);
|
||||
|
||||
auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
SmallVector<Value> outerIndices =
|
||||
outerShape.empty() ? SmallVector<Value> {}
|
||||
: delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
|
||||
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
|
||||
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
||||
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
|
||||
: buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
||||
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
|
||||
Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
|
||||
const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
|
||||
const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
|
||||
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
struct NormalizeCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
if (!copyOp->getParentOfType<pim::PimCoreOp>() && !copyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
return failure();
|
||||
|
||||
auto status = rewriteSubviewCopyLikeOp(
|
||||
copyOp,
|
||||
copyOp.getTarget(),
|
||||
copyOp.getSource(),
|
||||
copyOp.getTargetOffset(),
|
||||
copyOp.getSourceOffset(),
|
||||
copyOp.getSize(),
|
||||
/*allowLoopRewrite=*/true,
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getTarget());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct NormalizeHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||
auto dstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset());
|
||||
auto srcOffset = resolveIndexValue(copyOp.getHostSourceOffset());
|
||||
if (failed(dstOffset) || failed(srcOffset))
|
||||
return failure();
|
||||
|
||||
auto status = rewriteSubviewCopyLikeOp(
|
||||
copyOp,
|
||||
copyOp.getDeviceTarget(),
|
||||
copyOp.getHostSource(),
|
||||
*dstOffset,
|
||||
*srcOffset,
|
||||
copyOp.getSize(),
|
||||
/*allowLoopRewrite=*/true,
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
|
||||
Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
|
||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dstOffsetValue,
|
||||
srcOffsetValue,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct NormalizeHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDevToHostOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
|
||||
auto dstOffset = resolveIndexValue(copyOp.getHostTargetOffset());
|
||||
auto srcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset());
|
||||
if (failed(dstOffset) || failed(srcOffset))
|
||||
return failure();
|
||||
|
||||
auto status = rewriteSubviewCopyLikeOp(
|
||||
copyOp,
|
||||
copyOp.getHostTarget(),
|
||||
copyOp.getDeviceSource(),
|
||||
*dstOffset,
|
||||
*srcOffset,
|
||||
copyOp.getSize(),
|
||||
/*allowLoopRewrite=*/false,
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
|
||||
Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dstOffsetValue,
|
||||
srcOffsetValue,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getHostTarget());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populatePimContiguityNormalizationPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<NormalizeCoreSubviewCopyPattern, NormalizeHostSubviewLoadPattern, NormalizeHostSubviewStorePattern>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir::pim
|
||||
@@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
void populatePimContiguityNormalizationPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
} // namespace onnx_mlir::pim
|
||||
@@ -47,32 +47,6 @@ struct MemCopyHostToDevOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct MemCopyHostToDevBatchOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevBatchOpInterface, PimMemCopyHostToDevBatchOp> {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevBatchOp>(op);
|
||||
auto deviceTargetOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state);
|
||||
if (failed(deviceTargetOpt))
|
||||
return failure();
|
||||
auto hostSourceOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getHostSource(), options, state);
|
||||
if (failed(hostSourceOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevBatchOp>(rewriter,
|
||||
memCopyHostToDevOp,
|
||||
deviceTargetOpt->getType(),
|
||||
*deviceTargetOpt,
|
||||
*hostSourceOpt,
|
||||
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
|
||||
memCopyHostToDevOp.getHostSourceOffsetAttr(),
|
||||
memCopyHostToDevOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct MemCopyDevToHostOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
@@ -151,8 +125,9 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
replaceOpWithNewBufferizedOp<PimReceiveOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
|
||||
rewriter, op, contiguousOutput.getType(), contiguousOutput, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -173,15 +148,16 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
|
||||
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
inputs.push_back(materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter));
|
||||
inputs.push_back(materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter));
|
||||
}
|
||||
|
||||
auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
replaceOpWithNewBufferizedOp<PimConcatOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), concatOp.getAxisAttr(), ValueRange(inputs), *outputBufferOpt);
|
||||
rewriter, op, contiguousOutput.getType(), concatOp.getAxisAttr(), ValueRange(inputs), contiguousOutput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -206,7 +182,7 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimSendOp>(rewriter,
|
||||
op,
|
||||
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
|
||||
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter),
|
||||
sendOp.getSizeAttr(),
|
||||
sendOp.getTargetCoreId());
|
||||
return success();
|
||||
@@ -431,8 +407,8 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimTransposeOp>(
|
||||
rewriter, op, contiguousOutput.getType(), contiguousInput, transposeOp.getPermutation(), contiguousOutput);
|
||||
@@ -475,8 +451,8 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||
rewriter, op, contiguousOutput.getType(), *weightOpt, contiguousInput, contiguousOutput);
|
||||
@@ -514,9 +490,9 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||
Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<OpTy>(
|
||||
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
|
||||
@@ -547,9 +523,9 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInter
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||
Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVVDMulOp>(
|
||||
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
|
||||
@@ -583,8 +559,8 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, contiguousOutput.getType(), contiguousInput, contiguousOutput);
|
||||
return success();
|
||||
@@ -599,7 +575,6 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimSendOp::attachInterface<SendOpInterface>(*ctx);
|
||||
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimMemCopyOp::attachInterface<MemCopyOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||
|
||||
@@ -6,14 +6,15 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Compiler/PimCodeGen.hpp"
|
||||
#include "Dialect/Pim/PimOps.hpp"
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
#include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
@@ -40,6 +41,7 @@ struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<Mo
|
||||
|
||||
private:
|
||||
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
||||
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -84,6 +86,20 @@ void PimBufferizationPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
// After this pass, executable PIM ops must only use contiguous/addressable memrefs.
|
||||
// Later PIM codegen passes may verify this invariant but must not repair it.
|
||||
RewritePatternSet contiguityPatterns(ctx);
|
||||
populatePimContiguityNormalizationPatterns(contiguityPatterns);
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(contiguityPatterns)))) {
|
||||
moduleOp.emitError("failed to normalize PIM runtime operand contiguity during bufferization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
if (failed(verifyContiguousRuntimeOperands(moduleOp))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
annotateWeightsMemrefs(moduleOp, funcOp);
|
||||
|
||||
// Dump to file for debug
|
||||
@@ -108,6 +124,75 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
|
||||
funcOp.walk([&](PimCoreBatchOp coreBatchOp) { markWeights(coreBatchOp); });
|
||||
}
|
||||
|
||||
LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp moduleOp) const {
|
||||
bool hasFailure = false;
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
auto verifyOperand = [&](Value operand, unsigned operandIndex) {
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
return;
|
||||
if (succeeded(resolveContiguousAddress(operand)) || succeeded(compileContiguousAddressExpr(operand)))
|
||||
return;
|
||||
op->emitOpError() << "operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage after PIM bufferization";
|
||||
hasFailure = true;
|
||||
};
|
||||
|
||||
if (auto memCopyOp = dyn_cast<PimMemCopyOp>(op)) {
|
||||
verifyOperand(memCopyOp.getTarget(), 0);
|
||||
verifyOperand(memCopyOp.getSource(), 1);
|
||||
return;
|
||||
}
|
||||
if (auto loadOp = dyn_cast<PimMemCopyHostToDevOp>(op)) {
|
||||
verifyOperand(loadOp.getDeviceTarget(), 2);
|
||||
verifyOperand(loadOp.getHostSource(), 3);
|
||||
return;
|
||||
}
|
||||
if (auto storeOp = dyn_cast<PimMemCopyDevToHostOp>(op)) {
|
||||
verifyOperand(storeOp.getHostTarget(), 2);
|
||||
verifyOperand(storeOp.getDeviceSource(), 3);
|
||||
return;
|
||||
}
|
||||
if (auto sendOp = dyn_cast<PimSendOp>(op)) {
|
||||
verifyOperand(sendOp.getInput(), 0);
|
||||
return;
|
||||
}
|
||||
if (auto receiveOp = dyn_cast<PimReceiveOp>(op)) {
|
||||
verifyOperand(receiveOp.getOutputBuffer(), 0);
|
||||
return;
|
||||
}
|
||||
if (auto concatOp = dyn_cast<PimConcatOp>(op)) {
|
||||
verifyOperand(concatOp.getOutputBuffer(), 0);
|
||||
for (auto inputAndIndex : llvm::enumerate(concatOp.getInputs()))
|
||||
verifyOperand(inputAndIndex.value(), inputAndIndex.index() + 1);
|
||||
return;
|
||||
}
|
||||
if (isa<PimTransposeOp,
|
||||
PimVMMOp,
|
||||
PimVVAddOp,
|
||||
PimVVSubOp,
|
||||
PimVVMulOp,
|
||||
PimVVMaxOp,
|
||||
PimVVDMulOp,
|
||||
PimVAvgOp,
|
||||
PimVReluOp,
|
||||
PimVTanhOp,
|
||||
PimVSigmOp,
|
||||
PimVSoftmaxOp>(op)) {
|
||||
for (auto operandAndIndex : llvm::enumerate(op->getOperands())) {
|
||||
if (auto vmmOp = dyn_cast<PimVMMOp>(op); vmmOp && operandAndIndex.index() == 0)
|
||||
continue;
|
||||
verifyOperand(operandAndIndex.value(), operandAndIndex.index());
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (hasFailure) {
|
||||
moduleOp.emitError("PIM bufferization must fully normalize executable runtime operand contiguity before codegen");
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -22,6 +22,10 @@ using namespace onnx_mlir::compact_asm;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
// This pass assumes bufferization has already normalized executable PIM
|
||||
// operands. It only reuses compatible local allocations with non-overlapping
|
||||
// lifetimes; it does not repair memory contiguity.
|
||||
|
||||
struct CoalescingReportRow {
|
||||
uint64_t numCandidates = 0;
|
||||
uint64_t numSkipped = 0;
|
||||
|
||||
Reference in New Issue
Block a user