standardize spatial and pim dialects

remove old unused stuff
This commit is contained in:
NiccoloN
2026-03-23 21:21:31 +01:00
parent 0478d979ff
commit 93e20c1dfc
18 changed files with 693 additions and 1519 deletions

View File

@@ -136,7 +136,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
if (failed(matrixShapeOpt))
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getVector().getType().getShape();
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
/* Two possible accepted shapes:
@@ -157,7 +157,7 @@ LogicalResult SpatWeightedVMMOp::verify() {
if (failed(matrixShapeOpt))
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getVector().getType().getShape();
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
/* Accepted shape:
@@ -200,59 +200,6 @@ LogicalResult SpatVMaxOp::verify() {
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatImgConcatOp::verify() {
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = getImageWidth(imgShape);
size_t img_h = getImageHeight(imgShape);
size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
size_t channelTileRest = img_c % crossbarSize;
auto operands = getOperands();
// Check number of operands
if (img_w * img_h * channelTiles != operands.size())
return emitError("Number of operands does not match output image size");
// For each output pixel, check that the inputTiles have a correct shape
for (size_t x = 0; x < img_w; x++) {
for (size_t y = 0; y < img_h; y++) {
size_t channel_counts = 0;
for (size_t t = 0; t < channelTiles; t++) {
auto inputShape = mlir::cast<ShapedType>(getInputTile(x, y, t).getType());
if (!inputShape)
return emitError("Invalid input type, must be ShapedType");
// N == W == H == 1
if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
return emitError("Invalid input shape: N,W,H must all be 1");
size_t inputChannels = getImageChannel(inputShape);
// Check the number of channels in this tile are correct:
// - CASE1: last tile of pixel, if there is some rest it must match that
// - CASE2: common case, the channel count is exactly the crossbarSize
if (t == channelTiles - 1 && channelTileRest != 0) {
if (inputChannels != channelTileRest)
return emitError("Invalid channel count for last tile of pixel");
}
else {
if (inputChannels != crossbarSize)
return emitError("Invalid channel count for some pixel tile");
}
channel_counts += inputChannels;
}
if (channel_counts != img_c)
emitError("Invalid number of channels for some pixel");
}
}
return success();
}
LogicalResult SpatWeightedCompute::verify() {
// Check that it has a terminator, it is a yieldOp, and it has a single
// operand with the same type as the result
@@ -308,22 +255,6 @@ LogicalResult SpatWeightedCompute::verify() {
return success();
}
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
auto operands = getOperands();
auto imgShape = mlir::cast<ShapedType>(getType());
size_t img_w = getImageWidth(imgShape);
size_t img_h = getImageHeight(imgShape);
size_t img_c = getImageChannel(imgShape);
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
assert(tile < channelTiles);
assert(x < img_w);
assert(y < img_h);
return operands[tile + x * channelTiles + y * img_w * channelTiles];
}
} // namespace spatial
} // namespace onnx_mlir