Reduce compile times of spatial convolutions by emitting an scf.for loop instead of materializing a huge number of unrolled instructions
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -13,6 +13,7 @@ add_pim_library(OMPimPasses
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRLinalgDialect
|
||||
MLIRSCFDialect
|
||||
OMCompilerUtils
|
||||
OMPimCommon
|
||||
)
|
||||
|
||||
@@ -85,12 +85,8 @@ FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
StaticSubviewInfo info;
|
||||
info.source = source;
|
||||
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
info.offsets.push_back(*staticOffset);
|
||||
}
|
||||
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
|
||||
info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto staticSize = getConstantIntValue(size);
|
||||
if (!staticSize)
|
||||
@@ -106,14 +102,16 @@ FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
return info;
|
||||
}
|
||||
|
||||
int64_t
|
||||
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(info.sourceShape.size());
|
||||
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
|
||||
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
|
||||
sourceIndices.push_back(info.offsets.back());
|
||||
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
|
||||
/// Extracts every offset in `info` as a compile-time integer.
/// Returns failure as soon as one offset is not a static constant.
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
  SmallVector<int64_t> result;
  result.reserve(info.offsets.size());
  for (OpFoldResult mixedOffset : info.offsets) {
    // getConstantIntValue yields std::nullopt for dynamic (Value) offsets.
    auto constantOffset = getConstantIntValue(mixedOffset);
    if (!constantOffset)
      return failure();
    result.push_back(*constantOffset);
  }
  return result;
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -14,7 +14,7 @@ namespace onnx_mlir {
|
||||
struct StaticSubviewInfo {
|
||||
mlir::Value source;
|
||||
llvm::SmallVector<int64_t> sourceShape;
|
||||
llvm::SmallVector<int64_t> offsets;
|
||||
llvm::SmallVector<mlir::OpFoldResult> offsets;
|
||||
llvm::SmallVector<int64_t> sizes;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
};
|
||||
@@ -34,8 +34,7 @@ llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp modu
|
||||
|
||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
|
||||
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
|
||||
llvm::ArrayRef<int64_t> outerIndices,
|
||||
int64_t elementByteWidth);
|
||||
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
||||
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -120,7 +120,15 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp);
|
||||
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
mapOp.getLoc(),
|
||||
initType,
|
||||
mapOp.getInit(),
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes));
|
||||
rewriter.eraseOp(mapOp);
|
||||
return success();
|
||||
}
|
||||
@@ -416,6 +424,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
return failure();
|
||||
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||
@@ -428,7 +439,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices))
|
||||
for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
|
||||
sourceIndices.push_back(off + idx);
|
||||
int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
|
||||
resultValues[i] = sourceValues[srcLinear];
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
|
||||
#include "../Common.hpp"
|
||||
#include "../Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
@@ -8,6 +10,62 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static bool isSubviewContiguous(const StaticSubviewInfo& info) {
|
||||
if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
|
||||
return false;
|
||||
|
||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(info.sizes.rbegin(), info.sizes.rend()),
|
||||
llvm::make_range(info.sourceShape.rbegin(), info.sourceShape.rend()));
|
||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||
auto [size, dimension] = sizeAndShape;
|
||||
return size != dimension;
|
||||
});
|
||||
if (firstDifferentSize == sizesAndShape.end())
|
||||
return true;
|
||||
|
||||
++firstDifferentSize;
|
||||
return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) {
|
||||
auto [size, _dimension] = sizeAndShape;
|
||||
return size == 1;
|
||||
});
|
||||
}
|
||||
|
||||
/// Adds the constant `extraOffset` to a mixed offset. Static offsets are
/// folded into a new index attribute; dynamic offsets get an arith.addi
/// with a materialized constant. Returns `baseOffset` untouched when the
/// increment is zero.
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
  if (extraOffset == 0)
    return baseOffset;

  // Dynamic offset: materialize the increment and add it in IR.
  if (auto dynamicOffset = dyn_cast<Value>(baseOffset)) {
    Location loc = dynamicOffset.getLoc();
    auto increment = arith::ConstantIndexOp::create(rewriter, loc, extraOffset);
    return arith::AddIOp::create(rewriter, loc, dynamicOffset, increment).getResult();
  }

  // Static offset: fold the addition at rewrite time.
  auto integerAttr = dyn_cast<IntegerAttr>(cast<Attribute>(baseOffset));
  assert(integerAttr && "expected integer offset attribute");
  return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
}
|
||||
|
||||
/// Creates a memref.subview selecting one contiguous innermost chunk of
/// `info`: every outer dimension is collapsed to a single element shifted
/// by the corresponding entry of `outerIndices`, while the innermost
/// dimension keeps its full subview extent.
static Value buildSubviewChunk(const StaticSubviewInfo& info,
    ArrayRef<int64_t> outerIndices,
    Location loc,
    PatternRewriter& rewriter) {
  const size_t rank = info.sizes.size();
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);

  for (size_t dim = 0; dim < rank; ++dim) {
    const bool innermost = dim + 1 == rank;
    // Outer dims advance by outerIndices[dim] (in stride units); the
    // innermost dim starts at the original offset.
    const int64_t shift = innermost ? 0 : outerIndices[dim] * info.strides[dim];
    offsets.push_back(addConstantOffset(info.offsets[dim], shift, rewriter));
    sizes.push_back(rewriter.getIndexAttr(innermost ? info.sizes.back() : 1));
    strides.push_back(rewriter.getIndexAttr(info.strides[dim]));
  }

  return memref::SubViewOp::create(rewriter, loc, info.source, offsets, sizes, strides);
}
|
||||
|
||||
template <typename CopyOp, typename CreateCopyOp>
|
||||
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
Value dst,
|
||||
@@ -19,12 +77,8 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
CreateCopyOp createCopyOp) {
|
||||
auto srcSubview = getStaticSubviewInfo(src);
|
||||
auto dstSubview = getStaticSubviewInfo(dst);
|
||||
const bool splitSrc =
|
||||
succeeded(srcSubview)
|
||||
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
||||
const bool splitDst =
|
||||
succeeded(dstSubview)
|
||||
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
||||
const bool splitSrc = succeeded(srcSubview) && !isSubviewContiguous(*srcSubview);
|
||||
const bool splitDst = succeeded(dstSubview) && !isSubviewContiguous(*dstSubview);
|
||||
if (!splitSrc && !splitDst)
|
||||
return failure();
|
||||
|
||||
@@ -35,9 +89,9 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
if (sourceType.getElementType() != dstType.getElementType())
|
||||
return failure();
|
||||
|
||||
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
|
||||
return failure();
|
||||
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||
@@ -64,18 +118,11 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
const int64_t srcByteOffset =
|
||||
srcOffset
|
||||
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
|
||||
const int64_t dstByteOffset =
|
||||
dstOffset
|
||||
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
|
||||
createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
||||
splitDst ? dstSubview->source : dst,
|
||||
splitSrc ? srcSubview->source : src,
|
||||
dstByteOffset,
|
||||
srcByteOffset,
|
||||
sliceBytes);
|
||||
Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
|
||||
Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
|
||||
const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
|
||||
const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
|
||||
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
|
||||
}
|
||||
|
||||
return success();
|
||||
@@ -198,6 +245,9 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
|
||||
return failure();
|
||||
if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
auto staticOffsets = getStaticSubviewOffsets(*subviewInfo);
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
@@ -217,7 +267,7 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
|
||||
for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
|
||||
sourceIndices.push_back(off + idx);
|
||||
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
||||
}
|
||||
|
||||
@@ -132,38 +132,37 @@ private:
|
||||
}
|
||||
|
||||
static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) {
|
||||
bool hasFailure = false;
|
||||
for (Operation& op : coreOp.getBody().front()) {
|
||||
if (isa<pim::PimHaltOp>(op))
|
||||
continue;
|
||||
return walkPimCoreBlock(
|
||||
coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
bool hasFailure = false;
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
continue;
|
||||
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
continue;
|
||||
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
||||
if (failed(resolvedAddress)) {
|
||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(operand);
|
||||
if (failed(resolvedAddress)) {
|
||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
if (isExplicitHostOperand(&op, operandIndex)) {
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
op.emitOpError() << "host operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isExplicitHostOperand(&op, operandIndex)) {
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||
op.emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||
op.emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return success(!hasFailure);
|
||||
return success(!hasFailure);
|
||||
});
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
|
||||
|
||||
Reference in New Issue
Block a user