big cleanup: remove remaining pim many operations, simplify bufferization logic
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-11 14:38:13 +02:00
parent b1272d2283
commit 5ff364027b
12 changed files with 390 additions and 1164 deletions
-230
View File
@@ -147,69 +147,6 @@ ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) {
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
}
void PimMapOp::print(OpAsmPrinter& printer) {
printer << " ";
printArgumentBindings(printer, getBody().front(), getInputs());
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printer.printType(getInputs().front().getType());
printer << " -> ";
printer.printType(getOutputs().front().getType());
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
ParseResult PimMapOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
Type inputType;
Type outputType;
if (parseArgumentBindings(parser, regionArgs, inputs))
return failure();
if (inputs.empty())
return parser.emitError(parser.getCurrentLocation(), "map requires at least one input");
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|| parser.parseArrow() || parser.parseType(outputType))
return failure();
SmallVector<Type> inputTypes(inputs.size(), inputType);
SmallVector<Type> outputTypes(inputs.size(), outputType);
if (regionArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
applyArgumentTypes(inputTypes, regionArgs);
Region* body = result.addRegion();
return parser.parseRegion(*body, regionArgs);
}
void PimEmptyManyOp::print(OpAsmPrinter& printer) {
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printer.printType(getOutputs().front().getType());
printer << " x" << getOutputs().size();
}
ParseResult PimEmptyManyOp::parse(OpAsmParser& parser, OperationState& result) {
Type outputType;
int64_t resultCount = 0;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType)
|| parser.parseKeyword("x") || parser.parseInteger(resultCount))
return failure();
if (resultCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "result count after 'x' must be positive");
SmallVector<Type> resultTypes(resultCount, outputType);
result.addTypes(resultTypes);
return success();
}
void PimSendBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
@@ -237,36 +174,6 @@ ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
return parser.resolveOperand(input, inputType, result.operands);
}
void PimSendManyOp::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueSequence(printer, getInputs());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, TypeRange(getInputs()));
}
ParseResult PimSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> inputTypes;
SmallVector<int32_t> targetCoreIds;
if (parseCompressedOperandSequence(parser, inputs) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
return failure();
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
}
void PimSendTensorOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
@@ -294,72 +201,6 @@ ParseResult PimSendTensorOp::parse(OpAsmParser& parser, OperationState& result)
return parser.resolveOperand(input, inputType, result.operands);
}
void PimSendManyBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueSequence(printer, getInputs());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, TypeRange(getInputs()));
}
ParseResult PimSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> inputTypes;
SmallVector<int32_t> targetCoreIds;
if (parseCompressedOperandSequence(parser, inputs) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
return failure();
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
}
void PimReceiveManyOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printCompressedValueSequence(printer, getOutputBuffers());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, getOutputs().getTypes());
}
ParseResult PimReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
SmallVector<Type> outputTypes;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
return failure();
if (outputBuffers.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
return success();
}
void PimReceiveTensorOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
@@ -434,77 +275,6 @@ ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result
return success();
}
void PimReceiveManyBatchOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printCompressedValueSequence(printer, getOutputBuffers());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printCompressedTypeSequence(printer, getOutputs().getTypes());
}
ParseResult PimReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
SmallVector<Type> outputTypes;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
return failure();
if (outputBuffers.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
return success();
}
void PimExtractRowsOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printCompressedValueSequence(printer, getOutputBuffers());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printer.printType(getInput().getType());
printer << " -> ";
printCompressedTypeSequence(printer, getOutputs().getTypes());
}
ParseResult PimExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
SmallVector<OpAsmParser::UnresolvedOperand> outputBuffers;
Type inputType;
SmallVector<Type> outputTypes;
if (parser.parseOperand(input) || parser.parseKeyword("into") || parser.parseLParen()
|| parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen()
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
return failure();
if (outputBuffers.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match");
if (parser.resolveOperand(input, inputType, result.operands)
|| parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
return success();
}
void PimConcatOp::print(OpAsmPrinter& printer) {
printer << " axis " << getAxis() << " ";
printCompressedValueSequence(printer, getInputs());