132 lines
5.9 KiB
C++
132 lines
5.9 KiB
C++
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace raptor {
|
|
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
|
|
|
} // namespace raptor
|
|
|
|
/// Wraps each integer in `values` as an index attribute created by `builder`.
///
/// The returned vector has one entry per input value, in the same order, so
/// it can be handed to builders that take static offsets or sizes as
/// `OpFoldResult`s.
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
  const size_t count = values.size();

  SmallVector<OpFoldResult, 4> indexAttrs;
  indexAttrs.reserve(count);
  for (size_t position = 0; position != count; ++position)
    indexAttrs.push_back(builder.getIndexAttr(values[position]));
  return indexAttrs;
}
|
|
|
|
/// Returns `rank` copies of the index attribute `1`, i.e. a unit stride for
/// every dimension of a rank-`rank` slice.
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
  // Every entry is the same attribute, so build it once and replicate it.
  const OpFoldResult unitStride = builder.getIndexAttr(1);
  return SmallVector<OpFoldResult, 4>(static_cast<size_t>(rank), unitStride);
}
|
|
|
|
struct LowerFragmentAssemblyReconciliatorPattern
|
|
: OpConversionPattern<spatial::SpatReconciliatorOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult matchAndRewrite(spatial::SpatReconciliatorOp op,
|
|
OpAdaptor adaptor,
|
|
ConversionPatternRewriter& rewriter) const override {
|
|
std::optional<StringRef> modeAttr = op.getMode();
|
|
if (!modeAttr || *modeAttr != "fragment_assembly")
|
|
return failure();
|
|
|
|
auto resultType = dyn_cast<ShapedType>(op.getOutput().getType());
|
|
if (!resultType || !resultType.hasStaticShape())
|
|
return op.emitOpError("fragment assembly lowering requires a static ranked tensor result");
|
|
|
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = op.getFragmentOperandIndices();
|
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = op.getFragmentSourceOffsets();
|
|
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = op.getFragmentStrides();
|
|
if (!operandIndicesAttr || !sourceOffsetsAttr || !fragmentStridesAttr)
|
|
return op.emitOpError("fragment assembly lowering requires explicit fragment metadata");
|
|
|
|
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
|
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
|
ArrayRef<int64_t> flatOffsets = op.getFragmentOffsets();
|
|
ArrayRef<int64_t> flatSizes = op.getFragmentSizes();
|
|
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
|
int64_t rank = resultType.getRank();
|
|
|
|
SmallVector<Value> fragmentOperands {adaptor.getInput()};
|
|
llvm::append_range(fragmentOperands, adaptor.getFragments());
|
|
if (failed(validateFragmentAssemblyMetadata(
|
|
op, rank, fragmentOperands.size(), operandIndices, sourceOffsets, flatOffsets, flatSizes, flatStrides)))
|
|
return failure();
|
|
|
|
Value currentOutput =
|
|
tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), resultType.getElementType()).getResult();
|
|
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
|
int64_t operandIndex = operandIndices[fragmentIndex];
|
|
|
|
SmallVector<int64_t, 4> fragmentOffsets;
|
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
|
int64_t flatIndex = fragmentIndex * rank + dim;
|
|
if (flatStrides[flatIndex] != 1)
|
|
return op.emitOpError("fragment assembly lowering only supports unit strides");
|
|
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
|
}
|
|
|
|
Value source = fragmentOperands[operandIndex];
|
|
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
|
if (!sourceType || !sourceType.hasStaticShape())
|
|
return op.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
|
|
|
SmallVector<int64_t, 4> fragmentShape;
|
|
fragmentShape.reserve(rank);
|
|
for (int64_t dim = 0; dim < rank; ++dim)
|
|
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
|
|
|
|
Value fragment = source;
|
|
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
|
|
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
|
|
op, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
|
|
if (failed(extractOffsets))
|
|
return failure();
|
|
fragment = tensor::ExtractSliceOp::create(rewriter,
|
|
op.getLoc(),
|
|
source,
|
|
getStaticIndexAttrs(rewriter, *extractOffsets),
|
|
getStaticIndexAttrs(rewriter, fragmentShape),
|
|
getUnitStrides(rewriter, rank));
|
|
}
|
|
|
|
currentOutput = tensor::InsertSliceOp::create(rewriter,
|
|
op.getLoc(),
|
|
fragment,
|
|
currentOutput,
|
|
getStaticIndexAttrs(rewriter, fragmentOffsets),
|
|
getStaticIndexAttrs(rewriter, fragmentShape),
|
|
getUnitStrides(rewriter, rank))
|
|
.getResult();
|
|
}
|
|
|
|
rewriter.replaceOp(op, currentOutput);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Adds the patterns for the initial conversion stage to `patterns`: the
/// generated `raptor` patterns first, then the transpose lowering patterns.
void populateInitialPatterns(RewritePatternSet& patterns) {
  // Registration order is kept as-is: generated patterns, then transpose lowering.
  raptor::populateWithGenerated(patterns);
  populateTransposeLoweringPatterns(patterns);
}
|
|
|
|
/// Adds the patterns used for core bodies to `patterns`: the generated
/// `raptor` patterns, the transpose lowering patterns, and finally the
/// fragment-assembly reconciliator lowering.
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
  MLIRContext* context = patterns.getContext();

  raptor::populateWithGenerated(patterns);
  populateTransposeLoweringPatterns(patterns);
  patterns.add<LowerFragmentAssemblyReconciliatorPattern>(context);
}
|
|
|
|
} // namespace onnx_mlir
|