standardize spatial and pim dialects
remove old unused stuff
This commit is contained in:
@@ -136,7 +136,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getVector().getType().getShape();
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
|
||||
/* Two possible accepted shapes:
|
||||
@@ -157,7 +157,7 @@ LogicalResult SpatWeightedVMMOp::verify() {
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getVector().getType().getShape();
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
|
||||
/* Accepted shape:
|
||||
@@ -200,59 +200,6 @@ LogicalResult SpatVMaxOp::verify() {
|
||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||
}
|
||||
|
||||
LogicalResult SpatImgConcatOp::verify() {
|
||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
||||
size_t img_w = getImageWidth(imgShape);
|
||||
size_t img_h = getImageHeight(imgShape);
|
||||
size_t img_c = getImageChannel(imgShape);
|
||||
|
||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
||||
size_t channelTileRest = img_c % crossbarSize;
|
||||
|
||||
auto operands = getOperands();
|
||||
|
||||
// Check number of operands
|
||||
if (img_w * img_h * channelTiles != operands.size())
|
||||
return emitError("Number of operands does not match output image size");
|
||||
|
||||
// For each output pixel, check that the inputTiles have a correct shape
|
||||
for (size_t x = 0; x < img_w; x++) {
|
||||
for (size_t y = 0; y < img_h; y++) {
|
||||
size_t channel_counts = 0;
|
||||
for (size_t t = 0; t < channelTiles; t++) {
|
||||
auto inputShape = mlir::cast<ShapedType>(getInputTile(x, y, t).getType());
|
||||
if (!inputShape)
|
||||
return emitError("Invalid input type, must be ShapedType");
|
||||
|
||||
// N == W == H == 1
|
||||
if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
|
||||
return emitError("Invalid input shape: N,W,H must all be 1");
|
||||
|
||||
size_t inputChannels = getImageChannel(inputShape);
|
||||
|
||||
// Check the number of channels in this tile are correct:
|
||||
// - CASE1: last tile of pixel, if there is some rest it must match that
|
||||
// - CASE2: common case, the channel count is exactly the crossbarSize
|
||||
if (t == channelTiles - 1 && channelTileRest != 0) {
|
||||
if (inputChannels != channelTileRest)
|
||||
return emitError("Invalid channel count for last tile of pixel");
|
||||
}
|
||||
else {
|
||||
if (inputChannels != crossbarSize)
|
||||
return emitError("Invalid channel count for some pixel tile");
|
||||
}
|
||||
|
||||
channel_counts += inputChannels;
|
||||
}
|
||||
|
||||
if (channel_counts != img_c)
|
||||
emitError("Invalid number of channels for some pixel");
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatWeightedCompute::verify() {
|
||||
// Check that it has a terminator, it is a yieldOp, and it has a single
|
||||
// operand with the same type as the result
|
||||
@@ -308,22 +255,6 @@ LogicalResult SpatWeightedCompute::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
|
||||
auto operands = getOperands();
|
||||
auto imgShape = mlir::cast<ShapedType>(getType());
|
||||
size_t img_w = getImageWidth(imgShape);
|
||||
size_t img_h = getImageHeight(imgShape);
|
||||
size_t img_c = getImageChannel(imgShape);
|
||||
|
||||
size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
|
||||
|
||||
assert(tile < channelTiles);
|
||||
assert(x < img_w);
|
||||
assert(y < img_h);
|
||||
|
||||
return operands[tile + x * channelTiles + y * img_w * channelTiles];
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
|
||||
Reference in New Issue
Block a user