cleanup unused channel operations and related logic
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-25 20:58:51 +02:00
parent bdc4ca33f3
commit 0f240af271
15 changed files with 3 additions and 1182 deletions
-102
View File
@@ -102,42 +102,6 @@ def PimSendOp : PimOp<"send", []> {
}];
}
def PimSendTensorOp : PimOp<"send_tensor", []> {
let summary = "Send equal contiguous chunks of one tensor to target cores";
let arguments = (ins
PimTensor:$input,
DenseI32ArrayAttr:$targetCoreIds
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimSendBatchOp : PimOp<"send_batch", []> {
let summary = "Send a per-lane tensor to target cores from a batched core";
let arguments = (ins
PimTensor:$input,
I32Attr:$size,
DenseI32ArrayAttr:$targetCoreIds
);
let hasCustomAssemblyFormat = 1;
}
def PimSendTensorBatchOp : PimOp<"send_tensor_batch", []> {
let summary = "Send equal contiguous chunks of one per-lane tensor from a batched core";
let arguments = (ins
PimTensor:$input,
DenseI32ArrayAttr:$targetCoreIds
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
let summary = "Receive a tensor from another core";
@@ -162,72 +126,6 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
}];
}
def PimReceiveTensorOp : PimOp<"receive_tensor", [DestinationStyleOpInterface]> {
let summary = "Receive equal contiguous chunks from source cores into one tensor";
let arguments = (ins
PimTensor:$outputBuffer,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
let summary = "Receive per-lane tensors from source cores into a batched core";
let arguments = (ins
PimTensor:$outputBuffer,
I32Attr:$size,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let hasCustomAssemblyFormat = 1;
}
def PimReceiveTensorBatchOp : PimOp<"receive_tensor_batch", [DestinationStyleOpInterface]> {
let summary = "Receive equal contiguous chunks into one per-lane tensor inside a batched core";
let arguments = (ins
PimTensor:$outputBuffer,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from host memory into device memory";
-226
View File
@@ -28,34 +28,6 @@ static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred,
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
}
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
printer << "(";
for (auto [index, argument] : llvm::enumerate(arguments)) {
if (index != 0)
printer << ", ";
printer.printOperand(argument);
}
printer << ")";
}
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
if (parser.parseLParen())
return failure();
if (succeeded(parser.parseOptionalRParen()))
return success();
OpAsmParser::Argument argument;
if (parser.parseArgument(argument))
return failure();
arguments.push_back(argument);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseArgument(argument))
return failure();
arguments.push_back(argument);
}
return parser.parseRParen();
}
static void
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
printCompressedValueList(printer, arguments, delimiter);
@@ -98,12 +70,6 @@ static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<i
printCompressedIntegerList(printer, coreIds);
}
static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keyword, SmallVectorImpl<int32_t>& coreIds) {
if (failed(parser.parseOptionalKeyword(keyword)))
return success();
return parseCompressedIntegerList(parser, coreIds);
}
} // namespace
void PimCoreOp::print(OpAsmPrinter& printer) {
@@ -295,198 +261,6 @@ ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) {
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
}
void PimSendBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
}
ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperand(input, inputType, result.operands);
}
void PimSendTensorBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
}
ParseResult PimSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperand(input, inputType, result.operands);
}
void PimSendTensorOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
}
ParseResult PimSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperand(input, inputType, result.operands);
}
void PimReceiveTensorOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printer.printOperand(getOutputBuffer());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutputBuffer().getType());
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult PimReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand outputBuffer;
Type outputBufferType;
Type outputType;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|| parser.parseType(outputType))
return failure();
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
return failure();
result.addTypes(outputType);
return success();
}
void PimReceiveBatchOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printer.printOperand(getOutputBuffer());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutputBuffer().getType());
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand outputBuffer;
Type outputBufferType;
Type outputType;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|| parser.parseType(outputType))
return failure();
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
return failure();
result.addTypes(outputType);
return success();
}
void PimReceiveTensorBatchOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printer.printOperand(getOutputBuffer());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutputBuffer().getType());
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult PimReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand outputBuffer;
Type outputBufferType;
Type outputType;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|| parser.parseType(outputType))
return failure();
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
return failure();
result.addTypes(outputType);
return success();
}
void PimConcatOp::print(OpAsmPrinter& printer) {
printer << " axis " << getAxis() << " ";
printCompressedValueSequence(printer, getInputs());
-75
View File
@@ -90,56 +90,6 @@ static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type r
return success();
}
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
if (coreIds.empty())
return op->emitError() << kind << " must carry at least one chunk";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor or memref";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % static_cast<int64_t>(coreIds.size()) != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";
return success();
}
static LogicalResult
verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
if (coreIds.empty())
return op->emitError() << kind << " must carry at least one chunk";
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
if (!coreBatchOp)
return op->emitError() << kind << " must be nested inside pim.core_batch";
int32_t laneCount = coreBatchOp.getLaneCount();
if (laneCount <= 0)
return op->emitError() << kind << " requires a positive parent laneCount";
if (coreIds.size() % static_cast<size_t>(laneCount) != 0)
return op->emitError() << kind << " core id count must be divisible by the parent laneCount";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor or memref";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % chunkCount != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
return success();
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
auto shapedType = dyn_cast<ShapedType>(weight.getType());
if (!shapedType)
@@ -177,31 +127,6 @@ LogicalResult PimCoreBatchOp::verify() {
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
}
LogicalResult PimSendTensorOp::verify() {
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
}
LogicalResult PimSendTensorBatchOp::verify() {
return verifyTensorBatchCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor_batch");
}
LogicalResult PimReceiveTensorOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure();
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
}
LogicalResult PimReceiveTensorBatchOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure();
return verifyTensorBatchCommunication(
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
}
LogicalResult PimVMMOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
@@ -157,72 +157,6 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
}
};
struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveBatchOpInterface, PimReceiveBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveBatchOp>(op);
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveBatchOp>(rewriter,
op,
outputBufferOpt->getType(),
*outputBufferOpt,
receiveOp.getSizeAttr(),
receiveOp.getSourceCoreIdsAttr());
return success();
}
};
struct ReceiveTensorOpInterface
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorOpInterface, PimReceiveTensorOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveTensorOp>(op);
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveTensorOp>(
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
return success();
}
};
struct ReceiveTensorBatchOpInterface
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorBatchOpInterface, PimReceiveTensorBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveTensorBatchOp>(op);
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveTensorBatchOp>(
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
return success();
}
};
struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInterface, PimConcatOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -252,30 +186,6 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
}
};
struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel<SendTensorOpInterface, PimSendTensorOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto sendOp = cast<PimSendTensorOp>(op);
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
replaceOpWithNewBufferizedOp<PimSendTensorOp>(
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
return success();
}
};
struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface, PimSendOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
@@ -303,58 +213,6 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
}
};
struct SendBatchOpInterface : BufferizableOpInterface::ExternalModel<SendBatchOpInterface, PimSendBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto sendOp = cast<PimSendBatchOp>(op);
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
replaceOpWithNewBufferizedOp<PimSendBatchOp>(rewriter,
op,
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
sendOp.getSizeAttr(),
sendOp.getTargetCoreIdsAttr());
return success();
}
};
struct SendTensorBatchOpInterface
: BufferizableOpInterface::ExternalModel<SendTensorBatchOpInterface, PimSendTensorBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto sendOp = cast<PimSendTensorBatchOp>(op);
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
replaceOpWithNewBufferizedOp<PimSendTensorBatchOp>(
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
return success();
}
};
struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface, PimCoreOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
@@ -699,13 +557,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimCoreOp::attachInterface<CoreOpInterface>(*ctx);
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
PimReceiveTensorOp::attachInterface<ReceiveTensorOpInterface>(*ctx);
PimReceiveTensorBatchOp::attachInterface<ReceiveTensorBatchOpInterface>(*ctx);
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
PimSendOp::attachInterface<SendOpInterface>(*ctx);
PimSendBatchOp::attachInterface<SendBatchOpInterface>(*ctx);
PimSendTensorBatchOp::attachInterface<SendTensorBatchOpInterface>(*ctx);
PimSendTensorOp::attachInterface<SendTensorOpInterface>(*ctx);
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
-105
View File
@@ -194,111 +194,6 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
}];
}
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", [AttrSizedOperandSegments]> {
let summary = "Send equal contiguous chunks of one tensor through logical channels";
let arguments = (ins
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds,
SpatTensor:$input
);
let hasVerifier = 1;
let assemblyFormat = [{
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
}];
}
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", [AttrSizedOperandSegments]> {
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
let arguments = (ins
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
}];
}
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", [AttrSizedOperandSegments]> {
let summary = "Send per-lane tensors through logical channels in a batch body";
let arguments = (ins
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds,
SpatTensor:$input
);
let hasVerifier = 1;
let assemblyFormat = [{
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
}];
}
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", [AttrSizedOperandSegments]> {
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds,
SpatTensor:$input
);
let hasVerifier = 1;
let assemblyFormat = [{
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
}];
}
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", [AttrSizedOperandSegments]> {
let summary = "Receive a per-lane tensor through logical channels in a batch body";
let arguments = (ins
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
}];
}
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", [AttrSizedOperandSegments]> {
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
}];
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
@@ -95,13 +95,6 @@ static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
return shapedType.getShape();
}
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
auto batchOp = op->getParentOfType<SpatComputeBatch>();
if (!batchOp)
return failure();
return batchOp.getLaneCount();
}
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
if (batchOp.getNumResults() == 0)
return false;
@@ -233,68 +226,6 @@ static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::Paralle
return success();
}
static LogicalResult verifyTensorChannelSizes(
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
if (channelCount == 0)
return op->emitError() << kind << " must carry at least one chunk";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % static_cast<int64_t>(channelCount) != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
return success();
}
static LogicalResult
verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
if (channelCount != static_cast<size_t>(*laneCount))
return op->emitError("channel metadata length must match parent laneCount");
return success();
}
static LogicalResult verifyTensorBatchChannelSizes(
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
if (channelCount == 0 || channelCount % static_cast<size_t>(*laneCount) != 0)
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t chunkCount = static_cast<int64_t>(channelCount) / *laneCount;
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % chunkCount != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
return success();
}
static Region* getParentRegion(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent();
@@ -564,52 +495,6 @@ LogicalResult SpatCompute::verify() {
return success();
}
LogicalResult SpatChannelSendTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(),
getInput().getType(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_send_tensor");
}
LogicalResult SpatChannelReceiveTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(),
getOutput().getType(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_receive_tensor");
}
LogicalResult SpatChannelSendBatchOp::verify() {
return verifyBatchChannelSizes(
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
}
LogicalResult SpatChannelSendTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(),
getInput().getType(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_send_tensor_batch");
}
LogicalResult SpatChannelReceiveBatchOp::verify() {
return verifyBatchChannelSizes(
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
}
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(),
getOutput().getType(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_receive_tensor_batch");
}
LogicalResult SpatComputeBatch::verify() {
int32_t count = getLaneCount();
if (count <= 0)
@@ -79,8 +79,6 @@ struct MergeIrCounts {
uint64_t topLevelComputeBatchCount = 0;
uint64_t scalarChannelSendCount = 0;
uint64_t scalarChannelReceiveCount = 0;
uint64_t tensorChannelSendCount = 0;
uint64_t tensorChannelReceiveCount = 0;
uint64_t wvmmCount = 0;
uint64_t vaddCount = 0;
uint64_t scfForCount = 0;
@@ -95,10 +93,6 @@ MergeIrCounts collectMergeIrCounts(func::FuncOp funcOp) {
++counts.scalarChannelSendCount;
else if (isa<spatial::SpatChannelReceiveOp>(nestedOp))
++counts.scalarChannelReceiveCount;
else if (isa<spatial::SpatChannelSendTensorOp, spatial::SpatChannelSendTensorBatchOp>(nestedOp))
++counts.tensorChannelSendCount;
else if (isa<spatial::SpatChannelReceiveTensorOp, spatial::SpatChannelReceiveTensorBatchOp>(nestedOp))
++counts.tensorChannelReceiveCount;
else if (isa<spatial::SpatVMMOp>(nestedOp))
++counts.wvmmCount;
else if (isa<spatial::SpatVAddOp>(nestedOp))
@@ -130,9 +124,8 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
<< " scalar_send=" << counts.scalarChannelSendCount
<< " scalar_recv=" << counts.scalarChannelReceiveCount
<< " tensor_send=" << counts.tensorChannelSendCount
<< " tensor_recv=" << counts.tensorChannelReceiveCount << " wvmm=" << counts.wvmmCount
<< " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n";
<< " wvmm=" << counts.wvmmCount << " vadd=" << counts.vaddCount
<< " scf_for=" << counts.scfForCount << "\n";
}
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {