fast pim bufferization using tensors
Validate Operations / validate-operations (push) Successful in 24m29s

This commit is contained in:
NiccoloN
2026-05-08 14:21:45 +02:00
parent 58e6587697
commit b1272d2283
7 changed files with 541 additions and 81 deletions
+34
View File
@@ -133,6 +133,18 @@ def PimSendManyOp : PimOp<"send_many", []> {
let hasCustomAssemblyFormat = 1;
}
def PimSendTensorOp : PimOp<"send_tensor", []> {
let summary = "Send equal contiguous chunks of one tensor to target cores";
let arguments = (ins
PimTensor:$input,
DenseI32ArrayAttr:$targetCoreIds
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimSendBatchOp : PimOp<"send_batch", []> {
let summary = "Send a per-lane tensor to target cores from a batched core";
@@ -203,6 +215,28 @@ def PimReceiveManyOp : PimOp<"receive_many", [DestinationStyleOpInterface]> {
let hasCustomAssemblyFormat = 1;
}
def PimReceiveTensorOp : PimOp<"receive_tensor", [DestinationStyleOpInterface]> {
let summary = "Receive equal contiguous chunks from source cores into one tensor";
let arguments = (ins
PimTensor:$outputBuffer,
DenseI32ArrayAttr:$sourceCoreIds
);
let results = (outs
PimTensor:$output
);
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
let summary = "Receive per-lane tensors from source cores into a batched core";
+68 -4
View File
@@ -4,8 +4,8 @@
#include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
@@ -100,9 +100,9 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
auto& builder = parser.getBuilder();
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
result.addAttribute("operandSegmentSizes",
builder.getDenseI32ArrayAttr(
{static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreIds)
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
@@ -267,6 +267,33 @@ ParseResult PimSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
}
void PimSendTensorOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printCoreIdList(printer, "to", getTargetCoreIds());
printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
}
ParseResult PimSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
|| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"targetCoreIds cannot be specified both positionally and in attr-dict");
if (!targetCoreIds.empty())
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
return parser.resolveOperand(input, inputType, result.operands);
}
void PimSendManyBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueSequence(printer, getInputs());
@@ -333,6 +360,43 @@ ParseResult PimReceiveManyOp::parse(OpAsmParser& parser, OperationState& result)
return success();
}
void PimReceiveTensorOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
printOpenDelimiter(printer, ListDelimiter::Paren);
printer.printOperand(getOutputBuffer());
printCloseDelimiter(printer, ListDelimiter::Paren);
printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutputBuffer().getType());
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult PimReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand outputBuffer;
Type outputBufferType;
Type outputType;
SmallVector<int32_t> sourceCoreIds;
if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
|| parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
|| parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
|| parser.parseType(outputType))
return failure();
if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
return parser.emitError(parser.getCurrentLocation(),
"sourceCoreIds cannot be specified both positionally and in attr-dict");
if (!sourceCoreIds.empty())
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
return failure();
result.addTypes(outputType);
return success();
}
void PimReceiveBatchOp::print(OpAsmPrinter& printer) {
printCoreIdList(printer, "from", getSourceCoreIds());
printer << " into ";
+38 -6
View File
@@ -48,12 +48,32 @@ static LogicalResult verifyManyCommunicationTypes(Operation* op, TypeRange types
return op->emitError() << kind << " values must all have the same type";
if (firstIsTensor != isa<RankedTensorType>(type) || firstIsMemRef != isa<MemRefType>(type))
return op->emitError() << kind << " values must all use the same shaped container kind";
if (firstShapedType.getElementType() != shapedType.getElementType() || firstShapedType.getShape() != shapedType.getShape())
if (firstShapedType.getElementType() != shapedType.getElementType()
|| firstShapedType.getShape() != shapedType.getShape())
return op->emitError() << kind << " values must all have the same shape and element type";
}
return success();
}
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
if (coreIds.empty())
return op->emitError() << kind << " must carry at least one chunk";
auto shapedType = dyn_cast<ShapedType>(type);
if (!shapedType || !shapedType.hasStaticShape())
return op->emitError() << kind << " requires a static shaped tensor or memref";
int64_t elementBits = shapedType.getElementTypeBitWidth();
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % static_cast<int64_t>(coreIds.size()) != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";
return success();
}
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
if (!coreBatchOp)
@@ -61,9 +81,7 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
return coreBatchOp.getLaneCount();
}
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op,
ArrayRef<int32_t> coreIds,
size_t valueCount) {
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside pim.core_batch");
@@ -109,7 +127,8 @@ LogicalResult PimMapOp::verify() {
Block& block = getBody().front();
if (block.getNumArguments() != 1)
return emitError("body must have exactly one block argument");
if (block.getArgument(0).getType() != inputType)
if (failed(verifyCompatibleShapedTypes(
getOperation(), block.getArgument(0).getType(), inputType, "body block argument type must match input type")))
return emitError("body block argument type must match input type");
auto yieldOp = dyn_cast_or_null<PimYieldOp>(block.getTerminator());
@@ -117,7 +136,8 @@ LogicalResult PimMapOp::verify() {
return emitError("body must terminate with pim.yield");
if (yieldOp.getNumOperands() != 1)
return emitError("body yield must produce exactly one value");
if (yieldOp.getOperand(0).getType() != outputType)
if (failed(verifyCompatibleShapedTypes(
getOperation(), yieldOp.getOperand(0).getType(), outputType, "body yield type must match output type")))
return emitError("body yield type must match output type");
return success();
@@ -129,6 +149,10 @@ LogicalResult PimSendManyOp::verify() {
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many");
}
LogicalResult PimSendTensorOp::verify() {
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
}
LogicalResult PimSendManyBatchOp::verify() {
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
return failure();
@@ -153,6 +177,14 @@ LogicalResult PimReceiveManyOp::verify() {
return success();
}
LogicalResult PimReceiveTensorOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure();
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
}
LogicalResult PimReceiveManyBatchOp::verify() {
if (getOutputBuffers().size() != getOutputs().size())
return emitError("number of output buffers must match the number of outputs");
@@ -34,10 +34,8 @@ static Value materializeContiguousMemRef(Value memrefValue, Location loc, Rewrit
.getOutput();
}
static FailureOr<Value> getBufferOrValue(RewriterBase& rewriter,
Value value,
const BufferizationOptions& options,
BufferizationState& state) {
static FailureOr<Value>
getBufferOrValue(RewriterBase& rewriter, Value value, const BufferizationOptions& options, BufferizationState& state) {
if (isa<BufferLikeType>(value.getType()))
return value;
return getBuffer(rewriter, value, options, state);
@@ -205,13 +203,37 @@ struct ReceiveManyOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveM
resultTypes.push_back(outputBufferOpt->getType());
}
auto newOp = PimReceiveManyOp::create(
rewriter, receiveOp.getLoc(), TypeRange(resultTypes), ValueRange(outputBuffers), receiveOp.getSourceCoreIdsAttr());
auto newOp = PimReceiveManyOp::create(rewriter,
receiveOp.getLoc(),
TypeRange(resultTypes),
ValueRange(outputBuffers),
receiveOp.getSourceCoreIdsAttr());
rewriter.replaceOp(receiveOp, newOp.getOutputs());
return success();
}
};
struct ReceiveTensorOpInterface
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorOpInterface, PimReceiveTensorOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveTensorOp>(op);
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveTensorOp>(
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
return success();
}
};
struct ReceiveManyBatchOpInterface
: DstBufferizableOpInterfaceExternalModel<ReceiveManyBatchOpInterface, PimReceiveManyBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
@@ -337,6 +359,30 @@ struct EmptyManyOpInterface : BufferizableOpInterface::ExternalModel<EmptyManyOp
}
};
struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel<SendTensorOpInterface, PimSendTensorOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto sendOp = cast<PimSendTensorOp>(op);
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
replaceOpWithNewBufferizedOp<PimSendTensorOp>(
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
return success();
}
};
struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, PimMapOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
@@ -349,23 +395,26 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, P
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
auto mapOp = cast<PimMapOp>(op);
auto bbArg = dyn_cast<BlockArgument>(value);
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 || mapOp.getInputs().empty())
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0
|| mapOp.getInputs().empty())
return {};
return {{&mapOp->getOpOperand(0), BufferRelation::Equivalent}};
return {
{&mapOp->getOpOperand(0), BufferRelation::Equivalent}
};
}
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
FailureOr<BufferLikeType>
getBufferType(Operation* op,
Value value,
const BufferizationOptions& options,
const BufferizationState& state,
SmallVector<Value>& invocationStack) const {
FailureOr<BufferLikeType> getBufferType(Operation* op,
Value value,
const BufferizationOptions& options,
const BufferizationState& state,
SmallVector<Value>& invocationStack) const {
auto mapOp = cast<PimMapOp>(op);
auto bbArg = dyn_cast<BlockArgument>(value);
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 || mapOp.getInputs().empty())
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0
|| mapOp.getInputs().empty())
return failure();
auto inputType = dyn_cast<BufferLikeType>(mapOp.getInputs().front().getType());
@@ -417,13 +466,9 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, P
};
struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOpInterface, PimCoreBatchOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return true;
}
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return false;
}
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
@@ -436,19 +481,18 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
return {};
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber();
return {{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}};
return {
{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}
};
}
bool isWritable(Operation* op, Value value, const AnalysisState& state) const {
return false;
}
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
FailureOr<BufferLikeType>
getBufferType(Operation* op,
Value value,
const BufferizationOptions& options,
const BufferizationState& state,
SmallVector<Value>& invocationStack) const {
FailureOr<BufferLikeType> getBufferType(Operation* op,
Value value,
const BufferizationOptions& options,
const BufferizationState& state,
SmallVector<Value>& invocationStack) const {
auto coreBatchOp = cast<PimCoreBatchOp>(op);
auto bbArg = dyn_cast<BlockArgument>(value);
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
@@ -467,13 +511,11 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
BufferizationState& state) const {
auto coreBatchOp = cast<PimCoreBatchOp>(op);
bool alreadyBufferized = llvm::all_of(coreBatchOp.getWeights(), [](Value weight) {
return isa<BufferLikeType>(weight.getType());
}) && llvm::all_of(coreBatchOp.getInputs(), [](Value input) {
return isa<BufferLikeType>(input.getType());
}) && llvm::all_of(coreBatchOp.getBody().front().getArguments(), [](BlockArgument arg) {
return isa<BufferLikeType>(arg.getType());
});
bool alreadyBufferized =
llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
&& llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa<BufferLikeType>(input.getType()); })
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(),
[](BlockArgument arg) { return isa<BufferLikeType>(arg.getType()); });
if (alreadyBufferized)
return success();
@@ -693,8 +735,10 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
PimReceiveManyOp::attachInterface<ReceiveManyOpInterface>(*ctx);
PimReceiveTensorOp::attachInterface<ReceiveTensorOpInterface>(*ctx);
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
PimReceiveManyBatchOp::attachInterface<ReceiveManyBatchOpInterface>(*ctx);
PimSendTensorOp::attachInterface<SendTensorOpInterface>(*ctx);
PimExtractRowsOp::attachInterface<ExtractRowsOpInterface>(*ctx);
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);