compact spatial IR through different new operations and dedicated syntax
fast spatial node merging with batch operations
This commit is contained in:
@@ -7,23 +7,12 @@
|
||||
#include <cstddef>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
|
||||
auto attr = op->getAttrOfType<IntegerAttr>(attrName);
|
||||
assert(attr && "required precomputed channel attr is missing");
|
||||
return IntegerAttr::get(builder.getI32Type(), attr.getInt());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
|
||||
/*
|
||||
EXAMPLE RUN:
|
||||
@@ -74,37 +63,6 @@ IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
||||
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
||||
}
|
||||
|
||||
IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
|
||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
|
||||
return getRequiredI32Attr(builder, channelNewOp, kChannelSourceCoreIdAttrName);
|
||||
}
|
||||
|
||||
IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
|
||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
|
||||
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
|
||||
}
|
||||
|
||||
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
|
||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName);
|
||||
}
|
||||
|
||||
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
|
||||
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
|
||||
}
|
||||
|
||||
mlir::Value
|
||||
createPimReceiveFromSpatialChannel(PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
|
||||
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
|
||||
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
|
||||
return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||
auto users = value.getUsers();
|
||||
|
||||
|
||||
@@ -2,16 +2,10 @@
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id";
|
||||
inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id";
|
||||
|
||||
/**
|
||||
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
||||
* its static tensor input.
|
||||
@@ -30,17 +24,6 @@ size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
||||
|
||||
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||
|
||||
mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||
|
||||
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||
|
||||
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
|
||||
|
||||
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
|
||||
|
||||
mlir::Value createPimReceiveFromSpatialChannel(
|
||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
|
||||
|
||||
template <class T>
|
||||
size_t rangeLength(const mlir::iterator_range<T> range) {
|
||||
return std::distance(range.begin(), range.end());
|
||||
|
||||
@@ -9,17 +9,6 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def HasSpatialChannelSourceCoreIdAttr: Constraint<
|
||||
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
|
||||
"spatial channel has precomputed source core id">;
|
||||
|
||||
def HasSpatialChannelTargetCoreIdAttr: Constraint<
|
||||
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
|
||||
"spatial channel has precomputed target core id">;
|
||||
|
||||
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
|
||||
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
|
||||
|
||||
def onnxToPimTranspose : Pat<
|
||||
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
||||
(PimTransposeOp $data, $perms,
|
||||
@@ -80,18 +69,4 @@ def spatToPimVSoftmax : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatChannelSendToPimSend : Pat<
|
||||
(SpatChannelSendOp $channel, $input),
|
||||
(PimSendOp $input,
|
||||
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
|
||||
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
|
||||
[(HasSpatialChannelTargetCoreIdAttr $channel)]
|
||||
>;
|
||||
|
||||
def spatChannelReceiveToPimReceive : Pat<
|
||||
(SpatChannelReceiveOp:$srcOpRes $channel),
|
||||
(createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
|
||||
[(HasSpatialChannelSourceCoreIdAttr $channel)]
|
||||
>;
|
||||
|
||||
#endif // SPATIAL_TO_PIM
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user