multiple-output spat computes
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m38s

This commit is contained in:
NiccoloN
2026-04-23 09:28:57 +02:00
parent 0f13269040
commit 412ca957f6
16 changed files with 415 additions and 420 deletions

View File

@@ -1,5 +1,4 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -14,10 +13,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"
@@ -119,13 +115,10 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
}
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
auto wcomputeOp = dyn_cast<SpatWeightedCompute>(weigthedOp->getParentOp());
if (wcomputeOp)
return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
if (auto computeOp = dyn_cast<SpatCompute>(weigthedOp->getParentOp()))
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp());
if (coreOp)
if (auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp()))
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
return failure();
@@ -134,7 +127,7 @@ llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigth
LogicalResult SpatWeightedMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
@@ -155,7 +148,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
LogicalResult SpatWeightedVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
@@ -200,9 +193,8 @@ LogicalResult SpatVMaxOp::verify() {
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatWeightedCompute::verify() {
// Check that it has a terminator, it is a yieldOp, and it has a single
// operand with the same type as the result
LogicalResult SpatCompute::verify() {
// Check that the terminator yields the same number and types as the compute results.
auto& block = getBody().front();
if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
@@ -257,7 +249,7 @@ LogicalResult SpatWeightedCompute::verify() {
return success();
}
LogicalResult SpatWeightedCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
Block& block = getBody().front();
if (!llvm::hasSingleElement(block))
return failure();