remove spatial bufferization logic (didn't make much sense) and move channel lowering to SpatialToPim

NiccoloN
2026-04-14 11:55:19 +02:00
parent 368e340a40
commit 30ee9640d4
8 changed files with 158 additions and 551 deletions
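
The new SpatialToPim lowering itself is not shown in this excerpt. For orientation only, a channel-send pattern there might look like the sketch below, which reuses the size and core-id logic visible in the removed bufferization code; lookupDstCoreId is a hypothetical stand-in for however the real pass pairs a send with its matching receive:

// Hedged sketch, not the code added by this commit.
#include <cstdint>
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {

// Hypothetical helper: returns the core id of the receive end paired with
// this send (the removed file did this with getCoreIdOfOtherEndOfChannel).
FailureOr<uint32_t> lookupDstCoreId(Operation* op);

struct ChannelSendLowering : OpConversionPattern<SpatChannelSendOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(SpatChannelSendOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    // Same size computation as the removed bufferize(): elements * bytes per element.
    auto srcType = cast<ShapedType>(op.getOperand(1).getType());
    auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
    auto dstCoreId = lookupDstCoreId(op);
    if (failed(dstCoreId))
      return failure();
    // Mirrors the pim.send built by the removed ChannelSendOpInterface below.
    rewriter.replaceOpWithNewOp<pim::PimSendOp>(op, adaptor.getOperands()[1],
        rewriter.getI32IntegerAttr(sizeInBytes),
        rewriter.getI32IntegerAttr(*dstCoreId));
    return success();
  }
};

} // namespace spatial
} // namespace onnx_mlir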


@@ -3,7 +3,6 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
 add_pim_library(SpatialOps
   SpatialOps.cpp
-  Transforms/SpatialBufferizableOpInterface.cpp
   Transforms/MergeComputeNode/MergeComputeNodePass.cpp
   DCPGraph/Graph.cpp
   DCPGraph/Task.cpp


@@ -1,532 +0,0 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace bufferization;
namespace onnx_mlir {
namespace spatial {
memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase& rewriter) {
auto resultShape = cast<ShapedType>(resultType);
auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType());
// Alloc an output memref
return memref::AllocOp::create(rewriter, loc, memrefResultType);
}
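// If memrefValue already has a resolvable contiguous address, return it
// unchanged; otherwise allocate a fresh contiguous buffer and fill it with a
// PimMemCopyOp.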
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
return pim::PimMemCopyOp::create(rewriter,
loc,
contiguousBuffer.getType(),
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput();
}
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
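// A channel produced by ChannelNewOp is expected to have exactly two users:
// the send end and the receive end. Given one end, return the other,
// verifying it is of the opposite kind.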
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
++usersIterator;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
++usersIterator;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, more than two found.");
return failure();
}
Operation* otherUser = nullptr;
if (firstUser == op)
otherUser = secondUser;
else if (secondUser == op)
otherUser = firstUser;
else {
op->emitError("Operand generated by ChannelNewOp must have two users and one of them must be the current op.");
return failure();
}
if (opIsReceive && !isa<spatial::SpatChannelSendOp>(otherUser)) {
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelSendOp.");
return failure();
}
if (!opIsReceive && !isa<spatial::SpatChannelReceiveOp>(otherUser)) {
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelReceiveOp.");
return failure();
}
return otherUser;
}
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
// This function requires the existence of ChannelNewOp and the other
// Receive/Send operation. However, during bufferization, the first of the
// Receive/Send operation that is processed gets removed. As such, we need to
// "precompute" the coreId needed for the other op, and save it as attribute
auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
if (precomputedOtherCoreId)
return cast<IntegerAttr>(precomputedOtherCoreId).getInt();
auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive);
if (failed(notOpUserOpt))
return failure();
Operation* notOpUser = *notOpUserOpt;
// Save the coreId for this op into the other op as attribute
auto opCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
notOpUser->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, opCoreIdAttr);
return cast<pim::PimCoreOp>(notOpUser->getParentOp()).getCoreId();
}
struct WComputeOpInterface : BufferizableOpInterface::ExternalModel<WComputeOpInterface, SpatWeightedCompute> {
// Input tensors to the compute OP are always read into its local memory
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
// Input tensors to the compute OP are _never_ written into its local memory
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// In general, no tensor is aliased with any other tensor in the compute OP
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
// Bufferize its block
auto& block = op->getRegion(0).front();
return bufferizeBlockSignature(&block, rewriter, options, state);
}
};
/*
* This can be used for operations that have a single argument, which is a
* variadic of tensors, and a single output with the same shape
* Example: VAdd, VSub, VExp
*/
template <typename InterfaceName, typename OpTy, typename ToTy>
struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
// Input tensors to the OP are always read
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
// Input tensors to the OP are _never_ written
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// In general, no tensor is aliased with any other tensor in the OP
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
// Cast tensor values into memref values
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
// Turn Tensor Operands into Memref Operands
SmallVector<Value> memrefOperands;
memrefOperands.reserve(op->getNumOperands());
for (auto operand : op->getOperands()) {
auto memref = getBuffer(rewriter, operand, options, state);
if (failed(memref))
return failure();
memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
}
// TODO: Support addition with more than 2 operands
if (memrefOperands.size() > 2) {
op->emitError("VariadicArgumentElementWiseOpInterface only supports OPs "
"with 1 or 2 operands, for now.");
return failure();
}
// Alloc an output memref
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
memrefOperands.push_back(outputTensor);
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
};
template <typename InterfaceName, typename OpTy, typename ToTy>
struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
// Input tensors to the OP are always read
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
// Input tensors to the OP are _never_ written
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// In general, no tensor is aliased with any other tensor in the OP
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
// Cast tensor value into memref value
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state);
if (failed(memrefOperandOpt))
return failure();
auto memrefOperand = *memrefOperandOpt;
// Alloc an output memref
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
Value newValue = ToTy::create(rewriter,
op->getLoc(),
outputTensor.getType(),
cast<OpTy>(op).getWeightIndexAttr(),
memrefOperand,
outputTensor)
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
};
struct ChannelReceiveOpInterface
: BufferizableOpInterface::ExternalModel<ChannelReceiveOpInterface, SpatChannelReceiveOp> {
// Input value is the channel (not read/written; it's more of an attribute)
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
/*
* Turn the channel receive into a pim.receive
*/
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
auto srcCoreId = getCoreIdOfOtherEndOfChannel(op, true, rewriter);
if (failed(srcCoreId))
return failure();
Value newValue = pim::PimReceiveOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(srcCoreId.value()))
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
};
struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel<ChannelSendOpInterface, SpatChannelSendOp> {
// First input is the channel (not read/written); second input is the tensor
// to send, which is read
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// The tensor is operand 1 (operand 0 is the channel), matching getOperand(1) in bufferize() below
return opOperand.getOperandNumber() == 1;
}
// See above (both non-written)
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
/*
* Turn the channel send into a pim.send
*/
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto srcTensor = op->getOperand(1);
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
if (failed(srcTensorOpt))
return failure();
auto srcMemRef = *srcTensorOpt;
auto numElements = cast<ShapedType>(srcTensor.getType()).getNumElements();
auto elementSize = cast<ShapedType>(srcTensor.getType()).getElementTypeBitWidth() / 8;
auto dstCoreId = getCoreIdOfOtherEndOfChannel(op, false, rewriter);
if (failed(dstCoreId))
return failure();
replaceOpWithNewBufferizedOp<pim::PimSendOp>(rewriter,
op,
srcMemRef,
rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(dstCoreId.value()));
return success();
}
};
struct ChannelBroadcastReceiveOpInterface
: BufferizableOpInterface::ExternalModel<ChannelBroadcastReceiveOpInterface, SpatChannelBroadcastReceiveOp> {
// Input value is the channel (not read/written; it's more of an attribute)
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
/*
* Turn the broadcast receive into a regular pim.receive from the broadcaster.
*/
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
if (precomputedOtherCoreId) {
Value newValue = pim::PimReceiveOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize),
cast<IntegerAttr>(precomputedOtherCoreId))
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("ChannelBroadcastReceiveOp does not use a channel as operand");
return failure();
}
auto srcCoreId = [&]() -> FailureOr<uint32_t> {
for (Operation* user : channelNewOp->getUsers()) {
auto sendOp = dyn_cast<SpatChannelBroadcastSendOp>(user);
if (!sendOp)
continue;
auto sendCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
op->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, sendCoreIdAttr);
return cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreId();
}
op->emitError("ChannelBroadcastReceiveOp has no matching ChannelBroadcastSendOp");
return failure();
}();
if (failed(srcCoreId))
return failure();
Value newValue = pim::PimReceiveOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(srcCoreId.value()))
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
};
struct ChannelBroadcastSendOpInterface
: BufferizableOpInterface::ExternalModel<ChannelBroadcastSendOpInterface, SpatChannelBroadcastSendOp> {
// First input is the channel (not read/written); second input is the tensor
// to send, which is read
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// The tensor is operand 1 (operand 0 is the channel), matching getOperand(1) in bufferize() below
return opOperand.getOperandNumber() == 1;
}
// See above (both non-written)
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
/*
* Turn the broadcast send into one pim.send per broadcast receiver.
*/
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto srcTensor = op->getOperand(1);
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
if (failed(srcTensorOpt))
return failure();
auto srcMemRef = *srcTensorOpt;
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("SpatChannelBroadcastSendOp does not use a channel as operand");
return failure();
}
auto srcType = cast<ShapedType>(srcTensor.getType());
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
auto srcCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
rewriter.setInsertionPoint(op);
bool foundReceiver = false;
for (Operation* user : channelNewOp->getUsers()) {
auto receiveOp = dyn_cast<SpatChannelBroadcastReceiveOp>(user);
if (!receiveOp)
continue;
foundReceiver = true;
auto dstCoreId = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreId();
receiveOp->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, srcCoreIdAttr);
pim::PimSendOp::create(rewriter,
op->getLoc(),
srcMemRef,
rewriter.getI32IntegerAttr(sizeInBytes),
rewriter.getI32IntegerAttr(dstCoreId));
}
if (!foundReceiver) {
op->emitError("SpatChannelBroadcastSendOp has no matching ChannelBroadcastReceiveOp");
return failure();
}
rewriter.eraseOp(op);
return success();
}
};
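// Concrete instantiations of the templates above, pairing each spatial op
// with the pim op it bufferizes to.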
struct VAddOpInterfaceFromTemplate
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
});
}
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};
struct ONNXSigmoidInterface
: VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
});
}
} // namespace spatial
} // namespace onnx_mlir


@@ -1,15 +0,0 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
namespace spatial {
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
} // namespace spatial
} // namespace onnx_mlir
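
For reference, the deleted hooks above would have been called wherever the pipeline assembles its DialectRegistry; a minimal sketch of such a call site (the actual caller is not part of this diff, and configureRegistry is a hypothetical name):

#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"

// Hypothetical hook: attach the external bufferization models before
// one-shot bufferization runs.
void configureRegistry(mlir::DialectRegistry& registry) {
  onnx_mlir::spatial::registerBufferizableOpInterfaceExternalModels(registry);
  onnx_mlir::spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
}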