add shared loop creation helpers
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
add shared checked arithmetic helpers refactor pim passes into Pim/Transforms more robust memory coalescing pass
This commit is contained in:
@@ -2,7 +2,10 @@ add_onnx_mlir_dialect(Pim pim)
|
||||
add_onnx_mlir_dialect_doc(pim Pim.td)
|
||||
|
||||
add_subdirectory(Transforms/Bufferization)
|
||||
add_subdirectory(Transforms/StaticMemoryCoalescing)
|
||||
add_subdirectory(Transforms/MemoryCoalescing)
|
||||
add_subdirectory(Transforms/HostConstantFolding)
|
||||
add_subdirectory(Transforms/HostConstantMaterialization)
|
||||
add_subdirectory(Transforms/Verification)
|
||||
|
||||
add_pim_library(PimOps
|
||||
PimOps.hpp
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
|
||||
|
||||
@@ -11,24 +12,25 @@ using namespace bufferization;
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
Value materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
|
||||
return memrefValue;
|
||||
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
||||
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
||||
auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
|
||||
auto sizeInBytes =
|
||||
getCheckedShapedTypeSizeInBytes(shapedType, contiguousBuffer.getDefiningOp(), "contiguous copy byte size");
|
||||
if (failed(sizeInBytes))
|
||||
return failure();
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, contiguousBuffer.getDefiningOp(), 0);
|
||||
auto sizeAttr =
|
||||
getCheckedI32Attr(rewriter, contiguousBuffer.getDefiningOp(), *sizeInBytes, "contiguous copy byte size");
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
|
||||
return PimMemCopyOp::create(rewriter,
|
||||
loc,
|
||||
contiguousType,
|
||||
zeroOffset,
|
||||
zeroOffset,
|
||||
contiguousBuffer,
|
||||
memrefValue,
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
return PimMemCopyOp::create(
|
||||
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
mlir::Value materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||
llvm::FailureOr<mlir::Value>
|
||||
materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||
mlir::Value
|
||||
allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
|
||||
FailureOr<IntegerAttr> onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Operation* anchor, Value memref) {
|
||||
auto type = mlir::cast<MemRefType>(memref.getType());
|
||||
int32_t sizeInBytes = static_cast<int32_t>(getShapedTypeSizeInBytes(type));
|
||||
return builder.getI32IntegerAttr(sizeInBytes);
|
||||
auto byteSize = getCheckedShapedTypeSizeInBytes(type, anchor, "memref byte size");
|
||||
if (failed(byteSize))
|
||||
return failure();
|
||||
return getCheckedI32Attr(builder, anchor, *byteSize, "memref byte size");
|
||||
}
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
mlir::IntegerAttr getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Value memref);
|
||||
mlir::FailureOr<mlir::IntegerAttr>
|
||||
getMemRefSizeInBytesAttr(mlir::OpBuilder& builder, mlir::Operation* anchor, mlir::Value memref);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -4,8 +4,10 @@
|
||||
|
||||
#include "ContiguityPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -85,7 +87,13 @@ static FailureOr<SmallVector<int64_t>> getStaticMemRefStrides(MemRefType type) {
|
||||
static FailureOr<int64_t> getShapedByteSize(MemRefType type) {
|
||||
if (!type.hasStaticShape() || !hasByteSizedElementType(type.getElementType()))
|
||||
return failure();
|
||||
return static_cast<int64_t>(getShapedTypeSizeInBytes(type));
|
||||
auto byteSize =
|
||||
pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "normalized copy byte size");
|
||||
if (failed(byteSize))
|
||||
return failure();
|
||||
if (*byteSize > static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
|
||||
return failure();
|
||||
return static_cast<int64_t>(*byteSize);
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int64_t>>
|
||||
@@ -325,12 +333,11 @@ static LogicalResult rewriteCopyLikeOp(CopyOp copyOp,
|
||||
if (plan->kind == CopyRewritePlan::Kind::Direct) {
|
||||
Value newTargetOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->target.offset);
|
||||
Value newSourceOffset = materializeByteOffset(rewriter, loc, anchorOp, plan->source.offset);
|
||||
auto newCopyOp = createCopyOp(loc,
|
||||
plan->target.base,
|
||||
plan->source.base,
|
||||
newTargetOffset,
|
||||
newSourceOffset,
|
||||
static_cast<int32_t>(plan->directBytes));
|
||||
auto checkedDirectBytes = pim::checkedI32(plan->directBytes, anchorOp, "normalized direct copy byte size");
|
||||
if (failed(checkedDirectBytes))
|
||||
return failure();
|
||||
auto newCopyOp =
|
||||
createCopyOp(loc, plan->target.base, plan->source.base, newTargetOffset, newSourceOffset, *checkedDirectBytes);
|
||||
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
|
||||
rewriter.replaceOp(copyOp, replacementValue);
|
||||
return success();
|
||||
@@ -339,23 +346,30 @@ static LogicalResult rewriteCopyLikeOp(CopyOp copyOp,
|
||||
Value c0 = createIndexConstant(rewriter, anchorOp, 0);
|
||||
Value cUpper = createIndexConstant(rewriter, anchorOp, getNumElements(plan->loop.outerShape));
|
||||
Value cStep = createIndexConstant(rewriter, anchorOp, 1);
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, cStep, ValueRange {});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
SmallVector<Value> outerIndices =
|
||||
materializeDelinearizedIndices(rewriter, loc, anchorOp, loop.getInductionVar(), plan->loop.outerShape);
|
||||
Value loopTargetOffset = materializeOuterByteOffset(
|
||||
rewriter, loc, anchorOp, plan->loop.targetBaseOffset, outerIndices, plan->loop.targetOuterByteStrides);
|
||||
Value loopSourceOffset = materializeOuterByteOffset(
|
||||
rewriter, loc, anchorOp, plan->loop.sourceBaseOffset, outerIndices, plan->loop.sourceOuterByteStrides);
|
||||
auto newCopyOp = createCopyOp(loc,
|
||||
plan->target.base,
|
||||
plan->source.base,
|
||||
loopTargetOffset,
|
||||
loopSourceOffset,
|
||||
static_cast<int32_t>(plan->loop.chunkBytes));
|
||||
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
|
||||
rewriter.setInsertionPointAfter(loop);
|
||||
auto loop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
loc,
|
||||
c0,
|
||||
cUpper,
|
||||
cStep,
|
||||
ValueRange {},
|
||||
[&](OpBuilder&, Location nestedLoc, Value inductionVar, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||
SmallVector<Value> outerIndices =
|
||||
materializeDelinearizedIndices(rewriter, nestedLoc, anchorOp, inductionVar, plan->loop.outerShape);
|
||||
Value loopTargetOffset = materializeOuterByteOffset(
|
||||
rewriter, nestedLoc, anchorOp, plan->loop.targetBaseOffset, outerIndices, plan->loop.targetOuterByteStrides);
|
||||
Value loopSourceOffset = materializeOuterByteOffset(
|
||||
rewriter, nestedLoc, anchorOp, plan->loop.sourceBaseOffset, outerIndices, plan->loop.sourceOuterByteStrides);
|
||||
auto checkedChunkBytes = pim::checkedI32(plan->loop.chunkBytes, anchorOp, "normalized loop copy byte size");
|
||||
if (failed(checkedChunkBytes))
|
||||
return failure();
|
||||
auto newCopyOp = createCopyOp(
|
||||
nestedLoc, plan->target.base, plan->source.base, loopTargetOffset, loopSourceOffset, *checkedChunkBytes);
|
||||
assert(isNormalizedCopyOp(newCopyOp) && "copy normalization emitted a non-normalized copy");
|
||||
return success();
|
||||
});
|
||||
if (failed(loop))
|
||||
return failure();
|
||||
rewriter.replaceOp(copyOp, replacementValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -148,7 +148,10 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
|
||||
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
inputs.push_back(materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter));
|
||||
auto contiguous = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
if (failed(contiguous))
|
||||
return failure();
|
||||
inputs.push_back(*contiguous);
|
||||
}
|
||||
|
||||
auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state);
|
||||
@@ -179,12 +182,12 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
|
||||
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
if (failed(contiguousInput))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimSendOp>(rewriter,
|
||||
op,
|
||||
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter),
|
||||
sendOp.getSizeAttr(),
|
||||
sendOp.getTargetCoreId());
|
||||
replaceOpWithNewBufferizedOp<PimSendOp>(
|
||||
rewriter, op, *contiguousInput, sendOp.getSizeAttr(), sendOp.getTargetCoreId());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -407,11 +410,13 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
if (failed(contiguousInput))
|
||||
return failure();
|
||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimTransposeOp>(
|
||||
rewriter, op, contiguousOutput.getType(), contiguousInput, transposeOp.getPermutation(), contiguousOutput);
|
||||
rewriter, op, contiguousOutput.getType(), *contiguousInput, transposeOp.getPermutation(), contiguousOutput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -451,11 +456,13 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
if (failed(contiguousInput))
|
||||
return failure();
|
||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||
rewriter, op, contiguousOutput.getType(), *weightOpt, contiguousInput, contiguousOutput);
|
||||
rewriter, op, contiguousOutput.getType(), *weightOpt, *contiguousInput, contiguousOutput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -490,12 +497,16 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||
Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||
auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||
if (failed(contiguousLhs))
|
||||
return failure();
|
||||
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||
if (failed(contiguousRhs))
|
||||
return failure();
|
||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<OpTy>(
|
||||
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
|
||||
rewriter, op, contiguousOutput.getType(), *contiguousLhs, *contiguousRhs, contiguousOutput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -523,12 +534,16 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInter
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||
Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||
auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||
if (failed(contiguousLhs))
|
||||
return failure();
|
||||
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||
if (failed(contiguousRhs))
|
||||
return failure();
|
||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVVDMulOp>(
|
||||
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
|
||||
rewriter, op, contiguousOutput.getType(), *contiguousLhs, *contiguousRhs, contiguousOutput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -559,10 +574,12 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
if (failed(contiguousInput))
|
||||
return failure();
|
||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, contiguousOutput.getType(), contiguousInput, contiguousOutput);
|
||||
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, contiguousOutput.getType(), *contiguousInput, contiguousOutput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -42,7 +42,9 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
return failure();
|
||||
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
||||
IntegerAttr sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getSource());
|
||||
auto sizeAttr = getMemRefSizeInBytesAttr(rewriter, copyOp.getOperation(), copyOp.getSource());
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
copyOp.getTarget().getType(),
|
||||
@@ -50,7 +52,7 @@ struct MemRefCopyToPimMemCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
zeroOffset,
|
||||
copyOp.getTarget(),
|
||||
copyOp.getSource(),
|
||||
sizeAttr);
|
||||
*sizeAttr);
|
||||
rewriter.eraseOp(copyOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
add_pim_library(OMPimHostConstantFolding
|
||||
Common.cpp
|
||||
Patterns/Constant.cpp
|
||||
HostConstantFoldingPass.cpp
|
||||
Patterns/Subview.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRLinalgDialect
|
||||
OMPimCommon
|
||||
)
|
||||
@@ -0,0 +1,156 @@
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
struct DenseSubviewKey {
|
||||
DenseElementsAttr source;
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> resultShape;
|
||||
|
||||
bool operator==(const DenseSubviewKey& other) const {
|
||||
return source == other.source && offsets == other.offsets && resultShape == other.resultShape;
|
||||
}
|
||||
};
|
||||
|
||||
struct DenseSubviewKeyInfo {
|
||||
static inline DenseSubviewKey getEmptyKey() {
|
||||
return {DenseElementsAttr(), {DenseMapInfo<int64_t>::getEmptyKey()}, {}};
|
||||
}
|
||||
|
||||
static inline DenseSubviewKey getTombstoneKey() {
|
||||
return {DenseElementsAttr(), {DenseMapInfo<int64_t>::getTombstoneKey()}, {}};
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const DenseSubviewKey& key) {
|
||||
return static_cast<unsigned>(
|
||||
llvm::hash_combine(key.source,
|
||||
llvm::hash_combine_range(key.offsets.begin(), key.offsets.end()),
|
||||
llvm::hash_combine_range(key.resultShape.begin(), key.resultShape.end())));
|
||||
}
|
||||
|
||||
static bool isEqual(const DenseSubviewKey& lhs, const DenseSubviewKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||
Location loc,
|
||||
MemRefType globalType,
|
||||
DenseElementsAttr denseAttr,
|
||||
StringRef nameStem,
|
||||
IntegerAttr alignment) {
|
||||
for (auto globalOp : moduleOp.getOps<memref::GlobalOp>()) {
|
||||
if (!globalOp.getConstant() || globalOp.getType() != globalType || globalOp.getAlignmentAttr() != alignment
|
||||
|| !globalOp.getInitialValue())
|
||||
continue;
|
||||
|
||||
auto existingDenseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (existingDenseAttr == denseAttr)
|
||||
return globalOp;
|
||||
}
|
||||
|
||||
auto globalName = nameStem.str();
|
||||
unsigned suffix = 0;
|
||||
while (moduleOp.lookupSymbol(globalName))
|
||||
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
|
||||
|
||||
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
|
||||
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
|
||||
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
|
||||
return memref::GlobalOp::create(moduleBuilder,
|
||||
loc,
|
||||
globalName,
|
||||
visibility,
|
||||
globalType,
|
||||
denseAttr,
|
||||
/*constant=*/true,
|
||||
alignment);
|
||||
}
|
||||
|
||||
FailureOr<DenseElementsAttr>
|
||||
foldDenseSubview(DenseElementsAttr denseAttr, ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> resultShape) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape() || sourceType.getRank() != static_cast<int64_t>(staticOffsets.size())
|
||||
|| sourceType.getRank() != static_cast<int64_t>(resultShape.size()))
|
||||
return failure();
|
||||
|
||||
static DenseMap<DenseSubviewKey, DenseElementsAttr, DenseSubviewKeyInfo> cache;
|
||||
DenseSubviewKey key {denseAttr,
|
||||
SmallVector<int64_t>(staticOffsets.begin(), staticOffsets.end()),
|
||||
SmallVector<int64_t>(resultShape.begin(), resultShape.end())};
|
||||
if (auto cached = cache.find(key); cached != cache.end())
|
||||
return cached->second;
|
||||
|
||||
auto resultTensorType = RankedTensorType::get(resultShape, sourceType.getElementType());
|
||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||
if (numResultElements < 0)
|
||||
return failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<Attribute> resultValues(numResultElements);
|
||||
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [offset, index] : llvm::zip_equal(staticOffsets, resultIndices))
|
||||
sourceIndices.push_back(offset + index);
|
||||
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
||||
}
|
||||
|
||||
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
cache.try_emplace(std::move(key), foldedAttr);
|
||||
return foldedAttr;
|
||||
}
|
||||
|
||||
FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
|
||||
value = stripMemRefCasts(value);
|
||||
|
||||
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return failure();
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
return denseAttr;
|
||||
}
|
||||
|
||||
FailureOr<DenseElementsAttr> foldDenseSourceToType(ModuleOp moduleOp, Value source, MemRefType resultType) {
|
||||
auto srcSubview = getStaticSubviewInfo(source);
|
||||
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(source);
|
||||
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
|
||||
if (failed(denseAttr))
|
||||
return failure();
|
||||
|
||||
if (succeeded(srcSubview)) {
|
||||
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
return foldDenseSubview(*denseAttr, *staticOffsets, resultType.getShape());
|
||||
}
|
||||
|
||||
auto resultTensorType = RankedTensorType::get(resultType.getShape(), resultType.getElementType());
|
||||
if (resultTensorType != denseAttr->getType())
|
||||
return failure();
|
||||
return *denseAttr;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,31 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
|
||||
mlir::Location loc,
|
||||
mlir::MemRefType globalType,
|
||||
mlir::DenseElementsAttr denseAttr,
|
||||
llvm::StringRef nameStem,
|
||||
mlir::IntegerAttr alignment = {});
|
||||
|
||||
llvm::FailureOr<mlir::DenseElementsAttr> foldDenseSubview(mlir::DenseElementsAttr denseAttr,
|
||||
llvm::ArrayRef<int64_t> staticOffsets,
|
||||
llvm::ArrayRef<int64_t> resultShape);
|
||||
|
||||
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
|
||||
|
||||
llvm::FailureOr<mlir::DenseElementsAttr>
|
||||
foldDenseSourceToType(mlir::ModuleOp moduleOp, mlir::Value source, mlir::MemRefType resultType);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,54 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostConstantFoldingPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-host-constant-folding-pass"; }
|
||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||
|
||||
LogicalResult initialize(MLIRContext* context) override {
|
||||
RewritePatternSet owningPatterns(context);
|
||||
for (auto* dialect : context->getLoadedDialects())
|
||||
dialect->getCanonicalizationPatterns(owningPatterns);
|
||||
for (RegisteredOperationName op : context->getRegisteredOperations())
|
||||
op.getCanonicalizationPatterns(owningPatterns, context);
|
||||
|
||||
populateConstantFoldingConstantPatterns(owningPatterns);
|
||||
populateConstantFoldingSubviewPatterns(owningPatterns);
|
||||
|
||||
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
|
||||
return success();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
GreedyRewriteConfig config;
|
||||
config.enableFolding();
|
||||
if (failed(applyPatternsGreedily(moduleOp, *patterns, config))) {
|
||||
moduleOp.emitError("PIM host constant folding failed in the greedy rewrite driver");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(moduleOp, "pim3_folded");
|
||||
}
|
||||
|
||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimHostConstantFoldingPass() { return std::make_unique<HostConstantFoldingPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateConstantFoldingConstantPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
void populateConstantFoldingSubviewPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,546 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
#include "../Common.hpp"
|
||||
#include "../Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct ConstantSubviewCopy {
|
||||
DenseElementsAttr source;
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> strides;
|
||||
Operation* copyOp = nullptr;
|
||||
};
|
||||
|
||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!tensorType)
|
||||
return failure();
|
||||
|
||||
int64_t rank = tensorType.getRank();
|
||||
if (static_cast<int64_t>(perms.size()) != rank)
|
||||
return failure();
|
||||
|
||||
llvm::SmallBitVector seen(rank);
|
||||
SmallVector<int64_t> transposedShape;
|
||||
transposedShape.reserve(rank);
|
||||
for (int64_t perm : perms) {
|
||||
if (perm < 0 || perm >= rank || seen.test(perm))
|
||||
return failure();
|
||||
seen.set(perm);
|
||||
transposedShape.push_back(tensorType.getShape()[perm]);
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType());
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<Attribute> transposedValues(originalValues.size());
|
||||
|
||||
SmallVector<int64_t> originalStrides(rank, 1);
|
||||
SmallVector<int64_t> transposedStrides(rank, 1);
|
||||
for (int64_t dim = rank - 2; dim >= 0; --dim) {
|
||||
originalStrides[dim] = originalStrides[dim + 1] * tensorType.getShape()[dim + 1];
|
||||
transposedStrides[dim] = transposedStrides[dim + 1] * transposedShape[dim + 1];
|
||||
}
|
||||
|
||||
SmallVector<int64_t> originalIndices(rank);
|
||||
SmallVector<int64_t> transposedIndices(rank);
|
||||
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
|
||||
int64_t remaining = static_cast<int64_t>(linearIndex);
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
originalIndices[dim] = remaining / originalStrides[dim];
|
||||
remaining %= originalStrides[dim];
|
||||
}
|
||||
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
transposedIndices[dim] = originalIndices[perms[dim]];
|
||||
|
||||
int64_t transposedLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim];
|
||||
|
||||
transposedValues[transposedLinearIndex] = value;
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||
}
|
||||
|
||||
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
|
||||
if (!mapOp.getInputs().empty())
|
||||
return failure();
|
||||
|
||||
auto yieldOp = dyn_cast<linalg::YieldOp>(mapOp.getMapper().front().getTerminator());
|
||||
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||
return failure();
|
||||
|
||||
Attribute attr;
|
||||
if (!matchPattern(yieldOp.getValues().front(), m_Constant(&attr)))
|
||||
return failure();
|
||||
return attr;
|
||||
}
|
||||
|
||||
// Folds constant linalg fills inside cores into private globals plus device copies.
|
||||
struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
|
||||
auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
|
||||
if (!coreOp)
|
||||
return failure();
|
||||
|
||||
auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
|
||||
if (!initType || !initType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto fillValue = getConstantMapYield(mapOp);
|
||||
if (failed(fillValue))
|
||||
return failure();
|
||||
|
||||
auto tensorType = RankedTensorType::get(initType.getShape(), initType.getElementType());
|
||||
DenseElementsAttr splatAttr = DenseElementsAttr::get(tensorType, *fillValue);
|
||||
|
||||
auto moduleOp = mapOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||
auto sizeInBytes = pim::getCheckedShapedTypeSizeInBytes(initType, mapOp, "host constant folding byte size");
|
||||
if (failed(sizeInBytes))
|
||||
return failure();
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, mapOp, 0);
|
||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter, mapOp, *sizeInBytes, "host constant folding byte size");
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
pim::PimMemCopyOp::create(
|
||||
rewriter, mapOp.getLoc(), initType, zeroOffset, zeroOffset, mapOp.getInit(), getGlobalOp.getResult(), *sizeAttr);
|
||||
rewriter.eraseOp(mapOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
|
||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||
if (!allocType || !allocType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
const int64_t numElements = resultTensorType.getNumElements();
|
||||
if (numElements < 0)
|
||||
return failure();
|
||||
|
||||
Attribute fillValue;
|
||||
SmallVector<ConstantSubviewCopy> copies;
|
||||
llvm::SmallPtrSet<Operation*, 8> visitedAliases;
|
||||
SmallVector<Value> pendingAliases;
|
||||
pendingAliases.push_back(allocOp.getResult());
|
||||
|
||||
while (!pendingAliases.empty()) {
|
||||
Value alias = pendingAliases.pop_back_val();
|
||||
for (Operation* user : alias.getUsers()) {
|
||||
if (!visitedAliases.insert(user).second)
|
||||
continue;
|
||||
|
||||
if (auto mapOp = dyn_cast<linalg::MapOp>(user)) {
|
||||
if (mapOp.getInit() != alias)
|
||||
return failure();
|
||||
auto maybeFillValue = getConstantMapYield(mapOp);
|
||||
if (failed(maybeFillValue))
|
||||
return failure();
|
||||
if (fillValue && fillValue != *maybeFillValue)
|
||||
return failure();
|
||||
fillValue = *maybeFillValue;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(user)) {
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> strides;
|
||||
offsets.reserve(subviewOp.getMixedOffsets().size());
|
||||
strides.reserve(subviewOp.getMixedStrides().size());
|
||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
offsets.push_back(*staticOffset);
|
||||
}
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto staticStride = getConstantIntValue(stride);
|
||||
if (!staticStride)
|
||||
return failure();
|
||||
strides.push_back(*staticStride);
|
||||
}
|
||||
|
||||
for (Operation* subviewUser : subviewOp->getUsers()) {
|
||||
if (auto copyOp = dyn_cast<memref::CopyOp>(subviewUser)) {
|
||||
if (copyOp.getTarget() != subviewOp.getResult())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, copyOp.getSource());
|
||||
if (failed(denseAttr))
|
||||
return failure();
|
||||
copies.push_back({*denseAttr, offsets, strides, copyOp});
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isa<pim::PimCoreOp, memref::DeallocOp>(user))
|
||||
continue;
|
||||
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(user)) {
|
||||
pendingAliases.push_back(castOp.getResult());
|
||||
continue;
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
if (!fillValue)
|
||||
return failure();
|
||||
|
||||
SmallVector<Attribute> resultValues(numElements, fillValue);
|
||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||
|
||||
llvm::sort(copies, [](const ConstantSubviewCopy& lhs, const ConstantSubviewCopy& rhs) {
|
||||
return lhs.copyOp->isBeforeInBlock(rhs.copyOp);
|
||||
});
|
||||
|
||||
for (const ConstantSubviewCopy& copy : copies) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(copy.source.getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getRank() != static_cast<int64_t>(copy.offsets.size())
|
||||
|| sourceType.getRank() != static_cast<int64_t>(copy.strides.size()))
|
||||
return failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
SmallVector<Attribute> sourceValues(copy.source.getValues<Attribute>());
|
||||
for (auto [linearIndex, value] : llvm::enumerate(sourceValues)) {
|
||||
SmallVector<int64_t> sourceIndices =
|
||||
delinearizeIndex(static_cast<int64_t>(linearIndex), sourceType.getShape(), sourceStrides);
|
||||
SmallVector<int64_t> resultIndices;
|
||||
resultIndices.reserve(sourceIndices.size());
|
||||
for (auto [offset, sourceIndex, stride] : llvm::zip_equal(copy.offsets, sourceIndices, copy.strides))
|
||||
resultIndices.push_back(offset + sourceIndex * stride);
|
||||
|
||||
int64_t resultLinearIndex = linearizeIndex(resultIndices, resultStrides);
|
||||
resultValues[resultLinearIndex] = value;
|
||||
}
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
}
|
||||
|
||||
// Folds transposes of constant globals so weight-only transposes stay host-side.
|
||||
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
|
||||
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutput().getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
// Look through an optional pim.memcp_hd to find the source get_global.
|
||||
// This occurs when the constant was staged into device memory before transposing.
|
||||
pim::PimMemCopyHostToDevOp memcpHd;
|
||||
auto sourceGetGlobal = transposeOp.getInput().getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!sourceGetGlobal) {
|
||||
memcpHd = transposeOp.getInput().getDefiningOp<pim::PimMemCopyHostToDevOp>();
|
||||
if (!memcpHd)
|
||||
return failure();
|
||||
sourceGetGlobal = memcpHd.getHostSource().getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!sourceGetGlobal)
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
|
||||
if (!sourceGlobal || !sourceGlobal.getConstant() || !sourceGlobal.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> perms;
|
||||
perms.reserve(transposeOp.getPermutation().size());
|
||||
for (IntegerAttr attr : transposeOp.getPermutation().getAsRange<IntegerAttr>())
|
||||
perms.push_back(attr.getInt());
|
||||
FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
|
||||
if (failed(transposedAttr))
|
||||
return failure();
|
||||
|
||||
auto transposedShape = cast<RankedTensorType>(transposedAttr->getType()).getShape();
|
||||
if (!llvm::equal(transposedShape, resultType.getShape()))
|
||||
return failure();
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp,
|
||||
transposeOp.getLoc(),
|
||||
resultType,
|
||||
*transposedAttr,
|
||||
sourceGlobal.getName().str() + "__folded_transpose",
|
||||
sourceGlobal.getAlignmentAttr());
|
||||
|
||||
rewriter.setInsertionPoint(transposeOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());
|
||||
|
||||
bool isAlwaysWeight = !transposeOp->getUsers().empty()
|
||||
&& llvm::all_of(transposeOp->getUsers(),
|
||||
[](Operation* user) { return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user); });
|
||||
if (isAlwaysWeight) {
|
||||
markWeightAlways(newGlobal);
|
||||
markWeightAlways(newGetGlobal);
|
||||
}
|
||||
|
||||
auto outputAllocOp = transposeOp.getOutputBuffer().getDefiningOp<memref::AllocOp>();
|
||||
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
|
||||
|
||||
if (memcpHd && memcpHd.use_empty()) {
|
||||
auto deviceAllocOp = memcpHd.getDeviceTarget().getDefiningOp<memref::AllocOp>();
|
||||
rewriter.eraseOp(memcpHd);
|
||||
if (deviceAllocOp && deviceAllocOp->use_empty())
|
||||
rewriter.eraseOp(deviceAllocOp);
|
||||
}
|
||||
if (outputAllocOp && outputAllocOp->use_empty())
|
||||
rewriter.eraseOp(outputAllocOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Collapses fill-and-copy allocation chains into one folded constant global.
|
||||
struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::AllocOp allocOp, PatternRewriter& rewriter) const override {
|
||||
auto moduleOp = allocOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto foldedAttr = foldConstantAlloc(allocOp, moduleOp);
|
||||
if (failed(foldedAttr))
|
||||
return failure();
|
||||
|
||||
auto allocType = cast<MemRefType>(allocOp.getType());
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_constant");
|
||||
|
||||
rewriter.setInsertionPoint(allocOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
|
||||
|
||||
SmallVector<Operation*> opsToErase;
|
||||
SmallVector<memref::CastOp> castsToReplace;
|
||||
bool allLiveUsersAreCoreOps = true;
|
||||
for (Operation* user : llvm::make_early_inc_range(allocOp->getUsers())) {
|
||||
if (isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp>(user)) {
|
||||
opsToErase.push_back(user);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(user)) {
|
||||
castsToReplace.push_back(castOp);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
|
||||
return llvm::all_of(castOp->getUsers(),
|
||||
[](Operation* user) { return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user); });
|
||||
})) {
|
||||
allLiveUsersAreCoreOps = false;
|
||||
}
|
||||
|
||||
if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) {
|
||||
return isa<linalg::MapOp,
|
||||
memref::SubViewOp,
|
||||
memref::DeallocOp,
|
||||
memref::CastOp,
|
||||
pim::PimCoreOp,
|
||||
pim::PimCoreBatchOp>(user);
|
||||
})) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (allLiveUsersAreCoreOps) {
|
||||
markWeightAlways(newGlobal);
|
||||
markWeightAlways(newGetGlobal);
|
||||
}
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 8> preservedUsers(opsToErase.begin(), opsToErase.end());
|
||||
for (memref::CastOp castOp : castsToReplace)
|
||||
preservedUsers.insert(castOp);
|
||||
rewriter.replaceAllUsesExcept(allocOp.getResult(), newGetGlobal.getResult(), preservedUsers);
|
||||
|
||||
for (memref::CastOp castOp : castsToReplace) {
|
||||
rewriter.setInsertionPoint(castOp);
|
||||
Value replacementCast = memref::CastOp::create(rewriter, castOp.getLoc(), castOp.getType(), newGetGlobal);
|
||||
rewriter.replaceOp(castOp, replacementCast);
|
||||
if (allLiveUsersAreCoreOps)
|
||||
markWeightAlways(replacementCast.getDefiningOp());
|
||||
}
|
||||
|
||||
for (Operation* op : llvm::make_early_inc_range(opsToErase)) {
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
||||
for (Operation* subviewUser : llvm::make_early_inc_range(subviewOp->getUsers()))
|
||||
rewriter.eraseOp(subviewUser);
|
||||
if (op->use_empty())
|
||||
rewriter.eraseOp(op);
|
||||
}
|
||||
|
||||
if (allocOp.use_empty())
|
||||
rewriter.eraseOp(allocOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Converts host copies from dense globals into direct folded globals.
|
||||
struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
if (copyOp->getParentOfType<pim::PimCoreOp>())
|
||||
return failure();
|
||||
|
||||
auto allocOp = copyOp.getTarget().getDefiningOp<memref::AllocOp>();
|
||||
if (!allocOp)
|
||||
return failure();
|
||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||
if (!allocType || !allocType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType);
|
||||
if (failed(foldedAttr))
|
||||
return failure();
|
||||
|
||||
bool allLiveUsersAreCores = true;
|
||||
for (Operation* user : allocOp->getUsers()) {
|
||||
if (user == copyOp)
|
||||
continue;
|
||||
if (isa<memref::DeallocOp>(user))
|
||||
continue;
|
||||
if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
|
||||
continue;
|
||||
if (isa<memref::SubViewOp>(user)) {
|
||||
allLiveUsersAreCores = false;
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_host_copy");
|
||||
if (allLiveUsersAreCores)
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
rewriter.setInsertionPoint(allocOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
|
||||
if (allLiveUsersAreCores)
|
||||
markWeightAlways(newGetGlobal);
|
||||
|
||||
rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
|
||||
rewriter.eraseOp(copyOp);
|
||||
if (allocOp.use_empty())
|
||||
rewriter.eraseOp(allocOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Converts PIM copies from dense globals into direct folded globals before codegen.
|
||||
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
if (copyOp->getParentOfType<pim::PimCoreOp>())
|
||||
return failure();
|
||||
|
||||
auto allocOp = copyOp.getTarget().getDefiningOp<memref::AllocOp>();
|
||||
if (!allocOp)
|
||||
return failure();
|
||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||
if (!allocType || !allocType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto targetOffset = resolveIndexValue(copyOp.getTargetOffset());
|
||||
auto sourceOffset = resolveIndexValue(copyOp.getSourceOffset());
|
||||
if (failed(targetOffset) || failed(sourceOffset) || *targetOffset != 0 || *sourceOffset != 0)
|
||||
return failure();
|
||||
|
||||
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType);
|
||||
if (failed(foldedAttr))
|
||||
return failure();
|
||||
|
||||
bool allLiveUsersAreCores = true;
|
||||
for (Operation* user : allocOp->getUsers()) {
|
||||
if (user == copyOp)
|
||||
continue;
|
||||
if (isa<memref::DeallocOp>(user))
|
||||
continue;
|
||||
if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
|
||||
continue;
|
||||
if (isa<memref::SubViewOp>(user)) {
|
||||
allLiveUsersAreCores = false;
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_memcp");
|
||||
if (allLiveUsersAreCores)
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
rewriter.setInsertionPoint(allocOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
|
||||
if (allLiveUsersAreCores)
|
||||
markWeightAlways(newGetGlobal);
|
||||
|
||||
rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
|
||||
rewriter.eraseOp(copyOp);
|
||||
if (allocOp.use_empty())
|
||||
rewriter.eraseOp(allocOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<FoldConstantTransposePattern,
|
||||
FoldConstantAllocPattern,
|
||||
FoldConstantCoreMapPattern,
|
||||
FoldConstantHostCopyPattern,
|
||||
FoldConstantMemCpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,61 @@
|
||||
#include "../Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
// Folds constant subviews used as core weights into standalone globals.
|
||||
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
|
||||
if (subviewOp.use_empty())
|
||||
return failure();
|
||||
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
|
||||
return failure();
|
||||
|
||||
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
|
||||
if (failed(denseAttr))
|
||||
return failure();
|
||||
|
||||
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
|
||||
if (failed(subviewInfo))
|
||||
return failure();
|
||||
if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
auto staticOffsets = getStaticSubviewOffsets(*subviewInfo);
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
auto resultMemRefType = cast<MemRefType>(subviewOp.getType());
|
||||
auto foldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, resultMemRefType.getShape());
|
||||
if (failed(foldedAttr))
|
||||
return failure();
|
||||
|
||||
auto newGlobal =
|
||||
createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, *foldedAttr, "pim_folded_subview");
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
rewriter.setInsertionPoint(subviewOp);
|
||||
auto newGetGlobal =
|
||||
memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
|
||||
markWeightAlways(newGetGlobal);
|
||||
|
||||
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<FoldConstantCoreSubviewPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,9 @@
|
||||
add_pim_library(OMPimHostConstantMaterialization
|
||||
MaterializeHostConstantsPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMPimCommon
|
||||
PimOps
|
||||
)
|
||||
+161
@@ -0,0 +1,161 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dominance.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename CoreOpTy>
|
||||
static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder,
|
||||
bool& hasFailure) {
|
||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||
DominanceInfo dominance(coreOp);
|
||||
SmallVector<Operation*> ops;
|
||||
coreOp.getBody().front().walk([&](Operation* op) {
|
||||
if (!isa<pim::PimHaltOp, scf::YieldOp>(op))
|
||||
ops.push_back(op);
|
||||
});
|
||||
|
||||
for (Operation* op : ops) {
|
||||
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
|
||||
continue;
|
||||
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value originalValue = operand.get();
|
||||
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(originalValue);
|
||||
if (failed(resolvedAddress))
|
||||
continue;
|
||||
|
||||
auto getGlobalOp = dyn_cast_or_null<memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
if (!getGlobalOp)
|
||||
continue;
|
||||
|
||||
auto originalType = dyn_cast<MemRefType>(originalValue.getType());
|
||||
if (!originalType || !originalType.hasStaticShape()) {
|
||||
op->emitOpError("host constant materialization requires a static memref operand");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& cachedByOffset = materializedValues[resolvedAddress->base];
|
||||
auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset];
|
||||
auto cachedValue = cachedByType.find(originalType);
|
||||
if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) {
|
||||
operand.set(cachedValue->second);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto type = dyn_cast<ShapedType>(originalValue.getType());
|
||||
auto totalBytes = type ? pim::getCheckedShapedTypeSizeInBytes(type, op, "host constant materialization byte size")
|
||||
: FailureOr<uint64_t>(failure());
|
||||
auto totalBytesAttr =
|
||||
succeeded(totalBytes)
|
||||
? pim::getCheckedI32Attr(rewriter, op, *totalBytes, "host constant materialization byte size")
|
||||
: FailureOr<IntegerAttr>(failure());
|
||||
if (failed(totalBytesAttr)
|
||||
|| failed(pim::checkedSize(resolvedAddress->byteOffset, op, "host constant materialization byte offset"))) {
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType());
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType);
|
||||
Value deviceDst = localAlloc;
|
||||
if (contiguousType != originalType)
|
||||
deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
|
||||
|
||||
Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0);
|
||||
Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset);
|
||||
Value copiedValue = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
originalType,
|
||||
zeroOffset,
|
||||
hostOffset,
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
*totalBytesAttr)
|
||||
.getOutput();
|
||||
|
||||
cachedByType[originalType] = copiedValue;
|
||||
operand.set(copiedValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
|
||||
|
||||
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
IRRewriter rewriter(moduleOp.getContext());
|
||||
OperationFolder constantFolder(moduleOp.getContext());
|
||||
bool hasFailure = false;
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
if (funcOp.isExternal())
|
||||
continue;
|
||||
|
||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
|
||||
materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure);
|
||||
|
||||
for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
|
||||
materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure);
|
||||
|
||||
SmallVector<Operation*> hostCompactOps;
|
||||
for (Operation& op : funcOp.getBody().front())
|
||||
if (isa<pim::PimConcatOp>(op))
|
||||
hostCompactOps.push_back(&op);
|
||||
|
||||
for (Operation* op : hostCompactOps) {
|
||||
rewriter.setInsertionPoint(op);
|
||||
auto concatOp = cast<pim::PimConcatOp>(op);
|
||||
concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFailure) {
|
||||
moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(moduleOp, "pim4_materialized");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
|
||||
return std::make_unique<MaterializeHostConstantsPass>();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,14 @@
|
||||
add_pim_library(OMPimMemoryCoalescing
|
||||
MemoryCoalescing.cpp
|
||||
MemoryCoalescing.hpp
|
||||
MemoryCoalescingPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
INCLUDE_DIRS PUBLIC
|
||||
${PIM_PUBLIC_INCLUDE_DIRS}
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMPimCommon
|
||||
PimOps
|
||||
)
|
||||
+29
-11
@@ -10,7 +10,7 @@
|
||||
#include <limits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -32,6 +32,13 @@ static uint64_t getTypeSizeBytes(MemRefType type) {
|
||||
return static_cast<uint64_t>(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()));
|
||||
}
|
||||
|
||||
static Operation* getTopLevelAncestorInBody(Operation* op, Block& body) {
|
||||
Operation* current = op;
|
||||
while (current && current->getBlock() != &body)
|
||||
current = current->getParentOp();
|
||||
return current;
|
||||
}
|
||||
|
||||
static FailureOr<uint64_t>
|
||||
getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Operation*, uint64_t>& opOrder) {
|
||||
uint64_t endInstruction = opOrder.lookup(allocOp);
|
||||
@@ -42,7 +49,8 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
|
||||
while (!pendingValues.empty()) {
|
||||
Value value = pendingValues.pop_back_val();
|
||||
for (Operation* user : value.getUsers()) {
|
||||
if (user->getBlock() != &body)
|
||||
Operation* orderedUser = getTopLevelAncestorInBody(user, body);
|
||||
if (!orderedUser)
|
||||
return failure();
|
||||
if (!visited.insert(user).second)
|
||||
continue;
|
||||
@@ -51,6 +59,15 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
|
||||
for (Value result : user->getResults())
|
||||
pendingValues.push_back(result);
|
||||
|
||||
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
|
||||
auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
|
||||
if (!forOp)
|
||||
return failure();
|
||||
for (auto [index, operand] : llvm::enumerate(yieldOp.getOperands()))
|
||||
if (operand == value)
|
||||
pendingValues.push_back(forOp.getResult(index));
|
||||
}
|
||||
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
|
||||
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs()))
|
||||
if (initArg == value)
|
||||
@@ -66,7 +83,7 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
|
||||
}
|
||||
}
|
||||
|
||||
auto order = opOrder.find(user);
|
||||
auto order = opOrder.find(orderedUser);
|
||||
if (order == opOrder.end())
|
||||
return failure();
|
||||
endInstruction = std::max(endInstruction, order->second);
|
||||
@@ -78,8 +95,8 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
|
||||
|
||||
} // namespace
|
||||
|
||||
StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(Operation* coreLikeOp) {
|
||||
StaticMemoryCoalescingAnalysis analysis;
|
||||
MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(Operation* coreLikeOp) {
|
||||
MemoryCoalescingAnalysis analysis;
|
||||
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
|
||||
return analysis;
|
||||
|
||||
@@ -107,18 +124,19 @@ StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(Operation
|
||||
}
|
||||
|
||||
analysis.candidates.push_back(
|
||||
StaticAllocationCandidate {allocOp, opOrder.lookup(allocOp), *endInstruction, getTypeSizeBytes(allocType)});
|
||||
AllocationCandidate {allocOp, opOrder.lookup(allocOp), *endInstruction, getTypeSizeBytes(allocType)});
|
||||
}
|
||||
|
||||
return analysis;
|
||||
}
|
||||
|
||||
StaticMemoryCoalescingStats coalesceStaticMemory(Operation* coreLikeOp, RewriterBase& rewriter) {
|
||||
StaticMemoryCoalescingStats stats;
|
||||
auto analysis = analyzeStaticMemoryCoalescingCandidates(coreLikeOp);
|
||||
MemoryCoalescingStats
|
||||
coalesceMemory(Operation* coreLikeOp, const MemoryCoalescingAnalysis& analysis, RewriterBase& rewriter) {
|
||||
MemoryCoalescingStats stats;
|
||||
stats.skippedAllocations = analysis.skippedAllocations;
|
||||
|
||||
llvm::sort(analysis.candidates, [](const StaticAllocationCandidate& lhs, const StaticAllocationCandidate& rhs) {
|
||||
auto candidates = analysis.candidates;
|
||||
llvm::sort(candidates, [](const AllocationCandidate& lhs, const AllocationCandidate& rhs) {
|
||||
if (lhs.startInstruction != rhs.startInstruction)
|
||||
return lhs.startInstruction < rhs.startInstruction;
|
||||
return lhs.endInstruction < rhs.endInstruction;
|
||||
@@ -132,7 +150,7 @@ StaticMemoryCoalescingStats coalesceStaticMemory(Operation* coreLikeOp, Rewriter
|
||||
SmallVector<ActiveStorage> active;
|
||||
SmallVector<memref::AllocOp> freeList;
|
||||
|
||||
for (StaticAllocationCandidate& candidate : analysis.candidates) {
|
||||
for (AllocationCandidate& candidate : candidates) {
|
||||
for (auto it = active.begin(); it != active.end();) {
|
||||
if (it->endInstruction < candidate.startInstruction) {
|
||||
freeList.push_back(it->root);
|
||||
+7
-6
@@ -8,27 +8,28 @@
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
struct StaticAllocationCandidate {
|
||||
struct AllocationCandidate {
|
||||
mlir::memref::AllocOp alloc;
|
||||
uint64_t startInstruction = 0;
|
||||
uint64_t endInstruction = 0;
|
||||
uint64_t sizeBytes = 0;
|
||||
};
|
||||
|
||||
struct StaticMemoryCoalescingAnalysis {
|
||||
llvm::SmallVector<StaticAllocationCandidate> candidates;
|
||||
struct MemoryCoalescingAnalysis {
|
||||
llvm::SmallVector<AllocationCandidate> candidates;
|
||||
uint64_t skippedAllocations = 0;
|
||||
};
|
||||
|
||||
struct StaticMemoryCoalescingStats {
|
||||
struct MemoryCoalescingStats {
|
||||
uint64_t removedAllocs = 0;
|
||||
uint64_t savedBytes = 0;
|
||||
uint64_t skippedAllocations = 0;
|
||||
};
|
||||
|
||||
StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(mlir::Operation* coreLikeOp);
|
||||
MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(mlir::Operation* coreLikeOp);
|
||||
|
||||
StaticMemoryCoalescingStats coalesceStaticMemory(mlir::Operation* coreLikeOp, mlir::RewriterBase& rewriter);
|
||||
MemoryCoalescingStats
|
||||
coalesceMemory(mlir::Operation* coreLikeOp, const MemoryCoalescingAnalysis& analysis, mlir::RewriterBase& rewriter);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
+26
-15
@@ -10,10 +10,11 @@
|
||||
#include "Common/IR/CompactAsmUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -151,34 +152,39 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
|
||||
file.close();
|
||||
}
|
||||
|
||||
struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StaticMemoryCoalescingPass)
|
||||
struct PimMemoryCoalescingPass : PassWrapper<PimMemoryCoalescingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMemoryCoalescingPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-static-memory-coalescing"; }
|
||||
StringRef getDescription() const override { return "Analyze static local PIM memory reuse opportunities"; }
|
||||
StringRef getArgument() const override { return "pim-memory-coalescing"; }
|
||||
StringRef getDescription() const override { return "Analyze local PIM memory reuse opportunities"; }
|
||||
|
||||
StaticMemoryCoalescingPass() = default;
|
||||
StaticMemoryCoalescingPass(const StaticMemoryCoalescingPass& pass) {}
|
||||
PimMemoryCoalescingPass() = default;
|
||||
PimMemoryCoalescingPass(const PimMemoryCoalescingPass& pass) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
IRRewriter rewriter(&getContext());
|
||||
SmallVector<CoalescingReportEntry, 32> reportEntries;
|
||||
uint64_t nextBatchId = 0;
|
||||
bool hasFailure = false;
|
||||
|
||||
getOperation().walk([&](Operation* op) {
|
||||
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
|
||||
if (hasFailure || !isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
|
||||
return;
|
||||
|
||||
auto analysis = pim::analyzeStaticMemoryCoalescingCandidates(op);
|
||||
auto stats = pim::coalesceStaticMemory(op, rewriter);
|
||||
auto analysis = pim::analyzeMemoryCoalescingCandidates(op);
|
||||
auto stats = pim::coalesceMemory(op, analysis, rewriter);
|
||||
CoalescingReportRow row {
|
||||
analysis.candidates.size(), stats.skippedAllocations, stats.removedAllocs, stats.savedBytes};
|
||||
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
reportEntries.push_back({CoalescingReportEntry::Kind::Core,
|
||||
static_cast<uint64_t>(coreOp.getCoreId()),
|
||||
{static_cast<int32_t>(coreOp.getCoreId())},
|
||||
row});
|
||||
auto checkedCoreId =
|
||||
pim::checkedI32(static_cast<uint64_t>(coreOp.getCoreId()), coreOp, "memory coalescing core id");
|
||||
if (failed(checkedCoreId)) {
|
||||
hasFailure = true;
|
||||
return;
|
||||
}
|
||||
reportEntries.push_back(
|
||||
{CoalescingReportEntry::Kind::Core, static_cast<uint64_t>(coreOp.getCoreId()), {*checkedCoreId}, row});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -191,6 +197,11 @@ struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, Oper
|
||||
reportEntries.push_back(std::move(entry));
|
||||
});
|
||||
|
||||
if (hasFailure) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
emitReport(reportEntries);
|
||||
dumpModule(getOperation(), "pim2_coalesced");
|
||||
}
|
||||
@@ -198,6 +209,6 @@ struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, Oper
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() { return std::make_unique<StaticMemoryCoalescingPass>(); }
|
||||
std::unique_ptr<Pass> createPimMemoryCoalescingPass() { return std::make_unique<PimMemoryCoalescingPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,14 +0,0 @@
|
||||
add_pim_library(OMPimStaticMemoryCoalescing
|
||||
StaticMemoryCoalescing.cpp
|
||||
StaticMemoryCoalescing.hpp
|
||||
StaticMemoryCoalescingPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
INCLUDE_DIRS PUBLIC
|
||||
${PIM_PUBLIC_INCLUDE_DIRS}
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMPimCommon
|
||||
PimOps
|
||||
)
|
||||
@@ -0,0 +1,11 @@
|
||||
add_pim_library(OMPimVerification
|
||||
VerificationPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMPimCommon
|
||||
OMPimBufferization
|
||||
PimOps
|
||||
SpatialOps
|
||||
)
|
||||
@@ -0,0 +1,434 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isAddressOnlyHostOp(Operation* op) {
|
||||
return isa<arith::ConstantOp,
|
||||
memref::AllocOp,
|
||||
memref::GetGlobalOp,
|
||||
memref::SubViewOp,
|
||||
memref::CastOp,
|
||||
memref::CollapseShapeOp,
|
||||
memref::ExpandShapeOp,
|
||||
memref::CopyOp>(op);
|
||||
}
|
||||
|
||||
static bool isCodegenAddressableValue(Value value) {
|
||||
auto resolvedAddress = resolveContiguousAddress(value);
|
||||
if (succeeded(resolvedAddress))
|
||||
return isa<BlockArgument>(resolvedAddress->base)
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
|
||||
auto compiledAddress = compileContiguousAddressExpr(value);
|
||||
if (failed(compiledAddress))
|
||||
return false;
|
||||
return isa<BlockArgument>(compiledAddress->base)
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
|
||||
}
|
||||
|
||||
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
|
||||
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
|
||||
if (succeeded(resolvedAddress))
|
||||
return isa<BlockArgument>(resolvedAddress->base)
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
|
||||
auto compiledAddress = compileContiguousAddressExpr(value);
|
||||
if (failed(compiledAddress))
|
||||
return false;
|
||||
return isa<BlockArgument>(compiledAddress->base)
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
|
||||
}
|
||||
|
||||
static bool isConstantGlobalView(Value value) {
|
||||
while (true) {
|
||||
Operation* defOp = value.getDefiningOp();
|
||||
if (!defOp)
|
||||
return false;
|
||||
if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)) {
|
||||
auto moduleOp = getGlobalOp->getParentOfType<ModuleOp>();
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
return globalOp && globalOp.getConstant() && globalOp.getInitialValue()
|
||||
&& isa<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
}
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!hasAllStaticSubviewParts(subview))
|
||||
return false;
|
||||
value = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
|
||||
value = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return false;
|
||||
value = collapse.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return false;
|
||||
value = expand.getSrc();
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static bool isCoreWeightBlockArgument(Value value) {
|
||||
auto blockArgument = dyn_cast<BlockArgument>(value);
|
||||
if (!blockArgument)
|
||||
return false;
|
||||
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(blockArgument.getOwner()->getParentOp()))
|
||||
return static_cast<unsigned>(blockArgument.getArgNumber()) < coreOp.getWeights().size();
|
||||
|
||||
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(blockArgument.getOwner()->getParentOp())) {
|
||||
unsigned argNumber = static_cast<unsigned>(blockArgument.getArgNumber());
|
||||
return argNumber > 0 && argNumber <= coreBatchOp.getWeights().size();
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool isSupportedCoreInstructionOp(Operation* op) {
|
||||
return isa<pim::PimMemCopyHostToDevOp,
|
||||
pim::PimMemCopyDevToHostOp,
|
||||
pim::PimMemCopyOp,
|
||||
pim::PimReceiveOp,
|
||||
pim::PimSendOp,
|
||||
pim::PimConcatOp,
|
||||
pim::PimVMMOp,
|
||||
pim::PimTransposeOp,
|
||||
pim::PimVVAddOp,
|
||||
pim::PimVVSubOp,
|
||||
pim::PimVVMulOp,
|
||||
pim::PimVVMaxOp,
|
||||
pim::PimVVDMulOp,
|
||||
pim::PimVAvgOp,
|
||||
pim::PimVReluOp,
|
||||
pim::PimVTanhOp,
|
||||
pim::PimVSigmOp,
|
||||
pim::PimVSoftmaxOp,
|
||||
memref::GetGlobalOp>(op);
|
||||
}
|
||||
|
||||
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
||||
|
||||
StringRef getArgument() const override { return "verify-pim-pass"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
|
||||
}
|
||||
|
||||
VerificationPass() {}
|
||||
VerificationPass(const VerificationPass& pass) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
if (op->getDialect()->getNamespace() != "spat")
|
||||
return;
|
||||
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitError("illegal Spatial operation reached PIM codegen verification");
|
||||
});
|
||||
});
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
if (funcOp.isExternal())
|
||||
continue;
|
||||
|
||||
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
||||
(void) verifyCoreWeights(moduleOp, coreOp, diagnostics);
|
||||
StaticValueKnowledge knowledge;
|
||||
(void) verifyCoreLikeOperands(coreOp, knowledge, diagnostics);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
||||
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
|
||||
llvm::SmallVector<unsigned, 2> lanes;
|
||||
lanes.push_back(0);
|
||||
if (coreBatchOp.getLaneCount() > 1)
|
||||
lanes.push_back(static_cast<unsigned>(coreBatchOp.getLaneCount() - 1));
|
||||
for (unsigned lane : lanes) {
|
||||
StaticValueKnowledge knowledge;
|
||||
knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
|
||||
for (unsigned i = 0; i < coreBatchOp.getInputs().size(); ++i)
|
||||
knowledge.aliases[coreBatchOp.getInputArgument(i)] = coreBatchOp.getInputs()[i];
|
||||
(void) verifyCoreLikeOperands(coreBatchOp, knowledge, diagnostics);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
|
||||
(void) verifyReturnOp(returnOp, diagnostics);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isAddressOnlyHostOp(&op)) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("illegal host-side runtime op remains after PIM bufferization; "
|
||||
"fold it to constants or lower it into pim.core");
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
(void) verifyAddressOnlyHostOp(&op, diagnostics);
|
||||
}
|
||||
}
|
||||
|
||||
if (diagnostics.hasFailure()) {
|
||||
diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
|
||||
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename CoreOpTy>
|
||||
static LogicalResult
|
||||
verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
bool hasFailure = false;
|
||||
for (auto it : llvm::enumerate(coreOp.getWeights())) {
|
||||
size_t weightIndex = it.index();
|
||||
Value weight = it.value();
|
||||
auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp && !isConstantGlobalView(weight)) {
|
||||
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must be materialized as a constant memref.global or a static view of one before "
|
||||
"JSON codegen";
|
||||
});
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!getGlobalOp)
|
||||
continue;
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp) {
|
||||
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
|
||||
});
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
|
||||
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must come from a constant memref.global with an initial value";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyReturnOp(func::ReturnOp returnOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
bool hasFailure = false;
|
||||
for (auto it : llvm::enumerate(returnOp.getOperands())) {
|
||||
size_t resultIndex = it.index();
|
||||
Value operand = it.value();
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
diagnostics.report(returnOp.getOperation(), [&](Operation*) {
|
||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
template <typename CoreLikeOpTy>
|
||||
static LogicalResult verifyCoreLikeOperands(CoreLikeOpTy coreLikeOp,
|
||||
const StaticValueKnowledge& initialKnowledge,
|
||||
pim::CappedDiagnosticReporter& diagnostics) {
|
||||
return walkPimCoreBlockStructurally(
|
||||
coreLikeOp.getBody().front(), initialKnowledge, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
bool hasFailure = false;
|
||||
if (!isSupportedCoreInstructionOp(&op)) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("unsupported executable op reached PIM codegen verification");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
|
||||
for (auto it : llvm::enumerate(op.getOperands())) {
|
||||
size_t operandIndex = it.index();
|
||||
Value operand = it.value();
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
continue;
|
||||
|
||||
if (isCoreWeightBlockArgument(operand))
|
||||
continue;
|
||||
|
||||
if (auto vmmOp = dyn_cast<pim::PimVMMOp>(&op);
|
||||
vmmOp && operandIndex == 0 && resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
||||
if (failed(resolvedAddress) && failed(compileContiguousAddressExpr(operand))) {
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isExplicitHostMemCopyOperand(&op, operandIndex)) {
|
||||
if (!isCodegenAddressableValue(operand, knowledge)) {
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "host operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
Value addressBase =
|
||||
succeeded(resolvedAddress) ? resolvedAddress->base : compileContiguousAddressExpr(operand)->base;
|
||||
if (!isa<memref::AllocOp>(addressBase.getDefiningOp())) {
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with "
|
||||
"pim.memcp_hd";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op)) {
|
||||
if (!pim::isNormalizedCopyOp(storeOp)) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
if (failed(resolveIndexValue(storeOp.getHostTargetOffset(), knowledge))
|
||||
|| failed(resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge))) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op)) {
|
||||
if (!pim::isNormalizedCopyOp(loadOp)) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
if (failed(resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge))
|
||||
|| failed(resolveIndexValue(loadOp.getHostSourceOffset(), knowledge))) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto copyOp = dyn_cast<pim::PimMemCopyOp>(op)) {
|
||||
if (!pim::isNormalizedCopyOp(copyOp)) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("must use base memref operands plus explicit byte offsets after bufferization");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
if (failed(resolveIndexValue(copyOp.getTargetOffset(), knowledge))
|
||||
|| failed(resolveIndexValue(copyOp.getSourceOffset(), knowledge))) {
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
return success(!hasFailure);
|
||||
});
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
||||
return verifyAddressOnlyBase(op, subviewOp.getSource(), diagnostics);
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(op))
|
||||
return verifyAddressOnlySource(op, castOp.getSource(), diagnostics);
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, collapseOp.getSrc(), diagnostics);
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics);
|
||||
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
|
||||
if (!isMemRefBaseAddressableValue(copyOp.getSource()) || !isMemRefBaseAddressableValue(copyOp.getTarget())) {
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
verifyAddressOnlySource(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
if (isCodegenAddressableValue(source))
|
||||
return success();
|
||||
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("depends on a value that is not backed by contiguous addressable storage");
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
if (isMemRefBaseAddressableValue(source))
|
||||
return success();
|
||||
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<VerificationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+412
-341
File diff suppressed because it is too large
Load Diff
@@ -37,6 +37,7 @@
|
||||
#include "Scheduling/MergeSchedulingAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -128,8 +129,12 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
|
||||
}
|
||||
|
||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return static_cast<int32_t>(coreIdAttr.getInt());
|
||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
||||
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
|
||||
if (failed(checkedCoreId))
|
||||
return std::nullopt;
|
||||
return *checkedCoreId;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user