refactor Pim constant folding pass
share contiguous address resolution in PimCommon group patterns in subdir for each pass with pattern files
This commit is contained in:
@@ -20,7 +20,10 @@ add_onnx_mlir_library(OMPIMAccel
|
|||||||
Pass/CountInstructionPass.cpp
|
Pass/CountInstructionPass.cpp
|
||||||
Pass/EmitPimJsonPass.cpp
|
Pass/EmitPimJsonPass.cpp
|
||||||
Pass/MessagePass.cpp
|
Pass/MessagePass.cpp
|
||||||
Pass/PimConstantFoldingPass.cpp
|
Pass/PimConstantFolding/Common.cpp
|
||||||
|
Pass/PimConstantFolding/Patterns/ConstantPatterns.cpp
|
||||||
|
Pass/PimConstantFolding/PimConstantFoldingPass.cpp
|
||||||
|
Pass/PimConstantFolding/Patterns/SubviewPatterns.cpp
|
||||||
Pass/PimHostVerificationPass.cpp
|
Pass/PimHostVerificationPass.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||||
|
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
@@ -236,4 +239,64 @@ bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||||
|
int64_t byteOffset = 0;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (isa<BlockArgument>(value))
|
||||||
|
return ResolvedContiguousAddress{value, byteOffset};
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
||||||
|
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
|
||||||
|
if (!tiedOperand)
|
||||||
|
return failure();
|
||||||
|
value = tiedOperand->get();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
|
||||||
|
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||||
|
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
ArrayRef<int64_t> offsets = subviewOp.getStaticOffsets();
|
||||||
|
ArrayRef<int64_t> sizes = subviewOp.getStaticSizes();
|
||||||
|
ArrayRef<int64_t> strides = subviewOp.getStaticStrides();
|
||||||
|
if (llvm::is_contained(offsets, ShapedType::kDynamic) || llvm::is_contained(sizes, ShapedType::kDynamic)
|
||||||
|
|| llvm::is_contained(strides, ShapedType::kDynamic))
|
||||||
|
return failure();
|
||||||
|
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
|
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
||||||
|
value = subviewOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||||
|
value = castOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = collapseOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = expandOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
|
||||||
|
return ResolvedContiguousAddress{value, byteOffset};
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -17,6 +17,11 @@ inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
struct ResolvedContiguousAddress {
|
||||||
|
mlir::Value base;
|
||||||
|
int64_t byteOffset = 0;
|
||||||
|
};
|
||||||
|
|
||||||
std::string getOutputDir();
|
std::string getOutputDir();
|
||||||
|
|
||||||
void createDirectory(const std::string& directory);
|
void createDirectory(const std::string& directory);
|
||||||
@@ -48,4 +53,6 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
|||||||
llvm::ArrayRef<int64_t> sizes,
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
llvm::ArrayRef<int64_t> strides);
|
llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
@@ -86,48 +86,9 @@ PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
||||||
size_t offset = 0;
|
auto resolvedAddress = resolveContiguousAddress(value);
|
||||||
while (true) {
|
if (failed(resolvedAddress)) {
|
||||||
auto definingOp = value.getDefiningOp();
|
errs() << "Failed to resolve contiguous address for value: ";
|
||||||
if (!definingOp)
|
|
||||||
break;
|
|
||||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
|
||||||
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
|
|
||||||
if (!tiedOperand)
|
|
||||||
break;
|
|
||||||
value = tiedOperand->get();
|
|
||||||
}
|
|
||||||
else if (auto subviewDefiningOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
|
||||||
auto source = subviewDefiningOp.getSource();
|
|
||||||
auto srcShape = source.getType().getShape();
|
|
||||||
auto subviewOffsets = subviewDefiningOp.getStaticOffsets();
|
|
||||||
auto subviewSizes = subviewDefiningOp.getStaticSizes();
|
|
||||||
auto subviewStrides = subviewDefiningOp.getStaticStrides();
|
|
||||||
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
|
|
||||||
for (unsigned i = 0; i < subviewOffsets.size(); i++) {
|
|
||||||
size_t localOffset = subviewOffsets[i];
|
|
||||||
for (unsigned j = i + 1; j < subviewSizes.size(); j++)
|
|
||||||
localOffset *= subviewSizes[j];
|
|
||||||
offset += localOffset * subviewDefiningOp.getType().getElementTypeBitWidth() / 8;
|
|
||||||
}
|
|
||||||
value = source;
|
|
||||||
}
|
|
||||||
else if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
|
||||||
value = castOp.getSource();
|
|
||||||
}
|
|
||||||
else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
|
||||||
value = collapseOp.getSrc();
|
|
||||||
}
|
|
||||||
else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
|
||||||
value = expandOp.getSrc();
|
|
||||||
}
|
|
||||||
else
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto iter = memEntriesMap.find(value);
|
|
||||||
if (iter == memEntriesMap.end()) {
|
|
||||||
errs() << "Missing mem entry for value: ";
|
|
||||||
value.print(errs());
|
value.print(errs());
|
||||||
errs() << "\n";
|
errs() << "\n";
|
||||||
if (auto* definingOp = value.getDefiningOp()) {
|
if (auto* definingOp = value.getDefiningOp()) {
|
||||||
@@ -135,10 +96,23 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
|||||||
definingOp->print(errs());
|
definingOp->print(errs());
|
||||||
errs() << "\n";
|
errs() << "\n";
|
||||||
}
|
}
|
||||||
|
llvm_unreachable("Failed to resolve contiguous address");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto iter = memEntriesMap.find(resolvedAddress->base);
|
||||||
|
if (iter == memEntriesMap.end()) {
|
||||||
|
errs() << "Missing mem entry for value: ";
|
||||||
|
resolvedAddress->base.print(errs());
|
||||||
|
errs() << "\n";
|
||||||
|
if (auto* definingOp = resolvedAddress->base.getDefiningOp()) {
|
||||||
|
errs() << "Defining op:\n";
|
||||||
|
definingOp->print(errs());
|
||||||
|
errs() << "\n";
|
||||||
|
}
|
||||||
llvm_unreachable("Missing mem entry");
|
llvm_unreachable("Missing mem entry");
|
||||||
}
|
}
|
||||||
|
|
||||||
return iter->second.address + offset;
|
return iter->second.address + resolvedAddress->byteOffset;
|
||||||
}
|
}
|
||||||
|
|
||||||
json::Object PimCodeGen::createEmptyOffset() {
|
json::Object PimCodeGen::createEmptyOffset() {
|
||||||
|
|||||||
@@ -3,19 +3,19 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
|||||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||||
|
|
||||||
add_onnx_mlir_library(OMONNXToSpatial
|
add_onnx_mlir_library(OMONNXToSpatial
|
||||||
Math/Gemm.cpp
|
Patterns/Math/Gemm.cpp
|
||||||
Math/Conv.cpp
|
Patterns/Math/Conv.cpp
|
||||||
Math/MatMul.cpp
|
Patterns/Math/MatMul.cpp
|
||||||
NN/Pooling.cpp
|
Patterns/NN/Pooling.cpp
|
||||||
NN/ReduceMean.cpp
|
Patterns/NN/ReduceMean.cpp
|
||||||
Tensor/ONNXConcatToTensorConcat.cpp
|
Patterns/Tensor/ONNXConcatToTensorConcat.cpp
|
||||||
Tensor/ONNXReshapeToTensorReshape.cpp
|
Patterns/Tensor/ONNXReshapeToTensorReshape.cpp
|
||||||
Tensor/RemoveUnusedHelperOps.cpp
|
Patterns/Tensor/RemoveUnusedHelperOps.cpp
|
||||||
Utils/SpatialReducer.cpp
|
Utils/SpatialReducer.cpp
|
||||||
Utils/WeightSubdivider.cpp
|
Utils/WeightSubdivider.cpp
|
||||||
Utils/AnnotateReplication.cpp
|
Utils/AnnotateReplication.cpp
|
||||||
ONNXToSpatialPass.cpp
|
ONNXToSpatialPass.cpp
|
||||||
ONNXToSpatialCommon.cpp
|
Common.cpp
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
ONNXToSpatialIncGen
|
ONNXToSpatialIncGen
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
#include <optional>
|
#include <optional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "ONNXToSpatialCommon.hpp"
|
#include "Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
@@ -11,7 +11,7 @@
|
|||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
#include "Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
#include <queue>
|
#include <queue>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ add_public_tablegen_target(SpatialToPimIncGen)
|
|||||||
|
|
||||||
add_onnx_mlir_library(OMSpatialToPim
|
add_onnx_mlir_library(OMSpatialToPim
|
||||||
SpatialToPimPass.cpp
|
SpatialToPimPass.cpp
|
||||||
SpatialToPimCommon.cpp
|
Common.cpp
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
SpatialToPimIncGen
|
SpatialToPimIncGen
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
|
||||||
#include "SpatialToPimCommon.hpp"
|
#include "Common.hpp"
|
||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -19,9 +19,9 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
namespace spatial {
|
|
||||||
|
|
||||||
// TODO: Add here eventual patterns
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -24,7 +24,7 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
|||||||
121
src/PIM/Pass/PimConstantFolding/Common.cpp
Normal file
121
src/PIM/Pass/PimConstantFolding/Common.cpp
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
#include "Common.hpp"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
Value stripMemRefCasts(Value value) {
|
||||||
|
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||||
|
value = castOp.getSource();
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value stripMemRefViewOps(Value value) {
|
||||||
|
while (true) {
|
||||||
|
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
||||||
|
value = castOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||||
|
value = collapseOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||||
|
value = expandOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||||
|
Location loc,
|
||||||
|
MemRefType globalType,
|
||||||
|
DenseElementsAttr denseAttr,
|
||||||
|
StringRef nameStem,
|
||||||
|
IntegerAttr alignment) {
|
||||||
|
auto globalName = nameStem.str();
|
||||||
|
unsigned suffix = 0;
|
||||||
|
while (moduleOp.lookupSymbol(globalName))
|
||||||
|
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
|
||||||
|
|
||||||
|
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
|
||||||
|
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
|
||||||
|
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
|
||||||
|
return memref::GlobalOp::create(moduleBuilder,
|
||||||
|
loc,
|
||||||
|
globalName,
|
||||||
|
visibility,
|
||||||
|
globalType,
|
||||||
|
denseAttr,
|
||||||
|
/*constant=*/true,
|
||||||
|
alignment);
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
|
||||||
|
value = stripMemRefCasts(value);
|
||||||
|
|
||||||
|
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
|
||||||
|
if (!getGlobalOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||||
|
if (!denseAttr)
|
||||||
|
return failure();
|
||||||
|
return denseAttr;
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||||
|
value = stripMemRefViewOps(value);
|
||||||
|
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||||
|
if (!subviewOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto source = stripMemRefCasts(subviewOp.getSource());
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||||
|
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||||
|
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
StaticSubviewInfo info;
|
||||||
|
info.source = source;
|
||||||
|
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||||
|
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||||
|
auto staticOffset = getConstantIntValue(offset);
|
||||||
|
if (!staticOffset)
|
||||||
|
return failure();
|
||||||
|
info.offsets.push_back(*staticOffset);
|
||||||
|
}
|
||||||
|
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||||
|
auto staticSize = getConstantIntValue(size);
|
||||||
|
if (!staticSize)
|
||||||
|
return failure();
|
||||||
|
info.sizes.push_back(*staticSize);
|
||||||
|
}
|
||||||
|
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||||
|
auto staticStride = getConstantIntValue(stride);
|
||||||
|
if (!staticStride)
|
||||||
|
return failure();
|
||||||
|
info.strides.push_back(*staticStride);
|
||||||
|
}
|
||||||
|
return info;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
|
||||||
|
ArrayRef<int64_t> outerIndices,
|
||||||
|
int64_t elementByteWidth) {
|
||||||
|
SmallVector<int64_t> sourceIndices;
|
||||||
|
sourceIndices.reserve(info.sourceShape.size());
|
||||||
|
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
|
||||||
|
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
|
||||||
|
sourceIndices.push_back(info.offsets.back());
|
||||||
|
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
41
src/PIM/Pass/PimConstantFolding/Common.hpp
Normal file
41
src/PIM/Pass/PimConstantFolding/Common.hpp
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
struct StaticSubviewInfo {
|
||||||
|
mlir::Value source;
|
||||||
|
llvm::SmallVector<int64_t> sourceShape;
|
||||||
|
llvm::SmallVector<int64_t> offsets;
|
||||||
|
llvm::SmallVector<int64_t> sizes;
|
||||||
|
llvm::SmallVector<int64_t> strides;
|
||||||
|
};
|
||||||
|
|
||||||
|
mlir::Value stripMemRefCasts(mlir::Value value);
|
||||||
|
|
||||||
|
mlir::Value stripMemRefViewOps(mlir::Value value);
|
||||||
|
|
||||||
|
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::MemRefType globalType,
|
||||||
|
mlir::DenseElementsAttr denseAttr,
|
||||||
|
llvm::StringRef nameStem,
|
||||||
|
mlir::IntegerAttr alignment = {});
|
||||||
|
|
||||||
|
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
|
||||||
|
|
||||||
|
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||||
|
|
||||||
|
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
|
||||||
|
llvm::ArrayRef<int64_t> outerIndices,
|
||||||
|
int64_t elementByteWidth);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
11
src/PIM/Pass/PimConstantFolding/Patterns.hpp
Normal file
11
src/PIM/Pass/PimConstantFolding/Patterns.hpp
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void populateConstantFoldingConstantPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
|
||||||
|
void populateConstantFoldingSubviewPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,19 +1,11 @@
|
|||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "../Common.hpp"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "../Patterns.hpp"
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/ADT/SmallBitVector.h"
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
@@ -21,73 +13,14 @@
|
|||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static Value stripMemRefCasts(Value value) {
|
struct ConstantSubviewCopy {
|
||||||
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
DenseElementsAttr source;
|
||||||
value = castOp.getSource();
|
SmallVector<int64_t> offsets;
|
||||||
return value;
|
SmallVector<int64_t> strides;
|
||||||
}
|
Operation* copyOp = nullptr;
|
||||||
|
};
|
||||||
static Value stripMemRefViewOps(Value value) {
|
|
||||||
while (true) {
|
|
||||||
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
|
||||||
value = castOp.getSource();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
|
||||||
value = collapseOp.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
|
||||||
value = expandOp.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
|
||||||
Location loc,
|
|
||||||
MemRefType globalType,
|
|
||||||
DenseElementsAttr denseAttr,
|
|
||||||
StringRef nameStem,
|
|
||||||
IntegerAttr alignment = {}) {
|
|
||||||
auto globalName = nameStem.str();
|
|
||||||
unsigned suffix = 0;
|
|
||||||
while (moduleOp.lookupSymbol(globalName))
|
|
||||||
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
|
|
||||||
|
|
||||||
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
|
|
||||||
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
|
|
||||||
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
|
|
||||||
return memref::GlobalOp::create(moduleBuilder,
|
|
||||||
loc,
|
|
||||||
globalName,
|
|
||||||
visibility,
|
|
||||||
globalType,
|
|
||||||
denseAttr,
|
|
||||||
/*constant=*/true,
|
|
||||||
alignment);
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
|
|
||||||
value = stripMemRefCasts(value);
|
|
||||||
|
|
||||||
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
|
|
||||||
if (!getGlobalOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
|
||||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
|
||||||
if (!denseAttr)
|
|
||||||
return failure();
|
|
||||||
return denseAttr;
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||||
@@ -144,13 +77,6 @@ static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr den
|
|||||||
return DenseElementsAttr::get(transposedType, transposedValues);
|
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ConstantSubviewCopy {
|
|
||||||
DenseElementsAttr source;
|
|
||||||
SmallVector<int64_t> offsets;
|
|
||||||
SmallVector<int64_t> strides;
|
|
||||||
Operation* copyOp = nullptr;
|
|
||||||
};
|
|
||||||
|
|
||||||
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
|
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
|
||||||
if (!mapOp.getInputs().empty())
|
if (!mapOp.getInputs().empty())
|
||||||
return failure();
|
return failure();
|
||||||
@@ -213,204 +139,6 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct StaticSubviewInfo {
|
|
||||||
Value source;
|
|
||||||
SmallVector<int64_t> sourceShape;
|
|
||||||
SmallVector<int64_t> offsets;
|
|
||||||
SmallVector<int64_t> sizes;
|
|
||||||
SmallVector<int64_t> strides;
|
|
||||||
};
|
|
||||||
|
|
||||||
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
|
||||||
value = stripMemRefViewOps(value);
|
|
||||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
|
||||||
if (!subviewOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto source = stripMemRefCasts(subviewOp.getSource());
|
|
||||||
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
|
||||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
|
||||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
StaticSubviewInfo info;
|
|
||||||
info.source = source;
|
|
||||||
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
|
||||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
|
||||||
auto staticOffset = getConstantIntValue(offset);
|
|
||||||
if (!staticOffset)
|
|
||||||
return failure();
|
|
||||||
info.offsets.push_back(*staticOffset);
|
|
||||||
}
|
|
||||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
|
||||||
auto staticSize = getConstantIntValue(size);
|
|
||||||
if (!staticSize)
|
|
||||||
return failure();
|
|
||||||
info.sizes.push_back(*staticSize);
|
|
||||||
}
|
|
||||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
|
||||||
auto staticStride = getConstantIntValue(stride);
|
|
||||||
if (!staticStride)
|
|
||||||
return failure();
|
|
||||||
info.strides.push_back(*staticStride);
|
|
||||||
}
|
|
||||||
return info;
|
|
||||||
}
|
|
||||||
|
|
||||||
static int64_t
|
|
||||||
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
|
|
||||||
SmallVector<int64_t> sourceIndices;
|
|
||||||
sourceIndices.reserve(info.sourceShape.size());
|
|
||||||
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
|
|
||||||
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
|
|
||||||
sourceIndices.push_back(info.offsets.back());
|
|
||||||
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
|
||||||
if (!copyOp->getParentOfType<pim::PimCoreOp>())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
|
|
||||||
auto dstSubview = getStaticSubviewInfo(copyOp.getDst());
|
|
||||||
const bool splitSrc = succeeded(srcSubview)
|
|
||||||
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
|
||||||
const bool splitDst = succeeded(dstSubview)
|
|
||||||
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
|
||||||
if (!splitSrc && !splitDst)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto sourceType = dyn_cast<MemRefType>(copyOp.getSrc().getType());
|
|
||||||
auto dstType = dyn_cast<MemRefType>(copyOp.getDst().getType());
|
|
||||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
if (sourceType.getElementType() != dstType.getElementType())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
|
||||||
return failure();
|
|
||||||
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
|
||||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
|
||||||
if (elementByteWidth <= 0)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
|
||||||
if (copyOp.getSize() != totalBytes)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
|
||||||
if (sliceBytes <= 0)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
|
||||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
|
||||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(copyOp);
|
|
||||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
|
||||||
SmallVector<int64_t> outerIndices =
|
|
||||||
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
|
||||||
const int64_t srcByteOffset = copyOp.getSrcOffset()
|
|
||||||
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
|
|
||||||
: linearIndex * sliceBytes);
|
|
||||||
const int64_t dstByteOffset = copyOp.getDstOffset()
|
|
||||||
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
|
|
||||||
: linearIndex * sliceBytes);
|
|
||||||
pim::PimMemCopyOp::create(rewriter,
|
|
||||||
copyOp.getLoc(),
|
|
||||||
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
|
||||||
splitDst ? dstSubview->source : copyOp.getDst(),
|
|
||||||
splitSrc ? srcSubview->source : copyOp.getSrc(),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(copyOp, copyOp.getDst());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
|
||||||
auto srcSubview = getStaticSubviewInfo(copyOp.getHostSrc());
|
|
||||||
auto dstSubview = getStaticSubviewInfo(copyOp.getDeviceDst());
|
|
||||||
const bool splitSrc = succeeded(srcSubview)
|
|
||||||
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
|
||||||
const bool splitDst = succeeded(dstSubview)
|
|
||||||
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
|
||||||
if (!splitSrc && !splitDst)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto sourceType = dyn_cast<MemRefType>(copyOp.getHostSrc().getType());
|
|
||||||
auto dstType = dyn_cast<MemRefType>(copyOp.getDeviceDst().getType());
|
|
||||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
if (sourceType.getElementType() != dstType.getElementType())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
|
||||||
return failure();
|
|
||||||
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
|
||||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
|
||||||
if (elementByteWidth <= 0)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
|
||||||
if (copyOp.getSize() != totalBytes)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
|
||||||
if (sliceBytes <= 0)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
|
||||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
|
||||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(copyOp);
|
|
||||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
|
||||||
SmallVector<int64_t> outerIndices =
|
|
||||||
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
|
||||||
const int64_t srcByteOffset = copyOp.getHostSrcOffset()
|
|
||||||
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
|
|
||||||
: linearIndex * sliceBytes);
|
|
||||||
const int64_t dstByteOffset = copyOp.getDeviceDstOffset()
|
|
||||||
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
|
|
||||||
: linearIndex * sliceBytes);
|
|
||||||
pim::PimMemCopyHostToDevOp::create(
|
|
||||||
rewriter,
|
|
||||||
copyOp.getLoc(),
|
|
||||||
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
|
||||||
splitDst ? dstSubview->source : copyOp.getDeviceDst(),
|
|
||||||
splitSrc ? srcSubview->source : copyOp.getHostSrc(),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
|
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
|
||||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||||
if (!allocType || !allocType.hasStaticShape())
|
if (!allocType || !allocType.hasStaticShape())
|
||||||
@@ -563,17 +291,15 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
|||||||
if (!llvm::equal(transposedShape, resultType.getShape()))
|
if (!llvm::equal(transposedShape, resultType.getShape()))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
MemRefType globalType = resultType;
|
|
||||||
|
|
||||||
auto newGlobal = createFoldedGlobal(moduleOp,
|
auto newGlobal = createFoldedGlobal(moduleOp,
|
||||||
transposeOp.getLoc(),
|
transposeOp.getLoc(),
|
||||||
globalType,
|
resultType,
|
||||||
*transposedAttr,
|
*transposedAttr,
|
||||||
sourceGlobal.getName().str() + "__folded_transpose",
|
sourceGlobal.getName().str() + "__folded_transpose",
|
||||||
sourceGlobal.getAlignmentAttr());
|
sourceGlobal.getAlignmentAttr());
|
||||||
|
|
||||||
rewriter.setInsertionPoint(transposeOp);
|
rewriter.setInsertionPoint(transposeOp);
|
||||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName());
|
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());
|
||||||
|
|
||||||
bool isAlwaysWeight =
|
bool isAlwaysWeight =
|
||||||
!transposeOp->getUsers().empty()
|
!transposeOp->getUsers().empty()
|
||||||
@@ -672,11 +398,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
// Only match top-level memcp (not inside pim.core)
|
|
||||||
if (copyOp->getParentOfType<pim::PimCoreOp>())
|
if (copyOp->getParentOfType<pim::PimCoreOp>())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// dst must be an alloc with static shape
|
|
||||||
auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
|
auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
|
||||||
if (!allocOp)
|
if (!allocOp)
|
||||||
return failure();
|
return failure();
|
||||||
@@ -684,11 +408,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|||||||
if (!allocType || !allocType.hasStaticShape())
|
if (!allocType || !allocType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// The copy must cover the full destination (offsets both zero)
|
|
||||||
if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
|
if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Resolve the source through an optional subview to a get_global
|
|
||||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
|
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
|
||||||
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());
|
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());
|
||||||
|
|
||||||
@@ -700,14 +422,12 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|||||||
if (failed(denseAttr))
|
if (failed(denseAttr))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Build the folded dense attribute
|
|
||||||
DenseElementsAttr foldedAttr;
|
DenseElementsAttr foldedAttr;
|
||||||
if (succeeded(srcSubview)) {
|
if (succeeded(srcSubview)) {
|
||||||
// Extract the sub-tensor from the source constant
|
|
||||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||||
if (!sourceType || !sourceType.hasStaticShape())
|
if (!sourceType || !sourceType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
if (llvm::any_of(srcSubview->strides, [](int64_t s) { return s != 1; }))
|
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||||
@@ -729,14 +449,12 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|||||||
foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// Direct copy from a global — just reuse its dense attribute
|
|
||||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||||
if (resultTensorType != denseAttr->getType())
|
if (resultTensorType != denseAttr->getType())
|
||||||
return failure();
|
return failure();
|
||||||
foldedAttr = *denseAttr;
|
foldedAttr = *denseAttr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that the alloc's remaining users are supported ops.
|
|
||||||
bool allLiveUsersAreCores = true;
|
bool allLiveUsersAreCores = true;
|
||||||
for (Operation* user : allocOp->getUsers()) {
|
for (Operation* user : allocOp->getUsers()) {
|
||||||
if (user == copyOp)
|
if (user == copyOp)
|
||||||
@@ -769,110 +487,13 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
|
|
||||||
// Only handle subviews whose users are all pim.core ops.
|
|
||||||
if (subviewOp.use_empty())
|
|
||||||
return failure();
|
|
||||||
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Source must resolve to a constant get_global.
|
|
||||||
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
|
|
||||||
if (!moduleOp)
|
|
||||||
return failure();
|
|
||||||
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
|
|
||||||
if (failed(denseAttr))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Static subview info.
|
|
||||||
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
|
|
||||||
if (failed(subviewInfo))
|
|
||||||
return failure();
|
|
||||||
if (llvm::any_of(subviewInfo->strides, [](int64_t s) { return s != 1; }))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
|
||||||
if (!sourceType || !sourceType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Build the contiguous result type.
|
|
||||||
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
|
||||||
auto resultMemRefType = MemRefType::get(
|
|
||||||
SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
|
||||||
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
|
|
||||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
|
||||||
|
|
||||||
// Extract the sub-tensor.
|
|
||||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
|
||||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
|
||||||
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
|
||||||
SmallVector<Attribute> resultValues(numResultElements);
|
|
||||||
for (int64_t i = 0; i < numResultElements; ++i) {
|
|
||||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
|
||||||
SmallVector<int64_t> sourceIndices;
|
|
||||||
sourceIndices.reserve(resultIndices.size());
|
|
||||||
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
|
|
||||||
sourceIndices.push_back(off + idx);
|
|
||||||
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
|
||||||
}
|
|
||||||
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
|
||||||
|
|
||||||
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
|
|
||||||
markWeightAlways(newGlobal);
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(subviewOp);
|
|
||||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
|
|
||||||
markWeightAlways(newGetGlobal);
|
|
||||||
|
|
||||||
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
|
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
|
|
||||||
|
|
||||||
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
|
||||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
|
||||||
|
|
||||||
LogicalResult initialize(MLIRContext* context) override {
|
|
||||||
RewritePatternSet owningPatterns(context);
|
|
||||||
for (auto* dialect : context->getLoadedDialects())
|
|
||||||
dialect->getCanonicalizationPatterns(owningPatterns);
|
|
||||||
for (RegisteredOperationName op : context->getRegisteredOperations())
|
|
||||||
op.getCanonicalizationPatterns(owningPatterns, context);
|
|
||||||
owningPatterns
|
|
||||||
.add<FoldConstantTransposePattern,
|
|
||||||
FoldConstantAllocPattern,
|
|
||||||
FoldConstantCoreMapPattern,
|
|
||||||
RewriteCoreSubviewCopyPattern,
|
|
||||||
RewriteHostSubviewLoadPattern,
|
|
||||||
FoldConstantMemCpPattern,
|
|
||||||
FoldConstantCoreSubviewPattern>(
|
|
||||||
context);
|
|
||||||
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
void runOnOperation() override {
|
|
||||||
GreedyRewriteConfig config;
|
|
||||||
config.enableFolding();
|
|
||||||
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dumpModule(getOperation(), "pim2_folded");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
|
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
|
||||||
|
patterns.add<FoldConstantTransposePattern,
|
||||||
|
FoldConstantAllocPattern,
|
||||||
|
FoldConstantCoreMapPattern,
|
||||||
|
FoldConstantMemCpPattern>(patterns.getContext());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
223
src/PIM/Pass/PimConstantFolding/Patterns/SubviewPatterns.cpp
Normal file
223
src/PIM/Pass/PimConstantFolding/Patterns/SubviewPatterns.cpp
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
#include "../Common.hpp"
|
||||||
|
#include "../Patterns.hpp"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename CopyOp, typename CreateCopyOp>
|
||||||
|
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||||
|
Value dst,
|
||||||
|
Value src,
|
||||||
|
int64_t dstOffset,
|
||||||
|
int64_t srcOffset,
|
||||||
|
int64_t size,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
CreateCopyOp createCopyOp) {
|
||||||
|
auto srcSubview = getStaticSubviewInfo(src);
|
||||||
|
auto dstSubview = getStaticSubviewInfo(dst);
|
||||||
|
const bool splitSrc = succeeded(srcSubview)
|
||||||
|
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
||||||
|
const bool splitDst = succeeded(dstSubview)
|
||||||
|
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
||||||
|
if (!splitSrc && !splitDst)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(src.getType());
|
||||||
|
auto dstType = dyn_cast<MemRefType>(dst.getType());
|
||||||
|
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (sourceType.getElementType() != dstType.getElementType())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||||
|
return failure();
|
||||||
|
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||||
|
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
||||||
|
if (elementByteWidth <= 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||||
|
if (size != totalBytes)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||||
|
if (sliceBytes <= 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||||
|
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||||
|
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(copyOp);
|
||||||
|
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||||
|
SmallVector<int64_t> outerIndices =
|
||||||
|
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||||
|
const int64_t srcByteOffset = srcOffset
|
||||||
|
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
|
||||||
|
: linearIndex * sliceBytes);
|
||||||
|
const int64_t dstByteOffset = dstOffset
|
||||||
|
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
|
||||||
|
: linearIndex * sliceBytes);
|
||||||
|
createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
||||||
|
splitDst ? dstSubview->source : dst,
|
||||||
|
splitSrc ? srcSubview->source : src,
|
||||||
|
dstByteOffset,
|
||||||
|
srcByteOffset,
|
||||||
|
sliceBytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
|
if (!copyOp->getParentOfType<pim::PimCoreOp>())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto status =
|
||||||
|
rewriteSubviewCopyLikeOp(copyOp,
|
||||||
|
copyOp.getDst(),
|
||||||
|
copyOp.getSrc(),
|
||||||
|
copyOp.getDstOffset(),
|
||||||
|
copyOp.getSrcOffset(),
|
||||||
|
copyOp.getSize(),
|
||||||
|
rewriter,
|
||||||
|
[&](MemRefType resultType,
|
||||||
|
Value dst,
|
||||||
|
Value src,
|
||||||
|
int64_t dstByteOffset,
|
||||||
|
int64_t srcByteOffset,
|
||||||
|
int64_t sliceBytes) {
|
||||||
|
pim::PimMemCopyOp::create(
|
||||||
|
rewriter,
|
||||||
|
copyOp.getLoc(),
|
||||||
|
resultType,
|
||||||
|
dst,
|
||||||
|
src,
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||||
|
});
|
||||||
|
if (failed(status))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOp(copyOp, copyOp.getDst());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
|
auto status =
|
||||||
|
rewriteSubviewCopyLikeOp(copyOp,
|
||||||
|
copyOp.getDeviceDst(),
|
||||||
|
copyOp.getHostSrc(),
|
||||||
|
copyOp.getDeviceDstOffset(),
|
||||||
|
copyOp.getHostSrcOffset(),
|
||||||
|
copyOp.getSize(),
|
||||||
|
rewriter,
|
||||||
|
[&](MemRefType resultType,
|
||||||
|
Value dst,
|
||||||
|
Value src,
|
||||||
|
int64_t dstByteOffset,
|
||||||
|
int64_t srcByteOffset,
|
||||||
|
int64_t sliceBytes) {
|
||||||
|
pim::PimMemCopyHostToDevOp::create(
|
||||||
|
rewriter,
|
||||||
|
copyOp.getLoc(),
|
||||||
|
resultType,
|
||||||
|
dst,
|
||||||
|
src,
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||||
|
});
|
||||||
|
if (failed(status))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
|
||||||
|
if (subviewOp.use_empty())
|
||||||
|
return failure();
|
||||||
|
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
|
||||||
|
if (!moduleOp)
|
||||||
|
return failure();
|
||||||
|
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
|
||||||
|
if (failed(denseAttr))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
|
||||||
|
if (failed(subviewInfo))
|
||||||
|
return failure();
|
||||||
|
if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||||
|
if (!sourceType || !sourceType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
||||||
|
auto resultMemRefType =
|
||||||
|
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
||||||
|
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
|
||||||
|
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||||
|
|
||||||
|
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
|
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||||
|
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
||||||
|
SmallVector<Attribute> resultValues(numResultElements);
|
||||||
|
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||||
|
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||||
|
SmallVector<int64_t> sourceIndices;
|
||||||
|
sourceIndices.reserve(resultIndices.size());
|
||||||
|
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
|
||||||
|
sourceIndices.push_back(off + idx);
|
||||||
|
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
||||||
|
}
|
||||||
|
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||||
|
|
||||||
|
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
|
||||||
|
markWeightAlways(newGlobal);
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(subviewOp);
|
||||||
|
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
|
||||||
|
markWeightAlways(newGetGlobal);
|
||||||
|
|
||||||
|
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/// Registers the subview-related constant-folding rewrites — copy rewriting,
/// host load rewriting, and folding of subviews over constant globals — into
/// the given pattern set.
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
  MLIRContext* context = patterns.getContext();
  patterns.add<RewriteCoreSubviewCopyPattern, RewriteHostSubviewLoadPattern,
      FoldConstantCoreSubviewPattern>(context);
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
53
src/PIM/Pass/PimConstantFolding/PimConstantFoldingPass.cpp
Normal file
53
src/PIM/Pass/PimConstantFolding/PimConstantFoldingPass.cpp
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
#include "Patterns.hpp"
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/// Module pass that folds host-side constant expressions before PIM
/// verification: it runs all dialect/op canonicalization patterns together
/// with the PIM-specific constant and subview folding patterns.
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)

  StringRef getArgument() const override { return "pim-constant-folding-pass"; }
  StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }

  /// Builds the frozen pattern set once: canonicalization patterns from every
  /// loaded dialect and every registered operation, plus the PIM folding
  /// patterns from this pass's pattern files.
  LogicalResult initialize(MLIRContext* context) override {
    RewritePatternSet collected(context);
    for (Dialect* dialect : context->getLoadedDialects())
      dialect->getCanonicalizationPatterns(collected);
    for (RegisteredOperationName opName : context->getRegisteredOperations())
      opName.getCanonicalizationPatterns(collected, context);

    populateConstantFoldingConstantPatterns(collected);
    populateConstantFoldingSubviewPatterns(collected);

    patterns = std::make_shared<FrozenRewritePatternSet>(std::move(collected));
    return success();
  }

  void runOnOperation() override {
    GreedyRewriteConfig config;
    config.enableFolding();
    // On success, dump the folded module for inspection; otherwise report
    // pass failure and emit nothing.
    if (succeeded(applyPatternsGreedily(getOperation(), *patterns, config)))
      dumpModule(getOperation(), "pim2_folded");
    else
      signalPassFailure();
  }

  // Frozen pattern set built once in initialize() and reused on every run.
  std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
/// Factory for the PIM constant-folding pass.
std::unique_ptr<Pass> createPimConstantFoldingPass() {
  return std::make_unique<PimConstantFoldingPass>();
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
@@ -26,37 +27,12 @@ static bool isAddressOnlyHostOp(Operation* op) {
|
|||||||
spatial::SpatChannelNewOp>(op);
|
spatial::SpatChannelNewOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns true when `value` resolves (via resolveContiguousAddress) to
/// storage the codegen can address directly: a block argument, a
/// memref.alloc, or a memref.get_global.
static bool isCodegenAddressableValue(Value value) {
  auto resolvedAddress = resolveContiguousAddress(value);
  if (failed(resolvedAddress))
    return false;
  Value base = resolvedAddress->base;
  if (isa<BlockArgument>(base))
    return true;
  // A successfully resolved non-block-argument base always has a defining op.
  return isa<memref::AllocOp, memref::GetGlobalOp>(base.getDefiningOp());
}
|
||||||
|
|
||||||
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> {
|
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> {
|
||||||
@@ -80,7 +56,7 @@ struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationP
|
|||||||
|
|
||||||
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
||||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
||||||
if (failed(verifyCoreWeights(moduleOp, coreOp)))
|
if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp)))
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -139,8 +115,27 @@ private:
|
|||||||
/// Verifies that every returned value is backed by contiguous addressable
/// storage; emits one error per offending result and fails if any was found.
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
  bool anyBadResult = false;
  for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
    if (isCodegenAddressableValue(operand))
      continue;
    returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
    anyBadResult = true;
  }
  return success(!anyBadResult);
}
|
||||||
|
|
||||||
|
static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) {
|
||||||
|
bool hasFailure = false;
|
||||||
|
for (Operation& op : coreOp.getBody().front()) {
|
||||||
|
if (isa<pim::PimHaltOp>(op))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||||
|
if (!isa<BaseMemRefType>(operand.getType()))
|
||||||
|
continue;
|
||||||
|
if (succeeded(resolveContiguousAddress(operand)))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -160,10 +155,10 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Checks that `source` of an address-only host op is codegen-addressable;
/// emits an error on `op` and fails otherwise.
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
  if (!isCodegenAddressableValue(source)) {
    op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
    return failure();
  }
  return success();
}
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user