multiple-output spat computes
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m38s
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m38s
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
@@ -14,10 +13,7 @@
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
@@ -119,13 +115,10 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
|
||||
}
|
||||
|
||||
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
|
||||
auto wcomputeOp = dyn_cast<SpatWeightedCompute>(weigthedOp->getParentOp());
|
||||
if (wcomputeOp)
|
||||
return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
|
||||
if (auto computeOp = dyn_cast<SpatCompute>(weigthedOp->getParentOp()))
|
||||
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp());
|
||||
|
||||
if (coreOp)
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp()))
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
return failure();
|
||||
@@ -134,7 +127,7 @@ llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigth
|
||||
LogicalResult SpatWeightedMVMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
|
||||
return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
@@ -155,7 +148,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
|
||||
LogicalResult SpatWeightedVMMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
|
||||
return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
@@ -200,9 +193,8 @@ LogicalResult SpatVMaxOp::verify() {
|
||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||
}
|
||||
|
||||
LogicalResult SpatWeightedCompute::verify() {
|
||||
// Check that it has a terminator, it is a yieldOp, and it has a single
|
||||
// operand with the same type as the result
|
||||
LogicalResult SpatCompute::verify() {
|
||||
// Check that the terminator yields the same number and types as the compute results.
|
||||
auto& block = getBody().front();
|
||||
if (block.mightHaveTerminator()) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
@@ -257,7 +249,7 @@ LogicalResult SpatWeightedCompute::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatWeightedCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||
Block& block = getBody().front();
|
||||
if (!llvm::hasSingleElement(block))
|
||||
return failure();
|
||||
|
||||
Reference in New Issue
Block a user