Refactor ONNXToSpatial Common and diagnostics

This commit is contained in:
NiccoloN
2026-05-04 13:42:43 +02:00
parent b6ba1e4fea
commit f789954ad7
26 changed files with 686 additions and 486 deletions
@@ -12,14 +12,13 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include <iterator>
#include <utility>
#include "Common.hpp"
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
@@ -32,8 +31,6 @@ using namespace mlir;
namespace onnx_mlir {
bool haveSameStaticShape(Value lhs, Value rhs);
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
@@ -50,7 +47,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
void encapsulateGlobalInstruction(func::FuncOp funcOp);
LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
};
@@ -186,7 +183,10 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc);
if (failed(encapsulateGlobalInstruction(*entryFunc))) {
signalPassFailure();
return;
}
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
signalPassFailure();
@@ -287,64 +287,7 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
return false;
}
static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
if (!tensorType || !tensorType.hasStaticShape())
return failure();
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(tensorType.getRank());
for (int64_t dim : tensorType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
auto referencedValue =
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
mapper.map(value, referencedValue.getResult());
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
return failure();
IRMapping localMapper;
for (Value operand : definingOp->getOperands()) {
if (auto mapped = mapper.lookupOrNull(operand)) {
localMapper.map(operand, cast<Value>(mapped));
continue;
}
if (isWeightLikeComputeOperand(operand)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
localMapper.map(operand, *clonedOperand);
continue;
}
localMapper.map(operand, operand);
}
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
auto mapped = mapper.lookupOrNull(value);
if (!mapped)
return failure();
return cast<Value>(mapped);
}
bool sourceOpernadHasWeightAlways(Operation* op) {
static FailureOr<bool> sourceOperandHasWeightAlways(Operation* op) {
if (op == nullptr)
return false;
@@ -416,30 +359,32 @@ bool sourceOpernadHasWeightAlways(Operation* op) {
return res;
}
else {
op->dump();
llvm_unreachable("Global instruction not handle in func");
op->emitOpError("unsupported global instruction while promoting weight-backed operands into Spatial computes");
return failure();
}
}
while (source == nullptr);
if (hasWeightAlways(source))
return true;
return false;
return hasWeightAlways(source);
}
// TODO what we want to keep in global?
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
bool keep = true;
while (keep) {
keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
instruction)
|| isa<func::ReturnOp>(instruction)
|| sourceOpernadHasWeightAlways(&instruction))
|| isa<func::ReturnOp>(instruction))
continue;
auto weightBacked = sourceOperandHasWeightAlways(&instruction);
if (failed(weightBacked))
return failure();
if (*weightBacked)
continue;
keep |= encapsulateSlice(rewriter, loc, &instruction);
@@ -456,6 +401,7 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
keep |= encapsulateConcat(rewriter, loc, &instruction);
}
}
return success();
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {