Add DCP alghoritm, partial working test
This commit is contained in:
@@ -204,45 +204,47 @@ LogicalResult SpatWeightedCompute::verify() {
|
||||
// Check that it has a terminator, it is a yieldOp, and it has a single
|
||||
// operand with the same type as the result
|
||||
auto& block = getBody().front();
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return emitError("ComputeOp must have a single yield operation");
|
||||
if (block.mightHaveTerminator()) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return emitError("ComputeOp must have a single yield operation");
|
||||
|
||||
auto resultTypes = getResultTypes();
|
||||
auto yieldTypes = yieldOp->getOperandTypes();
|
||||
if (resultTypes.size() != yieldTypes.size()) {
|
||||
return emitError("ComputeOp must have same number of results as yieldOp "
|
||||
"operands");
|
||||
}
|
||||
|
||||
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
|
||||
auto resultType = std::get<0>(it);
|
||||
auto yieldType = std::get<1>(it);
|
||||
|
||||
// Same type and compatible shape
|
||||
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) {
|
||||
return emitError("ComputeOp output must be of the same type as yieldOp "
|
||||
"operand");
|
||||
auto resultTypes = getResultTypes();
|
||||
auto yieldTypes = yieldOp->getOperandTypes();
|
||||
if (resultTypes.size() != yieldTypes.size()) {
|
||||
return emitError("ComputeOp must have same number of results as yieldOp "
|
||||
"operands");
|
||||
}
|
||||
|
||||
// Same encoding
|
||||
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
|
||||
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
|
||||
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) {
|
||||
return emitError("ComputeOp output must have the same encoding as "
|
||||
"yieldOp operand");
|
||||
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
|
||||
auto resultType = std::get<0>(it);
|
||||
auto yieldType = std::get<1>(it);
|
||||
|
||||
// Same type and compatible shape
|
||||
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) {
|
||||
return emitError("ComputeOp output must be of the same type as yieldOp "
|
||||
"operand");
|
||||
}
|
||||
|
||||
// Same encoding
|
||||
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
|
||||
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
|
||||
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) {
|
||||
return emitError("ComputeOp output must have the same encoding as "
|
||||
"yieldOp operand");
|
||||
}
|
||||
}
|
||||
else {
|
||||
return emitError("ComputeOp output has an encoding while yieldOp "
|
||||
"operand does not have one");
|
||||
}
|
||||
}
|
||||
else {
|
||||
return emitError("ComputeOp output has an encoding while yieldOp "
|
||||
"operand does not have one");
|
||||
}
|
||||
}
|
||||
else {
|
||||
// If result does not have an encoding, yield shouldn't either
|
||||
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
|
||||
return emitError("ComputeOp output must not have an encoding if "
|
||||
"yieldOp operand has one");
|
||||
// If result does not have an encoding, yield shouldn't either
|
||||
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
|
||||
return emitError("ComputeOp output must not have an encoding if "
|
||||
"yieldOp operand has one");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -255,6 +257,27 @@ LogicalResult SpatWeightedCompute::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatWeightedCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||
Block& block = getBody().front();
|
||||
if (!llvm::hasSingleElement(block))
|
||||
return failure();
|
||||
|
||||
auto yieldOp = dyn_cast<SpatYieldOp>(block.front());
|
||||
if (!yieldOp)
|
||||
return failure();
|
||||
|
||||
for (Value yieldedValue : yieldOp.getOperands()) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
|
||||
if (blockArg.getOwner() == &block) {
|
||||
results.push_back(getOperand(blockArg.getArgNumber()));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push_back(yieldedValue);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
|
||||
Reference in New Issue
Block a user