add shared loop creation helpers
Validate Operations / validate-operations (push) Has been cancelled

add shared checked arithmetic helpers
refactor pim passes into Pim/Transforms
more robust memory coalescing pass
This commit is contained in:
NiccoloN
2026-06-01 16:49:06 +02:00
parent 356be6ccc2
commit 636310d0cb
55 changed files with 2007 additions and 1103 deletions
+4 -1
View File
@@ -2,7 +2,10 @@ add_onnx_mlir_dialect(Pim pim)
add_onnx_mlir_dialect_doc(pim Pim.td)
add_subdirectory(Transforms/Bufferization)
add_subdirectory(Transforms/StaticMemoryCoalescing)
add_subdirectory(Transforms/MemoryCoalescing)
add_subdirectory(Transforms/HostConstantFolding)
add_subdirectory(Transforms/HostConstantMaterialization)
add_subdirectory(Transforms/Verification)
add_pim_library(PimOps
PimOps.hpp
@@ -3,6 +3,7 @@
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
@@ -11,24 +12,25 @@ using namespace bufferization;
namespace onnx_mlir::pim {
Value materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
auto sizeInBytes =
getCheckedShapedTypeSizeInBytes(shapedType, contiguousBuffer.getDefiningOp(), "contiguous copy byte size");
if (failed(sizeInBytes))
return failure();
Value zeroOffset = getOrCreateIndexConstant(rewriter, contiguousBuffer.getDefiningOp(), 0);
auto sizeAttr =
getCheckedI32Attr(rewriter, contiguousBuffer.getDefiningOp(), *sizeInBytes, "contiguous copy byte size");
if (failed(sizeAttr))
return failure();
return PimMemCopyOp::create(rewriter,
loc,
contiguousType,
zeroOffset,
zeroOffset,
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(sizeInBytes))
return PimMemCopyOp::create(
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
.getOutput();
}
@@ -5,7 +5,8 @@
namespace onnx_mlir::pim {
mlir::Value materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
llvm::FailureOr<mlir::Value>
materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
mlir::Value
allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
@@ -1,10 +1,13 @@
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
using namespace mlir;
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Operation* anchor, Value memref) {
auto type = mlir::cast<MemRefType>(memref.getType());
int32_t sizeInBytes = static_cast<int32_t>(getShapedTypeSizeInBytes(type));
return builder.getI32IntegerAttr(sizeInBytes);
auto byteSize = getCheckedShapedTypeSizeInBytes(type, anchor, "memref byte size");
if (failed(byteSize))
return failure();
return getCheckedI32Attr(builder, anchor, *byteSize, "memref byte size");
}
@@ -5,7 +5,8 @@
namespace onnx_mlir {
namespace pim {
mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref);
mlir::FailureOr<mlir::IntegerAttr>
getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Operation* anchor, mlir::Value memref);
} // namespace pim
} // namespace onnx_mlir
@@ -4,8 +4,10 @@
#include "ContiguityPatterns.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
using namespace mlir;
@@ -85,7 +87,13 @@ static FailureOr<SmallVector<int64_t>> getStaticMemRefStrides(MemRefType type) {
static FailureOr<int64_t> getShapedByteSize(MemRefType type) {
if (!type.hasStaticShape() || !hasByteSizedElementType(type.getElementType()))
return failure();
return static_cast<int64_t>(getShapedTypeSizeInBytes(type));
auto byteSize =
pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "normalized copy byte size");
if (failed(byteSize))
return failure();
if (*byteSize > static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
return failure();
return static_cast<int64_t>(*byteSize);
}
static FailureOr<SmallVector<int64_t>>
@@ -325,12 +333,11 @@ static LogicalResult rewriteCopyLikeOp(CopyOp copyOp,
if (plan->kind == CopyRewritePlan::Kind::Direct) {
Value newTargetOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->target.offset);
Value newSourceOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->source.offset);
auto newCopyOp = createCopyOp(loc,
plan->target.base,
plan->source.base,
newTargetOffset,
newSourceOffset,
static_cast<int32_t>(plan->directBytes));
auto checkedDirectBytes = pim::checkedI32(plan->directBytes, anchorOp, "normalized direct copy byte size");
if (failed(checkedDirectBytes))
return failure();
auto newCopyOp =
createCopyOp(loc, plan->target.base, plan->source.base, newTargetOffset, newSourceOffset, *checkedDirectBytes);
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
rewriter.replaceOp(copyOp, replacementValue);
return success();
@@ -339,23 +346,30 @@ static LogicalResult rewriteCopyLikeOp(CopyOp copyOp,
Value c0 = createIndexConstant(rewriter, anchorOp, 0);
Value cUpper = createIndexConstant(rewriter, anchorOp, getNumElements(plan->loop.outerShape));
Value cStep = createIndexConstant(rewriter, anchorOp, 1);
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, cStep, ValueRange {});
rewriter.setInsertionPointToStart(loop.getBody());
SmallVector<Value> outerIndices =
materializeDelinearizedIndices(rewriter, loc, anchorOp, loop.getInductionVar(), plan->loop.outerShape);
Value loopTargetOffset = materializeOuterByteOffset(
rewriter, loc, anchorOp, plan->loop.targetBaseOffset, outerIndices, plan->loop.targetOuterByteStrides);
Value loopSourceOffset = materializeOuterByteOffset(
rewriter, loc, anchorOp, plan->loop.sourceBaseOffset, outerIndices, plan->loop.sourceOuterByteStrides);
auto newCopyOp = createCopyOp(loc,
plan->target.base,
plan->source.base,
loopTargetOffset,
loopSourceOffset,
static_cast<int32_t>(plan->loop.chunkBytes));
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
rewriter.setInsertionPointAfter(loop);
auto loop = buildNormalizedScfFor(
rewriter,
loc,
c0,
cUpper,
cStep,
ValueRange {},
[&](OpBuilder&, Location nestedLoc, Value inductionVar, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
SmallVector<Value> outerIndices =
materializeDelinearizedIndices(rewriter, nestedLoc, anchorOp, inductionVar, plan->loop.outerShape);
Value loopTargetOffset = materializeOuterByteOffset(
rewriter, nestedLoc, anchorOp, plan->loop.targetBaseOffset, outerIndices, plan->loop.targetOuterByteStrides);
Value loopSourceOffset = materializeOuterByteOffset(
rewriter, nestedLoc, anchorOp, plan->loop.sourceBaseOffset, outerIndices, plan->loop.sourceOuterByteStrides);
auto checkedChunkBytes = pim::checkedI32(plan->loop.chunkBytes, anchorOp, "normalized loop copy byte size");
if (failed(checkedChunkBytes))
return failure();
auto newCopyOp = createCopyOp(
nestedLoc, plan->target.base, plan->source.base, loopTargetOffset, loopSourceOffset, *checkedChunkBytes);
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
return success();
});
if (failed(loop))
return failure();
rewriter.replaceOp(copyOp, replacementValue);
return success();
}
@@ -148,7 +148,10 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
if (failed(inputOpt))
return failure();
inputs.push_back(materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter));
auto contiguous = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
if (failed(contiguous))
return failure();
inputs.push_back(*contiguous);
}
auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state);
@@ -179,12 +182,12 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
if (failed(contiguousInput))
return failure();
replaceOpWithNewBufferizedOp<PimSendOp>(rewriter,
op,
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter),
sendOp.getSizeAttr(),
sendOp.getTargetCoreId());
replaceOpWithNewBufferizedOp<PimSendOp>(
rewriter, op, *contiguousInput, sendOp.getSizeAttr(), sendOp.getTargetCoreId());
return success();
}
};
@@ -407,11 +410,13 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
if (failed(outputBufferOpt))
return failure();
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
if (failed(contiguousInput))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimTransposeOp>(
rewriter, op, contiguousOutput.getType(), contiguousInput, transposeOp.getPermutation(), contiguousOutput);
rewriter, op, contiguousOutput.getType(), *contiguousInput, transposeOp.getPermutation(), contiguousOutput);
return success();
}
};
@@ -451,11 +456,13 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
if (failed(outputBufferOpt))
return failure();
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
if (failed(contiguousInput))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, contiguousOutput.getType(), *weightOpt, contiguousInput, contiguousOutput);
rewriter, op, contiguousOutput.getType(), *weightOpt, *contiguousInput, contiguousOutput);
return success();
}
};
@@ -490,12 +497,16 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
if (failed(outputBufferOpt))
return failure();
Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
if (failed(contiguousLhs))
return failure();
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
if (failed(contiguousRhs))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>(
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
rewriter, op, contiguousOutput.getType(), *contiguousLhs, *contiguousRhs, contiguousOutput);
return success();
}
};
@@ -523,12 +534,16 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInter
if (failed(outputBufferOpt))
return failure();
Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
if (failed(contiguousLhs))
return failure();
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
if (failed(contiguousRhs))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVVDMulOp>(
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
rewriter, op, contiguousOutput.getType(), *contiguousLhs, *contiguousRhs, contiguousOutput);
return success();
}
};
@@ -559,10 +574,12 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
if (failed(outputBufferOpt))
return failure();
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
if (failed(contiguousInput))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, contiguousOutput.getType(), contiguousInput, contiguousOutput);
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, contiguousOutput.getType(), *contiguousInput, contiguousOutput);
return success();
}
};
@@ -42,7 +42,9 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
return failure();
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
IntegerAttr sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getSource());
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
if (failed(sizeAttr))
return failure();
pim::PimMemCopyOp::create(rewriter,
copyOp.getLoc(),
copyOp.getTarget().getType(),
@@ -50,7 +52,7 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
zeroOffset,
copyOp.getTarget(),
copyOp.getSource(),
sizeAttr);
*sizeAttr);
rewriter.eraseOp(copyOp);
return success();
}
@@ -0,0 +1,12 @@
# Library target for the PIM host constant folding pass: shared helpers
# (Common.cpp), the pass driver, and the constant/subview rewrite patterns.
# NOTE(review): EXCLUDE_FROM_OM_LIBS presumably keeps this target out of the
# aggregated onnx-mlir library list — confirm against add_pim_library.
add_pim_library(OMPimHostConstantFolding
Common.cpp
Patterns/Constant.cpp
HostConstantFoldingPass.cpp
Patterns/Subview.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
MLIRLinalgDialect
OMPimCommon
)
@@ -0,0 +1,156 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
// Identifies one folded subview request: the source dense attribute together
// with the static offsets and the shape of the extracted window.
struct DenseSubviewKey {
  DenseElementsAttr source;
  SmallVector<int64_t> offsets;
  SmallVector<int64_t> resultShape;

