@@ -34,12 +34,25 @@ mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledg
|
|||||||
|
|
||||||
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
|
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
|
||||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
|
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
|
||||||
|
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
||||||
|
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
|
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
|
||||||
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
|
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static mlir::Value resolveForYieldedAliasToInit(mlir::scf::ForOp forOp,
|
||||||
|
mlir::Value yieldedValue,
|
||||||
|
const StaticValueKnowledge* knowledge) {
|
||||||
|
yieldedValue = resolveLoopCarriedAliasImpl(yieldedValue, knowledge);
|
||||||
|
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
||||||
|
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||||
|
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size())
|
||||||
|
return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
||||||
|
}
|
||||||
|
return yieldedValue;
|
||||||
|
}
|
||||||
|
|
||||||
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||||
value = resolveAlias(value, knowledge);
|
value = resolveAlias(value, knowledge);
|
||||||
|
|
||||||
@@ -60,15 +73,8 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
|
|||||||
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||||
if (result) {
|
if (result) {
|
||||||
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||||
if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands()) {
|
if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands())
|
||||||
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
return resolveForYieldedAliasToInit(forOp, yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||||
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
|
||||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
|
||||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size())
|
|
||||||
return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
|
||||||
}
|
|
||||||
return yieldedValue;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -515,16 +521,7 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
|||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|
||||||
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||||
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
value = resolveForYieldedAliasToInit(forOp, yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||||
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
|
||||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
|
||||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
|
||||||
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
value = yieldedValue;
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -643,16 +640,7 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
|||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|
||||||
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||||
mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber());
|
value = resolveForYieldedAliasToInit(forOp, yieldOp.getOperand(result.getResultNumber()), nullptr);
|
||||||
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
|
||||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
|
||||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
|
||||||
value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1];
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
value = yieldedValue;
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -862,7 +850,7 @@ llvm::FailureOr<ResolvedContiguousAddress> CompiledAddressExpr::evaluate(const S
|
|||||||
auto resolvedOffset = byteOffset.evaluate(knowledge);
|
auto resolvedOffset = byteOffset.evaluate(knowledge);
|
||||||
if (failed(resolvedOffset))
|
if (failed(resolvedOffset))
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
return ResolvedContiguousAddress {base, *resolvedOffset};
|
return ResolvedContiguousAddress {resolveAlias(base, &knowledge), *resolvedOffset};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -1334,6 +1334,38 @@ static Value affineAddConst(
|
|||||||
return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor);
|
return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value affineMulConst(
|
||||||
|
PatternRewriter& rewriter, Location loc, Value value, int64_t factor, Operation* constantAnchor) {
|
||||||
|
if (factor == 1)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
MLIRContext* context = rewriter.getContext();
|
||||||
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
|
return createOrFoldAffineApply(rewriter, loc, d0 * factor, ValueRange {value}, constantAnchor);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value affineFloorDivConst(
|
||||||
|
PatternRewriter& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
|
||||||
|
assert(divisor > 0 && "expected positive affine floordiv divisor");
|
||||||
|
if (divisor == 1)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
MLIRContext* context = rewriter.getContext();
|
||||||
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
|
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value affineModConst(
|
||||||
|
PatternRewriter& rewriter, Location loc, Value value, int64_t modulus, Operation* constantAnchor) {
|
||||||
|
assert(modulus > 0 && "expected positive affine mod divisor");
|
||||||
|
if (modulus == 1)
|
||||||
|
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
|
||||||
|
|
||||||
|
MLIRContext* context = rewriter.getContext();
|
||||||
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
|
return createOrFoldAffineApply(rewriter, loc, d0 % modulus, ValueRange {value}, constantAnchor);
|
||||||
|
}
|
||||||
|
|
||||||
static Value createConvInputPatch(Value input,
|
static Value createConvInputPatch(Value input,
|
||||||
RankedTensorType patchType,
|
RankedTensorType patchType,
|
||||||
Value batchIndex,
|
Value batchIndex,
|
||||||
@@ -2414,11 +2446,6 @@ static Value createIm2colRows(const ConvLoweringState& state,
|
|||||||
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||||
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||||
Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkNumPatches);
|
Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkNumPatches);
|
||||||
Value cChunkStart = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkStart);
|
|
||||||
Value cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, plan.numPatchesPerBatch);
|
|
||||||
Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth);
|
|
||||||
Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, state.strideHeight);
|
|
||||||
Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.strideWidth);
|
|
||||||
|
|
||||||
auto im2colLoop = buildNormalizedScfFor(
|
auto im2colLoop = buildNormalizedScfFor(
|
||||||
rewriter,
|
rewriter,
|
||||||
@@ -2429,13 +2456,17 @@ static Value createIm2colRows(const ConvLoweringState& state,
|
|||||||
ValueRange {im2colInit},
|
ValueRange {im2colInit},
|
||||||
[&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
[&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||||
Value im2colAcc = iterArgs.front();
|
Value im2colAcc = iterArgs.front();
|
||||||
Value globalPatchIndex = arith::AddIOp::create(rewriter, nestedLoc, patchIndex, cChunkStart);
|
Value globalPatchIndex = affineAddConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, anchorOp);
|
||||||
Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, globalPatchIndex, cNumPatchesPerBatch);
|
Value batchIndex =
|
||||||
Value batchPatchIndex = arith::RemUIOp::create(rewriter, nestedLoc, globalPatchIndex, cNumPatchesPerBatch);
|
affineFloorDivConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
|
||||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth);
|
Value batchPatchIndex =
|
||||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth);
|
affineModConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
|
||||||
Value inputHeightOffset = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight);
|
Value outHeightIndex = affineFloorDivConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
|
||||||
Value inputWidthOffset = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth);
|
Value outWidthIndex = affineModConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
|
||||||
|
Value inputHeightOffset =
|
||||||
|
affineMulConst(rewriter, nestedLoc, outHeightIndex, state.strideHeight, anchorOp);
|
||||||
|
Value inputWidthOffset =
|
||||||
|
affineMulConst(rewriter, nestedLoc, outWidthIndex, state.strideWidth, anchorOp);
|
||||||
|
|
||||||
auto patchType =
|
auto patchType =
|
||||||
RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elemType);
|
RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elemType);
|
||||||
@@ -2844,11 +2875,9 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
|
|||||||
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||||
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||||
Value cNumKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices);
|
Value cNumKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices);
|
||||||
Value cXbar = getOrCreateIndexConstant(rewriter, anchorOp, xbarDim);
|
|
||||||
Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth);
|
Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth);
|
||||||
Value cNumChannels = getOrCreateIndexConstant(rewriter, anchorOp, state.numChannelsIn);
|
Value cNumChannels = getOrCreateIndexConstant(rewriter, anchorOp, state.numChannelsIn);
|
||||||
Value cPatchWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.wHeight * state.wWidth);
|
Value localHeightOffset = args.lane;
|
||||||
Value localHeightOffset = arith::MulIOp::create(rewriter, loc, args.lane, c1);
|
|
||||||
Value packedRowInit =
|
Value packedRowInit =
|
||||||
tensor::EmptyOp::create(rewriter, loc, ArrayRef<int64_t> {1, state.outWidth, state.numChannelsOut}, elementType);
|
tensor::EmptyOp::create(rewriter, loc, ArrayRef<int64_t> {1, state.outWidth, state.numChannelsOut}, elementType);
|
||||||
auto widthLoop = buildNormalizedScfFor(
|
auto widthLoop = buildNormalizedScfFor(
|
||||||
@@ -2859,7 +2888,7 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
|
|||||||
c1,
|
c1,
|
||||||
ValueRange {packedRowInit},
|
ValueRange {packedRowInit},
|
||||||
[&](OpBuilder&, Location widthLoc, Value widthIndex, ValueRange widthIterArgs, SmallVectorImpl<Value>& widthYielded) {
|
[&](OpBuilder&, Location widthLoc, Value widthIndex, ValueRange widthIterArgs, SmallVectorImpl<Value>& widthYielded) {
|
||||||
Value localWidthOffset = arith::MulIOp::create(rewriter, widthLoc, widthIndex, c1);
|
Value localWidthOffset = widthIndex;
|
||||||
Value rowInit = tensor::EmptyOp::create(rewriter, widthLoc, ArrayRef<int64_t> {1, patchSize}, elementType);
|
Value rowInit = tensor::EmptyOp::create(rewriter, widthLoc, ArrayRef<int64_t> {1, patchSize}, elementType);
|
||||||
auto rowLoop = buildNormalizedScfFor(
|
auto rowLoop = buildNormalizedScfFor(
|
||||||
rewriter,
|
rewriter,
|
||||||
@@ -2878,7 +2907,8 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
|
|||||||
rewriter, rowLoc, flatPatchType, channelPatch, SmallVector<ReassociationIndices> {{0, 1, 2}});
|
rewriter, rowLoc, flatPatchType, channelPatch, SmallVector<ReassociationIndices> {{0, 1, 2}});
|
||||||
Value rowChunk = tensor::ExpandShapeOp::create(
|
Value rowChunk = tensor::ExpandShapeOp::create(
|
||||||
rewriter, rowLoc, rowChunkType, flatPatch, SmallVector<ReassociationIndices> {{0, 1}});
|
rewriter, rowLoc, rowChunkType, flatPatch, SmallVector<ReassociationIndices> {{0, 1}});
|
||||||
Value flatOffset = arith::MulIOp::create(rewriter, rowLoc, channel, cPatchWidth);
|
Value flatOffset = affineMulConst(
|
||||||
|
rewriter, rowLoc, channel, state.wHeight * state.wWidth, anchorOp);
|
||||||
SmallVector<OpFoldResult> rowOffsets {rewriter.getIndexAttr(0), flatOffset};
|
SmallVector<OpFoldResult> rowOffsets {rewriter.getIndexAttr(0), flatOffset};
|
||||||
SmallVector<OpFoldResult> rowSizes {
|
SmallVector<OpFoldResult> rowSizes {
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.wHeight * state.wWidth)};
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.wHeight * state.wWidth)};
|
||||||
@@ -2905,7 +2935,7 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
|
|||||||
c1,
|
c1,
|
||||||
ValueRange {zeroRow},
|
ValueRange {zeroRow},
|
||||||
[&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl<Value>& reduceYielded) {
|
[&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl<Value>& reduceYielded) {
|
||||||
Value kOffset = arith::MulIOp::create(rewriter, reduceLoc, kSlice, cXbar);
|
Value kOffset = affineMulConst(rewriter, reduceLoc, kSlice, xbarDim, anchorOp);
|
||||||
SmallVector<OpFoldResult> aOffsets {rewriter.getIndexAttr(0), kOffset};
|
SmallVector<OpFoldResult> aOffsets {rewriter.getIndexAttr(0), kOffset};
|
||||||
SmallVector<OpFoldResult> aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)};
|
SmallVector<OpFoldResult> aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)};
|
||||||
Value aTile = tensor::ExtractSliceOp::create(
|
Value aTile = tensor::ExtractSliceOp::create(
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
@@ -10,6 +11,7 @@
|
|||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
@@ -131,46 +133,92 @@ static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
|||||||
return result.getUses().begin()->getOperandNumber();
|
return result.getUses().begin()->getOperandNumber();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct BatchFragmentAssemblyPlan {
|
static FailureOr<SmallVector<FragmentAssemblyCopy, 8>>
|
||||||
unsigned returnIndex = 0;
|
collectFragmentAssemblyCopiesFromBlueprint(spatial::SpatBlueprintOp blueprint,
|
||||||
int64_t localSourceElementOffset = 0;
|
IRMapping& mapper,
|
||||||
int64_t fragmentByteSize = 0;
|
int64_t lane,
|
||||||
SmallVector<int64_t, 8> hostOffsetsByLane;
|
unsigned hostTargetIndex,
|
||||||
};
|
Value fixedSource = {}) {
|
||||||
|
SmallVector<FragmentAssemblyCopy, 8> copies;
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||||
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
|
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor results");
|
||||||
|
|
||||||
static Value createLaneIndexedOffset(IRRewriter& rewriter, Operation* anchor, Value laneArg, ArrayRef<int64_t> values, Location loc) {
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||||
assert(!values.empty() && "expected lane-indexed values");
|
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = blueprint.getFragmentStrides();
|
||||||
if (llvm::all_of(values.drop_front(), [&](int64_t value) { return value == values.front(); }))
|
if (!operandIndicesAttr || !fragmentStridesAttr)
|
||||||
return getOrCreateIndexConstant(rewriter, anchor, values.front());
|
return blueprint.emitOpError(
|
||||||
|
"fragment assembly lowering requires explicit operand indices and unit strides");
|
||||||
|
|
||||||
if (values.size() >= 2) {
|
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||||
int64_t step = values[1] - values[0];
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||||
bool arithmetic = llvm::all_of(llvm::seq<size_t>(2, values.size()), [&](size_t index) {
|
if (!sourceOffsetsAttr)
|
||||||
return values[index] == values.front() + static_cast<int64_t>(index) * step;
|
return blueprint.emitOpError("fragment assembly lowering requires explicit source offsets");
|
||||||
});
|
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||||
if (arithmetic) {
|
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
|
||||||
Value base = getOrCreateIndexConstant(rewriter, anchor, values.front());
|
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
||||||
if (step == 0)
|
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||||
return base;
|
int64_t rank = resultType.getRank();
|
||||||
Value stepValue = getOrCreateIndexConstant(rewriter, anchor, step);
|
|
||||||
Value scaledLane = arith::MulIOp::create(rewriter, loc, laneArg, stepValue).getResult();
|
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
||||||
return arith::AddIOp::create(rewriter, loc, base, scaledLane).getResult();
|
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
||||||
|
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||||
|
rank,
|
||||||
|
fragmentOperands.size(),
|
||||||
|
operandIndices,
|
||||||
|
sourceOffsets,
|
||||||
|
flatOffsets,
|
||||||
|
flatSizes,
|
||||||
|
flatStrides)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t> hostStrides = computeRowMajorStrides(resultType.getShape());
|
||||||
|
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||||
|
Value source = fixedSource ? fixedSource : mapper.lookupOrDefault(fragmentOperands[operandIndices[fragmentIndex]]);
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
||||||
|
if (!sourceType || !sourceType.hasStaticShape())
|
||||||
|
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||||
|
|
||||||
|
size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType());
|
||||||
|
SmallVector<int64_t, 4> fragmentOffsets;
|
||||||
|
SmallVector<int64_t, 4> fragmentSizes;
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
|
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||||
|
if (flatStrides[flatIndex] != 1)
|
||||||
|
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
|
||||||
|
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||||
|
fragmentSizes.push_back(flatSizes[flatIndex]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (failed(forEachContiguousDestinationChunk(
|
||||||
|
resultType.getShape(),
|
||||||
|
fragmentOffsets,
|
||||||
|
fragmentSizes,
|
||||||
|
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
|
||||||
|
int64_t hostElementOffset = 0;
|
||||||
|
for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
|
||||||
|
hostElementOffset += offset * hostStrides[dim];
|
||||||
|
|
||||||
|
FragmentAssemblyCopy copy;
|
||||||
|
copy.source = source;
|
||||||
|
copy.sourceType = sourceType;
|
||||||
|
copy.hostTargetIndex = hostTargetIndex;
|
||||||
|
copy.lane = lane;
|
||||||
|
copy.sourceByteOffset = (sourceOffsets[fragmentIndex] + relativeSourceOffset) * static_cast<int64_t>(elementSize);
|
||||||
|
copy.hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
|
||||||
|
copy.byteSize = chunkElements * static_cast<int64_t>(elementSize);
|
||||||
|
copies.push_back(copy);
|
||||||
|
return success();
|
||||||
|
})))
|
||||||
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
Value selected = getOrCreateIndexConstant(rewriter, anchor, values.front());
|
return copies;
|
||||||
for (auto [lane, value] : llvm::enumerate(values.drop_front())) {
|
|
||||||
Value laneValue = getOrCreateIndexConstant(rewriter, anchor, static_cast<int64_t>(lane + 1));
|
|
||||||
Value cmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, laneArg, laneValue);
|
|
||||||
Value candidate = getOrCreateIndexConstant(rewriter, anchor, value);
|
|
||||||
selected = arith::SelectOp::create(rewriter, loc, cmp, candidate, selected);
|
|
||||||
}
|
|
||||||
return selected;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<SmallVector<BatchFragmentAssemblyPlan, 8>>
|
static FailureOr<SmallVector<FragmentAssemblyCopy, 8>>
|
||||||
analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResultType, uint32_t laneCount) {
|
collectTopLevelFragmentAssemblyCopies(OpResult result, RankedTensorType packedResultType, uint32_t laneCount) {
|
||||||
SmallVector<BatchFragmentAssemblyPlan, 8> plans;
|
SmallVector<FragmentAssemblyCopy, 8> copies;
|
||||||
if (!packedResultType.hasStaticShape() || laneCount == 0)
|
if (!packedResultType.hasStaticShape() || laneCount == 0)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -187,15 +235,14 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
|
|||||||
std::optional<StringRef> mode = blueprint.getMode();
|
std::optional<StringRef> mode = blueprint.getMode();
|
||||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||||
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
|
if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr)
|
||||||
if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr)
|
|
||||||
return failure();
|
return failure();
|
||||||
if (!blueprint.getOutput().hasOneUse() || !isa<func::ReturnOp>(*blueprint.getOutput().getUsers().begin()))
|
if (!blueprint.getOutput().hasOneUse() || !isa<func::ReturnOp>(*blueprint.getOutput().getUsers().begin()))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
unsigned returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber();
|
|
||||||
auto hostResultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
auto hostResultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||||
if (!hostResultType || !hostResultType.hasStaticShape())
|
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
|
||||||
|
if (!hostResultType || !hostResultType.hasStaticShape() || !stridesAttr)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||||
@@ -204,6 +251,7 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
|
|||||||
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
||||||
ArrayRef<int64_t> flatStrides = *stridesAttr;
|
ArrayRef<int64_t> flatStrides = *stridesAttr;
|
||||||
int64_t rank = hostResultType.getRank();
|
int64_t rank = hostResultType.getRank();
|
||||||
|
unsigned returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber();
|
||||||
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
||||||
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
||||||
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||||
@@ -215,16 +263,15 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
|
|||||||
flatSizes,
|
flatSizes,
|
||||||
flatStrides)))
|
flatStrides)))
|
||||||
return failure();
|
return failure();
|
||||||
|
SmallVector<int64_t> hostStrides = computeRowMajorStrides(hostResultType.getShape());
|
||||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||||
if (operandIndices[fragmentIndex] != static_cast<int64_t>(use.getOperandNumber()))
|
if (operandIndices[fragmentIndex] != static_cast<int64_t>(use.getOperandNumber()))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
int64_t sourceElementOffset = sourceOffsets[fragmentIndex];
|
int64_t sourceElementOffset = sourceOffsets[fragmentIndex];
|
||||||
int64_t lane = sourceElementOffset / payloadElementCount;
|
int64_t lane = sourceElementOffset / payloadElementCount;
|
||||||
int64_t localSourceElementOffset = sourceElementOffset % payloadElementCount;
|
|
||||||
if (lane < 0 || lane >= static_cast<int64_t>(laneCount))
|
if (lane < 0 || lane >= static_cast<int64_t>(laneCount))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
SmallVector<int64_t, 4> fragmentOffsets;
|
SmallVector<int64_t, 4> fragmentOffsets;
|
||||||
SmallVector<int64_t, 4> fragmentSizes;
|
SmallVector<int64_t, 4> fragmentSizes;
|
||||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
@@ -236,44 +283,31 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (failed(forEachContiguousDestinationChunk(
|
if (failed(forEachContiguousDestinationChunk(
|
||||||
hostResultType.getShape(),
|
hostResultType.getShape(),
|
||||||
fragmentOffsets,
|
fragmentOffsets,
|
||||||
fragmentSizes,
|
fragmentSizes,
|
||||||
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
|
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
|
||||||
int64_t hostElementOffset = 0;
|
int64_t hostElementOffset = 0;
|
||||||
SmallVector<int64_t> hostStrides = computeRowMajorStrides(hostResultType.getShape());
|
for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
|
||||||
for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
|
hostElementOffset += offset * hostStrides[dim];
|
||||||
hostElementOffset += offset * hostStrides[dim];
|
|
||||||
int64_t hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
|
|
||||||
int64_t fragmentByteSize = chunkElements * static_cast<int64_t>(elementSize);
|
|
||||||
int64_t chunkSourceOffset = localSourceElementOffset + relativeSourceOffset;
|
|
||||||
|
|
||||||
auto planIt = llvm::find_if(plans, [&](const BatchFragmentAssemblyPlan& plan) {
|
FragmentAssemblyCopy copy;
|
||||||
return plan.returnIndex == returnIndex && plan.localSourceElementOffset == chunkSourceOffset
|
copy.source = result;
|
||||||
&& plan.fragmentByteSize == fragmentByteSize;
|
copy.sourceType = packedResultType;
|
||||||
});
|
copy.hostTargetIndex = returnIndex;
|
||||||
if (planIt == plans.end()) {
|
copy.lane = lane;
|
||||||
BatchFragmentAssemblyPlan plan;
|
copy.sourceByteOffset =
|
||||||
plan.returnIndex = returnIndex;
|
((sourceElementOffset % payloadElementCount) + relativeSourceOffset) * static_cast<int64_t>(elementSize);
|
||||||
plan.localSourceElementOffset = chunkSourceOffset;
|
copy.hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
|
||||||
plan.fragmentByteSize = fragmentByteSize;
|
copy.byteSize = chunkElements * static_cast<int64_t>(elementSize);
|
||||||
plan.hostOffsetsByLane.assign(laneCount, std::numeric_limits<int64_t>::min());
|
copies.push_back(copy);
|
||||||
plan.hostOffsetsByLane[static_cast<size_t>(lane)] = hostByteOffset;
|
return success();
|
||||||
plans.push_back(std::move(plan));
|
})))
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
planIt->hostOffsetsByLane[static_cast<size_t>(lane)] = hostByteOffset;
|
|
||||||
return success();
|
|
||||||
})))
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const BatchFragmentAssemblyPlan& plan : plans)
|
return copies;
|
||||||
if (llvm::any_of(plan.hostOffsetsByLane, [](int64_t offset) { return offset == std::numeric_limits<int64_t>::min(); }))
|
|
||||||
return failure();
|
|
||||||
return plans;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
|
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
|
||||||
@@ -284,22 +318,6 @@ static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value ba
|
|||||||
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
|
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
|
|
||||||
SmallVector<OpFoldResult, 4> attrs;
|
|
||||||
attrs.reserve(values.size());
|
|
||||||
for (int64_t value : values)
|
|
||||||
attrs.push_back(builder.getIndexAttr(value));
|
|
||||||
return attrs;
|
|
||||||
}
|
|
||||||
|
|
||||||
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
|
|
||||||
SmallVector<OpFoldResult, 4> strides;
|
|
||||||
strides.reserve(rank);
|
|
||||||
for (int64_t dim = 0; dim < rank; ++dim)
|
|
||||||
strides.push_back(builder.getIndexAttr(1));
|
|
||||||
return strides;
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value createHostTargetOffset(IRRewriter& rewriter,
|
static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||||
Location loc,
|
Location loc,
|
||||||
ShapedType destinationType,
|
ShapedType destinationType,
|
||||||
@@ -351,123 +369,6 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
|||||||
mapper);
|
mapper);
|
||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<OpFoldResult, 4> buildFragmentOffsets(IRRewriter& rewriter,
|
|
||||||
Location loc,
|
|
||||||
ArrayRef<OpFoldResult> baseOffsets,
|
|
||||||
ArrayRef<int64_t> fragmentOffsets,
|
|
||||||
IRMapping& mapper) {
|
|
||||||
SmallVector<OpFoldResult, 4> combined;
|
|
||||||
combined.reserve(fragmentOffsets.size());
|
|
||||||
for (auto [dim, baseOffset] : llvm::enumerate(baseOffsets)) {
|
|
||||||
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
|
||||||
int64_t base = cast<IntegerAttr>(attr).getInt();
|
|
||||||
combined.push_back(rewriter.getIndexAttr(base + fragmentOffsets[dim]));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
Value dynamicBase = mapper.lookupOrDefault(cast<Value>(baseOffset));
|
|
||||||
if (fragmentOffsets[dim] == 0) {
|
|
||||||
combined.push_back(dynamicBase);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
Value staticOffset =
|
|
||||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), fragmentOffsets[dim]);
|
|
||||||
combined.push_back(arith::AddIOp::create(rewriter, loc, dynamicBase, staticOffset).getResult());
|
|
||||||
}
|
|
||||||
return combined;
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
|
|
||||||
spatial::SpatBlueprintOp blueprint,
|
|
||||||
Value hostTarget,
|
|
||||||
ArrayRef<OpFoldResult> baseOffsets,
|
|
||||||
IRMapping& mapper) {
|
|
||||||
auto hostTargetType = dyn_cast<RankedTensorType>(hostTarget.getType());
|
|
||||||
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
|
||||||
if (!hostTargetType || !resultType || !resultType.hasStaticShape())
|
|
||||||
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor results");
|
|
||||||
|
|
||||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
|
||||||
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = blueprint.getFragmentStrides();
|
|
||||||
if (!operandIndicesAttr || !fragmentStridesAttr)
|
|
||||||
return blueprint.emitOpError(
|
|
||||||
"fragment assembly lowering requires explicit operand indices and unit strides");
|
|
||||||
|
|
||||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
|
||||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
|
||||||
if (!sourceOffsetsAttr)
|
|
||||||
return blueprint.emitOpError("fragment assembly lowering requires explicit source offsets");
|
|
||||||
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
|
||||||
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
|
|
||||||
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
|
||||||
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
|
||||||
int64_t rank = resultType.getRank();
|
|
||||||
|
|
||||||
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
|
||||||
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
|
||||||
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
|
||||||
rank,
|
|
||||||
fragmentOperands.size(),
|
|
||||||
operandIndices,
|
|
||||||
sourceOffsets,
|
|
||||||
flatOffsets,
|
|
||||||
flatSizes,
|
|
||||||
flatStrides)))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
|
||||||
int64_t operandIndex = operandIndices[fragmentIndex];
|
|
||||||
|
|
||||||
SmallVector<int64_t, 4> fragmentOffsets;
|
|
||||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
|
||||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
|
||||||
if (flatStrides[flatIndex] != 1)
|
|
||||||
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
|
|
||||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
|
||||||
}
|
|
||||||
|
|
||||||
Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]);
|
|
||||||
auto sourceType = dyn_cast<ShapedType>(source.getType());
|
|
||||||
if (!sourceType || !sourceType.hasStaticShape())
|
|
||||||
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
|
||||||
|
|
||||||
SmallVector<int64_t, 4> fragmentShape;
|
|
||||||
fragmentShape.reserve(rank);
|
|
||||||
for (int64_t dim = 0; dim < rank; ++dim)
|
|
||||||
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
|
|
||||||
|
|
||||||
Value fragment = source;
|
|
||||||
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
|
|
||||||
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
|
|
||||||
blueprint, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
|
|
||||||
if (failed(extractOffsets))
|
|
||||||
return failure();
|
|
||||||
fragment = tensor::ExtractSliceOp::create(rewriter,
|
|
||||||
blueprint.getLoc(),
|
|
||||||
source,
|
|
||||||
getStaticIndexAttrs(rewriter, *extractOffsets),
|
|
||||||
getStaticIndexAttrs(rewriter, fragmentShape),
|
|
||||||
getUnitStrides(rewriter, rank));
|
|
||||||
}
|
|
||||||
|
|
||||||
hostTarget = tensor::InsertSliceOp::create(rewriter,
|
|
||||||
blueprint.getLoc(),
|
|
||||||
fragment,
|
|
||||||
hostTarget,
|
|
||||||
buildFragmentOffsets(rewriter,
|
|
||||||
blueprint.getLoc(),
|
|
||||||
baseOffsets,
|
|
||||||
fragmentOffsets,
|
|
||||||
mapper),
|
|
||||||
getStaticIndexAttrs(rewriter, fragmentShape),
|
|
||||||
getUnitStrides(rewriter, rank))
|
|
||||||
.getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
return hostTarget;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||||
@@ -505,10 +406,10 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
|||||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds));
|
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds));
|
||||||
|
|
||||||
SmallVector<unsigned> returnOperandIndices;
|
SmallVector<unsigned> returnOperandIndices;
|
||||||
SmallVector<SmallVector<BatchFragmentAssemblyPlan, 1>, 4> fragmentAssemblyPlansByResult;
|
SmallVector<SmallVector<FragmentAssemblyCopyRun, 1>, 4> fragmentAssemblyRunsByResult;
|
||||||
if (computeBatchOp.getNumResults() != 0) {
|
if (computeBatchOp.getNumResults() != 0) {
|
||||||
returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits<unsigned>::max());
|
returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits<unsigned>::max());
|
||||||
fragmentAssemblyPlansByResult.resize(computeBatchOp.getNumResults());
|
fragmentAssemblyRunsByResult.resize(computeBatchOp.getNumResults());
|
||||||
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
||||||
if (result.use_empty())
|
if (result.use_empty())
|
||||||
continue;
|
continue;
|
||||||
@@ -522,12 +423,15 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
|||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
return computeBatchOp.emitOpError(
|
return computeBatchOp.emitOpError(
|
||||||
"resultful compute_batch publication lowering requires static ranked tensor results");
|
"resultful compute_batch publication lowering requires static ranked tensor results");
|
||||||
FailureOr<SmallVector<BatchFragmentAssemblyPlan, 8>> fragmentAssemblyPlans =
|
FailureOr<SmallVector<FragmentAssemblyCopy, 8>> fragmentAssemblyCopies =
|
||||||
analyzeTopLevelFragmentAssemblyUses(cast<OpResult>(result), resultType, computeBatchOp.getLaneCount());
|
collectTopLevelFragmentAssemblyCopies(cast<OpResult>(result), resultType, computeBatchOp.getLaneCount());
|
||||||
if (failed(fragmentAssemblyPlans))
|
if (failed(fragmentAssemblyCopies))
|
||||||
return computeBatchOp.emitOpError(
|
return computeBatchOp.emitOpError("failed to collect top-level fragment assembly copies for compute_batch result");
|
||||||
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>> fragmentAssemblyRuns =
|
||||||
fragmentAssemblyPlansByResult[resultIndex].assign(fragmentAssemblyPlans->begin(), fragmentAssemblyPlans->end());
|
groupFragmentAssemblyCopyRuns(*fragmentAssemblyCopies, computeBatchOp.getLaneCount());
|
||||||
|
if (failed(fragmentAssemblyRuns))
|
||||||
|
return computeBatchOp.emitOpError("failed to group top-level fragment assembly copies into regular runs");
|
||||||
|
fragmentAssemblyRunsByResult[resultIndex].assign(fragmentAssemblyRuns->begin(), fragmentAssemblyRuns->end());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -614,8 +518,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
|||||||
if (resultIndex >= returnOperandIndices.size())
|
if (resultIndex >= returnOperandIndices.size())
|
||||||
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
||||||
bool hasDirectReturn = returnOperandIndices[resultIndex] != std::numeric_limits<unsigned>::max();
|
bool hasDirectReturn = returnOperandIndices[resultIndex] != std::numeric_limits<unsigned>::max();
|
||||||
bool hasFragmentAssembly = resultIndex < fragmentAssemblyPlansByResult.size()
|
bool hasFragmentAssembly = resultIndex < fragmentAssemblyRunsByResult.size()
|
||||||
&& !fragmentAssemblyPlansByResult[resultIndex].empty();
|
&& !fragmentAssemblyRunsByResult[resultIndex].empty();
|
||||||
if (!hasDirectReturn && !hasFragmentAssembly)
|
if (!hasDirectReturn && !hasFragmentAssembly)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
@@ -626,27 +530,23 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
|||||||
auto mappedSourceType = dyn_cast<ShapedType>(mappedSource.getType());
|
auto mappedSourceType = dyn_cast<ShapedType>(mappedSource.getType());
|
||||||
if (!mappedSourceType || !mappedSourceType.hasStaticShape())
|
if (!mappedSourceType || !mappedSourceType.hasStaticShape())
|
||||||
return insertSlice.emitOpError("fragment assembly batch lowering requires a static ranked lane-local source");
|
return insertSlice.emitOpError("fragment assembly batch lowering requires a static ranked lane-local source");
|
||||||
for (const BatchFragmentAssemblyPlan& plan : fragmentAssemblyPlansByResult[resultIndex]) {
|
DenseMap<unsigned, Value> updatedOutputs;
|
||||||
Value outputTensor = outputTensors[plan.returnIndex](rewriter, insertSlice.getLoc());
|
for (const FragmentAssemblyCopyRun& run : fragmentAssemblyRunsByResult[resultIndex]) {
|
||||||
auto sizeAttr = pim::getCheckedI32Attr(
|
Value outputTensor = updatedOutputs.lookup(run.hostTargetIndex);
|
||||||
rewriter, coreBatchOp.getOperation(), plan.fragmentByteSize, "fragment assembly host copy byte size");
|
if (!outputTensor)
|
||||||
if (failed(sizeAttr))
|
outputTensor = outputTensors[run.hostTargetIndex](rewriter, insertSlice.getLoc());
|
||||||
|
FragmentAssemblyCopyRun mappedRun = run;
|
||||||
|
mappedRun.source = mappedSource;
|
||||||
|
FailureOr<Value> updated =
|
||||||
|
emitFragmentAssemblyCopyRuns(rewriter,
|
||||||
|
insertSlice.getLoc(),
|
||||||
|
ArrayRef<FragmentAssemblyCopyRun> {mappedRun},
|
||||||
|
outputTensor,
|
||||||
|
coreBatchOp.getOperation(),
|
||||||
|
laneArg);
|
||||||
|
if (failed(updated))
|
||||||
return failure();
|
return failure();
|
||||||
Value hostTargetOffset =
|
updatedOutputs[run.hostTargetIndex] = *updated;
|
||||||
createLaneIndexedOffset(rewriter, coreBatchOp.getOperation(), laneArg, plan.hostOffsetsByLane, insertSlice.getLoc());
|
|
||||||
Value deviceSourceOffset = getOrCreateIndexConstant(
|
|
||||||
rewriter, coreBatchOp.getOperation(),
|
|
||||||
plan.localSourceElementOffset * static_cast<int64_t>(getElementTypeSizeInBytes(mappedSourceType.getElementType())));
|
|
||||||
outputTensor =
|
|
||||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
|
||||||
insertSlice.getLoc(),
|
|
||||||
outputTensor.getType(),
|
|
||||||
hostTargetOffset,
|
|
||||||
deviceSourceOffset,
|
|
||||||
outputTensor,
|
|
||||||
mappedSource,
|
|
||||||
*sizeAttr)
|
|
||||||
.getOutput();
|
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -657,11 +557,28 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
|||||||
insertSlice.getSource().getDefiningOp<spatial::SpatBlueprintOp>()) {
|
insertSlice.getSource().getDefiningOp<spatial::SpatBlueprintOp>()) {
|
||||||
std::optional<StringRef> modeAttr = blueprint.getMode();
|
std::optional<StringRef> modeAttr = blueprint.getMode();
|
||||||
if (modeAttr && *modeAttr == "fragment_assembly") {
|
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||||
FailureOr<Value> updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter,
|
FailureOr<SmallVector<FragmentAssemblyCopy, 8>> fragmentAssemblyCopies =
|
||||||
blueprint,
|
collectFragmentAssemblyCopiesFromBlueprint(blueprint, mapper, /*lane=*/0, /*hostTargetIndex=*/0);
|
||||||
hostTarget,
|
if (failed(fragmentAssemblyCopies))
|
||||||
insertSlice.getMixedOffsets(),
|
return failure();
|
||||||
mapper);
|
FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>> fragmentAssemblyRuns =
|
||||||
|
groupFragmentAssemblyCopyRuns(*fragmentAssemblyCopies, /*laneCount=*/1);
|
||||||
|
if (failed(fragmentAssemblyRuns))
|
||||||
|
return failure();
|
||||||
|
SmallVector<int64_t> zeroOffsets(hostTargetType.getRank(), 0);
|
||||||
|
Value baseHostOffset = createHostTargetOffset(rewriter,
|
||||||
|
blueprint.getLoc(),
|
||||||
|
hostTargetType,
|
||||||
|
insertSlice.getMixedOffsets(),
|
||||||
|
zeroOffsets,
|
||||||
|
mapper);
|
||||||
|
FailureOr<Value> updatedHostTarget = emitFragmentAssemblyCopyRuns(rewriter,
|
||||||
|
blueprint.getLoc(),
|
||||||
|
*fragmentAssemblyRuns,
|
||||||
|
hostTarget,
|
||||||
|
coreBatchOp.getOperation(),
|
||||||
|
std::nullopt,
|
||||||
|
baseHostOffset);
|
||||||
if (failed(updatedHostTarget))
|
if (failed(updatedHostTarget))
|
||||||
return failure();
|
return failure();
|
||||||
hostOutputTensors[resultIndex] = *updatedHostTarget;
|
hostOutputTensors[resultIndex] = *updatedHostTarget;
|
||||||
|
|||||||
@@ -1,12 +1,17 @@
|
|||||||
#include "mlir/IR/ValueRange.h"
|
#include "mlir/IR/ValueRange.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
#include "Common.hpp"
|
#include "Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -186,4 +191,375 @@ forEachContiguousDestinationChunk(ArrayRef<int64_t> destShape,
|
|||||||
return visit(visit, 0);
|
return visit(visit, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static mlir::Value
|
||||||
|
createSteppedOffset(OpBuilder& builder, Location loc, mlir::Value start, mlir::Value index, int64_t stepBytes) {
|
||||||
|
if (stepBytes == 0)
|
||||||
|
return start;
|
||||||
|
mlir::Value step = arith::ConstantIndexOp::create(builder, loc, stepBytes);
|
||||||
|
mlir::Value scaled = arith::MulIOp::create(builder, loc, index, step).getResult();
|
||||||
|
return arith::AddIOp::create(builder, loc, start, scaled).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static mlir::Value createIndexedOffset(OpBuilder& builder,
|
||||||
|
Location loc,
|
||||||
|
mlir::Value indexArg,
|
||||||
|
ArrayRef<int64_t> values) {
|
||||||
|
assert(!values.empty() && "expected lane-indexed values");
|
||||||
|
if (llvm::all_of(values.drop_front(), [&](int64_t value) { return value == values.front(); }))
|
||||||
|
return arith::ConstantIndexOp::create(builder, loc, values.front());
|
||||||
|
|
||||||
|
if (values.size() >= 2) {
|
||||||
|
int64_t step = values[1] - values[0];
|
||||||
|
bool arithmetic = llvm::all_of(llvm::seq<size_t>(2, values.size()), [&](size_t index) {
|
||||||
|
return values[index] == values.front() + static_cast<int64_t>(index) * step;
|
||||||
|
});
|
||||||
|
if (arithmetic) {
|
||||||
|
mlir::Value base = arith::ConstantIndexOp::create(builder, loc, values.front());
|
||||||
|
mlir::Value stepValue = arith::ConstantIndexOp::create(builder, loc, step);
|
||||||
|
mlir::Value scaledIndex = arith::MulIOp::create(builder, loc, indexArg, stepValue).getResult();
|
||||||
|
return arith::AddIOp::create(builder, loc, base, scaledIndex).getResult();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Value selected = arith::ConstantIndexOp::create(builder, loc, values.front());
|
||||||
|
for (auto [lane, value] : llvm::enumerate(values.drop_front())) {
|
||||||
|
mlir::Value indexValue = arith::ConstantIndexOp::create(builder, loc, static_cast<int64_t>(lane + 1));
|
||||||
|
mlir::Value cmp = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, indexArg, indexValue);
|
||||||
|
mlir::Value candidate = arith::ConstantIndexOp::create(builder, loc, value);
|
||||||
|
selected = arith::SelectOp::create(builder, loc, cmp, candidate, selected);
|
||||||
|
}
|
||||||
|
return selected;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FragmentAssemblyCopyRunFamily {
|
||||||
|
FragmentAssemblyCopyRun prototype;
|
||||||
|
SmallVector<int64_t, 8> sourceRunStartDeltas;
|
||||||
|
SmallVector<int64_t, 8> hostRunStartDeltas;
|
||||||
|
};
|
||||||
|
|
||||||
|
static bool computeUniformRunStartDelta(ArrayRef<int64_t> prototypeStarts,
|
||||||
|
ArrayRef<int64_t> runStarts,
|
||||||
|
int64_t& delta) {
|
||||||
|
if (prototypeStarts.size() != runStarts.size() || prototypeStarts.empty())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
delta = runStarts.front() - prototypeStarts.front();
|
||||||
|
return llvm::all_of(llvm::zip_equal(prototypeStarts, runStarts), [&](auto pair) {
|
||||||
|
auto [prototypeStart, runStart] = pair;
|
||||||
|
return runStart - prototypeStart == delta;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool canMergeFragmentAssemblyCopyRunIntoFamily(const FragmentAssemblyCopyRunFamily& family,
|
||||||
|
const FragmentAssemblyCopyRun& run,
|
||||||
|
int64_t& sourceRunStartDelta,
|
||||||
|
int64_t& hostRunStartDelta) {
|
||||||
|
const FragmentAssemblyCopyRun& prototype = family.prototype;
|
||||||
|
if (prototype.source != run.source || prototype.sourceType != run.sourceType
|
||||||
|
|| prototype.hostTargetIndex != run.hostTargetIndex || prototype.count != run.count
|
||||||
|
|| prototype.sourceStepBytes != run.sourceStepBytes || prototype.hostStepBytes != run.hostStepBytes
|
||||||
|
|| prototype.byteSize != run.byteSize)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (!computeUniformRunStartDelta(prototype.sourceStartBytesByLane, run.sourceStartBytesByLane, sourceRunStartDelta))
|
||||||
|
return false;
|
||||||
|
return computeUniformRunStartDelta(prototype.hostStartBytesByLane, run.hostStartBytesByLane, hostRunStartDelta);
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<FragmentAssemblyCopyRunFamily, 8>
|
||||||
|
groupFragmentAssemblyCopyRunFamilies(ArrayRef<FragmentAssemblyCopyRun> runs) {
|
||||||
|
auto compareRunStarts = [](ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs) {
|
||||||
|
return std::lexicographical_compare(lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
|
||||||
|
};
|
||||||
|
|
||||||
|
SmallVector<FragmentAssemblyCopyRun, 8> sortedRuns(runs.begin(), runs.end());
|
||||||
|
llvm::sort(sortedRuns, [&](const FragmentAssemblyCopyRun& lhs, const FragmentAssemblyCopyRun& rhs) {
|
||||||
|
if (lhs.hostTargetIndex != rhs.hostTargetIndex)
|
||||||
|
return lhs.hostTargetIndex < rhs.hostTargetIndex;
|
||||||
|
if (lhs.source != rhs.source)
|
||||||
|
return lhs.source.getAsOpaquePointer() < rhs.source.getAsOpaquePointer();
|
||||||
|
if (lhs.byteSize != rhs.byteSize)
|
||||||
|
return lhs.byteSize < rhs.byteSize;
|
||||||
|
if (lhs.count != rhs.count)
|
||||||
|
return lhs.count < rhs.count;
|
||||||
|
if (lhs.sourceStepBytes != rhs.sourceStepBytes)
|
||||||
|
return lhs.sourceStepBytes < rhs.sourceStepBytes;
|
||||||
|
if (lhs.hostStepBytes != rhs.hostStepBytes)
|
||||||
|
return lhs.hostStepBytes < rhs.hostStepBytes;
|
||||||
|
if (compareRunStarts(lhs.sourceStartBytesByLane, rhs.sourceStartBytesByLane))
|
||||||
|
return true;
|
||||||
|
if (compareRunStarts(rhs.sourceStartBytesByLane, lhs.sourceStartBytesByLane))
|
||||||
|
return false;
|
||||||
|
return compareRunStarts(lhs.hostStartBytesByLane, rhs.hostStartBytesByLane);
|
||||||
|
});
|
||||||
|
|
||||||
|
SmallVector<FragmentAssemblyCopyRunFamily, 8> families;
|
||||||
|
for (const FragmentAssemblyCopyRun& run : sortedRuns) {
|
||||||
|
int64_t sourceRunStartDelta = 0;
|
||||||
|
int64_t hostRunStartDelta = 0;
|
||||||
|
if (!families.empty()
|
||||||
|
&& canMergeFragmentAssemblyCopyRunIntoFamily(
|
||||||
|
families.back(), run, sourceRunStartDelta, hostRunStartDelta)) {
|
||||||
|
families.back().sourceRunStartDeltas.push_back(sourceRunStartDelta);
|
||||||
|
families.back().hostRunStartDeltas.push_back(hostRunStartDelta);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
FragmentAssemblyCopyRunFamily family;
|
||||||
|
family.prototype = run;
|
||||||
|
family.sourceRunStartDeltas.push_back(0);
|
||||||
|
family.hostRunStartDeltas.push_back(0);
|
||||||
|
families.push_back(std::move(family));
|
||||||
|
}
|
||||||
|
|
||||||
|
return families;
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>>
|
||||||
|
groupFragmentAssemblyCopyRuns(ArrayRef<FragmentAssemblyCopy> copies, uint32_t laneCount) {
|
||||||
|
if (laneCount == 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
struct LaneLocalCopyRun {
|
||||||
|
FragmentAssemblyCopyRun run;
|
||||||
|
int64_t lane = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
SmallVector<FragmentAssemblyCopy, 8> sortedCopies(copies.begin(), copies.end());
|
||||||
|
llvm::sort(sortedCopies, [](const FragmentAssemblyCopy& lhs, const FragmentAssemblyCopy& rhs) {
|
||||||
|
if (lhs.hostTargetIndex != rhs.hostTargetIndex)
|
||||||
|
return lhs.hostTargetIndex < rhs.hostTargetIndex;
|
||||||
|
if (lhs.source != rhs.source)
|
||||||
|
return lhs.source.getAsOpaquePointer() < rhs.source.getAsOpaquePointer();
|
||||||
|
if (lhs.lane != rhs.lane)
|
||||||
|
return lhs.lane < rhs.lane;
|
||||||
|
if (lhs.byteSize != rhs.byteSize)
|
||||||
|
return lhs.byteSize < rhs.byteSize;
|
||||||
|
if (lhs.sourceByteOffset != rhs.sourceByteOffset)
|
||||||
|
return lhs.sourceByteOffset < rhs.sourceByteOffset;
|
||||||
|
return lhs.hostByteOffset < rhs.hostByteOffset;
|
||||||
|
});
|
||||||
|
|
||||||
|
SmallVector<LaneLocalCopyRun, 8> laneRuns;
|
||||||
|
for (const FragmentAssemblyCopy& copy : sortedCopies) {
|
||||||
|
if (copy.lane < 0 || copy.lane >= static_cast<int64_t>(laneCount))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (!laneRuns.empty()) {
|
||||||
|
LaneLocalCopyRun& laneRun = laneRuns.back();
|
||||||
|
FragmentAssemblyCopyRun& run = laneRun.run;
|
||||||
|
if (run.source == copy.source && run.sourceType == copy.sourceType
|
||||||
|
&& run.hostTargetIndex == copy.hostTargetIndex && laneRun.lane == copy.lane && run.byteSize == copy.byteSize
|
||||||
|
&& run.sourceStartBytesByLane.size() == 1 && run.hostStartBytesByLane.size() == 1) {
|
||||||
|
int64_t previousSourceOffset = run.sourceStartBytesByLane.front() + (run.count - 1) * run.sourceStepBytes;
|
||||||
|
int64_t previousHostOffset = run.hostStartBytesByLane.front() + (run.count - 1) * run.hostStepBytes;
|
||||||
|
int64_t sourceDelta = copy.sourceByteOffset - previousSourceOffset;
|
||||||
|
int64_t hostDelta = copy.hostByteOffset - previousHostOffset;
|
||||||
|
if (run.count == 1) {
|
||||||
|
run.sourceStepBytes = sourceDelta;
|
||||||
|
run.hostStepBytes = hostDelta;
|
||||||
|
++run.count;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (run.sourceStepBytes == sourceDelta && run.hostStepBytes == hostDelta) {
|
||||||
|
++run.count;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LaneLocalCopyRun laneRun;
|
||||||
|
laneRun.run.source = copy.source;
|
||||||
|
laneRun.run.sourceType = copy.sourceType;
|
||||||
|
laneRun.run.hostTargetIndex = copy.hostTargetIndex;
|
||||||
|
laneRun.run.count = 1;
|
||||||
|
laneRun.run.byteSize = copy.byteSize;
|
||||||
|
laneRun.run.sourceStartBytesByLane.push_back(copy.sourceByteOffset);
|
||||||
|
laneRun.run.hostStartBytesByLane.push_back(copy.hostByteOffset);
|
||||||
|
laneRun.lane = copy.lane;
|
||||||
|
laneRuns.push_back(std::move(laneRun));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<FragmentAssemblyCopyRun, 8> mergedRuns;
|
||||||
|
for (const LaneLocalCopyRun& laneRun : laneRuns) {
|
||||||
|
size_t laneIndex = static_cast<size_t>(laneRun.lane);
|
||||||
|
auto mergedIt = llvm::find_if(mergedRuns, [&](const FragmentAssemblyCopyRun& run) {
|
||||||
|
return run.source == laneRun.run.source && run.sourceType == laneRun.run.sourceType
|
||||||
|
&& run.hostTargetIndex == laneRun.run.hostTargetIndex && run.count == laneRun.run.count
|
||||||
|
&& run.byteSize == laneRun.run.byteSize && run.sourceStepBytes == laneRun.run.sourceStepBytes
|
||||||
|
&& run.hostStepBytes == laneRun.run.hostStepBytes && laneIndex < run.sourceStartBytesByLane.size()
|
||||||
|
&& run.sourceStartBytesByLane[laneIndex] == std::numeric_limits<int64_t>::min();
|
||||||
|
});
|
||||||
|
|
||||||
|
if (mergedIt == mergedRuns.end()) {
|
||||||
|
FragmentAssemblyCopyRun merged = laneRun.run;
|
||||||
|
merged.sourceStartBytesByLane.assign(laneCount, std::numeric_limits<int64_t>::min());
|
||||||
|
merged.hostStartBytesByLane.assign(laneCount, std::numeric_limits<int64_t>::min());
|
||||||
|
merged.sourceStartBytesByLane[laneIndex] = laneRun.run.sourceStartBytesByLane.front();
|
||||||
|
merged.hostStartBytesByLane[laneIndex] = laneRun.run.hostStartBytesByLane.front();
|
||||||
|
mergedRuns.push_back(std::move(merged));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
mergedIt->sourceStartBytesByLane[laneIndex] = laneRun.run.sourceStartBytesByLane.front();
|
||||||
|
mergedIt->hostStartBytesByLane[laneIndex] = laneRun.run.hostStartBytesByLane.front();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const FragmentAssemblyCopyRun& run : mergedRuns) {
|
||||||
|
if (llvm::any_of(run.sourceStartBytesByLane,
|
||||||
|
[](int64_t value) { return value == std::numeric_limits<int64_t>::min(); }))
|
||||||
|
return failure();
|
||||||
|
if (llvm::any_of(run.hostStartBytesByLane,
|
||||||
|
[](int64_t value) { return value == std::numeric_limits<int64_t>::min(); }))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
return mergedRuns;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<mlir::Value> emitFragmentAssemblyCopyRun(OpBuilder& builder,
|
||||||
|
Location loc,
|
||||||
|
const FragmentAssemblyCopyRun& run,
|
||||||
|
mlir::Value hostTarget,
|
||||||
|
Operation* anchor,
|
||||||
|
std::optional<mlir::Value> laneArg,
|
||||||
|
mlir::Value baseHostOffset,
|
||||||
|
mlir::Value sourceRunStartDelta = {},
|
||||||
|
mlir::Value hostRunStartDelta = {}) {
|
||||||
|
auto sizeAttr = pim::getCheckedI32Attr(builder, anchor, run.byteSize, "fragment assembly host copy byte size");
|
||||||
|
if (failed(sizeAttr))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
mlir::Value hostStart;
|
||||||
|
mlir::Value sourceStart;
|
||||||
|
if (laneArg) {
|
||||||
|
hostStart = createIndexedOffset(builder, loc, *laneArg, run.hostStartBytesByLane);
|
||||||
|
sourceStart = createIndexedOffset(builder, loc, *laneArg, run.sourceStartBytesByLane);
|
||||||
|
} else {
|
||||||
|
hostStart = arith::ConstantIndexOp::create(builder, loc, run.hostStartBytesByLane.front());
|
||||||
|
sourceStart = arith::ConstantIndexOp::create(builder, loc, run.sourceStartBytesByLane.front());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hostRunStartDelta)
|
||||||
|
hostStart = arith::AddIOp::create(builder, loc, hostStart, hostRunStartDelta).getResult();
|
||||||
|
if (sourceRunStartDelta)
|
||||||
|
sourceStart = arith::AddIOp::create(builder, loc, sourceStart, sourceRunStartDelta).getResult();
|
||||||
|
if (baseHostOffset)
|
||||||
|
hostStart = arith::AddIOp::create(builder, loc, baseHostOffset, hostStart).getResult();
|
||||||
|
|
||||||
|
if (run.count == 1) {
|
||||||
|
return pim::PimMemCopyDevToHostOp::create(builder,
|
||||||
|
loc,
|
||||||
|
hostTarget.getType(),
|
||||||
|
hostStart,
|
||||||
|
sourceStart,
|
||||||
|
hostTarget,
|
||||||
|
run.source,
|
||||||
|
*sizeAttr)
|
||||||
|
.getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Value lowerBound = arith::ConstantIndexOp::create(builder, loc, 0);
|
||||||
|
mlir::Value upperBound = arith::ConstantIndexOp::create(builder, loc, run.count);
|
||||||
|
mlir::Value step = arith::ConstantIndexOp::create(builder, loc, 1);
|
||||||
|
FailureOr<NormalizedLoopResult> loop = buildNormalizedScfFor(
|
||||||
|
builder,
|
||||||
|
loc,
|
||||||
|
lowerBound,
|
||||||
|
upperBound,
|
||||||
|
step,
|
||||||
|
ValueRange {hostTarget},
|
||||||
|
[&](OpBuilder& loopBuilder,
|
||||||
|
Location bodyLoc,
|
||||||
|
mlir::Value flatIndex,
|
||||||
|
ValueRange iterArgs,
|
||||||
|
SmallVectorImpl<mlir::Value>& yielded) {
|
||||||
|
mlir::Value hostOffset = createSteppedOffset(loopBuilder, bodyLoc, hostStart, flatIndex, run.hostStepBytes);
|
||||||
|
mlir::Value sourceOffset =
|
||||||
|
createSteppedOffset(loopBuilder, bodyLoc, sourceStart, flatIndex, run.sourceStepBytes);
|
||||||
|
mlir::Value copied =
|
||||||
|
pim::PimMemCopyDevToHostOp::create(loopBuilder,
|
||||||
|
bodyLoc,
|
||||||
|
iterArgs.front().getType(),
|
||||||
|
hostOffset,
|
||||||
|
sourceOffset,
|
||||||
|
iterArgs.front(),
|
||||||
|
run.source,
|
||||||
|
*sizeAttr)
|
||||||
|
.getOutput();
|
||||||
|
yielded.push_back(copied);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
if (failed(loop))
|
||||||
|
return failure();
|
||||||
|
return loop->results.front();
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<mlir::Value> emitFragmentAssemblyCopyRunFamily(OpBuilder& builder,
|
||||||
|
Location loc,
|
||||||
|
const FragmentAssemblyCopyRunFamily& family,
|
||||||
|
mlir::Value hostTarget,
|
||||||
|
Operation* anchor,
|
||||||
|
std::optional<mlir::Value> laneArg,
|
||||||
|
mlir::Value baseHostOffset) {
|
||||||
|
if (family.sourceRunStartDeltas.size() == 1)
|
||||||
|
return emitFragmentAssemblyCopyRun(
|
||||||
|
builder, loc, family.prototype, hostTarget, anchor, laneArg, baseHostOffset);
|
||||||
|
|
||||||
|
mlir::Value lowerBound = arith::ConstantIndexOp::create(builder, loc, 0);
|
||||||
|
mlir::Value upperBound = arith::ConstantIndexOp::create(builder, loc, family.sourceRunStartDeltas.size());
|
||||||
|
mlir::Value step = arith::ConstantIndexOp::create(builder, loc, 1);
|
||||||
|
FailureOr<NormalizedLoopResult> outerLoop = buildNormalizedScfFor(
|
||||||
|
builder,
|
||||||
|
loc,
|
||||||
|
lowerBound,
|
||||||
|
upperBound,
|
||||||
|
step,
|
||||||
|
ValueRange {hostTarget},
|
||||||
|
[&](OpBuilder& loopBuilder,
|
||||||
|
Location bodyLoc,
|
||||||
|
mlir::Value runIndex,
|
||||||
|
ValueRange iterArgs,
|
||||||
|
SmallVectorImpl<mlir::Value>& yielded) {
|
||||||
|
mlir::Value sourceRunStartDelta =
|
||||||
|
createIndexedOffset(loopBuilder, bodyLoc, runIndex, family.sourceRunStartDeltas);
|
||||||
|
mlir::Value hostRunStartDelta =
|
||||||
|
createIndexedOffset(loopBuilder, bodyLoc, runIndex, family.hostRunStartDeltas);
|
||||||
|
FailureOr<mlir::Value> copied = emitFragmentAssemblyCopyRun(loopBuilder,
|
||||||
|
bodyLoc,
|
||||||
|
family.prototype,
|
||||||
|
iterArgs.front(),
|
||||||
|
anchor,
|
||||||
|
laneArg,
|
||||||
|
baseHostOffset,
|
||||||
|
sourceRunStartDelta,
|
||||||
|
hostRunStartDelta);
|
||||||
|
if (failed(copied))
|
||||||
|
return failure();
|
||||||
|
yielded.push_back(*copied);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
if (failed(outerLoop))
|
||||||
|
return failure();
|
||||||
|
return outerLoop->results.front();
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<mlir::Value> emitFragmentAssemblyCopyRuns(IRRewriter& rewriter,
|
||||||
|
Location loc,
|
||||||
|
ArrayRef<FragmentAssemblyCopyRun> runs,
|
||||||
|
mlir::Value hostTarget,
|
||||||
|
Operation* anchor,
|
||||||
|
std::optional<mlir::Value> laneArg,
|
||||||
|
mlir::Value baseHostOffset) {
|
||||||
|
for (const FragmentAssemblyCopyRunFamily& family : groupFragmentAssemblyCopyRunFamilies(runs)) {
|
||||||
|
FailureOr<mlir::Value> updatedHostTarget =
|
||||||
|
emitFragmentAssemblyCopyRunFamily(rewriter, loc, family, hostTarget, anchor, laneArg, baseHostOffset);
|
||||||
|
if (failed(updatedHostTarget))
|
||||||
|
return failure();
|
||||||
|
hostTarget = *updatedHostTarget;
|
||||||
|
}
|
||||||
|
|
||||||
|
return hostTarget;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
@@ -59,6 +65,39 @@ forEachContiguousDestinationChunk(llvm::ArrayRef<int64_t> destShape,
|
|||||||
llvm::function_ref<mlir::LogicalResult(llvm::ArrayRef<int64_t>, int64_t, int64_t)>
|
llvm::function_ref<mlir::LogicalResult(llvm::ArrayRef<int64_t>, int64_t, int64_t)>
|
||||||
callback);
|
callback);
|
||||||
|
|
||||||
|
struct FragmentAssemblyCopy {
|
||||||
|
mlir::Value source;
|
||||||
|
mlir::RankedTensorType sourceType;
|
||||||
|
unsigned hostTargetIndex = 0;
|
||||||
|
int64_t lane = 0;
|
||||||
|
int64_t sourceByteOffset = 0;
|
||||||
|
int64_t hostByteOffset = 0;
|
||||||
|
int64_t byteSize = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FragmentAssemblyCopyRun {
|
||||||
|
mlir::Value source;
|
||||||
|
mlir::RankedTensorType sourceType;
|
||||||
|
unsigned hostTargetIndex = 0;
|
||||||
|
int64_t count = 0;
|
||||||
|
int64_t sourceStepBytes = 0;
|
||||||
|
int64_t hostStepBytes = 0;
|
||||||
|
int64_t byteSize = 0;
|
||||||
|
mlir::SmallVector<int64_t, 8> sourceStartBytesByLane;
|
||||||
|
mlir::SmallVector<int64_t, 8> hostStartBytesByLane;
|
||||||
|
};
|
||||||
|
|
||||||
|
mlir::FailureOr<mlir::SmallVector<FragmentAssemblyCopyRun, 8>>
|
||||||
|
groupFragmentAssemblyCopyRuns(llvm::ArrayRef<FragmentAssemblyCopy> copies, uint32_t laneCount = 1);
|
||||||
|
|
||||||
|
mlir::FailureOr<mlir::Value> emitFragmentAssemblyCopyRuns(mlir::IRRewriter& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
llvm::ArrayRef<FragmentAssemblyCopyRun> runs,
|
||||||
|
mlir::Value hostTarget,
|
||||||
|
mlir::Operation* anchor,
|
||||||
|
std::optional<mlir::Value> laneArg = std::nullopt,
|
||||||
|
mlir::Value baseHostOffset = {});
|
||||||
|
|
||||||
inline mlir::tensor::EmptyOp
|
inline mlir::tensor::EmptyOp
|
||||||
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
||||||
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
|
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
@@ -8,6 +9,7 @@
|
|||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
@@ -30,22 +32,9 @@ static bool isChannelUseChainOp(Operation* op) {
|
|||||||
pim::PimTransposeOp>(op);
|
pim::PimTransposeOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createStaticHostTargetOffset(IRRewriter& rewriter,
|
|
||||||
Location loc,
|
|
||||||
ShapedType destinationType,
|
|
||||||
ArrayRef<int64_t> fragmentOffsets) {
|
|
||||||
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
|
||||||
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
|
|
||||||
|
|
||||||
int64_t byteOffset = 0;
|
|
||||||
for (auto [dim, offset] : llvm::enumerate(fragmentOffsets))
|
|
||||||
byteOffset += offset * strides[dim] * elementBytes;
|
|
||||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), byteOffset);
|
|
||||||
}
|
|
||||||
|
|
||||||
static FailureOr<Value> lowerFragmentAssemblyBlueprint(IRRewriter& rewriter,
|
static FailureOr<Value> lowerFragmentAssemblyBlueprint(IRRewriter& rewriter,
|
||||||
spatial::SpatBlueprintOp blueprint,
|
spatial::SpatBlueprintOp blueprint,
|
||||||
IRMapping& mapping) {
|
IRMapping& mapping) {
|
||||||
auto resultType = dyn_cast<ShapedType>(blueprint.getOutput().getType());
|
auto resultType = dyn_cast<ShapedType>(blueprint.getOutput().getType());
|
||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
return blueprint.emitOpError("fragment assembly lowering requires a static ranked tensor result");
|
return blueprint.emitOpError("fragment assembly lowering requires a static ranked tensor result");
|
||||||
@@ -77,56 +66,54 @@ static FailureOr<Value> lowerFragmentAssemblyBlueprint(IRRewriter& rewriter,
|
|||||||
flatStrides)))
|
flatStrides)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value currentOutput = createEmptyTensorFromShaped(rewriter, blueprint.getLoc(), resultType);
|
SmallVector<int64_t> hostStrides = computeRowMajorStrides(resultType.getShape());
|
||||||
|
SmallVector<FragmentAssemblyCopy, 8> copies;
|
||||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||||
int64_t operandIndex = operandIndices[fragmentIndex];
|
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||||
|
|
||||||
SmallVector<int64_t, 4> fragmentOffsets;
|
SmallVector<int64_t, 4> fragmentOffsets;
|
||||||
int64_t fragmentElements = 1;
|
SmallVector<int64_t, 4> fragmentSizes;
|
||||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||||
if (flatStrides[flatIndex] != 1)
|
if (flatStrides[flatIndex] != 1)
|
||||||
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
|
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
|
||||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||||
fragmentElements *= flatSizes[flatIndex];
|
fragmentSizes.push_back(flatSizes[flatIndex]);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]);
|
Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]);
|
||||||
auto sourceType = dyn_cast<ShapedType>(source.getType());
|
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
||||||
if (!sourceType || !sourceType.hasStaticShape())
|
if (!sourceType || !sourceType.hasStaticShape())
|
||||||
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||||
|
size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType());
|
||||||
|
if (failed(forEachContiguousDestinationChunk(
|
||||||
|
resultType.getShape(),
|
||||||
|
fragmentOffsets,
|
||||||
|
fragmentSizes,
|
||||||
|
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
|
||||||
|
int64_t hostElementOffset = 0;
|
||||||
|
for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
|
||||||
|
hostElementOffset += offset * hostStrides[dim];
|
||||||
|
|
||||||
int64_t fragmentBytes =
|
FragmentAssemblyCopy copy;
|
||||||
fragmentElements * static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
|
copy.source = source;
|
||||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter,
|
copy.sourceType = sourceType;
|
||||||
blueprint.getOperation(),
|
copy.sourceByteOffset =
|
||||||
fragmentBytes,
|
(sourceOffsets[fragmentIndex] + relativeSourceOffset) * static_cast<int64_t>(elementSize);
|
||||||
"fragment assembly host copy size");
|
copy.hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
|
||||||
if (failed(sizeAttr))
|
copy.byteSize = chunkElements * static_cast<int64_t>(elementSize);
|
||||||
|
copies.push_back(copy);
|
||||||
|
return success();
|
||||||
|
})))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value hostTargetOffset = createStaticHostTargetOffset(rewriter, blueprint.getLoc(), resultType, fragmentOffsets);
|
|
||||||
auto deviceSourceOffsetBytes = pim::checkedMul(static_cast<uint64_t>(sourceOffsets[fragmentIndex]),
|
|
||||||
static_cast<uint64_t>(getElementTypeSizeInBytes(sourceType.getElementType())),
|
|
||||||
blueprint,
|
|
||||||
"fragment assembly device source offset");
|
|
||||||
if (failed(deviceSourceOffsetBytes))
|
|
||||||
return failure();
|
|
||||||
Value deviceSourceOffset = getOrCreateIndexConstant(rewriter,
|
|
||||||
rewriter.getInsertionBlock()->getParentOp(),
|
|
||||||
static_cast<int64_t>(*deviceSourceOffsetBytes));
|
|
||||||
currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter,
|
|
||||||
blueprint.getLoc(),
|
|
||||||
currentOutput.getType(),
|
|
||||||
hostTargetOffset,
|
|
||||||
deviceSourceOffset,
|
|
||||||
currentOutput,
|
|
||||||
source,
|
|
||||||
*sizeAttr)
|
|
||||||
.getOutput();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return currentOutput;
|
Value currentOutput = createEmptyTensorFromShaped(rewriter, blueprint.getLoc(), resultType);
|
||||||
|
FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>> runs = groupFragmentAssemblyCopyRuns(copies);
|
||||||
|
if (failed(runs))
|
||||||
|
return failure();
|
||||||
|
return emitFragmentAssemblyCopyRuns(
|
||||||
|
rewriter, blueprint.getLoc(), *runs, currentOutput, blueprint.getOperation());
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
@@ -11,6 +12,7 @@
|
|||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
@@ -638,6 +640,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
flatSizes,
|
flatSizes,
|
||||||
flatStrides)))
|
flatStrides)))
|
||||||
return ReturnPathLoweringResult::Failure;
|
return ReturnPathLoweringResult::Failure;
|
||||||
|
SmallVector<FragmentAssemblyCopy, 8> copies;
|
||||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||||
if (operandIndices[fragmentIndex] != static_cast<int64_t>(operandNumber))
|
if (operandIndices[fragmentIndex] != static_cast<int64_t>(operandNumber))
|
||||||
continue;
|
continue;
|
||||||
@@ -675,29 +678,27 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
failedChunk = true;
|
failedChunk = true;
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
auto sizeAttr =
|
FragmentAssemblyCopy copy;
|
||||||
pim::getCheckedI32Attr(rewriter, producerOp, *fragmentBytes, "fragment assembly host copy byte size");
|
copy.source = storedValue;
|
||||||
if (failed(sizeAttr)) {
|
copy.sourceType = sourceType;
|
||||||
failedChunk = true;
|
copy.hostByteOffset = *hostOffset;
|
||||||
return failure();
|
copy.sourceByteOffset = *sourceOffset;
|
||||||
}
|
copy.byteSize = *fragmentBytes;
|
||||||
|
copies.push_back(copy);
|
||||||
outputTensor =
|
|
||||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
|
||||||
blueprint.getLoc(),
|
|
||||||
outputTensor.getType(),
|
|
||||||
getOrCreateIndexConstant(rewriter, producerOp, *hostOffset),
|
|
||||||
getOrCreateIndexConstant(rewriter, producerOp, *sourceOffset),
|
|
||||||
outputTensor,
|
|
||||||
storedValue,
|
|
||||||
*sizeAttr)
|
|
||||||
.getOutput();
|
|
||||||
return success();
|
return success();
|
||||||
})))
|
})))
|
||||||
failedChunk = true;
|
failedChunk = true;
|
||||||
if (failedChunk)
|
if (failedChunk)
|
||||||
return ReturnPathLoweringResult::Failure;
|
return ReturnPathLoweringResult::Failure;
|
||||||
}
|
}
|
||||||
|
FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>> runs = groupFragmentAssemblyCopyRuns(copies);
|
||||||
|
if (failed(runs))
|
||||||
|
return ReturnPathLoweringResult::Failure;
|
||||||
|
FailureOr<Value> updatedOutput =
|
||||||
|
emitFragmentAssemblyCopyRuns(rewriter, blueprint.getLoc(), *runs, outputTensor, producerOp);
|
||||||
|
if (failed(updatedOutput))
|
||||||
|
return ReturnPathLoweringResult::Failure;
|
||||||
|
outputTensor = *updatedOutput;
|
||||||
markOpToRemove(blueprint.getOperation());
|
markOpToRemove(blueprint.getOperation());
|
||||||
}
|
}
|
||||||
return ReturnPathLoweringResult::Handled;
|
return ReturnPathLoweringResult::Handled;
|
||||||
|
|||||||
+686
-140
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user