Reduce compile times for spatial convolutions by emitting an scf.for loop instead of materializing a huge number of instructions

NiccoloN
2026-04-10 18:50:25 +02:00
parent f3a36e9d43
commit f054e66ed0
18 changed files with 623 additions and 241 deletions
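The gist of the change, as a hedged sketch (illustrative only, not the commit's actual code): instead of materializing one copy op per subview slice at rewrite time, emit a single scf.for that visits the slices at run time, so IR size stays constant in the slice count. Everything here except the scf/arith builder APIs is hypothetical, including the emitSliceLoop and buildSliceCopy names.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical sketch: one scf.for over numSlices instead of numSlices
// materialized copy ops. buildSliceCopy stands in for the per-slice rewrite.
static void emitSliceLoop(PatternRewriter& rewriter, Location loc, int64_t numSlices,
    llvm::function_ref<void(OpBuilder&, Location, Value)> buildSliceCopy) {
  Value lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Value ub = arith::ConstantIndexOp::create(rewriter, loc, numSlices);
  Value step = arith::ConstantIndexOp::create(rewriter, loc, 1);
  scf::ForOp::create(rewriter, loc, lb, ub, step, ValueRange {},
      [&](OpBuilder& b, Location l, Value iv, ValueRange) {
        buildSliceCopy(b, l, iv); // one dynamic slice per iteration
        scf::YieldOp::create(b, l);
      });
}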


@@ -13,6 +13,7 @@ add_pim_library(OMPimPasses
   LINK_LIBS PUBLIC
   MLIRLinalgDialect
+  MLIRSCFDialect
   OMCompilerUtils
   OMPimCommon
 )


@@ -85,12 +85,8 @@ FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
   StaticSubviewInfo info;
   info.source = source;
   info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
-  for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
-    auto staticOffset = getConstantIntValue(offset);
-    if (!staticOffset)
-      return failure();
-    info.offsets.push_back(*staticOffset);
-  }
+  SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
+  info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
   for (OpFoldResult size : subviewOp.getMixedSizes()) {
     auto staticSize = getConstantIntValue(size);
     if (!staticSize)
@@ -106,14 +102,16 @@ FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
   return info;
 }
 
-int64_t
-getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
-  SmallVector<int64_t> sourceIndices;
-  sourceIndices.reserve(info.sourceShape.size());
-  for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
-    sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
-  sourceIndices.push_back(info.offsets.back());
-  return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
+FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
+  SmallVector<int64_t> staticOffsets;
+  staticOffsets.reserve(info.offsets.size());
+  for (OpFoldResult offset : info.offsets) {
+    auto staticOffset = getConstantIntValue(offset);
+    if (!staticOffset)
+      return failure();
+    staticOffsets.push_back(*staticOffset);
+  }
+  return staticOffsets;
 }
 
 } // namespace onnx_mlir
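Net effect of this file's change: StaticSubviewInfo now records offsets as OpFoldResult, so subviews with dynamic (SSA-value) offsets no longer make the whole analysis fail; callers that genuinely need constants recover them via getStaticSubviewOffsets. A usage sketch mirroring the constant-folding patterns later in this commit:

// Sketch of the intended call pattern (mirrors FoldConstantMemCpPattern below).
auto info = getStaticSubviewInfo(value);
if (failed(info))
  return failure();
auto staticOffsets = getStaticSubviewOffsets(*info);
if (failed(staticOffsets))
  return failure(); // some offset is a run-time SSA value, not a constant
// *staticOffsets is a SmallVector<int64_t>, usable for compile-time folding.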


@@ -14,7 +14,7 @@ namespace onnx_mlir {
 struct StaticSubviewInfo {
   mlir::Value source;
   llvm::SmallVector<int64_t> sourceShape;
-  llvm::SmallVector<int64_t> offsets;
+  llvm::SmallVector<mlir::OpFoldResult> offsets;
   llvm::SmallVector<int64_t> sizes;
   llvm::SmallVector<int64_t> strides;
 };
@@ -34,8 +34,7 @@ llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp modu
 llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
-int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
-    llvm::ArrayRef<int64_t> outerIndices,
-    int64_t elementByteWidth);
+/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
+llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
 
 } // namespace onnx_mlir
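For reference, OpFoldResult is MLIR's attribute-or-SSA-value union. A minimal sketch of how the two cases are distinguished (the addConstantOffset helper later in this commit does exactly this; tryGetStatic is a hypothetical name):

// Static offsets arrive as IntegerAttr, dynamic ones as SSA index Values.
std::optional<int64_t> tryGetStatic(OpFoldResult ofr) {
  if (auto attr = dyn_cast<Attribute>(ofr))
    return cast<IntegerAttr>(attr).getInt();
  return std::nullopt; // dynamic: only known at run time
}

MLIR's getConstantIntValue does the same, and additionally folds Values produced by constant ops.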


@@ -120,7 +120,15 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
     auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
     rewriter.setInsertionPoint(mapOp);
     rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp);
+    auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
+    pim::PimMemCopyOp::create(rewriter,
+        mapOp.getLoc(),
+        initType,
+        mapOp.getInit(),
+        getGlobalOp.getResult(),
+        rewriter.getI32IntegerAttr(0),
+        rewriter.getI32IntegerAttr(0),
+        rewriter.getI32IntegerAttr(sizeInBytes));
     rewriter.eraseOp(mapOp);
     return success();
   }
@@ -416,6 +424,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
       return failure();
     if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
       return failure();
+    auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
+    if (failed(staticOffsets))
+      return failure();
 
     auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
     const int64_t numResultElements = resultTensorType.getNumElements();
@@ -428,7 +439,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
       auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
       SmallVector<int64_t> sourceIndices;
       sourceIndices.reserve(resultIndices.size());
-      for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices))
+      for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
         sourceIndices.push_back(off + idx);
       int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
       resultValues[i] = sourceValues[srcLinear];
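The sizeInBytes expression above is plain numElements × elementBitWidth / 8; for a hypothetical memref<64x128xf32> init buffer that is 64 · 128 · 32 / 8 = 32768 bytes. Note that the integer division assumes byte-multiple element widths; a sub-byte type such as i1 would truncate to zero.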


@@ -1,3 +1,5 @@
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "../Common.hpp"
 #include "../Patterns.hpp"
 #include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -8,6 +10,62 @@ using namespace mlir;
 namespace onnx_mlir {
 namespace {
 
+static bool isSubviewContiguous(const StaticSubviewInfo& info) {
+  if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
+    return false;
+  auto sizesAndShape = llvm::zip_equal(llvm::make_range(info.sizes.rbegin(), info.sizes.rend()),
+      llvm::make_range(info.sourceShape.rbegin(), info.sourceShape.rend()));
+  auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
+    auto [size, dimension] = sizeAndShape;
+    return size != dimension;
+  });
+  if (firstDifferentSize == sizesAndShape.end())
+    return true;
+  ++firstDifferentSize;
+  return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) {
+    auto [size, _dimension] = sizeAndShape;
+    return size == 1;
+  });
+}
+
+static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
+  if (extraOffset == 0)
+    return baseOffset;
+  if (auto attr = dyn_cast<Attribute>(baseOffset)) {
+    auto integerAttr = dyn_cast<IntegerAttr>(attr);
+    assert(integerAttr && "expected integer offset attribute");
+    return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
+  }
+  auto value = cast<Value>(baseOffset);
+  auto cst = arith::ConstantIndexOp::create(rewriter, value.getLoc(), extraOffset);
+  return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult();
+}
+
+static Value buildSubviewChunk(const StaticSubviewInfo& info,
+    ArrayRef<int64_t> outerIndices,
+    Location loc,
+    PatternRewriter& rewriter) {
+  SmallVector<OpFoldResult> chunkOffsets;
+  SmallVector<OpFoldResult> chunkSizes;
+  SmallVector<OpFoldResult> chunkStrides;
+  chunkOffsets.reserve(info.offsets.size());
+  chunkSizes.reserve(info.sizes.size());
+  chunkStrides.reserve(info.strides.size());
+  for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
+    int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0;
+    chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter));
+    chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back()));
+    chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
+  }
+  return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
+}
+
 template <typename CopyOp, typename CreateCopyOp>
 static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
     Value dst,
@@ -19,12 +77,8 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
     CreateCopyOp createCopyOp) {
   auto srcSubview = getStaticSubviewInfo(src);
   auto dstSubview = getStaticSubviewInfo(dst);
-  const bool splitSrc =
-      succeeded(srcSubview)
-      && !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
-  const bool splitDst =
-      succeeded(dstSubview)
-      && !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
+  const bool splitSrc = succeeded(srcSubview) && !isSubviewContiguous(*srcSubview);
+  const bool splitDst = succeeded(dstSubview) && !isSubviewContiguous(*dstSubview);
 
   if (!splitSrc && !splitDst)
     return failure();
@@ -35,9 +89,9 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
   if (sourceType.getElementType() != dstType.getElementType())
     return failure();
 
-  if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
+  if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
     return failure();
-  if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
+  if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
     return failure();
 
   ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
@@ -64,18 +118,11 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
   for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
     SmallVector<int64_t> outerIndices =
         outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
-    const int64_t srcByteOffset =
-        srcOffset
-        + (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
-    const int64_t dstByteOffset =
-        dstOffset
-        + (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
-    createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
-        splitDst ? dstSubview->source : dst,
-        splitSrc ? srcSubview->source : src,
-        dstByteOffset,
-        srcByteOffset,
-        sliceBytes);
+    Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
+    Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
+    const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
+    const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
+    createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
   }
 
   return success();
@@ -198,6 +245,9 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
       return failure();
     if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
       return failure();
+    auto staticOffsets = getStaticSubviewOffsets(*subviewInfo);
+    if (failed(staticOffsets))
+      return failure();
 
     auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
     if (!sourceType || !sourceType.hasStaticShape())
@@ -217,7 +267,7 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
       auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
       SmallVector<int64_t> sourceIndices;
       sourceIndices.reserve(resultIndices.size());
-      for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
+      for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
         sourceIndices.push_back(off + idx);
       resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
     }
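A worked example of the isSubviewContiguous rule above, with illustrative shapes: for a unit-strided subview of a 4x8x64 source, scan the sizes from the innermost dimension outward; the first dimension whose size differs from the source extent may be partial, but every dimension outside it must have size 1. So sizes 1x4x64 are contiguous (a prefix of one 8x64 plane), 2x8x64 are contiguous (two whole planes), but 2x4x64 are not (two partially covered planes). Non-contiguous subviews are the ones rewriteSubviewCopyLikeOp splits: buildSubviewChunk pins every outer dimension to size 1 so each emitted chunk is a contiguous innermost run, and addConstantOffset folds the chunk's outer index into the offset, staying an attribute when the base offset is static and emitting arith ops only when it is an SSA value.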


@@ -132,38 +132,37 @@ private:
   }
 
   static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) {
-    bool hasFailure = false;
-    for (Operation& op : coreOp.getBody().front()) {
-      if (isa<pim::PimHaltOp>(op))
-        continue;
-      for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
-        if (!isa<BaseMemRefType>(operand.getType()))
-          continue;
-        auto resolvedAddress = resolveContiguousAddress(operand);
-        if (failed(resolvedAddress)) {
-          op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
-          hasFailure = true;
-          continue;
-        }
-        if (isExplicitHostOperand(&op, operandIndex)) {
-          if (!isCodegenAddressableValue(operand)) {
-            op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage";
-            hasFailure = true;
-          }
-          continue;
-        }
-        if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
-          op.emitOpError() << "operand #" << operandIndex
-                           << " must be backed by device-local memory; materialize host values with pim.memcp_hd";
-          hasFailure = true;
-        }
-      }
-    }
-    return success(!hasFailure);
+    return walkPimCoreBlock(
+        coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
+          bool hasFailure = false;
+          for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
+            if (!isa<BaseMemRefType>(operand.getType()))
+              continue;
+            auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
+            if (failed(resolvedAddress)) {
+              op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
+              hasFailure = true;
+              continue;
+            }
+            if (isExplicitHostOperand(&op, operandIndex)) {
+              if (!isCodegenAddressableValue(operand)) {
+                op.emitOpError() << "host operand #" << operandIndex
+                                 << " is not backed by contiguous addressable storage";
+                hasFailure = true;
+              }
+              continue;
+            }
+            if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
+              op.emitOpError() << "operand #" << operandIndex
+                               << " must be backed by device-local memory; materialize host values with pim.memcp_hd";
+              hasFailure = true;
+            }
+          }
+          return success(!hasFailure);
+        });
   }
 
   static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
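walkPimCoreBlock and StaticValueKnowledge are introduced elsewhere in this commit and are not shown in the hunks above; from the call site one can infer a shape roughly like the following (an assumption, not the actual declaration):

// Hypothetical signature inferred from the call site above: walk each
// operation in the pim.core block, threading the statically known values at
// that point into the callback, and aggregate the per-op results.
static LogicalResult walkPimCoreBlock(mlir::Block& block, StaticValueKnowledge knowledge,
    llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);

This fits the commit's theme: once an scf.for appears inside the core region, operand addresses can depend on the induction variable, so the verifier threads value knowledge into resolveContiguousAddress instead of requiring purely static subviews.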