  // Two keys match only when all three components match.
  bool operator==(const DenseSubviewKey& other) const {
    if (source != other.source)
      return false;
    if (offsets != other.offsets)
      return false;
    return resultShape == other.resultShape;
  }
};
// DenseMap traits for DenseSubviewKey. The empty and tombstone sentinels both
// use a null source attribute and differ only in their single offset entry.
struct DenseSubviewKeyInfo {
  static inline DenseSubviewKey getEmptyKey() {
    DenseSubviewKey sentinel;
    sentinel.source = DenseElementsAttr();
    sentinel.offsets.push_back(DenseMapInfo<int64_t>::getEmptyKey());
    return sentinel;
  }

  static inline DenseSubviewKey getTombstoneKey() {
    DenseSubviewKey sentinel;
    sentinel.source = DenseElementsAttr();
    sentinel.offsets.push_back(DenseMapInfo<int64_t>::getTombstoneKey());
    return sentinel;
  }

  // Combines the source attribute hash with the hashes of both index vectors.
  static unsigned getHashValue(const DenseSubviewKey& key) {
    auto offsetsHash = llvm::hash_combine_range(key.offsets.begin(), key.offsets.end());
    auto shapeHash = llvm::hash_combine_range(key.resultShape.begin(), key.resultShape.end());
    return static_cast<unsigned>(llvm::hash_combine(key.source, offsetsHash, shapeHash));
  }

