Files
Raptor/src/PIM/Conversion/SpatialToPim/Patterns.cpp
T
NiccoloN 984f362623
Validate Operations / validate-operations (push) Waiting to run
roba
2026-06-26 13:02:38 +02:00

132 lines
5.9 KiB
C++

#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace raptor {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
} // namespace raptor
/// Converts a list of integers into index attributes.
///
/// \param builder Builder used to create each index attribute.
/// \param values  Integers to convert.
/// \returns One OpFoldResult per entry of `values`, in the same order, each
///          holding the attribute produced by `builder.getIndexAttr`.
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
  SmallVector<OpFoldResult, 4> indexAttrs;
  indexAttrs.reserve(values.size());
  std::transform(values.begin(),
                 values.end(),
                 std::back_inserter(indexAttrs),
                 [&builder](int64_t value) -> OpFoldResult { return builder.getIndexAttr(value); });
  return indexAttrs;
}
/// Builds a stride list made entirely of the index attribute 1.
///
/// \param builder Builder used to create the index attribute.
/// \param rank    Number of stride entries to produce.
/// \returns `rank` OpFoldResult entries, each holding `builder.getIndexAttr(1)`.
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
  const OpFoldResult unitStride = builder.getIndexAttr(1);
  SmallVector<OpFoldResult, 4> unitStrides(static_cast<size_t>(rank), unitStride);
  return unitStrides;
}
/// Conversion pattern for `spatial::SpatReconciliatorOp` in "fragment_assembly"
/// mode.
///
/// The op is replaced by a chain of tensor ops: a `tensor::EmptyOp` of the
/// result shape, followed by one `tensor::InsertSliceOp` per fragment. A
/// fragment is taken from its source operand as-is, or through a
/// `tensor::ExtractSliceOp` when its shape differs from the source shape or its
/// source offset is non-zero.
///
/// Fragment metadata is stored flat: entry `fragment * rank + dim` of the
/// offsets / sizes / strides arrays belongs to dimension `dim` of that fragment.
struct LowerFragmentAssemblyReconciliatorPattern : OpConversionPattern<spatial::SpatReconciliatorOp> {
  using OpConversionPattern::OpConversionPattern;

