This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include <limits>
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
@@ -138,26 +139,49 @@ static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value ba
|
||||
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
|
||||
}
|
||||
|
||||
/// Wraps every entry of `values` in an index attribute created by `builder`
/// and returns the attributes, in input order, as OpFoldResults.
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
  SmallVector<OpFoldResult, 4> indexAttrs;
  indexAttrs.reserve(values.size());
  llvm::append_range(indexAttrs, llvm::map_range(values, [&builder](int64_t value) -> OpFoldResult {
    return builder.getIndexAttr(value);
  }));
  return indexAttrs;
}
|
||||
|
||||
/// Returns `rank` copies of the index attribute 1, i.e. a unit stride for
/// every dimension.
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
  const OpFoldResult unitStride = builder.getIndexAttr(1);
  return SmallVector<OpFoldResult, 4>(static_cast<size_t>(rank), unitStride);
}
|
||||
|
||||
static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||
tensor::ParallelInsertSliceOp insertSlice,
|
||||
Location loc,
|
||||
ShapedType destinationType,
|
||||
ArrayRef<OpFoldResult> mixedOffsets,
|
||||
ArrayRef<int64_t> additionalOffsets,
|
||||
IRMapping& mapper) {
|
||||
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
||||
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
|
||||
|
||||
Value totalOffset;
|
||||
Location loc = insertSlice.getLoc();
|
||||
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
|
||||
for (auto [dim, offset] : llvm::enumerate(mixedOffsets)) {
|
||||
int64_t scale = strides[dim] * elementBytes;
|
||||
Value scaledOffset;
|
||||
if (auto attr = dyn_cast<Attribute>(offset)) {
|
||||
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||
assert(intAttr && "expected integer offset attribute");
|
||||
scaledOffset =
|
||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), intAttr.getInt() * scale);
|
||||
}
|
||||
else {
|
||||
scaledOffset = getOrCreateIndexConstant(rewriter,
|
||||
rewriter.getInsertionBlock()->getParentOp(),
|
||||
(intAttr.getInt() + additionalOffsets[dim]) * scale);
|
||||
} else {
|
||||
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
|
||||
if (additionalOffsets[dim] != 0) {
|
||||
Value staticOffset = getOrCreateIndexConstant(rewriter,
|
||||
rewriter.getInsertionBlock()->getParentOp(),
|
||||
additionalOffsets[dim] * scale);
|
||||
scaledOffset = arith::AddIOp::create(rewriter, loc, scaledOffset, staticOffset).getResult();
|
||||
}
|
||||
}
|
||||
|
||||
totalOffset =
|
||||
@@ -169,6 +193,127 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||
return totalOffset;
|
||||
}
|
||||
|
||||
/// Convenience overload: forwards to the general createHostTargetOffset using
/// the location and mixed offsets of `insertSlice` and an all-zero additional
/// offset for every dimension of `destinationType`.
static Value createHostTargetOffset(IRRewriter& rewriter,
                                    tensor::ParallelInsertSliceOp insertSlice,
                                    ShapedType destinationType,
                                    IRMapping& mapper) {
  const SmallVector<int64_t> noAdditionalOffsets(destinationType.getRank(), 0);
  Location sliceLoc = insertSlice.getLoc();
  SmallVector<OpFoldResult> sliceOffsets = insertSlice.getMixedOffsets();
  return createHostTargetOffset(rewriter, sliceLoc, destinationType, sliceOffsets, noAdditionalOffsets, mapper);
}
|
||||
|
||||
/// Combines each base offset with the static fragment offset of the same
/// dimension.
///
/// - A static (attribute) base offset yields a new index attribute holding the
///   sum.
/// - A dynamic base offset is first looked up through `mapper`; it is returned
///   as-is when the fragment offset is zero, otherwise an arith.addi with an
///   index constant is created at `loc`.
static SmallVector<OpFoldResult, 4> buildFragmentOffsets(IRRewriter& rewriter,
                                                         Location loc,
                                                         ArrayRef<OpFoldResult> baseOffsets,
                                                         ArrayRef<int64_t> fragmentOffsets,
                                                         IRMapping& mapper) {
  SmallVector<OpFoldResult, 4> result;
  result.reserve(fragmentOffsets.size());

  for (size_t dim = 0, numDims = baseOffsets.size(); dim < numDims; ++dim) {
    OpFoldResult base = baseOffsets[dim];
    const int64_t delta = fragmentOffsets[dim];

    if (auto staticBase = dyn_cast<Attribute>(base)) {
      // Fully static: fold the addition into a single attribute.
      result.push_back(rewriter.getIndexAttr(cast<IntegerAttr>(staticBase).getInt() + delta));
    } else {
      Value mappedBase = mapper.lookupOrDefault(cast<Value>(base));
      if (delta != 0) {
        Value deltaConstant =
            getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), delta);
        mappedBase = arith::AddIOp::create(rewriter, loc, mappedBase, deltaConstant).getResult();
      }
      result.push_back(mappedBase);
    }
  }

  return result;
}
|
||||
|
||||
/// Lowers a "fragment_assembly" reconciliator that feeds a host output by
/// inserting every fragment directly into `hostTarget`.
///
/// The op's fragment operand indices select, per fragment, one of the operands
/// (index 0 is the op input, the following indices are its fragments); the
/// selected value is looked up through `mapper`. When the operand shape differs
/// from the fragment shape, the fragment is carved out with
/// tensor.extract_slice, using a dim-0 offset equal to the number of fragments
/// already taken from that operand times the fragment's dim-0 size. Each
/// fragment is then written with tensor.insert_slice at `baseOffsets` plus the
/// fragment's own offsets.
///
/// \param rewriter      builder used to create the slice operations.
/// \param reconciliator op carrying the fragment metadata.
/// \param hostTarget    tensor the fragments are inserted into.
/// \param baseOffsets   offsets added to every fragment offset (one per
///                      result dimension).
/// \param mapper        value mapping applied to operands and dynamic offsets.
/// \returns the result of the last tensor.insert_slice (`hostTarget` itself if
///          there are no fragments), or failure after emitting an op error.
static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
                                                        spatial::SpatReconciliatorOp reconciliator,
                                                        Value hostTarget,
                                                        ArrayRef<OpFoldResult> baseOffsets,
                                                        IRMapping& mapper) {
  auto hostTargetType = dyn_cast<RankedTensorType>(hostTarget.getType());
  auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
  if (!hostTargetType || !resultType || !resultType.hasStaticShape())
    return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor results");

  std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
  std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
  if (!operandIndicesAttr || !fragmentStridesAttr)
    return reconciliator.emitOpError(
        "fragment assembly lowering requires explicit operand indices and unit strides");

  ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
  ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
  ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
  ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
  const int64_t rank = resultType.getRank();
  const Location loc = reconciliator.getLoc();

  // Every insert_slice below passes `rank` sizes and strides, and
  // buildFragmentOffsets pairs base offset i with fragment offset i, so the
  // destination rank and the number of base offsets must both equal `rank`.
  if (hostTargetType.getRank() != rank || static_cast<int64_t>(baseOffsets.size()) != rank)
    return reconciliator.emitOpError(
        "fragment assembly lowering requires the host target and base offsets to match the result rank");

  // The flat metadata arrays are indexed with fragmentIndex * rank + dim;
  // reject arrays that are too short instead of reading out of bounds.
  const size_t requiredEntries = operandIndices.size() * static_cast<size_t>(rank);
  if (flatOffsets.size() < requiredEntries || flatSizes.size() < requiredEntries
      || flatStrides.size() < requiredEntries)
    return reconciliator.emitOpError("fragment assembly metadata is shorter than fragment count times rank");

  SmallVector<Value> fragmentOperands {reconciliator.getInput()};
  llvm::append_range(fragmentOperands, reconciliator.getFragments());

  // Counts, per operand, how many fragments have already been taken from it.
  DenseMap<int64_t, int64_t> packedFragmentOrdinals;
  for (size_t fragmentIndex = 0; fragmentIndex < operandIndices.size(); ++fragmentIndex) {
    const int64_t operandIndex = operandIndices[fragmentIndex];
    if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
      return reconciliator.emitOpError("fragment assembly operand index is out of range");

    const size_t firstEntry = fragmentIndex * static_cast<size_t>(rank);
    ArrayRef<int64_t> fragmentStrides = flatStrides.slice(firstEntry, rank);
    if (llvm::any_of(fragmentStrides, [](int64_t stride) { return stride != 1; }))
      return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
    ArrayRef<int64_t> fragmentOffsets = flatOffsets.slice(firstEntry, rank);
    ArrayRef<int64_t> fragmentShape = flatSizes.slice(firstEntry, rank);

    Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]);
    auto sourceType = dyn_cast<ShapedType>(source.getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");
    // The extract_slice below is built with `rank` offsets/sizes/strides, which
    // is only well-formed for an operand of that same rank. This also
    // guarantees rank >= 1 whenever the shapes differ.
    if (sourceType.getRank() != rank)
      return reconciliator.emitOpError(
          "fragment assembly lowering requires operands with the same rank as the result");

    const int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;

    Value fragment = source;
    if (sourceType.getShape() != fragmentShape) {
      // The operand holds more than this fragment: slice it out along dim 0.
      SmallVector<int64_t, 4> extractOffsets(rank, 0);
      extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
      fragment = tensor::ExtractSliceOp::create(rewriter,
                                                loc,
                                                source,
                                                getStaticIndexAttrs(rewriter, extractOffsets),
                                                getStaticIndexAttrs(rewriter, fragmentShape),
                                                getUnitStrides(rewriter, rank));
    }

    hostTarget = tensor::InsertSliceOp::create(rewriter,
                                               loc,
                                               fragment,
                                               hostTarget,
                                               buildFragmentOffsets(rewriter, loc, baseOffsets, fragmentOffsets, mapper),
                                               getStaticIndexAttrs(rewriter, fragmentShape),
                                               getUnitStrides(rewriter, rank))
                     .getResult();
  }

  return hostTarget;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||
