This commit is contained in:
@@ -198,7 +198,6 @@ static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm:
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
static std::optional<CompileTimeSource>
|
||||
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
|
||||
if (!op)
|
||||
@@ -217,7 +216,9 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
|
||||
chainLength += 1;
|
||||
|
||||
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||
return hasConstantIndices(extractOp) ? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength) : std::nullopt;
|
||||
return hasConstantIndices(extractOp)
|
||||
? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength)
|
||||
: std::nullopt;
|
||||
|
||||
if (!isStaticTensorResult(op))
|
||||
return std::nullopt;
|
||||
@@ -232,8 +233,9 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
|
||||
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return hasStaticUnitStrides(extractSliceOp) ? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
|
||||
: std::nullopt;
|
||||
return hasStaticUnitStrides(extractSliceOp)
|
||||
? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
|
||||
: std::nullopt;
|
||||
|
||||
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||
return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength);
|
||||
@@ -252,9 +254,8 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
|
||||
res = partialRes;
|
||||
continue;
|
||||
}
|
||||
if(res->chainLength < partialRes->chainLength){
|
||||
if (res->chainLength < partialRes->chainLength)
|
||||
res = partialRes;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@@ -264,8 +265,7 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
|
||||
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return getCompileTimeSourceImpl(op, visited);
|
||||
}
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
@@ -143,13 +143,12 @@ static Value createGemmBatchKOffset(
|
||||
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
||||
}
|
||||
|
||||
static Value createGemmBatchHOffset(
|
||||
Value lane,
|
||||
int64_t numOutRows,
|
||||
int64_t numKSlices,
|
||||
int64_t numOutHSlices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value createGemmBatchHOffset(Value lane,
|
||||
int64_t numOutRows,
|
||||
int64_t numKSlices,
|
||||
int64_t numOutHSlices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (numOutHSlices == 1)
|
||||
return createIndexConstant(rewriter, 0);
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user