  /// \returns failure() without a diagnostic when the mode is missing or is not
  ///          "fragment_assembly"; failure with an op error when a requirement
  ///          checked in this function is not met; failure() when one of the
  ///          external helpers fails; success() once the op has been replaced.
  LogicalResult matchAndRewrite(spatial::SpatReconciliatorOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    // Other modes are left to other patterns.
    std::optional<StringRef> mode = op.getMode();
    const bool isFragmentAssembly = mode.has_value() && *mode == "fragment_assembly";
    if (!isFragmentAssembly)
      return failure();

    auto outputType = dyn_cast<ShapedType>(op.getOutput().getType());
    const bool hasStaticOutput = outputType && outputType.hasStaticShape();
    if (!hasStaticOutput)
      return op.emitOpError("fragment assembly lowering requires a static ranked tensor result");

    // These three pieces of metadata are optional on the op but mandatory here.
    std::optional<ArrayRef<int64_t>> maybeOperandIndices = op.getFragmentOperandIndices();
    std::optional<ArrayRef<int64_t>> maybeSourceOffsets = op.getFragmentSourceOffsets();
    std::optional<ArrayRef<int64_t>> maybeStrides = op.getFragmentStrides();
    const bool hasAllMetadata =
        maybeOperandIndices.has_value() && maybeSourceOffsets.has_value() && maybeStrides.has_value();
    if (!hasAllMetadata)
      return op.emitOpError("fragment assembly lowering requires explicit fragment metadata");

    ArrayRef<int64_t> operandIndices = *maybeOperandIndices;
    ArrayRef<int64_t> sourceOffsets = *maybeSourceOffsets;
    ArrayRef<int64_t> flatOffsets = op.getFragmentOffsets();
    ArrayRef<int64_t> flatSizes = op.getFragmentSizes();
    ArrayRef<int64_t> flatStrides = *maybeStrides;
    const int64_t rank = outputType.getRank();

    // Candidate sources: index 0 is the converted input, the converted
    // fragments follow.
    SmallVector<Value> fragmentOperands {adaptor.getInput()};
    llvm::append_range(fragmentOperands, adaptor.getFragments());

    // NOTE(review): the helper is defined elsewhere; presumably it checks that
    // the flat arrays and operand indices are mutually consistent, which the
    // unchecked indexing below relies on — confirm.
    if (failed(validateFragmentAssemblyMetadata(
            op, rank, fragmentOperands.size(), operandIndices, sourceOffsets, flatOffsets, flatSizes, flatStrides)))
      return failure();

    const auto loc = op.getLoc();
    Value assembled =
        tensor::EmptyOp::create(rewriter, loc, outputType.getShape(), outputType.getElementType()).getResult();

    const int64_t fragmentCount = static_cast<int64_t>(operandIndices.size());
    for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
      // First flat entry belonging to this fragment.
      const int64_t flatBase = fragmentIndex * rank;
      const int64_t flatEnd = flatBase + rank;

      // Destination offsets inside the result; every stride must be 1.
      SmallVector<int64_t, 4> destinationOffsets;
      for (int64_t flatIndex = flatBase; flatIndex < flatEnd; ++flatIndex) {
        if (flatStrides[flatIndex] != 1)
          return op.emitOpError("fragment assembly lowering only supports unit strides");
        destinationOffsets.push_back(flatOffsets[flatIndex]);
      }

      Value source = fragmentOperands[operandIndices[fragmentIndex]];
      auto sourceType = dyn_cast<RankedTensorType>(source.getType());
      const bool hasStaticSource = sourceType && sourceType.hasStaticShape();
      if (!hasStaticSource)
        return op.emitOpError("fragment assembly lowering requires static ranked tensor operands");

      ArrayRef<int64_t> fragmentSizes = flatSizes.slice(flatBase, rank);
      SmallVector<int64_t, 4> fragmentShape(fragmentSizes.begin(), fragmentSizes.end());

      // Shared by the extract (if any) and the insert below.
      SmallVector<OpFoldResult, 4> fragmentSizeAttrs = getStaticIndexAttrs(rewriter, fragmentShape);
      SmallVector<OpFoldResult, 4> unitStrides = getUnitStrides(rewriter, rank);

      // The source is used directly only when the fragment spans it entirely
      // and starts at element offset 0; otherwise a slice is extracted first.
      Value fragment = source;
      const bool spansWholeSource = sourceType.getShape() == ArrayRef<int64_t>(fragmentShape);
      if (!spansWholeSource || sourceOffsets[fragmentIndex] != 0) {
        // NOTE(review): helper defined elsewhere; presumably it turns the
        // element offset into per-dimension slice offsets — confirm.
        FailureOr<SmallVector<int64_t, 4>> sliceOffsets = getStaticSliceOffsetsForElementOffset(
            op, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
        if (failed(sliceOffsets))
          return failure();
        fragment = tensor::ExtractSliceOp::create(
            rewriter, loc, source, getStaticIndexAttrs(rewriter, *sliceOffsets), fragmentSizeAttrs, unitStrides);
      }

      assembled = tensor::InsertSliceOp::create(rewriter,
                                                loc,
                                                fragment,
                                                assembled,
                                                getStaticIndexAttrs(rewriter, destinationOffsets),
                                                fragmentSizeAttrs,
                                                unitStrides)
                      .getResult();
    }

    rewriter.replaceOp(op, assembled);
    return success();
  }
};
/// Fills `patterns` with the initial pattern set: first whatever
/// raptor::populateWithGenerated adds (that function comes from the
/// SpatialToPim.hpp.inc file included above inside namespace raptor), then
/// whatever populateTransposeLoweringPatterns adds.
void populateInitialPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
}
/// Fills `patterns` with the core-body pattern set: the same two registrations
/// performed by populateInitialPatterns, in the same order, followed by
/// LowerFragmentAssemblyReconciliatorPattern.
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
// The fragment-assembly lowering is registered only for this pattern set.
patterns.add<LowerFragmentAssemblyReconciliatorPattern>(patterns.getContext());
}
} // namespace onnx_mlir