add constant folding and verification pass for pim host operations

better validation scripts output
big refactors
This commit is contained in:
NiccoloN
2026-03-20 12:08:12 +01:00
parent 4e50e056e3
commit 6e1de865bb
64 changed files with 1364 additions and 2265 deletions

View File

@@ -25,7 +25,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -202,9 +202,9 @@ LogicalResult SpatVMaxOp::verify() {
LogicalResult SpatImgConcatOp::verify() {
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = GET_IMAGE_WIDTH(imgShape);
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
size_t img_w = getImageWidth(imgShape);
size_t img_h = getImageHeight(imgShape);
size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
size_t channelTileRest = img_c % crossbarSize;
@@ -225,10 +225,10 @@ LogicalResult SpatImgConcatOp::verify() {
return emitError("Invalid input type, must be ShapedType");
// N == W == H == 1
if (GET_IMAGE_N(inputShape) != 1 || GET_IMAGE_WIDTH(inputShape) != 1 || GET_IMAGE_HEIGHT(inputShape) != 1)
if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
return emitError("Invalid input shape: N,W,H must all be 1");
size_t inputChannels = GET_IMAGE_CHANNEL(inputShape);
size_t inputChannels = getImageChannel(inputShape);
// Check the number of channels in this tile are correct:
// - CASE1: last tile of pixel, if there is some rest it must match that
@@ -311,9 +311,9 @@ LogicalResult SpatWeightedCompute::verify() {
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
auto operands = getOperands();
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = GET_IMAGE_WIDTH(imgShape);
size_t img_h = GET_IMAGE_HEIGHT(imgShape);
size_t img_c = GET_IMAGE_CHANNEL(imgShape);
size_t img_w = getImageWidth(imgShape);
size_t img_h = getImageHeight(imgShape);
size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());