@@ -207,8 +352,10 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
|
||||
SmallVector<unsigned> returnOperandIndices;
|
||||
if (computeBatchOp.getNumResults() != 0) {
|
||||
returnOperandIndices.resize(computeBatchOp.getNumResults());
|
||||
returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits<unsigned>::max());
|
||||
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
|
||||
if (failed(returnOperandIndex))
|
||||
return computeBatchOp.emitOpError(
|
||||
@@ -271,6 +418,18 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
if (isa<spatial::SpatYieldOp>(op))
|
||||
continue;
|
||||
|
||||
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
|
||||
std::optional<StringRef> modeAttr = reconciliator.getMode();
|
||||
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||
for (Operation* user : reconciliator.getOutput().getUsers()) {
|
||||
if (!isa<tensor::ParallelInsertSliceOp>(user))
|
||||
return reconciliator.emitOpError(
|
||||
"fragment assembly reconciliator lowering expects only tensor.parallel_insert_slice users");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
|
||||
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
|
||||
if (!firstOutputArg)
|
||||
@@ -287,10 +446,28 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
|
||||
if (resultIndex >= returnOperandIndices.size())
|
||||
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
||||
if (returnOperandIndices[resultIndex] == std::numeric_limits<unsigned>::max())
|
||||
continue;
|
||||
|
||||
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
||||
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
|
||||
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
||||
if (auto reconciliator =
|
||||
insertSlice.getSource().getDefiningOp<spatial::SpatReconciliatorOp>()) {
|
||||
std::optional<StringRef> modeAttr = reconciliator.getMode();
|
||||
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||
FailureOr<Value> updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter,
|
||||
reconciliator,
|
||||
hostTarget,
|
||||
insertSlice.getMixedOffsets(),
|
||||
mapper);
|
||||
if (failed(updatedHostTarget))
|
||||
return failure();
|
||||
hostOutputTensors[resultIndex] = *updatedHostTarget;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
||||
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource);
|
||||
|
||||
@@ -30,6 +30,91 @@ static bool isChannelUseChainOp(Operation* op) {
|
||||
pim::PimTransposeOp>(op);
|
||||
}
|
||||
|
||||
/// Returns an index constant holding the byte offset of `fragmentOffsets`
/// inside a row-major tensor of type `destinationType`: the sum over all
/// dimensions of offset * row-major stride * element size in bytes.
///
/// `loc` is accepted for signature symmetry but is not used here.
static Value createStaticHostTargetOffset(IRRewriter& rewriter,
                                          Location loc,
                                          ShapedType destinationType,
                                          ArrayRef<int64_t> fragmentOffsets) {
  const int64_t bytesPerElement =
      static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
  const SmallVector<int64_t> rowMajorStrides = computeRowMajorStrides(destinationType.getShape());

  int64_t totalBytes = 0;
  for (size_t dim = 0, numDims = fragmentOffsets.size(); dim < numDims; ++dim)
    totalBytes += fragmentOffsets[dim] * rowMajorStrides[dim] * bytesPerElement;

  Operation* constantScope = rewriter.getInsertionBlock()->getParentOp();
  return getOrCreateIndexConstant(rewriter, constantScope, totalBytes);
}
|
||||
|
||||
/// Lowers a "fragment_assembly" reconciliator into a chain of
/// pim.memcp device-to-host copies that fill a freshly created empty tensor of
/// the result type.
///
/// For each entry of the op's fragment operand indices, the selected operand
/// (index 0 is the op input, the following indices are its fragments) is
/// looked up through `mapping` and `fragmentBytes` bytes are copied from it:
///  - the host offset is the row-major byte offset of the fragment offsets in
///    the result;
///  - the device offset is the number of fragments already taken from the same
///    operand times `fragmentBytes`.
///
/// \returns the output of the last copy (the empty tensor itself when there
///          are no fragments), or failure after emitting an op error.
static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
                                                           spatial::SpatReconciliatorOp reconciliator,
                                                           IRMapping& mapping) {
  auto resultType = dyn_cast<ShapedType>(reconciliator.getOutput().getType());
  if (!resultType || !resultType.hasStaticShape())
    return reconciliator.emitOpError("fragment assembly lowering requires a static ranked tensor result");

  std::optional<StringRef> modeAttr = reconciliator.getMode();
  std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
  std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
  if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !fragmentStridesAttr)
    return reconciliator.emitOpError("fragment assembly lowering requires explicit fragment metadata");

  ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
  ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
  ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
  ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
  const int64_t rank = resultType.getRank();
  const Location loc = reconciliator.getLoc();

  // The flat metadata arrays are indexed with fragmentIndex * rank + dim;
  // reject arrays that are too short before creating any operation.
  const size_t requiredEntries = operandIndices.size() * static_cast<size_t>(rank);
  if (flatOffsets.size() < requiredEntries || flatSizes.size() < requiredEntries
      || flatStrides.size() < requiredEntries)
    return reconciliator.emitOpError("fragment assembly metadata is shorter than fragment count times rank");

  SmallVector<Value> fragmentOperands {reconciliator.getInput()};
  llvm::append_range(fragmentOperands, reconciliator.getFragments());

  Value currentOutput = createEmptyTensorFromShaped(rewriter, loc, resultType);
  // Counts, per operand, how many fragments have already been copied from it.
  DenseMap<int64_t, int64_t> packedFragmentOrdinals;
  for (size_t fragmentIndex = 0; fragmentIndex < operandIndices.size(); ++fragmentIndex) {
    const int64_t operandIndex = operandIndices[fragmentIndex];
    if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
      return reconciliator.emitOpError("fragment assembly operand index is out of range");

    const size_t firstEntry = fragmentIndex * static_cast<size_t>(rank);
    ArrayRef<int64_t> fragmentStrides = flatStrides.slice(firstEntry, rank);
    if (llvm::any_of(fragmentStrides, [](int64_t stride) { return stride != 1; }))
      return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
    ArrayRef<int64_t> fragmentOffsets = flatOffsets.slice(firstEntry, rank);

    int64_t fragmentElements = 1;
    for (int64_t extent : flatSizes.slice(firstEntry, rank))
      fragmentElements *= extent;

    Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]);
    auto sourceType = dyn_cast<ShapedType>(source.getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");

    const int64_t fragmentBytes =
        fragmentElements * static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
    auto sizeAttr = pim::getCheckedI32Attr(rewriter,
                                           reconciliator.getOperation(),
                                           fragmentBytes,
                                           "fragment assembly host copy size");
    if (failed(sizeAttr))
      return failure();

    const int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
    // NOTE(review): the whole fragment is moved by one copy of fragmentBytes
    // starting at hostTargetOffset — confirm fragments are always contiguous
    // in the row-major result.
    Value hostTargetOffset = createStaticHostTargetOffset(rewriter, loc, resultType, fragmentOffsets);
    Value deviceSourceOffset = getOrCreateIndexConstant(rewriter,
                                                        rewriter.getInsertionBlock()->getParentOp(),
                                                        packedFragmentOrdinal * fragmentBytes);
    currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter,
                                                       loc,
                                                       currentOutput.getType(),
                                                       hostTargetOffset,
                                                       deviceSourceOffset,
                                                       currentOutput,
                                                       source,
                                                       *sizeAttr)
                        .getOutput();
  }

  return currentOutput;
}
|
||||
|
||||
static void
|
||||
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
@@ -131,6 +216,17 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatSchedule
|
||||
mapping.map(*weightArg, weight);
|
||||
}
|
||||
for (Operation& op : block.without_terminator()) {
|
||||
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
|
||||
std::optional<StringRef> modeAttr = reconciliator.getMode();
|
||||
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||
auto lowered = lowerFragmentAssemblyReconciliator(rewriter, reconciliator, mapping);
|
||||
if (failed(lowered))
|
||||
return false;
|
||||
mapping.map(reconciliator.getOutput(), *lowered);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
|
||||
Operation* clonedOp = rewriter.clone(op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -11,6 +15,107 @@ namespace raptor {
|
||||
|
||||
} // namespace raptor
|
||||
|
||||
/// Builds one index attribute per entry of `values` (same order) using
/// `builder` and returns them as OpFoldResults.
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
  const size_t count = values.size();
  SmallVector<OpFoldResult, 4> result;
  result.reserve(count);
  for (size_t i = 0; i < count; ++i)
    result.emplace_back(builder.getIndexAttr(values[i]));
  return result;
}
|
||||
|
||||
/// Produces a stride list of length `rank` in which every stride is the index
/// attribute 1.
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
  SmallVector<OpFoldResult, 4> unitStrides;
  unitStrides.assign(static_cast<size_t>(rank), OpFoldResult(builder.getIndexAttr(1)));
  return unitStrides;
}
|
||||
|
||||
struct LowerFragmentAssemblyReconciliatorPattern
|
||||
: OpConversionPattern<spatial::SpatReconciliatorOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatReconciliatorOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
std::optional<StringRef> modeAttr = op.getMode();
|
||||
if (!modeAttr || *modeAttr != "fragment_assembly")
|
||||
return failure();
|
||||
|
||||
auto resultType = dyn_cast<ShapedType>(op.getOutput().getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return op.emitOpError("fragment assembly lowering requires a static ranked tensor result");
|
||||
|
||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = op.getFragmentOperandIndices();
|
||||
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = op.getFragmentStrides();
|
||||
if (!operandIndicesAttr || !fragmentStridesAttr)
|
||||
return op.emitOpError("fragment assembly lowering requires explicit fragment metadata");
|
||||
|
||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||
ArrayRef<int64_t> flatOffsets = op.getFragmentOffsets();
|
||||
ArrayRef<int64_t> flatSizes = op.getFragmentSizes();
|
||||
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||
int64_t rank = resultType.getRank();
|
||||
|
||||
SmallVector<Value> fragmentOperands {adaptor.getInput()};
|
||||
llvm::append_range(fragmentOperands, adaptor.getFragments());
|
||||
|
||||
Value currentOutput =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), resultType.getElementType()).getResult();
|
||||
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
|
||||
return op.emitOpError("fragment assembly operand index is out of range");
|
||||
|
||||
SmallVector<int64_t, 4> fragmentOffsets;
|
||||
int64_t fragmentElements = 1;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||
if (flatStrides[flatIndex] != 1)
|
||||
return op.emitOpError("fragment assembly lowering only supports unit strides");
|
||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||
fragmentElements *= flatSizes[flatIndex];
|
||||
}
|
||||
|
||||
Value source = fragmentOperands[operandIndex];
|
||||
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return op.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||
|
||||
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
|
||||
SmallVector<int64_t, 4> fragmentShape;
|
||||
fragmentShape.reserve(rank);
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
|
||||
|
||||
Value fragment = source;
|
||||
if (llvm::to_vector(sourceType.getShape()) != fragmentShape) {
|
||||
SmallVector<int64_t, 4> extractOffsets(rank, 0);
|
||||
extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
|
||||
fragment = tensor::ExtractSliceOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
source,
|
||||
getStaticIndexAttrs(rewriter, extractOffsets),
|
||||
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||
getUnitStrides(rewriter, rank));
|
||||
}
|
||||
|
||||
currentOutput = tensor::InsertSliceOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
fragment,
|
||||
currentOutput,
|
||||
getStaticIndexAttrs(rewriter, fragmentOffsets),
|
||||
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||
getUnitStrides(rewriter, rank))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, currentOutput);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateInitialPatterns(RewritePatternSet& patterns) {
|
||||
raptor::populateWithGenerated(patterns);
|
||||
populateTransposeLoweringPatterns(patterns);
|
||||
@@ -19,6 +124,7 @@ void populateInitialPatterns(RewritePatternSet& patterns) {
|
||||
/// Populates `patterns` with, in this order: the generated raptor patterns,
/// the transpose lowering patterns, and the fragment-assembly reconciliator
/// lowering pattern.
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
  auto* context = patterns.getContext();

  raptor::populateWithGenerated(patterns);
  populateTransposeLoweringPatterns(patterns);
  patterns.add<LowerFragmentAssemblyReconciliatorPattern>(context);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user