This commit is contained in:
@@ -5,7 +5,7 @@ add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||
add_pim_library(OMONNXToSpatial
|
||||
ConversionPatterns.cpp
|
||||
CompileTime.cpp
|
||||
HostLegality.cpp
|
||||
ONNXToSpatialVerifier.cpp
|
||||
PrePatterns.cpp
|
||||
PostPatterns.cpp
|
||||
Patterns/Math/Conv.cpp
|
||||
|
||||
@@ -7,6 +7,9 @@
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -195,66 +198,90 @@ static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm:
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool isCompileTimeOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
|
||||
static std::optional<CompileTimeSource>
|
||||
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
|
||||
if (!op)
|
||||
return false;
|
||||
return std::nullopt;
|
||||
|
||||
if (!visited.insert(op).second)
|
||||
return true;
|
||||
return {
|
||||
{op, chainLength}
|
||||
};
|
||||
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||
return true;
|
||||
return {
|
||||
{op, chainLength}
|
||||
};
|
||||
|
||||
chainLength += 1;
|
||||
|
||||
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||
return hasConstantIndices(extractOp) && isCompileTimeOpImpl(extractOp, visited);
|
||||
return hasConstantIndices(extractOp) ? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength) : std::nullopt;
|
||||
|
||||
if (!isStaticTensorResult(op))
|
||||
return false;
|
||||
return std::nullopt;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
||||
return isCompileTimeOpImpl(transposeOp, visited);
|
||||
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
||||
return isCompileTimeOpImpl(collapseShapeOp,visited);
|
||||
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
|
||||
return isCompileTimeOpImpl(expandShapeOp, visited);
|
||||
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return hasStaticUnitStrides(extractSliceOp) && isCompileTimeOpImpl(extractSliceOp, visited);
|
||||
return hasStaticUnitStrides(extractSliceOp) ? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
|
||||
: std::nullopt;
|
||||
|
||||
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||
return isCompileTimeOpImpl(splatOp, visited);
|
||||
return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||
return isCompileTimeOpImpl(extractRowsOp, visited);
|
||||
return getCompileTimeSourceImpl(extractRowsOp.getInput().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)){
|
||||
bool res = true;
|
||||
for(auto operandValue : concatOp.getOperands()){
|
||||
res &= isCompileTimeOpImpl(operandValue.getDefiningOp(), visited);
|
||||
if(!res) break;
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
||||
std::optional<CompileTimeSource> res = {};
|
||||
for (auto operandValue : concatOp.getOperands()) {
|
||||
auto partialRes = getCompileTimeSourceImpl(operandValue.getDefiningOp(), visited, chainLength);
|
||||
if (!partialRes)
|
||||
return std::nullopt;
|
||||
|
||||
if (!res) {
|
||||
res = partialRes;
|
||||
continue;
|
||||
}
|
||||
if(res->chainLength < partialRes->chainLength){
|
||||
res = partialRes;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
return false;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return getCompileTimeSourceImpl(op, visited);
|
||||
}
|
||||
|
||||
bool isCompileTimeComputable(Value value) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return false;
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return isCompileTimeOpImpl(definingOp, visited);
|
||||
return getCompileTimeSourceImpl(definingOp, visited).has_value();
|
||||
}
|
||||
|
||||
bool isCompileTimeOp(Operation* op) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return isCompileTimeOpImpl(op, visited);
|
||||
return getCompileTimeSourceImpl(op, visited).has_value();
|
||||
}
|
||||
|
||||
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {
|
||||
|
||||
@@ -6,6 +6,13 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct CompileTimeSource {
|
||||
mlir::Operation* source;
|
||||
size_t chainLength;
|
||||
};
|
||||
|
||||
std::optional<CompileTimeSource> getCompileTimeSource(mlir::Operation* op);
|
||||
|
||||
bool isCompileTimeComputable(mlir::Value value);
|
||||
|
||||
bool isCompileTimeOp(mlir::Operation* op);
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
for (Operation& op : funcOp.getFunctionBody().front()) {
|
||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||
continue;
|
||||
if (isCompileTimeOp(&op))
|
||||
continue;
|
||||
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
|
||||
"spat.compute");
|
||||
});
|
||||
}
|
||||
|
||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
|
||||
|
||||
return success(!diagnostics.hasFailure());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -220,7 +220,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
wrapTopLevelRuntimeTransposes(*entryFunc);
|
||||
|
||||
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
||||
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void checkWeightsDirectlyExtracted(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
for (auto extractSlice : func.getOps<tensor::ExtractSliceOp>()) {
|
||||
auto source = getCompileTimeSource(extractSlice.getOperation());
|
||||
if (source && hasWeightAlways(source->source) && source->chainLength > 1) {
|
||||
|
||||
diagnostics.report(extractSlice.getOperation(),
|
||||
[](Operation* illegalOp) { illegalOp->emitOpError("Weight not directly extracted"); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
for (Operation& op : funcOp.getOps()) {
|
||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||
continue;
|
||||
if (isCompileTimeOp(&op))
|
||||
continue;
|
||||
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError(
|
||||
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||
});
|
||||
}
|
||||
|
||||
checkWeightsDirectlyExtracted(funcOp, diagnostics);
|
||||
|
||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
|
||||
|
||||
return success(!diagnostics.hasFailure());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+1
-1
@@ -5,6 +5,6 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
|
||||
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user