refactor Pim constant folding pass
share contiguous address resolution in PimCommon group patterns in subdir for each pass with pattern files
This commit is contained in:
@@ -20,7 +20,10 @@ add_onnx_mlir_library(OMPIMAccel
|
||||
Pass/CountInstructionPass.cpp
|
||||
Pass/EmitPimJsonPass.cpp
|
||||
Pass/MessagePass.cpp
|
||||
Pass/PimConstantFoldingPass.cpp
|
||||
Pass/PimConstantFolding/Common.cpp
|
||||
Pass/PimConstantFolding/Patterns/ConstantPatterns.cpp
|
||||
Pass/PimConstantFolding/PimConstantFoldingPass.cpp
|
||||
Pass/PimConstantFolding/Patterns/SubviewPatterns.cpp
|
||||
Pass/PimHostVerificationPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <filesystem>
|
||||
@@ -236,4 +239,64 @@ bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
|
||||
return true;
|
||||
}
|
||||
|
||||
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||
int64_t byteOffset = 0;
|
||||
|
||||
while (true) {
|
||||
if (isa<BlockArgument>(value))
|
||||
return ResolvedContiguousAddress{value, byteOffset};
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return failure();
|
||||
|
||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
||||
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
return failure();
|
||||
value = tiedOperand->get();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
|
||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> offsets = subviewOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> sizes = subviewOp.getStaticSizes();
|
||||
ArrayRef<int64_t> strides = subviewOp.getStaticStrides();
|
||||
if (llvm::is_contained(offsets, ShapedType::kDynamic) || llvm::is_contained(sizes, ShapedType::kDynamic)
|
||||
|| llvm::is_contained(strides, ShapedType::kDynamic))
|
||||
return failure();
|
||||
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||
return failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
|
||||
return ResolvedContiguousAddress{value, byteOffset};
|
||||
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -17,6 +17,11 @@ inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ResolvedContiguousAddress {
|
||||
mlir::Value base;
|
||||
int64_t byteOffset = 0;
|
||||
};
|
||||
|
||||
std::string getOutputDir();
|
||||
|
||||
void createDirectory(const std::string& directory);
|
||||
@@ -48,4 +53,6 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include <cmath>
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
@@ -86,48 +86,9 @@ PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
|
||||
}
|
||||
|
||||
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
||||
size_t offset = 0;
|
||||
while (true) {
|
||||
auto definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
break;
|
||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
||||
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
break;
|
||||
value = tiedOperand->get();
|
||||
}
|
||||
else if (auto subviewDefiningOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||
auto source = subviewDefiningOp.getSource();
|
||||
auto srcShape = source.getType().getShape();
|
||||
auto subviewOffsets = subviewDefiningOp.getStaticOffsets();
|
||||
auto subviewSizes = subviewDefiningOp.getStaticSizes();
|
||||
auto subviewStrides = subviewDefiningOp.getStaticStrides();
|
||||
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
|
||||
for (unsigned i = 0; i < subviewOffsets.size(); i++) {
|
||||
size_t localOffset = subviewOffsets[i];
|
||||
for (unsigned j = i + 1; j < subviewSizes.size(); j++)
|
||||
localOffset *= subviewSizes[j];
|
||||
offset += localOffset * subviewDefiningOp.getType().getElementTypeBitWidth() / 8;
|
||||
}
|
||||
value = source;
|
||||
}
|
||||
else if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
}
|
||||
else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
}
|
||||
else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
|
||||
auto iter = memEntriesMap.find(value);
|
||||
if (iter == memEntriesMap.end()) {
|
||||
errs() << "Missing mem entry for value: ";
|
||||
auto resolvedAddress = resolveContiguousAddress(value);
|
||||
if (failed(resolvedAddress)) {
|
||||
errs() << "Failed to resolve contiguous address for value: ";
|
||||
value.print(errs());
|
||||
errs() << "\n";
|
||||
if (auto* definingOp = value.getDefiningOp()) {
|
||||
@@ -135,10 +96,23 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
||||
definingOp->print(errs());
|
||||
errs() << "\n";
|
||||
}
|
||||
llvm_unreachable("Failed to resolve contiguous address");
|
||||
}
|
||||
|
||||
auto iter = memEntriesMap.find(resolvedAddress->base);
|
||||
if (iter == memEntriesMap.end()) {
|
||||
errs() << "Missing mem entry for value: ";
|
||||
resolvedAddress->base.print(errs());
|
||||
errs() << "\n";
|
||||
if (auto* definingOp = resolvedAddress->base.getDefiningOp()) {
|
||||
errs() << "Defining op:\n";
|
||||
definingOp->print(errs());
|
||||
errs() << "\n";
|
||||
}
|
||||
llvm_unreachable("Missing mem entry");
|
||||
}
|
||||
|
||||
return iter->second.address + offset;
|
||||
return iter->second.address + resolvedAddress->byteOffset;
|
||||
}
|
||||
|
||||
json::Object PimCodeGen::createEmptyOffset() {
|
||||
|
||||
@@ -3,19 +3,19 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||
|
||||
add_onnx_mlir_library(OMONNXToSpatial
|
||||
Math/Gemm.cpp
|
||||
Math/Conv.cpp
|
||||
Math/MatMul.cpp
|
||||
NN/Pooling.cpp
|
||||
NN/ReduceMean.cpp
|
||||
Tensor/ONNXConcatToTensorConcat.cpp
|
||||
Tensor/ONNXReshapeToTensorReshape.cpp
|
||||
Tensor/RemoveUnusedHelperOps.cpp
|
||||
Patterns/Math/Gemm.cpp
|
||||
Patterns/Math/Conv.cpp
|
||||
Patterns/Math/MatMul.cpp
|
||||
Patterns/NN/Pooling.cpp
|
||||
Patterns/NN/ReduceMean.cpp
|
||||
Patterns/Tensor/ONNXConcatToTensorConcat.cpp
|
||||
Patterns/Tensor/ONNXReshapeToTensorReshape.cpp
|
||||
Patterns/Tensor/RemoveUnusedHelperOps.cpp
|
||||
Utils/SpatialReducer.cpp
|
||||
Utils/WeightSubdivider.cpp
|
||||
Utils/AnnotateReplication.cpp
|
||||
ONNXToSpatialPass.cpp
|
||||
ONNXToSpatialCommon.cpp
|
||||
Common.cpp
|
||||
|
||||
DEPENDS
|
||||
ONNXToSpatialIncGen
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include "ONNXToSpatialCommon.hpp"
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include <queue>
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
@@ -4,7 +4,7 @@ add_public_tablegen_target(SpatialToPimIncGen)
|
||||
|
||||
add_onnx_mlir_library(OMSpatialToPim
|
||||
SpatialToPimPass.cpp
|
||||
SpatialToPimCommon.cpp
|
||||
Common.cpp
|
||||
|
||||
DEPENDS
|
||||
SpatialToPimIncGen
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
|
||||
#include "SpatialToPimCommon.hpp"
|
||||
#include "Common.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
@@ -19,9 +19,9 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
#pragma once
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace spatial {
|
||||
|
||||
// TODO: Add here eventual patterns
|
||||
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -24,7 +24,7 @@
|
||||
#include <cstdint>
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
|
||||
121
src/PIM/Pass/PimConstantFolding/Common.cpp
Normal file
121
src/PIM/Pass/PimConstantFolding/Common.cpp
Normal file
@@ -0,0 +1,121 @@
|
||||
#include "Common.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value stripMemRefCasts(Value value) {
|
||||
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||
value = castOp.getSource();
|
||||
return value;
|
||||
}
|
||||
|
||||
Value stripMemRefViewOps(Value value) {
|
||||
while (true) {
|
||||
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||
Location loc,
|
||||
MemRefType globalType,
|
||||
DenseElementsAttr denseAttr,
|
||||
StringRef nameStem,
|
||||
IntegerAttr alignment) {
|
||||
auto globalName = nameStem.str();
|
||||
unsigned suffix = 0;
|
||||
while (moduleOp.lookupSymbol(globalName))
|
||||
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
|
||||
|
||||
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
|
||||
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
|
||||
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
|
||||
return memref::GlobalOp::create(moduleBuilder,
|
||||
loc,
|
||||
globalName,
|
||||
visibility,
|
||||
globalType,
|
||||
denseAttr,
|
||||
/*constant=*/true,
|
||||
alignment);
|
||||
}
|
||||
|
||||
FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
|
||||
value = stripMemRefCasts(value);
|
||||
|
||||
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return failure();
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
return denseAttr;
|
||||
}
|
||||
|
||||
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
value = stripMemRefViewOps(value);
|
||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||
if (!subviewOp)
|
||||
return failure();
|
||||
|
||||
auto source = stripMemRefCasts(subviewOp.getSource());
|
||||
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
StaticSubviewInfo info;
|
||||
info.source = source;
|
||||
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
info.offsets.push_back(*staticOffset);
|
||||
}
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto staticSize = getConstantIntValue(size);
|
||||
if (!staticSize)
|
||||
return failure();
|
||||
info.sizes.push_back(*staticSize);
|
||||
}
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto staticStride = getConstantIntValue(stride);
|
||||
if (!staticStride)
|
||||
return failure();
|
||||
info.strides.push_back(*staticStride);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
|
||||
ArrayRef<int64_t> outerIndices,
|
||||
int64_t elementByteWidth) {
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(info.sourceShape.size());
|
||||
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
|
||||
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
|
||||
sourceIndices.push_back(info.offsets.back());
|
||||
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
41
src/PIM/Pass/PimConstantFolding/Common.hpp
Normal file
41
src/PIM/Pass/PimConstantFolding/Common.hpp
Normal file
@@ -0,0 +1,41 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct StaticSubviewInfo {
|
||||
mlir::Value source;
|
||||
llvm::SmallVector<int64_t> sourceShape;
|
||||
llvm::SmallVector<int64_t> offsets;
|
||||
llvm::SmallVector<int64_t> sizes;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
};
|
||||
|
||||
mlir::Value stripMemRefCasts(mlir::Value value);
|
||||
|
||||
mlir::Value stripMemRefViewOps(mlir::Value value);
|
||||
|
||||
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
|
||||
mlir::Location loc,
|
||||
mlir::MemRefType globalType,
|
||||
mlir::DenseElementsAttr denseAttr,
|
||||
llvm::StringRef nameStem,
|
||||
mlir::IntegerAttr alignment = {});
|
||||
|
||||
llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value);
|
||||
|
||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
|
||||
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
|
||||
llvm::ArrayRef<int64_t> outerIndices,
|
||||
int64_t elementByteWidth);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
11
src/PIM/Pass/PimConstantFolding/Patterns.hpp
Normal file
11
src/PIM/Pass/PimConstantFolding/Patterns.hpp
Normal file
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateConstantFoldingConstantPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
void populateConstantFoldingSubviewPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,19 +1,11 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "../Common.hpp"
|
||||
#include "../Patterns.hpp"
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
@@ -21,73 +13,14 @@
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static Value stripMemRefCasts(Value value) {
|
||||
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||
value = castOp.getSource();
|
||||
return value;
|
||||
}
|
||||
|
||||
static Value stripMemRefViewOps(Value value) {
|
||||
while (true) {
|
||||
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||
Location loc,
|
||||
MemRefType globalType,
|
||||
DenseElementsAttr denseAttr,
|
||||
StringRef nameStem,
|
||||
IntegerAttr alignment = {}) {
|
||||
auto globalName = nameStem.str();
|
||||
unsigned suffix = 0;
|
||||
while (moduleOp.lookupSymbol(globalName))
|
||||
globalName = (nameStem + "_" + std::to_string(++suffix)).str();
|
||||
|
||||
auto visibility = StringAttr::get(moduleOp.getContext(), "private");
|
||||
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
|
||||
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
|
||||
return memref::GlobalOp::create(moduleBuilder,
|
||||
loc,
|
||||
globalName,
|
||||
visibility,
|
||||
globalType,
|
||||
denseAttr,
|
||||
/*constant=*/true,
|
||||
alignment);
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> getDenseGlobalValue(ModuleOp moduleOp, Value value) {
|
||||
value = stripMemRefCasts(value);
|
||||
|
||||
auto getGlobalOp = value.getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return failure();
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
return denseAttr;
|
||||
}
|
||||
struct ConstantSubviewCopy {
|
||||
DenseElementsAttr source;
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> strides;
|
||||
Operation* copyOp = nullptr;
|
||||
};
|
||||
|
||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
@@ -144,13 +77,6 @@ static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr den
|
||||
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||
}
|
||||
|
||||
struct ConstantSubviewCopy {
|
||||
DenseElementsAttr source;
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> strides;
|
||||
Operation* copyOp = nullptr;
|
||||
};
|
||||
|
||||
static FailureOr<Attribute> getConstantMapYield(linalg::MapOp mapOp) {
|
||||
if (!mapOp.getInputs().empty())
|
||||
return failure();
|
||||
@@ -213,204 +139,6 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct StaticSubviewInfo {
|
||||
Value source;
|
||||
SmallVector<int64_t> sourceShape;
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> sizes;
|
||||
SmallVector<int64_t> strides;
|
||||
};
|
||||
|
||||
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
value = stripMemRefViewOps(value);
|
||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||
if (!subviewOp)
|
||||
return failure();
|
||||
|
||||
auto source = stripMemRefCasts(subviewOp.getSource());
|
||||
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
StaticSubviewInfo info;
|
||||
info.source = source;
|
||||
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
info.offsets.push_back(*staticOffset);
|
||||
}
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto staticSize = getConstantIntValue(size);
|
||||
if (!staticSize)
|
||||
return failure();
|
||||
info.sizes.push_back(*staticSize);
|
||||
}
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto staticStride = getConstantIntValue(stride);
|
||||
if (!staticStride)
|
||||
return failure();
|
||||
info.strides.push_back(*staticStride);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
static int64_t
|
||||
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(info.sourceShape.size());
|
||||
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
|
||||
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
|
||||
sourceIndices.push_back(info.offsets.back());
|
||||
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
|
||||
}
|
||||
|
||||
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
if (!copyOp->getParentOfType<pim::PimCoreOp>())
|
||||
return failure();
|
||||
|
||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
|
||||
auto dstSubview = getStaticSubviewInfo(copyOp.getDst());
|
||||
const bool splitSrc = succeeded(srcSubview)
|
||||
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
||||
const bool splitDst = succeeded(dstSubview)
|
||||
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
||||
if (!splitSrc && !splitDst)
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(copyOp.getSrc().getType());
|
||||
auto dstType = dyn_cast<MemRefType>(copyOp.getDst().getType());
|
||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != dstType.getElementType())
|
||||
return failure();
|
||||
|
||||
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||
return failure();
|
||||
|
||||
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
||||
if (elementByteWidth <= 0)
|
||||
return failure();
|
||||
|
||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||
if (copyOp.getSize() != totalBytes)
|
||||
return failure();
|
||||
|
||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||
if (sliceBytes <= 0)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
const int64_t srcByteOffset = copyOp.getSrcOffset()
|
||||
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
const int64_t dstByteOffset = copyOp.getDstOffset()
|
||||
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
copyOp.getLoc(),
|
||||
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
||||
splitDst ? dstSubview->source : copyOp.getDst(),
|
||||
splitSrc ? srcSubview->source : copyOp.getSrc(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getDst());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||
auto srcSubview = getStaticSubviewInfo(copyOp.getHostSrc());
|
||||
auto dstSubview = getStaticSubviewInfo(copyOp.getDeviceDst());
|
||||
const bool splitSrc = succeeded(srcSubview)
|
||||
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
||||
const bool splitDst = succeeded(dstSubview)
|
||||
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
||||
if (!splitSrc && !splitDst)
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(copyOp.getHostSrc().getType());
|
||||
auto dstType = dyn_cast<MemRefType>(copyOp.getDeviceDst().getType());
|
||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != dstType.getElementType())
|
||||
return failure();
|
||||
|
||||
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||
return failure();
|
||||
|
||||
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
||||
if (elementByteWidth <= 0)
|
||||
return failure();
|
||||
|
||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||
if (copyOp.getSize() != totalBytes)
|
||||
return failure();
|
||||
|
||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||
if (sliceBytes <= 0)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
const int64_t srcByteOffset = copyOp.getHostSrcOffset()
|
||||
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
const int64_t dstByteOffset = copyOp.getDeviceDstOffset()
|
||||
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
copyOp.getLoc(),
|
||||
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
||||
splitDst ? dstSubview->source : copyOp.getDeviceDst(),
|
||||
splitSrc ? srcSubview->source : copyOp.getHostSrc(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
|
||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||
if (!allocType || !allocType.hasStaticShape())
|
||||
@@ -563,17 +291,15 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
||||
if (!llvm::equal(transposedShape, resultType.getShape()))
|
||||
return failure();
|
||||
|
||||
MemRefType globalType = resultType;
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp,
|
||||
transposeOp.getLoc(),
|
||||
globalType,
|
||||
resultType,
|
||||
*transposedAttr,
|
||||
sourceGlobal.getName().str() + "__folded_transpose",
|
||||
sourceGlobal.getAlignmentAttr());
|
||||
|
||||
rewriter.setInsertionPoint(transposeOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName());
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), resultType, newGlobal.getName());
|
||||
|
||||
bool isAlwaysWeight =
|
||||
!transposeOp->getUsers().empty()
|
||||
@@ -672,11 +398,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
// Only match top-level memcp (not inside pim.core)
|
||||
if (copyOp->getParentOfType<pim::PimCoreOp>())
|
||||
return failure();
|
||||
|
||||
// dst must be an alloc with static shape
|
||||
auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
|
||||
if (!allocOp)
|
||||
return failure();
|
||||
@@ -684,11 +408,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
if (!allocType || !allocType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
// The copy must cover the full destination (offsets both zero)
|
||||
if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
|
||||
return failure();
|
||||
|
||||
// Resolve the source through an optional subview to a get_global
|
||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
|
||||
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());
|
||||
|
||||
@@ -700,14 +422,12 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
if (failed(denseAttr))
|
||||
return failure();
|
||||
|
||||
// Build the folded dense attribute
|
||||
DenseElementsAttr foldedAttr;
|
||||
if (succeeded(srcSubview)) {
|
||||
// Extract the sub-tensor from the source constant
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return failure();
|
||||
if (llvm::any_of(srcSubview->strides, [](int64_t s) { return s != 1; }))
|
||||
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
@@ -729,14 +449,12 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
}
|
||||
else {
|
||||
// Direct copy from a global — just reuse its dense attribute
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
if (resultTensorType != denseAttr->getType())
|
||||
return failure();
|
||||
foldedAttr = *denseAttr;
|
||||
}
|
||||
|
||||
// Verify that the alloc's remaining users are supported ops.
|
||||
bool allLiveUsersAreCores = true;
|
||||
for (Operation* user : allocOp->getUsers()) {
|
||||
if (user == copyOp)
|
||||
@@ -769,110 +487,13 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
|
||||
// Only handle subviews whose users are all pim.core ops.
|
||||
if (subviewOp.use_empty())
|
||||
return failure();
|
||||
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
|
||||
return failure();
|
||||
|
||||
// Source must resolve to a constant get_global.
|
||||
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
|
||||
if (failed(denseAttr))
|
||||
return failure();
|
||||
|
||||
// Static subview info.
|
||||
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
|
||||
if (failed(subviewInfo))
|
||||
return failure();
|
||||
if (llvm::any_of(subviewInfo->strides, [](int64_t s) { return s != 1; }))
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
// Build the contiguous result type.
|
||||
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
||||
auto resultMemRefType = MemRefType::get(
|
||||
SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
||||
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
|
||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||
|
||||
// Extract the sub-tensor.
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
||||
SmallVector<Attribute> resultValues(numResultElements);
|
||||
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
|
||||
sourceIndices.push_back(off + idx);
|
||||
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
||||
}
|
||||
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
rewriter.setInsertionPoint(subviewOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
|
||||
markWeightAlways(newGetGlobal);
|
||||
|
||||
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||
|
||||
LogicalResult initialize(MLIRContext* context) override {
|
||||
RewritePatternSet owningPatterns(context);
|
||||
for (auto* dialect : context->getLoadedDialects())
|
||||
dialect->getCanonicalizationPatterns(owningPatterns);
|
||||
for (RegisteredOperationName op : context->getRegisteredOperations())
|
||||
op.getCanonicalizationPatterns(owningPatterns, context);
|
||||
owningPatterns
|
||||
.add<FoldConstantTransposePattern,
|
||||
FoldConstantAllocPattern,
|
||||
FoldConstantCoreMapPattern,
|
||||
RewriteCoreSubviewCopyPattern,
|
||||
RewriteHostSubviewLoadPattern,
|
||||
FoldConstantMemCpPattern,
|
||||
FoldConstantCoreSubviewPattern>(
|
||||
context);
|
||||
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
|
||||
return success();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
GreedyRewriteConfig config;
|
||||
config.enableFolding();
|
||||
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(getOperation(), "pim2_folded");
|
||||
}
|
||||
|
||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
|
||||
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<FoldConstantTransposePattern,
|
||||
FoldConstantAllocPattern,
|
||||
FoldConstantCoreMapPattern,
|
||||
FoldConstantMemCpPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
223
src/PIM/Pass/PimConstantFolding/Patterns/SubviewPatterns.cpp
Normal file
223
src/PIM/Pass/PimConstantFolding/Patterns/SubviewPatterns.cpp
Normal file
@@ -0,0 +1,223 @@
|
||||
#include "../Common.hpp"
|
||||
#include "../Patterns.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
template <typename CopyOp, typename CreateCopyOp>
|
||||
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
Value dst,
|
||||
Value src,
|
||||
int64_t dstOffset,
|
||||
int64_t srcOffset,
|
||||
int64_t size,
|
||||
PatternRewriter& rewriter,
|
||||
CreateCopyOp createCopyOp) {
|
||||
auto srcSubview = getStaticSubviewInfo(src);
|
||||
auto dstSubview = getStaticSubviewInfo(dst);
|
||||
const bool splitSrc = succeeded(srcSubview)
|
||||
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
||||
const bool splitDst = succeeded(dstSubview)
|
||||
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
||||
if (!splitSrc && !splitDst)
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(src.getType());
|
||||
auto dstType = dyn_cast<MemRefType>(dst.getType());
|
||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != dstType.getElementType())
|
||||
return failure();
|
||||
|
||||
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||
return failure();
|
||||
|
||||
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
||||
if (elementByteWidth <= 0)
|
||||
return failure();
|
||||
|
||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||
if (size != totalBytes)
|
||||
return failure();
|
||||
|
||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||
if (sliceBytes <= 0)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
const int64_t srcByteOffset = srcOffset
|
||||
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
const int64_t dstByteOffset = dstOffset
|
||||
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
||||
splitDst ? dstSubview->source : dst,
|
||||
splitSrc ? srcSubview->source : src,
|
||||
dstByteOffset,
|
||||
srcByteOffset,
|
||||
sliceBytes);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
if (!copyOp->getParentOfType<pim::PimCoreOp>())
|
||||
return failure();
|
||||
|
||||
auto status =
|
||||
rewriteSubviewCopyLikeOp(copyOp,
|
||||
copyOp.getDst(),
|
||||
copyOp.getSrc(),
|
||||
copyOp.getDstOffset(),
|
||||
copyOp.getSrcOffset(),
|
||||
copyOp.getSize(),
|
||||
rewriter,
|
||||
[&](MemRefType resultType,
|
||||
Value dst,
|
||||
Value src,
|
||||
int64_t dstByteOffset,
|
||||
int64_t srcByteOffset,
|
||||
int64_t sliceBytes) {
|
||||
pim::PimMemCopyOp::create(
|
||||
rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getDst());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||
auto status =
|
||||
rewriteSubviewCopyLikeOp(copyOp,
|
||||
copyOp.getDeviceDst(),
|
||||
copyOp.getHostSrc(),
|
||||
copyOp.getDeviceDstOffset(),
|
||||
copyOp.getHostSrcOffset(),
|
||||
copyOp.getSize(),
|
||||
rewriter,
|
||||
[&](MemRefType resultType,
|
||||
Value dst,
|
||||
Value src,
|
||||
int64_t dstByteOffset,
|
||||
int64_t srcByteOffset,
|
||||
int64_t sliceBytes) {
|
||||
pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
copyOp.getLoc(),
|
||||
resultType,
|
||||
dst,
|
||||
src,
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
});
|
||||
if (failed(status))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
|
||||
if (subviewOp.use_empty())
|
||||
return failure();
|
||||
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
|
||||
return failure();
|
||||
|
||||
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
|
||||
if (failed(denseAttr))
|
||||
return failure();
|
||||
|
||||
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
|
||||
if (failed(subviewInfo))
|
||||
return failure();
|
||||
if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
||||
auto resultMemRefType =
|
||||
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
||||
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
|
||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
||||
SmallVector<Attribute> resultValues(numResultElements);
|
||||
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
|
||||
sourceIndices.push_back(off + idx);
|
||||
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
||||
}
|
||||
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
rewriter.setInsertionPoint(subviewOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
|
||||
markWeightAlways(newGetGlobal);
|
||||
|
||||
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<RewriteCoreSubviewCopyPattern, RewriteHostSubviewLoadPattern, FoldConstantCoreSubviewPattern>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
53
src/PIM/Pass/PimConstantFolding/PimConstantFoldingPass.cpp
Normal file
53
src/PIM/Pass/PimConstantFolding/PimConstantFoldingPass.cpp
Normal file
@@ -0,0 +1,53 @@
|
||||
#include "Patterns.hpp"
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||
|
||||
LogicalResult initialize(MLIRContext* context) override {
|
||||
RewritePatternSet owningPatterns(context);
|
||||
for (auto* dialect : context->getLoadedDialects())
|
||||
dialect->getCanonicalizationPatterns(owningPatterns);
|
||||
for (RegisteredOperationName op : context->getRegisteredOperations())
|
||||
op.getCanonicalizationPatterns(owningPatterns, context);
|
||||
|
||||
populateConstantFoldingConstantPatterns(owningPatterns);
|
||||
populateConstantFoldingSubviewPatterns(owningPatterns);
|
||||
|
||||
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
|
||||
return success();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
GreedyRewriteConfig config;
|
||||
config.enableFolding();
|
||||
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(getOperation(), "pim2_folded");
|
||||
}
|
||||
|
||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
@@ -26,37 +27,12 @@ static bool isAddressOnlyHostOp(Operation* op) {
|
||||
spatial::SpatChannelNewOp>(op);
|
||||
}
|
||||
|
||||
static bool isHostAddressableValue(Value value) {
|
||||
while (true) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
return isa<func::FuncOp>(blockArg.getOwner()->getParentOp());
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return false;
|
||||
|
||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
|
||||
return true;
|
||||
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
static bool isCodegenAddressableValue(Value value) {
|
||||
auto resolvedAddress = resolveContiguousAddress(value);
|
||||
if (failed(resolvedAddress))
|
||||
return false;
|
||||
}
|
||||
return isa<BlockArgument>(resolvedAddress->base)
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
}
|
||||
|
||||
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> {
|
||||
@@ -80,7 +56,7 @@ struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationP
|
||||
|
||||
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
||||
if (failed(verifyCoreWeights(moduleOp, coreOp)))
|
||||
if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp)))
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
@@ -139,8 +115,27 @@ private:
|
||||
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
|
||||
bool hasFailure = false;
|
||||
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
|
||||
if (!isHostAddressableValue(operand)) {
|
||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by host-addressable storage";
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) {
|
||||
bool hasFailure = false;
|
||||
for (Operation& op : coreOp.getBody().front()) {
|
||||
if (isa<pim::PimHaltOp>(op))
|
||||
continue;
|
||||
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
continue;
|
||||
if (succeeded(resolveContiguousAddress(operand)))
|
||||
continue;
|
||||
|
||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
@@ -160,10 +155,10 @@ private:
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
|
||||
if (isHostAddressableValue(source))
|
||||
if (isCodegenAddressableValue(source))
|
||||
return success();
|
||||
|
||||
op->emitOpError("depends on a value that still requires host-side execution");
|
||||
op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user