This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -11,6 +15,107 @@ namespace raptor {
|
||||
|
||||
} // namespace raptor
|
||||
|
||||
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
|
||||
SmallVector<OpFoldResult, 4> attrs;
|
||||
attrs.reserve(values.size());
|
||||
for (int64_t value : values)
|
||||
attrs.push_back(builder.getIndexAttr(value));
|
||||
return attrs;
|
||||
}
|
||||
|
||||
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
|
||||
SmallVector<OpFoldResult, 4> strides;
|
||||
strides.reserve(rank);
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
strides.push_back(builder.getIndexAttr(1));
|
||||
return strides;
|
||||
}
|
||||
|
||||
struct LowerFragmentAssemblyReconciliatorPattern
|
||||
: OpConversionPattern<spatial::SpatReconciliatorOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatReconciliatorOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
std::optional<StringRef> modeAttr = op.getMode();
|
||||
if (!modeAttr || *modeAttr != "fragment_assembly")
|
||||
return failure();
|
||||
|
||||
auto resultType = dyn_cast<ShapedType>(op.getOutput().getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return op.emitOpError("fragment assembly lowering requires a static ranked tensor result");
|
||||
|
||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = op.getFragmentOperandIndices();
|
||||
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = op.getFragmentStrides();
|
||||
if (!operandIndicesAttr || !fragmentStridesAttr)
|
||||
return op.emitOpError("fragment assembly lowering requires explicit fragment metadata");
|
||||
|
||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||
ArrayRef<int64_t> flatOffsets = op.getFragmentOffsets();
|
||||
ArrayRef<int64_t> flatSizes = op.getFragmentSizes();
|
||||
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||
int64_t rank = resultType.getRank();
|
||||
|
||||
SmallVector<Value> fragmentOperands {adaptor.getInput()};
|
||||
llvm::append_range(fragmentOperands, adaptor.getFragments());
|
||||
|
||||
Value currentOutput =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), resultType.getElementType()).getResult();
|
||||
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
|
||||
return op.emitOpError("fragment assembly operand index is out of range");
|
||||
|
||||
SmallVector<int64_t, 4> fragmentOffsets;
|
||||
int64_t fragmentElements = 1;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||
if (flatStrides[flatIndex] != 1)
|
||||
return op.emitOpError("fragment assembly lowering only supports unit strides");
|
||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||
fragmentElements *= flatSizes[flatIndex];
|
||||
}
|
||||
|
||||
Value source = fragmentOperands[operandIndex];
|
||||
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return op.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||
|
||||
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
|
||||
SmallVector<int64_t, 4> fragmentShape;
|
||||
fragmentShape.reserve(rank);
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
|
||||
|
||||
Value fragment = source;
|
||||
if (llvm::to_vector(sourceType.getShape()) != fragmentShape) {
|
||||
SmallVector<int64_t, 4> extractOffsets(rank, 0);
|
||||
extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
|
||||
fragment = tensor::ExtractSliceOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
source,
|
||||
getStaticIndexAttrs(rewriter, extractOffsets),
|
||||
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||
getUnitStrides(rewriter, rank));
|
||||
}
|
||||
|
||||
currentOutput = tensor::InsertSliceOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
fragment,
|
||||
currentOutput,
|
||||
getStaticIndexAttrs(rewriter, fragmentOffsets),
|
||||
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||
getUnitStrides(rewriter, rank))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, currentOutput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateInitialPatterns(RewritePatternSet& patterns) {
|
||||
raptor::populateWithGenerated(patterns);
|
||||
populateTransposeLoweringPatterns(patterns);
|
||||
@@ -19,6 +124,7 @@ void populateInitialPatterns(RewritePatternSet& patterns) {
|
||||
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
|
||||
raptor::populateWithGenerated(patterns);
|
||||
populateTransposeLoweringPatterns(patterns);
|
||||
patterns.add<LowerFragmentAssemblyReconciliatorPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user