reformat code
This commit is contained in:
@@ -244,7 +244,7 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||
|
||||
while (true) {
|
||||
if (isa<BlockArgument>(value))
|
||||
return ResolvedContiguousAddress{value, byteOffset};
|
||||
return ResolvedContiguousAddress {value, byteOffset};
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
@@ -293,7 +293,7 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||
}
|
||||
|
||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
|
||||
return ResolvedContiguousAddress{value, byteOffset};
|
||||
return ResolvedContiguousAddress {value, byteOffset};
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -389,8 +389,8 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
|
||||
size_t totalElements = srcType.getNumElements();
|
||||
|
||||
// Read permutation. Destination dim i corresponds to source dim perm[i].
|
||||
SmallVector<int64_t> perm =
|
||||
map_to_vector(transposeOp.getPermutation().getAsRange<IntegerAttr>(), [](auto attr) -> int64_t { return attr.getInt(); });
|
||||
SmallVector<int64_t> perm = map_to_vector(transposeOp.getPermutation().getAsRange<IntegerAttr>(),
|
||||
[](auto attr) -> int64_t { return attr.getInt(); });
|
||||
|
||||
// Destination shape: dstShape[i] = srcShape[perm[i]]
|
||||
SmallVector<int64_t> dstShape(rank);
|
||||
|
||||
@@ -70,9 +70,8 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
|
||||
target.addDynamicallyLegalOp<ONNXMatMulOp>([](ONNXMatMulOp op) {
|
||||
return cast<ShapedType>(op.getY().getType()).getRank() != 2;
|
||||
});
|
||||
target.addDynamicallyLegalOp<ONNXMatMulOp>(
|
||||
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
|
||||
target.addIllegalOp<ONNXGemmOp>();
|
||||
target.addIllegalOp<ONNXConvOp>();
|
||||
target.addIllegalOp<ONNXLRNOp>();
|
||||
|
||||
@@ -58,8 +58,14 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value rhsSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
|
||||
Value rhsRow = tensor::CollapseShapeOp::create(
|
||||
rewriter, loc, rhsRowType, rhsSlice, SmallVector<ReassociationIndices>{{0}, {1, 2}});
|
||||
Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
rhsRowType,
|
||||
rhsSlice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
|
||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
@@ -89,10 +95,15 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
|
||||
rewriter.setInsertionPointAfter(concatComputeOp);
|
||||
Value gemmOut = concatComputeOp.getResult(0);
|
||||
Value gemmExpanded = tensor::ExpandShapeOp::create(
|
||||
rewriter, loc, gemmExpandedType, gemmOut, SmallVector<ReassociationIndices>{{0, 1}, {2}});
|
||||
Value result = ONNXTransposeOp::create(
|
||||
rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
gemmExpandedType,
|
||||
gemmOut,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
|
||||
@@ -114,8 +114,6 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<Reshape>(ctx);
|
||||
}
|
||||
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -186,7 +186,8 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct BinaryDstOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
|
||||
struct BinaryDstOpBufferizeInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<BinaryDstOpBufferizeInterface<OpTy>, OpTy> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#include "Common.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -107,9 +106,8 @@ FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
return info;
|
||||
}
|
||||
|
||||
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
|
||||
ArrayRef<int64_t> outerIndices,
|
||||
int64_t elementByteWidth) {
|
||||
int64_t
|
||||
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(info.sourceShape.size());
|
||||
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
#include "Patterns.hpp"
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
#include "../Common.hpp"
|
||||
#include "../Patterns.hpp"
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
#include "../Common.hpp"
|
||||
#include "../Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -478,10 +477,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
} // namespace
|
||||
|
||||
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<FoldConstantTransposePattern,
|
||||
FoldConstantAllocPattern,
|
||||
FoldConstantCoreMapPattern,
|
||||
FoldConstantMemCpPattern>(patterns.getContext());
|
||||
patterns
|
||||
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, FoldConstantMemCpPattern>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#include "../Common.hpp"
|
||||
#include "../Patterns.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -20,9 +19,11 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
CreateCopyOp createCopyOp) {
|
||||
auto srcSubview = getStaticSubviewInfo(src);
|
||||
auto dstSubview = getStaticSubviewInfo(dst);
|
||||
const bool splitSrc = succeeded(srcSubview)
|
||||
const bool splitSrc =
|
||||
succeeded(srcSubview)
|
||||
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
||||
const bool splitDst = succeeded(dstSubview)
|
||||
const bool splitDst =
|
||||
succeeded(dstSubview)
|
||||
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
||||
if (!splitSrc && !splitDst)
|
||||
return failure();
|
||||
@@ -62,13 +63,13 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
const int64_t srcByteOffset = srcOffset
|
||||
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
const int64_t dstByteOffset = dstOffset
|
||||
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
const int64_t srcByteOffset =
|
||||
srcOffset
|
||||
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
|
||||
const int64_t dstByteOffset =
|
||||
dstOffset
|
||||
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
|
||||
createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
||||
splitDst ? dstSubview->source : dst,
|
||||
splitSrc ? srcSubview->source : src,
|
||||
@@ -87,22 +88,17 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
|
||||
if (!copyOp->getParentOfType<pim::PimCoreOp>())
|
||||
return failure();
|
||||
|
||||
auto status =
|
||||
rewriteSubviewCopyLikeOp(copyOp,
|
||||
auto status = rewriteSubviewCopyLikeOp(
|
||||
copyOp,
|
||||
copyOp.getTarget(),
|
||||
copyOp.getSource(),
|
||||
copyOp.getTargetOffset(),
|
||||
copyOp.getSourceOffset(),
|
||||
copyOp.getSize(),
|
||||
rewriter,
|
||||
[&](MemRefType resultType,
|
||||
Value dst,
|
||||
Value src,
|
||||
int64_t dstByteOffset,
|
||||
int64_t srcByteOffset,
|
||||
int64_t sliceBytes) {
|
||||
pim::PimMemCopyOp::create(
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dst,
|
||||
@@ -123,22 +119,17 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||
auto status =
|
||||
rewriteSubviewCopyLikeOp(copyOp,
|
||||
auto status = rewriteSubviewCopyLikeOp(
|
||||
copyOp,
|
||||
copyOp.getDeviceTarget(),
|
||||
copyOp.getHostSource(),
|
||||
copyOp.getDeviceTargetOffset(),
|
||||
copyOp.getHostSourceOffset(),
|
||||
copyOp.getSize(),
|
||||
rewriter,
|
||||
[&](MemRefType resultType,
|
||||
Value dst,
|
||||
Value src,
|
||||
int64_t dstByteOffset,
|
||||
int64_t srcByteOffset,
|
||||
int64_t sliceBytes) {
|
||||
pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dst,
|
||||
@@ -201,11 +192,13 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
|
||||
}
|
||||
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
|
||||
auto newGlobal =
|
||||
createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
rewriter.setInsertionPoint(subviewOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
|
||||
auto newGetGlobal =
|
||||
memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
|
||||
markWeightAlways(newGetGlobal);
|
||||
|
||||
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
|
||||
|
||||
@@ -98,16 +98,15 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
|
||||
if (contiguousType != originalType)
|
||||
deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc);
|
||||
|
||||
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
op.getLoc(),
|
||||
originalType,
|
||||
deviceDst,
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(
|
||||
static_cast<int32_t>(resolvedAddress->byteOffset)),
|
||||
rewriter.getI32IntegerAttr(
|
||||
static_cast<int32_t>(totalBytes)));
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)));
|
||||
|
||||
cachedByType[originalType] = hostToDevCopy.getResult();
|
||||
operand.set(hostToDevCopy.getResult());
|
||||
@@ -127,8 +126,6 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createMaterializeConstantsPass() {
|
||||
return std::make_unique<MaterializeConstantsPass>();
|
||||
}
|
||||
std::unique_ptr<Pass> createMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user