remove spatial bufferization logic (didn't make much sense) and move channel lowering to SpatialToPim

This commit is contained in:
NiccoloN
2026-04-14 11:55:19 +02:00
parent 368e340a40
commit 30ee9640d4
8 changed files with 158 additions and 551 deletions
@@ -1,6 +1,7 @@
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include <cassert>
#include <cstddef>
@@ -12,6 +13,16 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
auto attr = op->getAttrOfType<IntegerAttr>(attrName);
assert(attr && "required precomputed channel attr is missing");
return IntegerAttr::get(builder.getI32Type(), attr.getInt());
}
} // namespace
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
/*
EXAMPLE RUN:
@@ -54,6 +65,26 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
return returnValue;
}
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
}
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
}
IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
return getRequiredI32Attr(builder, channelNewOp, kChannelSourceCoreIdAttrName);
}
IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
}
Operation* getEarliestUserWithinBlock(mlir::Value value) {
auto users = value.getUsers();