  static bool isEqual(const DenseSubviewKey& lhs, const DenseSubviewKey& rhs) { return lhs == rhs; }
};
} // namespace
// Returns a constant memref.global holding `denseAttr` with the requested type
// and alignment. An existing matching global in `moduleOp` is reused; otherwise
// a new private constant global is inserted at the start of the module body,
// named `nameStem` (with a numeric "_N" suffix if that symbol is already taken).
memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
                                    Location loc,
                                    MemRefType globalType,
                                    DenseElementsAttr denseAttr,
                                    StringRef nameStem,
                                    IntegerAttr alignment) {
  // First try to reuse an equivalent global that is already in the module.
  for (memref::GlobalOp candidate : moduleOp.getOps<memref::GlobalOp>()) {
    if (!candidate.getConstant())
      continue;
    if (candidate.getType() != globalType)
      continue;
    if (candidate.getAlignmentAttr() != alignment)
      continue;
    if (!candidate.getInitialValue())
      continue;

    if (dyn_cast<DenseElementsAttr>(*candidate.getInitialValue()) == denseAttr)
      return candidate;
  }

  // Pick the first free symbol name: stem, stem_1, stem_2, ...
  std::string uniqueName = nameStem.str();
  for (unsigned suffix = 1; moduleOp.lookupSymbol(uniqueName); ++suffix)
    uniqueName = (nameStem + "_" + std::to_string(suffix)).str();

  OpBuilder moduleBuilder(moduleOp.getBodyRegion());
  moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
  auto privateVisibility = StringAttr::get(moduleOp.getContext(), "private");
  return memref::GlobalOp::create(
      moduleBuilder, loc, uniqueName, privateVisibility, globalType, denseAttr, /*constant=*/true, alignment);
}
// Extracts the unit-stride window of `denseAttr` that starts at `staticOffsets`
// and has shape `resultShape`, returning it as a new dense attribute.
//
// Fails when the source is not a statically shaped ranked tensor, when the
// ranks of the offsets / result shape do not match the source rank, or when the
// requested window is not fully contained in the source (negative or dynamic
// extents, negative offsets, or offset + size exceeding the source dimension).
// Results are memoized per (source, offsets, shape).
FailureOr<DenseElementsAttr>
foldDenseSubview(DenseElementsAttr denseAttr, ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> resultShape) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType || !sourceType.hasStaticShape() || sourceType.getRank() != static_cast<int64_t>(staticOffsets.size())
      || sourceType.getRank() != static_cast<int64_t>(resultShape.size()))
    return failure();

  // Reject windows that fall outside the source. Without this, a dynamic
  // (negative) result dimension or an out-of-range offset would index past the
  // end of the source value list below. `sourceDim - size` cannot overflow
  // because both operands are already known to be non-negative.
  for (auto [sourceDim, offset, size] : llvm::zip_equal(sourceType.getShape(), staticOffsets, resultShape))
    if (offset < 0 || size < 0 || size > sourceDim || offset > sourceDim - size)
      return failure();

  // NOTE(review): function-local cache with no synchronization visible here,
  // keyed by attribute handles — confirm it is only used single-threaded and
  // within the lifetime of a single MLIRContext.
  static DenseMap<DenseSubviewKey, DenseElementsAttr, DenseSubviewKeyInfo> cache;
  DenseSubviewKey key {denseAttr,
                       SmallVector<int64_t>(staticOffsets.begin(), staticOffsets.end()),
                       SmallVector<int64_t>(resultShape.begin(), resultShape.end())};
  if (auto cached = cache.find(key); cached != cache.end())
    return cached->second;

  auto resultTensorType = RankedTensorType::get(resultShape, sourceType.getElementType());
  const int64_t numResultElements = resultTensorType.getNumElements();

  auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
  auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> resultValues(numResultElements);

  // Map every result element back to its source coordinate (offset + index).
  for (int64_t i = 0; i < numResultElements; ++i) {
    auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
    SmallVector<int64_t> sourceIndices;
    sourceIndices.reserve(resultIndices.size());
    for (auto [offset, index] : llvm::zip_equal(staticOffsets, resultIndices))
      sourceIndices.push_back(offset + index);
    resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
  }

  auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
  cache.try_emplace(std::move(key), foldedAttr);
  return foldedAttr;
}
// Looks through memref casts on `value` and, when it is produced by a
// memref.get_global of a constant global with a dense initial value, returns
// that dense attribute. Fails in every other case.
FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
  Value root = stripMemRefCasts(value);

  auto getGlobalOp = root.getDefiningOp<memref::GetGlobalOp>();
  if (!getGlobalOp)
    return failure();

  auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
  if (!globalOp)
    return failure();
  if (!globalOp.getConstant() || !globalOp.getInitialValue())
    return failure();

  if (auto denseValue = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue()))
    return denseValue;
  return failure();
}
// Resolves `source` to a dense constant shaped like `resultType`.
// - If `source` is a static subview of a constant global, the subview must have
//   unit strides and static offsets; the matching window is folded out.
// - Otherwise the global's dense value is returned as-is, provided its tensor
//   type equals the tensor form of `resultType`.
FailureOr<DenseElementsAttr> foldDenseSourceToType(ModuleOp moduleOp, Value source, MemRefType resultType) {
  auto subviewInfo = getStaticSubviewInfo(source);
  const bool isSubview = succeeded(subviewInfo);

  Value globalSource = isSubview ? subviewInfo->source : stripMemRefCasts(source);
  auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
  if (failed(denseAttr))
    return failure();

  if (!isSubview) {
    auto expectedTensorType = RankedTensorType::get(resultType.getShape(), resultType.getElementType());
    if (expectedTensorType != denseAttr->getType())
      return failure();
    return *denseAttr;
  }

  // Only contiguous (stride-1) windows can be folded.
  for (int64_t stride : subviewInfo->strides)
    if (stride != 1)
      return failure();

  auto staticOffsets = getStaticSubviewOffsets(*subviewInfo);
  if (failed(staticOffsets))
    return failure();

  return foldDenseSubview(*denseAttr, *staticOffsets, resultType.getShape());
}
} // namespace onnx_mlir
@@ -0,0 +1,31 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
namespace onnx_mlir {
/// Returns a constant memref.global in `moduleOp` whose type, alignment and
/// dense initial value match the arguments, reusing an existing one when
/// present. Otherwise creates a private constant global at the start of the
/// module body, named `nameStem` with a numeric suffix if the name is taken.
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
mlir::Location loc,
mlir::MemRefType globalType,
mlir::DenseElementsAttr denseAttr,
llvm::StringRef nameStem,
mlir::IntegerAttr alignment = {});
/// Extracts the window of `denseAttr` that starts at `staticOffsets` and has
/// shape `resultShape`. Fails if the source is not a statically shaped ranked
/// tensor or if the offsets / result shape rank differs from the source rank.
llvm::FailureOr<mlir::DenseElementsAttr> foldDenseSubview(mlir::DenseElementsAttr denseAttr,
llvm::ArrayRef<int64_t> staticOffsets,
llvm::ArrayRef<int64_t> resultShape);
/// Strips memref casts from `value` and returns the dense initial value of the
/// constant global it reads via memref.get_global; fails otherwise.
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
/// Resolves `source` (a constant global, optionally behind a unit-stride static
/// subview) to a dense attribute with the shape of `resultType`.
llvm::FailureOr<mlir::DenseElementsAttr>
foldDenseSourceToType(mlir::ModuleOp moduleOp, mlir::Value source, mlir::MemRefType resultType);
} // namespace onnx_mlir
@@ -0,0 +1,54 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>
#include "Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
// Module pass that greedily applies canonicalization patterns of all loaded
// dialects and registered operations together with the PIM constant-folding
// patterns, then dumps the resulting module.
struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostConstantFoldingPass)

  StringRef getArgument() const override { return "pim-host-constant-folding-pass"; }
  StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }

  // Builds the frozen pattern set once per context.
  LogicalResult initialize(MLIRContext* context) override {
    RewritePatternSet patternSet(context);

    // Canonicalization patterns contributed by dialects, then by operations.
    for (auto* loadedDialect : context->getLoadedDialects())
      loadedDialect->getCanonicalizationPatterns(patternSet);
    for (RegisteredOperationName opName : context->getRegisteredOperations())
      opName.getCanonicalizationPatterns(patternSet, context);

    // PIM-specific folding patterns.
    populateConstantFoldingConstantPatterns(patternSet);
    populateConstantFoldingSubviewPatterns(patternSet);

    patterns = std::make_shared<FrozenRewritePatternSet>(std::move(patternSet));
    return success();
  }

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();

    GreedyRewriteConfig config;
    config.enableFolding();

    if (succeeded(applyPatternsGreedily(moduleOp, *patterns, config))) {
      dumpModule(moduleOp, "pim3_folded");
      return;
    }

    moduleOp.emitError("PIM host constant folding failed in the greedy rewrite driver");
    signalPassFailure();
  }

  std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
} // namespace
// Factory for the PIM host constant folding pass.
std::unique_ptr<Pass> createPimHostConstantFoldingPass() {
  auto pass = std::make_unique<HostConstantFoldingPass>();
  return pass;
}
} // namespace onnx_mlir
@@ -0,0 +1,11 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
// Pattern registration hooks used by the host constant folding pass when it
// builds its pattern set.
// NOTE(review): presumably adds the patterns defined in Patterns/Constant.cpp — confirm.
void populateConstantFoldingConstantPatterns(mlir::RewritePatternSet& patterns);
// NOTE(review): presumably adds the patterns defined in Patterns/Subview.cpp — confirm.
void populateConstantFoldingSubviewPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir
@@ -0,0 +1,546 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
// One memref.copy of a constant global into a subview of an allocation, as
// collected while folding a fill-and-copy allocation chain.
struct ConstantSubviewCopy {
// Dense value of the constant global that is the copy source.
DenseElementsAttr source;
// Static offsets of the destination subview.
SmallVector<int64_t> offsets;
// Static strides of the destination subview.
SmallVector<int64_t> strides;
// The copy operation itself; used to order copies by position in their block.
Operation* copyOp = nullptr;
};
// Applies the dimension permutation `perms` to a dense ranked-tensor attribute.
// Result dimension d takes its extent (and index) from source dimension
// perms[d]. Fails if the attribute is not a ranked tensor or if `perms` is not
// a valid permutation of [0, rank).
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType)
    return failure();

  const int64_t rank = sourceType.getRank();
  if (rank != static_cast<int64_t>(perms.size()))
    return failure();

  // Validate the permutation while building the permuted shape.
  llvm::SmallBitVector usedDims(rank);
  SmallVector<int64_t> resultShape;
  resultShape.reserve(rank);
  for (int64_t sourceDim : perms) {
    const bool inRange = sourceDim >= 0 && sourceDim < rank;
    if (!inRange || usedDims.test(sourceDim))
      return failure();
    usedDims.set(sourceDim);
    resultShape.push_back(sourceType.getShape()[sourceDim]);
  }

  auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType());

  // A splat is invariant under permutation; only the type changes.
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> resultValues(sourceValues.size());

  // Row-major strides of both layouts; the innermost stride stays 1.
  SmallVector<int64_t> sourceStrides(rank, 1);
  SmallVector<int64_t> resultStrides(rank, 1);
  for (int64_t dim = rank - 1; dim > 0; --dim) {
    sourceStrides[dim - 1] = sourceStrides[dim] * sourceType.getShape()[dim];
    resultStrides[dim - 1] = resultStrides[dim] * resultShape[dim];
  }

  SmallVector<int64_t> sourceIndices(rank);
  const int64_t numValues = static_cast<int64_t>(sourceValues.size());
  for (int64_t linear = 0; linear < numValues; ++linear) {
    // Split the source linear index into per-dimension coordinates.
    int64_t remainder = linear;
    for (int64_t dim = 0; dim < rank; ++dim) {
      sourceIndices[dim] = remainder / sourceStrides[dim];
      remainder %= sourceStrides[dim];
    }

    // Re-linearize the permuted coordinates in the result layout.
    int64_t resultLinear = 0;
    for (int64_t dim = 0; dim < rank; ++dim)
      resultLinear += sourceIndices[perms[dim]] * resultStrides[dim];

    resultValues[resultLinear] = sourceValues[linear];
  }

  return DenseElementsAttr::get(resultType, resultValues);
}
// Returns the constant attribute yielded by an input-less linalg.map, i.e. a
// map whose mapper body yields exactly one value that matches a constant.
// Fails for maps with inputs or with any other terminator shape.
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
  const bool hasInputs = !mapOp.getInputs().empty();
  if (hasInputs)
    return failure();

  Operation* terminator = mapOp.getMapper().front().getTerminator();
  auto yieldOp = dyn_cast<linalg::YieldOp>(terminator);
  if (!yieldOp)
    return failure();
  if (yieldOp.getNumOperands() != 1)
    return failure();

  Attribute constantValue;
  if (matchPattern(yieldOp.getValues().front(), m_Constant(&constantValue)))
    return constantValue;
  return failure();
}
// Folds constant linalg fills inside cores into private globals plus device copies.
//
// Matches an input-less linalg.map nested in a pim.core whose init operand is a
// statically shaped memref and whose body yields a constant. The map is replaced
// by a pim memcopy from a (possibly reused) constant splat global into the init
// buffer.
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
    auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
    if (!coreOp)
      return failure();

    auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
    if (!initType || !initType.hasStaticShape())
      return failure();

    auto fillValue = getConstantMapYield(mapOp);
    if (failed(fillValue))
      return failure();

    auto moduleOp = mapOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    // Run every check that can fail BEFORE touching the IR: a pattern that
    // returns failure() must leave the IR unchanged, so the global and the
    // get_global may only be created once the byte size is known to fit.
    auto sizeInBytes = pim::getCheckedShapedTypeSizeInBytes(initType, mapOp, "host constant folding byte size");
    if (failed(sizeInBytes))
      return failure();
    auto sizeAttr = pim::getCheckedI32Attr(rewriter, mapOp, *sizeInBytes, "host constant folding byte size");
    if (failed(sizeAttr))
      return failure();

    // From here on the rewrite cannot fail.
    auto tensorType = RankedTensorType::get(initType.getShape(), initType.getElementType());
    DenseElementsAttr splatAttr = DenseElementsAttr::get(tensorType, *fillValue);
    auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(mapOp);

    auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
    Value zeroOffset = getOrCreateIndexConstant(rewriter, mapOp, 0);
    pim::PimMemCopyOp::create(
        rewriter, mapOp.getLoc(), initType, zeroOffset, zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), *sizeAttr);
    rewriter.eraseOp(mapOp);
    return success();
  }
};
// Tries to evaluate the final contents of a statically shaped allocation that
// is initialized by a constant linalg.map fill and then overwritten through
// memref.copy ops from constant globals into static subviews of it.
//
// Walks the users of the allocation (following memref.cast aliases). Accepted
// users are: a constant fill writing to the alias, static subviews whose only
// users are copies targeting them from constant globals, pim.core ops, and
// deallocs. Any other user, a missing fill, conflicting fill values, copies
// located in different blocks, or a copy that would write outside the
// allocation makes the fold fail.
//
// Returns the dense attribute: the fill value everywhere, overlaid with the
// copied data in block order.
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
  auto allocType = dyn_cast<MemRefType>(allocOp.getType());
  if (!allocType || !allocType.hasStaticShape())
    return failure();

  auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
  const int64_t numElements = resultTensorType.getNumElements();
  if (numElements < 0)
    return failure();

  Attribute fillValue;
  SmallVector<ConstantSubviewCopy> copies;
  llvm::SmallPtrSet<Operation*, 8> visitedAliases;
  SmallVector<Value> pendingAliases;
  pendingAliases.push_back(allocOp.getResult());

  while (!pendingAliases.empty()) {
    Value alias = pendingAliases.pop_back_val();
    for (Operation* user : alias.getUsers()) {
      // getUsers() yields one entry per use; handle each user op only once.
      if (!visitedAliases.insert(user).second)
        continue;

      if (auto mapOp = dyn_cast<linalg::MapOp>(user)) {
        if (mapOp.getInit() != alias)
          return failure();
        auto maybeFillValue = getConstantMapYield(mapOp);
        if (failed(maybeFillValue))
          return failure();
        if (fillValue && fillValue != *maybeFillValue)
          return failure();
        fillValue = *maybeFillValue;
        continue;
      }

      if (auto subviewOp = dyn_cast<memref::SubViewOp>(user)) {
        SmallVector<int64_t> offsets;
        SmallVector<int64_t> strides;
        offsets.reserve(subviewOp.getMixedOffsets().size());
        strides.reserve(subviewOp.getMixedStrides().size());
        for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
          auto staticOffset = getConstantIntValue(offset);
          if (!staticOffset)
            return failure();
          offsets.push_back(*staticOffset);
        }
        for (OpFoldResult stride : subviewOp.getMixedStrides()) {
          auto staticStride = getConstantIntValue(stride);
          if (!staticStride)
            return failure();
          strides.push_back(*staticStride);
        }

        // The subview may only be written by copies from constant globals.
        for (Operation* subviewUser : subviewOp->getUsers()) {
          if (auto copyOp = dyn_cast<memref::CopyOp>(subviewUser)) {
            if (copyOp.getTarget() != subviewOp.getResult())
              return failure();

            auto denseAttr = getDenseGlobalValue(moduleOp, copyOp.getSource());
            if (failed(denseAttr))
              return failure();
            copies.push_back({*denseAttr, offsets, strides, copyOp});
            continue;
          }
          return failure();
        }
        continue;
      }

      if (isa<pim::PimCoreOp, memref::DeallocOp>(user))
        continue;

      if (auto castOp = dyn_cast<memref::CastOp>(user)) {
        pendingAliases.push_back(castOp.getResult());
        continue;
      }

      return failure();
    }
  }

  if (!fillValue)
    return failure();

  // Operation::isBeforeInBlock is only defined for operations in the same
  // block, so bail out if the copies cannot be ordered that way.
  if (!copies.empty()) {
    Block* copyBlock = copies.front().copyOp->getBlock();
    for (const ConstantSubviewCopy& copy : copies)
      if (copy.copyOp->getBlock() != copyBlock)
        return failure();
  }

  ArrayRef<int64_t> resultShape = resultTensorType.getShape();
  SmallVector<Attribute> resultValues(numElements, fillValue);
  auto resultStrides = computeRowMajorStrides(resultShape);

  // Apply copies in program order so later copies win on overlap.
  llvm::sort(copies, [](const ConstantSubviewCopy& lhs, const ConstantSubviewCopy& rhs) {
    return lhs.copyOp->isBeforeInBlock(rhs.copyOp);
  });

  for (const ConstantSubviewCopy& copy : copies) {
    auto sourceType = dyn_cast<RankedTensorType>(copy.source.getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getRank() != static_cast<int64_t>(copy.offsets.size())
        || sourceType.getRank() != static_cast<int64_t>(copy.strides.size())
        || copy.offsets.size() != resultShape.size())
      return failure();

    auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
    SmallVector<Attribute> sourceValues(copy.source.getValues<Attribute>());
    for (auto [linearIndex, value] : llvm::enumerate(sourceValues)) {
      SmallVector<int64_t> sourceIndices =
          delinearizeIndex(static_cast<int64_t>(linearIndex), sourceType.getShape(), sourceStrides);
      SmallVector<int64_t> resultIndices;
      resultIndices.reserve(sourceIndices.size());
      for (auto [offset, sourceIndex, stride, dimSize] :
           llvm::zip_equal(copy.offsets, sourceIndices, copy.strides, resultShape)) {
        const int64_t resultIndex = offset + sourceIndex * stride;
        // A destination coordinate outside the allocation would index past
        // the end of resultValues; refuse to fold instead.
        if (resultIndex < 0 || resultIndex >= dimSize)
          return failure();
        resultIndices.push_back(resultIndex);
      }

      int64_t resultLinearIndex = linearizeIndex(resultIndices, resultStrides);
      resultValues[resultLinearIndex] = value;
    }
  }

  return DenseElementsAttr::get(resultTensorType, resultValues);
}
// Folds transposes of constant globals so weight-only transposes stay host-side.
// The transpose is replaced by a memref.get_global of a new (or reused) constant
// global holding the transposed data; staging copies and buffers that become
// unused are erased afterwards.
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutput().getType());
if (!resultType || !resultType.hasStaticShape())
return failure();
// Look through an optional pim.memcp_hd to find the source get_global.
// This occurs when the constant was staged into device memory before transposing.
pim::PimMemCopyHostToDevOp memcpHd;
auto sourceGetGlobal = transposeOp.getInput().getDefiningOp<memref::GetGlobalOp>();
if (!sourceGetGlobal) {
memcpHd = transposeOp.getInput().getDefiningOp<pim::PimMemCopyHostToDevOp>();
if (!memcpHd)
return failure();
sourceGetGlobal = memcpHd.getHostSource().getDefiningOp<memref::GetGlobalOp>();
if (!sourceGetGlobal)
return failure();
}
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
// Only constant globals with a dense initial value can be folded.
auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
if (!sourceGlobal || !sourceGlobal.getConstant() || !sourceGlobal.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
if (!denseAttr)
return failure();
// Materialize the permutation attribute as plain integers.
SmallVector<int64_t> perms;
perms.reserve(transposeOp.getPermutation().size());
for (IntegerAttr attr : transposeOp.getPermutation().getAsRange<IntegerAttr>())
perms.push_back(attr.getInt());
FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
if (failed(transposedAttr))
return failure();
// The folded data must have exactly the shape of the transpose result.
auto transposedShape = cast<RankedTensorType>(transposedAttr->getType()).getShape();
if (!llvm::equal(transposedShape, resultType.getShape()))
return failure();
// No failure paths past this point: the IR is modified from here on.
auto newGlobal = createFoldedGlobal(moduleOp,
transposeOp.getLoc(),
resultType,
*transposedAttr,
sourceGlobal.getName().str() + "__folded_transpose",
sourceGlobal.getAlignmentAttr());
rewriter.setInsertionPoint(transposeOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());
// Computed before replaceOp, while the transpose still has its users: true
// when there is at least one user and all users are pim core / core-batch ops.
bool isAlwaysWeight = !transposeOp->getUsers().empty()
&& llvm::all_of(transposeOp->getUsers(),
[](Operation* user) { return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user); });
if (isAlwaysWeight) {
// NOTE(review): presumably tags the global and its accessor as weight-only data — confirm markWeightAlways semantics.
markWeightAlways(newGlobal);
markWeightAlways(newGetGlobal);
}
// Capture the output buffer's alloc before the transpose is replaced.
auto outputAllocOp = transposeOp.getOutputBuffer().getDefiningOp<memref::AllocOp>();
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
// Clean up the host-to-device staging copy and its device buffer once unused.
if (memcpHd && memcpHd.use_empty()) {
auto deviceAllocOp = memcpHd.getDeviceTarget().getDefiningOp<memref::AllocOp>();
rewriter.eraseOp(memcpHd);
if (deviceAllocOp && deviceAllocOp->use_empty())
rewriter.eraseOp(deviceAllocOp);
}
// The transpose's output buffer is dead if nothing else uses it.
if (outputAllocOp && outputAllocOp->use_empty())
rewriter.eraseOp(outputAllocOp);
return success();
}
};
// Collapses fill-and-copy allocation chains into one folded constant global.
//
// Every user of the alloc is classified BEFORE any IR is created: a pattern
// must not return failure() after mutating the IR, and bailing out late would
// also leave an unused folded global behind in the module.
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::AllocOp allocOp, PatternRewriter& rewriter) const override {
    auto moduleOp = allocOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    auto foldedAttr = foldConstantAlloc(allocOp, moduleOp);
    if (failed(foldedAttr))
      return failure();

    auto isCoreLike = [](Operation* user) { return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user); };

    // Partition the users into: ops that become dead once the constant is
    // folded (erased), casts (rebuilt on top of the new get_global) and core
    // ops (simply rewired). Any other user makes the alloc unfoldable.
    // getUsers() yields an operation once per use, so `preservedUsers` doubles
    // as a "seen" set to keep each op in the work lists only once.
    SmallVector<Operation*> opsToErase;
    SmallVector<memref::CastOp> castsToReplace;
    llvm::SmallPtrSet<Operation*, 8> preservedUsers;
    for (Operation* user : allocOp->getUsers()) {
      if (isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp>(user)) {
        if (preservedUsers.insert(user).second)
          opsToErase.push_back(user);
        continue;
      }
      if (auto castOp = dyn_cast<memref::CastOp>(user)) {
        if (preservedUsers.insert(user).second)
          castsToReplace.push_back(castOp);
        continue;
      }
      if (!isCoreLike(user))
        return failure();
    }

    // The folded constant is tagged as an always-weight unless some cast of
    // the alloc has a user that is not a core op.
    const bool allLiveUsersAreCoreOps = llvm::all_of(
        castsToReplace, [&](memref::CastOp castOp) { return llvm::all_of(castOp->getUsers(), isCoreLike); });

    // From here on the rewrite cannot fail, so it is safe to mutate the IR.
    auto allocType = cast<MemRefType>(allocOp.getType());
    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_constant");

    rewriter.setInsertionPoint(allocOp);
    auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
    if (allLiveUsersAreCoreOps) {
      markWeightAlways(newGlobal);
      markWeightAlways(newGetGlobal);
    }

    // Rewire the remaining (core) users; erased ops and casts are handled below.
    rewriter.replaceAllUsesExcept(allocOp.getResult(), newGetGlobal.getResult(), preservedUsers);

    for (memref::CastOp castOp : castsToReplace) {
      rewriter.setInsertionPoint(castOp);
      Value replacementCast = memref::CastOp::create(rewriter, castOp.getLoc(), castOp.getType(), newGetGlobal);
      rewriter.replaceOp(castOp, replacementCast);
      if (allLiveUsersAreCoreOps)
        markWeightAlways(replacementCast.getDefiningOp());
    }

    // Erase the now-dead initialization ops, never erasing (or inspecting) the
    // same operation twice.
    llvm::SmallPtrSet<Operation*, 8> erasedOps;
    auto eraseOnce = [&](Operation* op) {
      if (erasedOps.insert(op).second)
        rewriter.eraseOp(op);
    };
    for (Operation* op : opsToErase) {
      if (erasedOps.count(op))
        continue;
      if (auto subviewOp = dyn_cast<memref::SubViewOp>(op)) {
        // Snapshot the users first: erasing a user edits the use list that
        // getUsers() iterates over.
        SmallVector<Operation*> subviewUsers = llvm::to_vector(subviewOp->getUsers());
        for (Operation* subviewUser : subviewUsers)
          eraseOnce(subviewUser);
      }
      if (op->use_empty())
        eraseOnce(op);
    }

    if (allocOp.use_empty())
      rewriter.eraseOp(allocOp);
    return success();
  }
};
// Converts host copies from dense globals into direct folded globals.
//
// A memref.copy into a static alloc is folded when the copied data can be
// computed from a dense global; the alloc is then replaced by a get_global of
// a new folded global. memref.dealloc users of the alloc are erased instead of
// being redirected, so the rewritten IR never deallocates a constant global.
struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
    // Copies nested inside a pim core op are not handled by this pattern.
    if (copyOp->getParentOfType<pim::PimCoreOp>())
      return failure();

    auto allocOp = copyOp.getTarget().getDefiningOp<memref::AllocOp>();
    if (!allocOp)
      return failure();
    auto allocType = dyn_cast<MemRefType>(allocOp.getType());
    if (!allocType || !allocType.hasStaticShape())
      return failure();

    auto moduleOp = copyOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType);
    if (failed(foldedAttr))
      return failure();

    // Besides the copy itself, the alloc may only be used by deallocs, core
    // ops and subviews. A subview user keeps the fold legal but prevents the
    // result from being tagged as an always-weight.
    SmallVector<Operation*> deallocsToErase;
    bool allLiveUsersAreCores = true;
    for (Operation* user : allocOp->getUsers()) {
      if (user == copyOp)
        continue;
      if (isa<memref::DeallocOp>(user)) {
        deallocsToErase.push_back(user);
        continue;
      }
      if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
        continue;
      if (isa<memref::SubViewOp>(user)) {
        allLiveUsersAreCores = false;
        continue;
      }
      return failure();
    }

    // All checks passed: from here on the IR is mutated.
    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_host_copy");
    if (allLiveUsersAreCores)
      markWeightAlways(newGlobal);

    rewriter.setInsertionPoint(allocOp);
    auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
    if (allLiveUsersAreCores)
      markWeightAlways(newGetGlobal);

    rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
    rewriter.eraseOp(copyOp);
    // The buffer is now a constant global; its deallocs no longer apply.
    for (Operation* deallocOp : deallocsToErase)
      rewriter.eraseOp(deallocOp);
    if (allocOp.use_empty())
      rewriter.eraseOp(allocOp);
    return success();
  }
};
// Converts PIM copies from dense globals into direct folded globals before codegen.
//
// A pim memcp into a static alloc, with both offsets resolving to zero, is
// folded when the copied data can be computed from a dense global; the alloc
// is then replaced by a get_global of a new folded global. memref.dealloc
// users of the alloc are erased instead of being redirected, so the rewritten
// IR never deallocates a constant global.
// NOTE(review): the copy's byte size is not compared against the alloc size
// here; presumably foldDenseSourceToType validates compatibility - confirm.
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
    // Copies nested inside a pim core op are not handled by this pattern.
    if (copyOp->getParentOfType<pim::PimCoreOp>())
      return failure();

    auto allocOp = copyOp.getTarget().getDefiningOp<memref::AllocOp>();
    if (!allocOp)
      return failure();
    auto allocType = dyn_cast<MemRefType>(allocOp.getType());
    if (!allocType || !allocType.hasStaticShape())
      return failure();

    // Only copies whose target and source offsets both resolve to 0 are folded.
    auto targetOffset = resolveIndexValue(copyOp.getTargetOffset());
    auto sourceOffset = resolveIndexValue(copyOp.getSourceOffset());
    if (failed(targetOffset) || failed(sourceOffset) || *targetOffset != 0 || *sourceOffset != 0)
      return failure();

    auto moduleOp = copyOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType);
    if (failed(foldedAttr))
      return failure();

    // Besides the copy itself, the alloc may only be used by deallocs, core
    // ops and subviews. A subview user keeps the fold legal but prevents the
    // result from being tagged as an always-weight.
    SmallVector<Operation*> deallocsToErase;
    bool allLiveUsersAreCores = true;
    for (Operation* user : allocOp->getUsers()) {
      if (user == copyOp)
        continue;
      if (isa<memref::DeallocOp>(user)) {
        deallocsToErase.push_back(user);
        continue;
      }
      if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
        continue;
      if (isa<memref::SubViewOp>(user)) {
        allLiveUsersAreCores = false;
        continue;
      }
      return failure();
    }

    // All checks passed: from here on the IR is mutated.
    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_memcp");
    if (allLiveUsersAreCores)
      markWeightAlways(newGlobal);

    rewriter.setInsertionPoint(allocOp);
    auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
    if (allLiveUsersAreCores)
      markWeightAlways(newGetGlobal);

    rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
    rewriter.eraseOp(copyOp);
    // The buffer is now a constant global; its deallocs no longer apply.
    for (Operation* deallocOp : deallocsToErase)
      rewriter.eraseOp(deallocOp);
    if (allocOp.use_empty())
      rewriter.eraseOp(allocOp);
    return success();
  }
};
} // namespace
// Adds the constant-folding rewrite patterns (transpose, alloc, core map,
// host copy and PIM memcp) to `patterns`, in that order.
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
  MLIRContext* context = patterns.getContext();
  patterns.add<FoldConstantTransposePattern>(context);
  patterns.add<FoldConstantAllocPattern>(context);
  patterns.add<FoldConstantCoreMapPattern>(context);
  patterns.add<FoldConstantHostCopyPattern>(context);
  patterns.add<FoldConstantMemCpPattern>(context);
}
} // namespace onnx_mlir
@@ -0,0 +1,61 @@
#include "../Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
// Turns a unit-stride, statically-described subview of a dense constant
// global into its own folded global when the subview feeds only pim core ops.
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
    // The subview must be used, and used exclusively by PimCoreOp.
    auto feedsCore = [](Operation* user) { return isa<pim::PimCoreOp>(user); };
    if (subviewOp.use_empty() || !llvm::all_of(subviewOp->getUsers(), feedsCore))
      return failure();

    auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    // The subview source (looking through casts) must be a dense global.
    Value strippedSource = stripMemRefCasts(subviewOp.getSource());
    auto sourceElements = getDenseGlobalValue(moduleOp, strippedSource);
    if (failed(sourceElements))
      return failure();

    auto staticInfo = getStaticSubviewInfo(subviewOp.getResult());
    if (failed(staticInfo))
      return failure();

    // Only unit strides are supported.
    const bool hasUnitStrides = llvm::all_of(staticInfo->strides, [](int64_t stride) { return stride == 1; });
    if (!hasUnitStrides)
      return failure();

    auto offsets = getStaticSubviewOffsets(*staticInfo);
    if (failed(offsets))
      return failure();

    auto foldedType = cast<MemRefType>(subviewOp.getType());
    auto foldedElements = foldDenseSubview(*sourceElements, *offsets, foldedType.getShape());
    if (failed(foldedElements))
      return failure();

    // Materialize the slice as a standalone global and tag both the global
    // and its accessor as always-weight.
    auto foldedGlobal =
        createFoldedGlobal(moduleOp, subviewOp.getLoc(), foldedType, *foldedElements, "pim_folded_subview");
    markWeightAlways(foldedGlobal);

    rewriter.setInsertionPoint(subviewOp);
    auto foldedGetGlobal =
        memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), foldedType, foldedGlobal.getName());
    markWeightAlways(foldedGetGlobal);

    rewriter.replaceOp(subviewOp, foldedGetGlobal.getResult());
    return success();
  }
};
} // namespace
// Adds the constant core-subview folding pattern to `patterns`.
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
  MLIRContext* context = patterns.getContext();
  patterns.add<FoldConstantCoreSubviewPattern>(context);
}
} // namespace onnx_mlir
@@ -0,0 +1,9 @@
# Library built from MaterializeHostConstantsPass.cpp (the host-constant
# materialization pass); links publicly against OMPimCommon and PimOps.
add_pim_library(OMPimHostConstantMaterialization
MaterializeHostConstantsPass.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
OMPimCommon
PimOps
)
@@ -0,0 +1,161 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
// Rewrites memref operands inside `coreOp` that resolve to a memref.get_global
// so that they read from a local buffer filled by an explicit host-to-device copy.
//
// For every op in the core body (pim halt and scf.yield excluded), each memref
// operand whose contiguous address resolves to a get_global base is replaced by
// the output of a new alloc + PimMemCopyHostToDevOp pair inserted right before
// the op. Copies are cached per (base, byte offset, memref type) and reused
// when the cached value properly dominates the current op.
//
// rewriter:       used to create the alloc / cast / copy ops.
// constantFolder: used to obtain the index constants for the copy offsets.
// hasFailure:     set to true when an operand cannot be materialized (non-static
//                 memref type, or a failed checked size / offset computation);
//                 processing then continues with the next operand.
template <typename CoreOpTy>
static void materializeHostConstantsInCore(CoreOpTy coreOp,
IRRewriter& rewriter,
OperationFolder& constantFolder,
bool& hasFailure) {
// Cache of already materialized copies: base value -> byte offset -> memref type -> copied value.
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
DominanceInfo dominance(coreOp);
// Snapshot the ops first: the loop below inserts new ops into the body.
SmallVector<Operation*> ops;
coreOp.getBody().front().walk([&](Operation* op) {
if (!isa<pim::PimHaltOp, scf::YieldOp>(op))
ops.push_back(op);
});
for (Operation* op : ops) {
// memref.load ops that produce an index value are left untouched.
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
continue;
for (OpOperand& operand : op->getOpOperands()) {
Value originalValue = operand.get();
// Only memref operands are considered; operands reported by
// isExplicitHostMemCopyOperand are skipped.
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber()))
continue;
auto resolvedAddress = resolveContiguousAddress(originalValue);
if (failed(resolvedAddress))
continue;
// Only addresses whose base is defined by a memref.get_global are materialized.
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
if (!getGlobalOp)
continue;
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
if (!originalType || !originalType.hasStaticShape()) {
op->emitOpError("host constant materialization requires a static memref operand");
hasFailure = true;
continue;
}
auto& cachedByOffset = materializedValues[resolvedAddress->base];
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
auto cachedValue = cachedByType.find(originalType);
// Reuse an earlier copy only if it properly dominates this op.
if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) {
operand.set(cachedValue->second);
continue;
}
// Checked computation of the copy size (as an i32 attribute) and of the
// host byte offset; any failure marks the pass as failed and skips the operand.
auto type = dyn_cast<ShapedType>(originalValue.getType());
auto totalBytes = type ? pim::getCheckedShapedTypeSizeInBytes(type, op, "host constant materialization byte size")
: FailureOr<uint64_t>(failure());
auto totalBytesAttr =
succeeded(totalBytes)
? pim::getCheckedI32Attr(rewriter, op, *totalBytes, "host constant materialization byte size")
: FailureOr<IntegerAttr>(failure());
if (failed(totalBytesAttr)
|| failed(pim::checkedSize(resolvedAddress->byteOffset, op, "host constant materialization byte offset"))) {
hasFailure = true;
continue;
}
// Allocate a buffer with the operand's shape and element type (no explicit
// layout); cast it back to the original type when the two types differ.
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
rewriter.setInsertionPoint(op);
Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType);
Value deviceDst = localAlloc;
if (contiguousType != originalType)
deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0);
Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset);
Value copiedValue = pim::PimMemCopyHostToDevOp::create(rewriter,
op->getLoc(),
originalType,
zeroOffset,
hostOffset,
deviceDst,
getGlobalOp.getResult(),
*totalBytesAttr)
.getOutput();
// Remember the copy for later operands and rewire the current one.
cachedByType[originalType] = copiedValue;
operand.set(copiedValue);
}
}
}
// Module pass that materializes host constants inside every pim core and
// core-batch op of each non-external function, and reports any pim concat op
// left at function level as an error.
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)

  StringRef getArgument() const override { return "materialize-pim-host-constants"; }

  StringRef getDescription() const override {
    return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
  }

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    MLIRContext* context = moduleOp.getContext();
    IRRewriter rewriter(context);
    OperationFolder constantFolder(context);
    bool hasFailure = false;

    for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
      if (funcOp.isExternal())
        continue;

      for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
        materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure);
      for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
        materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure);

      // A concat op directly in the function body is an error at this stage.
      for (Operation& hostOp : funcOp.getBody().front()) {
        auto concatOp = dyn_cast<pim::PimConcatOp>(&hostOp);
        if (!concatOp)
          continue;
        rewriter.setInsertionPoint(&hostOp);
        concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
        hasFailure = true;
      }
    }

    if (!hasFailure) {
      dumpModule(moduleOp, "pim4_materialized");
      return;
    }

    moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above");
    signalPassFailure();
  }
};
} // namespace
// Factory for the host-constant materialization pass.
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
  std::unique_ptr<Pass> pass = std::make_unique<MaterializeHostConstantsPass>();
  return pass;
}
} // namespace onnx_mlir
@@ -0,0 +1,14 @@
# Library with the memory coalescing analysis/transform (MemoryCoalescing.cpp/.hpp)
# and its pass (MemoryCoalescingPass.cpp); exposes PIM_PUBLIC_INCLUDE_DIRS and
# links publicly against OMPimCommon and PimOps.
add_pim_library(OMPimMemoryCoalescing
MemoryCoalescing.cpp
MemoryCoalescing.hpp
MemoryCoalescingPass.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC
OMPimCommon
PimOps
)
@@ -10,7 +10,7 @@
#include <limits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp"
using namespace mlir;
@@ -32,6 +32,13 @@ static uint64_t getTypeSizeBytes(MemRefType type) {
return static_cast<uint64_t>(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()));
}
// Walks up the parent chain of `op` (starting with `op` itself) and returns
// the first operation that sits directly in `body`, or nullptr if there is none.
static Operation* getTopLevelAncestorInBody(Operation* op, Block& body) {
  for (Operation* ancestor = op; ancestor; ancestor = ancestor->getParentOp())
    if (ancestor->getBlock() == &body)
      return ancestor;
  return nullptr;
}
static FailureOr<uint64_t>
getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Operation*, uint64_t>& opOrder) {
uint64_t endInstruction = opOrder.lookup(allocOp);
@@ -42,7 +49,8 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
while (!pendingValues.empty()) {
Value value = pendingValues.pop_back_val();
for (Operation* user : value.getUsers()) {
if (user->getBlock() != &body)
Operation* orderedUser = getTopLevelAncestorInBody(user, body);
if (!orderedUser)
return failure();
if (!visited.insert(user).second)
continue;
@@ -51,6 +59,15 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
for (Value result : user->getResults())
pendingValues.push_back(result);
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
if (!forOp)
return failure();
for (auto [index, operand] : llvm::enumerate(yieldOp.getOperands()))
if (operand == value)
pendingValues.push_back(forOp.getResult(index));
}
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs()))
if (initArg == value)
@@ -66,7 +83,7 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
}
}
auto order = opOrder.find(user);
auto order = opOrder.find(orderedUser);
if (order == opOrder.end())
return failure();
endInstruction = std::max(endInstruction, order->second);
@@ -78,8 +95,8 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
} // namespace
StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(Operation* coreLikeOp) {
StaticMemoryCoalescingAnalysis analysis;
MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(Operation* coreLikeOp) {
MemoryCoalescingAnalysis analysis;
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
return analysis;
@@ -107,18 +124,19 @@ StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(Operation
}
analysis.candidates.push_back(
StaticAllocationCandidate {allocOp, opOrder.lookup(allocOp), *endInstruction, getTypeSizeBytes(allocType)});
AllocationCandidate {allocOp, opOrder.lookup(allocOp), *endInstruction, getTypeSizeBytes(allocType)});
}
return analysis;
}
StaticMemoryCoalescingStats coalesceStaticMemory(Operation* coreLikeOp, RewriterBase& rewriter) {
StaticMemoryCoalescingStats stats;
auto analysis = analyzeStaticMemoryCoalescingCandidates(coreLikeOp);
MemoryCoalescingStats
coalesceMemory(Operation* coreLikeOp, const MemoryCoalescingAnalysis& analysis, RewriterBase& rewriter) {
MemoryCoalescingStats stats;
stats.skippedAllocations = analysis.skippedAllocations;
llvm::sort(analysis.candidates, [](const StaticAllocationCandidate& lhs, const StaticAllocationCandidate& rhs) {
auto candidates = analysis.candidates;
llvm::sort(candidates, [](const AllocationCandidate& lhs, const AllocationCandidate& rhs) {
if (lhs.startInstruction != rhs.startInstruction)
return lhs.startInstruction < rhs.startInstruction;
return lhs.endInstruction < rhs.endInstruction;
@@ -132,7 +150,7 @@ StaticMemoryCoalescingStats coalesceStaticMemory(Operation* coreLikeOp, Rewriter
SmallVector<ActiveStorage> active;
SmallVector<memref::AllocOp> freeList;
for (StaticAllocationCandidate& candidate : analysis.candidates) {
for (AllocationCandidate& candidate : candidates) {
for (auto it = active.begin(); it != active.end();) {
if (it->endInstruction < candidate.startInstruction) {
freeList.push_back(it->root);
@@ -8,27 +8,28 @@
namespace onnx_mlir {
namespace pim {
struct StaticAllocationCandidate {
struct AllocationCandidate {
mlir::memref::AllocOp alloc;
uint64_t startInstruction = 0;
uint64_t endInstruction = 0;
uint64_t sizeBytes = 0;
};
struct StaticMemoryCoalescingAnalysis {
llvm::SmallVector<StaticAllocationCandidate> candidates;
struct MemoryCoalescingAnalysis {
llvm::SmallVector<AllocationCandidate> candidates;
uint64_t skippedAllocations = 0;
};
struct StaticMemoryCoalescingStats {
struct MemoryCoalescingStats {
uint64_t removedAllocs = 0;
uint64_t savedBytes = 0;
uint64_t skippedAllocations = 0;
};
StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(mlir::Operation* coreLikeOp);
MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(mlir::Operation* coreLikeOp);
StaticMemoryCoalescingStats coalesceStaticMemory(mlir::Operation* coreLikeOp, mlir::RewriterBase& rewriter);
MemoryCoalescingStats
coalesceMemory(mlir::Operation* coreLikeOp, const MemoryCoalescingAnalysis& analysis, mlir::RewriterBase& rewriter);
} // namespace pim
} // namespace onnx_mlir
@@ -10,10 +10,11 @@
#include "Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
using namespace mlir;
@@ -151,34 +152,39 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
file.close();
}
struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StaticMemoryCoalescingPass)
struct PimMemoryCoalescingPass : PassWrapper<PimMemoryCoalescingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMemoryCoalescingPass)
StringRef getArgument() const override { return "pim-static-memory-coalescing"; }
StringRef getDescription() const override { return "Analyze static local PIM memory reuse opportunities"; }
StringRef getArgument() const override { return "pim-memory-coalescing"; }
StringRef getDescription() const override { return "Analyze local PIM memory reuse opportunities"; }
StaticMemoryCoalescingPass() = default;
StaticMemoryCoalescingPass(const StaticMemoryCoalescingPass& pass) {}
PimMemoryCoalescingPass() = default;
PimMemoryCoalescingPass(const PimMemoryCoalescingPass& pass) {}
void runOnOperation() override {
IRRewriter rewriter(&getContext());
SmallVector<CoalescingReportEntry, 32> reportEntries;
uint64_t nextBatchId = 0;
bool hasFailure = false;
getOperation().walk([&](Operation* op) {
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
if (hasFailure || !isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
return;
auto analysis = pim::analyzeStaticMemoryCoalescingCandidates(op);
auto stats = pim::coalesceStaticMemory(op, rewriter);
auto analysis = pim::analyzeMemoryCoalescingCandidates(op);
auto stats = pim::coalesceMemory(op, analysis, rewriter);
CoalescingReportRow row {
analysis.candidates.size(), stats.skippedAllocations, stats.removedAllocs, stats.savedBytes};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
reportEntries.push_back({CoalescingReportEntry::Kind::Core,
static_cast<uint64_t>(coreOp.getCoreId()),
{static_cast<int32_t>(coreOp.getCoreId())},
row});
auto checkedCoreId =
pim::checkedI32(static_cast<uint64_t>(coreOp.getCoreId()), coreOp, "memory coalescing core id");
if (failed(checkedCoreId)) {
hasFailure = true;
return;
}
reportEntries.push_back(
{CoalescingReportEntry::Kind::Core, static_cast<uint64_t>(coreOp.getCoreId()), {*checkedCoreId}, row});
return;
}
@@ -191,6 +197,11 @@ struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, Oper
reportEntries.push_back(std::move(entry));
});
if (hasFailure) {
signalPassFailure();
return;
}
emitReport(reportEntries);
dumpModule(getOperation(), "pim2_coalesced");
}
@@ -198,6 +209,6 @@ struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, Oper
} // namespace
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() { return std::make_unique<StaticMemoryCoalescingPass>(); }
std::unique_ptr<Pass> createPimMemoryCoalescingPass() { return std::make_unique<PimMemoryCoalescingPass>(); }
} // namespace onnx_mlir
@@ -1,14 +0,0 @@
add_pim_library(OMPimStaticMemoryCoalescing
StaticMemoryCoalescing.cpp
StaticMemoryCoalescing.hpp
StaticMemoryCoalescingPass.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC
OMPimCommon
PimOps
)
@@ -0,0 +1,11 @@
# Library built from VerificationPass.cpp (the PIM verification pass); links
# publicly against OMPimCommon, OMPimBufferization, PimOps and SpatialOps.
add_pim_library(OMPimVerification
VerificationPass.cpp
EXCLUDE_FROM_OM_LIBS
LINK_LIBS PUBLIC
OMPimCommon
OMPimBufferization
PimOps
SpatialOps
)
@@ -0,0 +1,434 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
// Returns true when `op` is one of the op kinds allowed at host (function)
// level: constants, allocs, get_globals, memref view/reshape ops and memref.copy.
static bool isAddressOnlyHostOp(Operation* op) {
  if (isa<arith::ConstantOp, memref::AllocOp, memref::GetGlobalOp>(op))
    return true;
  if (isa<memref::SubViewOp, memref::CastOp, memref::CollapseShapeOp, memref::ExpandShapeOp>(op))
    return true;
  return isa<memref::CopyOp>(op);
}
// Returns true when `value` has a contiguous address (resolved statically or,
// failing that, compiled as an address expression) whose base is a block
// argument or the result of a memref.alloc / memref.get_global.
static bool isCodegenAddressableValue(Value value) {
  auto hasAddressableBase = [](Value base) {
    if (isa<BlockArgument>(base))
      return true;
    return isa<memref::AllocOp, memref::GetGlobalOp>(base.getDefiningOp());
  };

  if (auto resolved = resolveContiguousAddress(value); succeeded(resolved))
    return hasAddressableBase(resolved->base);

  auto compiled = compileContiguousAddressExpr(value);
  return succeeded(compiled) && hasAddressableBase(compiled->base);
}
// Same check as the single-argument overload, but the static address
// resolution takes `knowledge` into account. The compiled-expression fallback
// does not use `knowledge`.
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
  auto hasAddressableBase = [](Value base) {
    if (isa<BlockArgument>(base))
      return true;
    return isa<memref::AllocOp, memref::GetGlobalOp>(base.getDefiningOp());
  };

  if (auto resolved = resolveContiguousAddress(value, knowledge); succeeded(resolved))
    return hasAddressableBase(resolved->base);

  auto compiled = compileContiguousAddressExpr(value);
  return succeeded(compiled) && hasAddressableBase(compiled->base);
}
// Returns true when `value` is a memref.get_global of a constant global with a
// dense initializer, possibly seen through a chain of static subviews, casts
// and static-shape collapse/expand reshapes. Block arguments and any other
// defining op yield false.
static bool isConstantGlobalView(Value value) {
  auto hasStaticMemRefShape = [](Type type) {
    auto memrefType = dyn_cast<MemRefType>(type);
    return memrefType && memrefType.hasStaticShape();
  };

  // Peel one view-like op per iteration until the chain's root is reached.
  while (Operation* defOp = value.getDefiningOp()) {
    if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)) {
      // A get_global outside of any module cannot be resolved to its global;
      // never hand a null module to the lookup.
      auto moduleOp = getGlobalOp->getParentOfType<ModuleOp>();
      if (!moduleOp)
        return false;
      auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
      return globalOp && globalOp.getConstant() && globalOp.getInitialValue()
          && isa<DenseElementsAttr>(*globalOp.getInitialValue());
    }

    if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
      if (!hasAllStaticSubviewParts(subview))
        return false;
      value = subview.getSource();
      continue;
    }

    if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
      value = cast.getSource();
      continue;
    }

    // Reshapes are only looked through when both sides have static shapes.
    if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
      if (!hasStaticMemRefShape(collapse.getSrc().getType())
          || !hasStaticMemRefShape(collapse.getResult().getType()))
        return false;
      value = collapse.getSrc();
      continue;
    }

    if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
      if (!hasStaticMemRefShape(expand.getSrc().getType()) || !hasStaticMemRefShape(expand.getResult().getType()))
        return false;
      value = expand.getSrc();
      continue;
    }

    return false;
  }

  // No defining op: the chain ends in a block argument.
  return false;
}
// Returns true when `value` is a block argument of a pim core / core-batch op
// whose index falls in that op's weight range: [0, #weights) for PimCoreOp and
// [1, #weights] for PimCoreBatchOp.
static bool isCoreWeightBlockArgument(Value value) {
  auto blockArgument = dyn_cast<BlockArgument>(value);
  if (!blockArgument)
    return false;

  Operation* owningOp = blockArgument.getOwner()->getParentOp();
  const unsigned argNumber = static_cast<unsigned>(blockArgument.getArgNumber());

  if (auto coreOp = dyn_cast<pim::PimCoreOp>(owningOp))
    return argNumber < coreOp.getWeights().size();

  if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(owningOp))
    return argNumber > 0 && argNumber <= coreBatchOp.getWeights().size();

  return false;
}
// Returns true when `op` is one of the op kinds accepted inside a core body:
// the pim mem-copy ops, receive/send/concat, the VMM / transpose / vector ops
// listed below, or memref.get_global.
static bool isSupportedCoreInstructionOp(Operation* op) {
  // PimMemCopy* ops.
  if (isa<pim::PimMemCopyHostToDevOp, pim::PimMemCopyDevToHostOp, pim::PimMemCopyOp>(op))
    return true;

  // Receive / send / concat.
  if (isa<pim::PimReceiveOp, pim::PimSendOp, pim::PimConcatOp>(op))
    return true;

  // VMM, transpose and the PimVV* ops.
  if (isa<pim::PimVMMOp,
          pim::PimTransposeOp,
          pim::PimVVAddOp,
          pim::PimVVSubOp,
          pim::PimVVMulOp,
          pim::PimVVMaxOp,
          pim::PimVVDMulOp>(op))
    return true;

  // Remaining PimV* ops.
  if (isa<pim::PimVAvgOp, pim::PimVReluOp, pim::PimVTanhOp, pim::PimVSigmOp, pim::PimVSoftmaxOp>(op))
    return true;

  return isa<memref::GetGlobalOp>(op);
}
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
StringRef getArgument() const override { return "verify-pim-pass"; }
StringRef getDescription() const override {
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
}
VerificationPass() {}
VerificationPass(const VerificationPass& pass) {}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
pim::CappedDiagnosticReporter diagnostics;
moduleOp.walk([&](Operation* op) {
if (op->getDialect()->getNamespace() != "spat")
return;
diagnostics.report(op, [](Operation* illegalOp) {
illegalOp->emitError("illegal Spatial operation reached PIM codegen verification");
});
});
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
if (funcOp.isExternal())
continue;
for (Operation& op : funcOp.getBody().front().getOperations()) {
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
(void) verifyCoreWeights(moduleOp, coreOp, diagnostics);
StaticValueKnowledge knowledge;
(void) verifyCoreLikeOperands(coreOp, knowledge, diagnostics);
continue;
}
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
llvm::SmallVector<unsigned, 2> lanes;
lanes.push_back(0);
if (coreBatchOp.getLaneCount() > 1)
lanes.push_back(static_cast<unsigned>(coreBatchOp.getLaneCount() - 1));
for (unsigned lane : lanes) {
StaticValueKnowledge knowledge;
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
for (unsigned i = 0; i < coreBatchOp.getInputs().size(); ++i)
knowledge.aliases[coreBatchOp.getInputArgument(i)] = coreBatchOp.getInputs()[i];
(void) verifyCoreLikeOperands(coreBatchOp, knowledge, diagnostics);
}
continue;
}
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
(void) verifyReturnOp(returnOp, diagnostics);
continue;
}
if (!isAddressOnlyHostOp(&op)) {
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("illegal host-side runtime op remains after PIM bufferization; "
"fold it to constants or lower it into pim.core");
});
continue;
}
(void) verifyAddressOnlyHostOp(&op, diagnostics);
}
}
if (diagnostics.hasFailure()) {
diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
signalPassFailure();
}
}
private:
/// Checks every weight operand of `coreOp`.
///
/// A weight is accepted when it is either
///   - produced by a memref.get_global whose global can be looked up in `moduleOp`, is marked
///     constant and has an initial value, or
///   - not produced by a memref.get_global but accepted by isConstantGlobalView.
///
/// Every violation is reported through `diagnostics`, attached to `coreOp`.
/// Returns success only when no weight was rejected.
template <typename CoreOpTy>
static LogicalResult
verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
  bool sawInvalidWeight = false;

  // Shared reporting path: every message has the form "weight #<index><problem>".
  auto rejectWeight = [&](size_t index, const char* problem) {
    diagnostics.report(coreOp.getOperation(), [&](Operation*) {
      coreOp.emitOpError() << "weight #" << index << problem;
    });
    sawInvalidWeight = true;
  };

  for (auto indexedWeight : llvm::enumerate(coreOp.getWeights())) {
    const size_t weightIndex = indexedWeight.index();
    Value weight = indexedWeight.value();

    auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
    if (!getGlobalOp) {
      // Not a direct get_global: the only other accepted form is a constant global view,
      // which needs no further inspection here.
      if (!isConstantGlobalView(weight))
        rejectWeight(weightIndex,
                     " must be materialized as a constant memref.global or a static view of one before "
                     "JSON codegen");
      continue;
    }

    auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
    if (!globalOp) {
      rejectWeight(weightIndex, " references an unknown memref.global");
      continue;
    }

    const bool isInitializedConstant = globalOp.getConstant() && globalOp.getInitialValue();
    if (!isInitializedConstant)
      rejectWeight(weightIndex, " must come from a constant memref.global with an initial value");
  }

  return success(!sawInvalidWeight);
}
/// Checks that every value returned by `returnOp` satisfies isCodegenAddressableValue.
///
/// Each offending operand is reported through `diagnostics` with its result position.
/// Returns success only when all operands pass.
static LogicalResult verifyReturnOp(func::ReturnOp returnOp, pim::CappedDiagnosticReporter& diagnostics) {
  bool everyResultAddressable = true;

  const size_t operandCount = returnOp->getNumOperands();
  for (size_t position = 0; position < operandCount; ++position) {
    if (isCodegenAddressableValue(returnOp->getOperand(position)))
      continue;

    // Per-iteration copy so the diagnostic callback sees this operand's index.
    const size_t resultIndex = position;
    diagnostics.report(returnOp.getOperation(), [&](Operation*) {
      returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
    });
    everyResultAddressable = false;
  }

  return success(everyResultAddressable);
}
/// Verifies the operations inside the first body block of a core-like op (pim.core or
/// pim.core_batch), as visited by walkPimCoreBlockStructurally starting from `initialKnowledge`.
///
/// For every visited operation:
///   - the operation itself must satisfy isSupportedCoreInstructionOp;
///   - every memref-typed operand, except core weight block arguments and the first operand of
///     a pim VMM op whose weight resolves via resolveWeightIndex, must resolve to a contiguous
///     address (statically or as a compiled address expression);
///   - operands flagged by isExplicitHostMemCopyOperand must be codegen-addressable;
///   - all remaining memref operands must be based on a memref.alloc result;
///   - the three pim memory-copy ops must be normalized and have statically evaluable offsets.
///
/// Violations are reported through `diagnostics`. The per-operation callback returns failure
/// when that operation had at least one violation; the overall result is whatever
/// walkPimCoreBlockStructurally returns.
template <typename CoreLikeOpTy>
static LogicalResult verifyCoreLikeOperands(CoreLikeOpTy coreLikeOp,
                                            const StaticValueKnowledge& initialKnowledge,
                                            pim::CappedDiagnosticReporter& diagnostics) {
  return walkPimCoreBlockStructurally(
      coreLikeOp.getBody().front(), initialKnowledge, [&](Operation& op, const StaticValueKnowledge& knowledge) {
        bool hasFailure = false;

        if (!isSupportedCoreInstructionOp(&op)) {
          diagnostics.report(&op, [](Operation* illegalOp) {
            illegalOp->emitOpError("unsupported executable op reached PIM codegen verification");
          });
          hasFailure = true;
        }

        for (auto it : llvm::enumerate(op.getOperands())) {
          size_t operandIndex = it.index();
          Value operand = it.value();

          if (!isa<BaseMemRefType>(operand.getType()))
            continue;
          if (isCoreWeightBlockArgument(operand))
            continue;
          if (auto vmmOp = dyn_cast<pim::PimVMMOp>(&op);
              vmmOp && operandIndex == 0 && resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight())) {
            continue;
          }

          // Resolve the base of the operand's storage exactly once: prefer the static
          // resolution and fall back to the compiled address expression only when needed.
          auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
          Value addressBase;
          if (succeeded(resolvedAddress)) {
            addressBase = resolvedAddress->base;
          } else {
            auto compiledAddress = compileContiguousAddressExpr(operand);
            if (failed(compiledAddress)) {
              diagnostics.report(&op, [&](Operation* illegalOp) {
                illegalOp->emitOpError() << "operand #" << operandIndex
                                         << " is not backed by contiguous addressable storage";
              });
              hasFailure = true;
              continue;
            }
            addressBase = compiledAddress->base;
          }

          if (isExplicitHostMemCopyOperand(&op, operandIndex)) {
            if (!isCodegenAddressableValue(operand, knowledge)) {
              diagnostics.report(&op, [&](Operation* illegalOp) {
                illegalOp->emitOpError() << "host operand #" << operandIndex
                                         << " is not backed by contiguous addressable storage";
              });
              hasFailure = true;
            }
            continue;
          }

          // getDefiningOp<OpTy>() yields a null op for block arguments, whereas
          // isa<OpTy>(value.getDefiningOp()) would assert on the null pointer.
          if (!addressBase.getDefiningOp<memref::AllocOp>()) {
            diagnostics.report(&op, [&](Operation* illegalOp) {
              illegalOp->emitOpError() << "operand #" << operandIndex
                                       << " must be backed by device-local memory; materialize host values with "
                                          "pim.memcp_hd";
            });
            hasFailure = true;
          }
        }

        // Shared check for the three memory-copy ops: the op must be normalized and both of
        // its offset operands must be statically evaluable under the current knowledge.
        auto verifyCopyOp = [&](auto copyLikeOp, auto firstOffset, auto secondOffset) {
          if (!pim::isNormalizedCopyOp(copyLikeOp)) {
            diagnostics.report(&op, [](Operation* illegalOp) {
              illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
            });
            hasFailure = true;
          }
          if (failed(resolveIndexValue(firstOffset, knowledge))
              || failed(resolveIndexValue(secondOffset, knowledge))) {
            diagnostics.report(&op, [](Operation* illegalOp) {
              illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen");
            });
            hasFailure = true;
          }
        };

        if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
          verifyCopyOp(storeOp, storeOp.getHostTargetOffset(), storeOp.getDeviceSourceOffset());
        if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
          verifyCopyOp(loadOp, loadOp.getDeviceTargetOffset(), loadOp.getHostSourceOffset());
        if (auto copyOp = dyn_cast<pim::PimMemCopyOp>(op))
          verifyCopyOp(copyOp, copyOp.getTargetOffset(), copyOp.getSourceOffset());

        return success(!hasFailure);
      });
}
/// Dispatches the verification of a host-side address-only operation by its kind.
///
///   - memref.cast / collapse_shape / expand_shape: the source is checked with
///     verifyAddressOnlySource;
///   - memref.subview: the source is checked with verifyAddressOnlyBase;
///   - memref.copy: both source and target must satisfy isMemRefBaseAddressableValue;
///   - any other operation is accepted without checks.
///
/// Violations are reported through `diagnostics`; the result tells whether `op` passed.
static LogicalResult verifyAddressOnlyHostOp(Operation* op, pim::CappedDiagnosticReporter& diagnostics) {
  if (auto castOp = dyn_cast<memref::CastOp>(op))
    return verifyAddressOnlySource(op, castOp.getSource(), diagnostics);

  if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
    return verifyAddressOnlySource(op, collapseOp.getSrc(), diagnostics);

  if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
    return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics);

  if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
    return verifyAddressOnlyBase(op, subviewOp.getSource(), diagnostics);

  auto copyOp = dyn_cast<memref::CopyOp>(op);
  if (!copyOp)
    return success();

  const bool bothEndsAddressable =
      isMemRefBaseAddressableValue(copyOp.getSource()) && isMemRefBaseAddressableValue(copyOp.getTarget());
  if (bothEndsAddressable)
    return success();

  diagnostics.report(op, [](Operation* illegalOp) {
    illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
  });
  return failure();
}
/// Succeeds when `source` satisfies isCodegenAddressableValue; otherwise reports an error on
/// `op` through `diagnostics` and fails.
static LogicalResult
verifyAddressOnlySource(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
  if (!isCodegenAddressableValue(source)) {
    diagnostics.report(op, [](Operation* illegalOp) {
      illegalOp->emitOpError("depends on a value that is not backed by contiguous addressable storage");
    });
    return failure();
  }
  return success();
}
/// Succeeds when `source` satisfies isMemRefBaseAddressableValue; otherwise reports an error
/// on `op` through `diagnostics` and fails.
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
  if (!isMemRefBaseAddressableValue(source)) {
    diagnostics.report(op, [](Operation* illegalOp) {
      illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
    });
    return failure();
  }
  return success();
}
};
} // namespace
/// Factory returning a newly allocated VerificationPass behind the generic Pass interface.
std::unique_ptr<Pass> createPimVerificationPass() {
  return std::make_unique<VerificationPass>();
}
} // namespace onnx_mlir
File diff suppressed because it is too large Load Diff
@@ -37,6 +37,7 @@
#include "Scheduling/MergeSchedulingAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -128,8 +129,12 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
}
/// Returns the core id stored in the kCoreIdAttrName integer attribute of `compute`.
///
/// Yields std::nullopt when the attribute is absent or when pim::checkedI32 rejects the
/// stored value. The value always goes through pim::checkedI32; an unchecked
/// static_cast<int32_t> would silently truncate out-of-range ids and bypass that check.
/// NOTE(review): checkedI32 presumably emits a diagnostic on `compute` when it fails - confirm.
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
  auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
  if (!coreIdAttr)
    return std::nullopt;

  auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
  if (failed(checkedCoreId))
    return std::nullopt;
  return *checkedCoreId;
}