@@ -34,12 +34,25 @@ mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledg
|
||||
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
|
||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
|
||||
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
||||
|
||||
template <typename... Args>
|
||||
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
|
||||
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
static mlir::Value resolveForYieldedAliasToInit(mlir::scf::ForOp forOp,
|
||||
mlir::Value yieldedValue,
|
||||
const StaticValueKnowledge* knowledge) {
|
||||
yieldedValue = resolveLoopCarriedAliasImpl(yieldedValue, knowledge);
|
||||
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size())
|
||||
return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
||||
}
|
||||
return yieldedValue;
|
||||
}
|
||||
|
||||
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
@@ -60,15 +73,8 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
|
||||
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||
if (result) {
|
||||
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands()) {
|
||||
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size())
|
||||
return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
||||
}
|
||||
return yieldedValue;
|
||||
}
|
||||
if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands())
|
||||
return resolveForYieldedAliasToInit(forOp, yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -515,16 +521,7 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
||||
return mlir::failure();
|
||||
|
||||
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
||||
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
value = yieldedValue;
|
||||
value = resolveForYieldedAliasToInit(forOp, yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -643,16 +640,7 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
||||
return mlir::failure();
|
||||
|
||||
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber());
|
||||
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
||||
value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1];
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
value = yieldedValue;
|
||||
value = resolveForYieldedAliasToInit(forOp, yieldOp.getOperand(result.getResultNumber()), nullptr);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -862,7 +850,7 @@ llvm::FailureOr<ResolvedContiguousAddress> CompiledAddressExpr::evaluate(const S
|
||||
auto resolvedOffset = byteOffset.evaluate(knowledge);
|
||||
if (failed(resolvedOffset))
|
||||
return mlir::failure();
|
||||
return ResolvedContiguousAddress {base, *resolvedOffset};
|
||||
return ResolvedContiguousAddress {resolveAlias(base, &knowledge), *resolvedOffset};
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1334,6 +1334,38 @@ static Value affineAddConst(
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
static Value affineMulConst(
|
||||
PatternRewriter& rewriter, Location loc, Value value, int64_t factor, Operation* constantAnchor) {
|
||||
if (factor == 1)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 * factor, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
static Value affineFloorDivConst(
|
||||
PatternRewriter& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
|
||||
assert(divisor > 0 && "expected positive affine floordiv divisor");
|
||||
if (divisor == 1)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
static Value affineModConst(
|
||||
PatternRewriter& rewriter, Location loc, Value value, int64_t modulus, Operation* constantAnchor) {
|
||||
assert(modulus > 0 && "expected positive affine mod divisor");
|
||||
if (modulus == 1)
|
||||
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 % modulus, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
static Value createConvInputPatch(Value input,
|
||||
RankedTensorType patchType,
|
||||
Value batchIndex,
|
||||
@@ -2414,11 +2446,6 @@ static Value createIm2colRows(const ConvLoweringState& state,
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||
Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkNumPatches);
|
||||
Value cChunkStart = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkStart);
|
||||
Value cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, plan.numPatchesPerBatch);
|
||||
Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth);
|
||||
Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, state.strideHeight);
|
||||
Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.strideWidth);
|
||||
|
||||
auto im2colLoop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
@@ -2429,13 +2456,17 @@ static Value createIm2colRows(const ConvLoweringState& state,
|
||||
ValueRange {im2colInit},
|
||||
[&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||
Value im2colAcc = iterArgs.front();
|
||||
Value globalPatchIndex = arith::AddIOp::create(rewriter, nestedLoc, patchIndex, cChunkStart);
|
||||
Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, globalPatchIndex, cNumPatchesPerBatch);
|
||||
Value batchPatchIndex = arith::RemUIOp::create(rewriter, nestedLoc, globalPatchIndex, cNumPatchesPerBatch);
|
||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth);
|
||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth);
|
||||
Value inputHeightOffset = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight);
|
||||
Value inputWidthOffset = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth);
|
||||
Value globalPatchIndex = affineAddConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, anchorOp);
|
||||
Value batchIndex =
|
||||
affineFloorDivConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
|
||||
Value batchPatchIndex =
|
||||
affineModConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
|
||||
Value outHeightIndex = affineFloorDivConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
|
||||
Value outWidthIndex = affineModConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
|
||||
Value inputHeightOffset =
|
||||
affineMulConst(rewriter, nestedLoc, outHeightIndex, state.strideHeight, anchorOp);
|
||||
Value inputWidthOffset =
|
||||
affineMulConst(rewriter, nestedLoc, outWidthIndex, state.strideWidth, anchorOp);
|
||||
|
||||
auto patchType =
|
||||
RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elemType);
|
||||
@@ -2844,11 +2875,9 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||
Value cNumKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices);
|
||||
Value cXbar = getOrCreateIndexConstant(rewriter, anchorOp, xbarDim);
|
||||
Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth);
|
||||
Value cNumChannels = getOrCreateIndexConstant(rewriter, anchorOp, state.numChannelsIn);
|
||||
Value cPatchWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.wHeight * state.wWidth);
|
||||
Value localHeightOffset = arith::MulIOp::create(rewriter, loc, args.lane, c1);
|
||||
Value localHeightOffset = args.lane;
|
||||
Value packedRowInit =
|
||||
tensor::EmptyOp::create(rewriter, loc, ArrayRef<int64_t> {1, state.outWidth, state.numChannelsOut}, elementType);
|
||||
auto widthLoop = buildNormalizedScfFor(
|
||||
@@ -2859,7 +2888,7 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
|
||||
c1,
|
||||
ValueRange {packedRowInit},
|
||||
[&](OpBuilder&, Location widthLoc, Value widthIndex, ValueRange widthIterArgs, SmallVectorImpl<Value>& widthYielded) {
|
||||
Value localWidthOffset = arith::MulIOp::create(rewriter, widthLoc, widthIndex, c1);
|
||||
Value localWidthOffset = widthIndex;
|
||||
Value rowInit = tensor::EmptyOp::create(rewriter, widthLoc, ArrayRef<int64_t> {1, patchSize}, elementType);
|
||||
auto rowLoop = buildNormalizedScfFor(
|
||||
rewriter,
|
||||
@@ -2878,7 +2907,8 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
|
||||
rewriter, rowLoc, flatPatchType, channelPatch, SmallVector<ReassociationIndices> {{0, 1, 2}});
|
||||
Value rowChunk = tensor::ExpandShapeOp::create(
|
||||
rewriter, rowLoc, rowChunkType, flatPatch, SmallVector<ReassociationIndices> {{0, 1}});
|
||||
Value flatOffset = arith::MulIOp::create(rewriter, rowLoc, channel, cPatchWidth);
|
||||
Value flatOffset = affineMulConst(
|
||||
rewriter, rowLoc, channel, state.wHeight * state.wWidth, anchorOp);
|
||||
SmallVector<OpFoldResult> rowOffsets {rewriter.getIndexAttr(0), flatOffset};
|
||||
SmallVector<OpFoldResult> rowSizes {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.wHeight * state.wWidth)};
|
||||
@@ -2905,7 +2935,7 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
|
||||
c1,
|
||||
ValueRange {zeroRow},
|
||||
[&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl<Value>& reduceYielded) {
|
||||
Value kOffset = arith::MulIOp::create(rewriter, reduceLoc, kSlice, cXbar);
|
||||
Value kOffset = affineMulConst(rewriter, reduceLoc, kSlice, xbarDim, anchorOp);
|
||||
SmallVector<OpFoldResult> aOffsets {rewriter.getIndexAttr(0), kOffset};
|
||||
SmallVector<OpFoldResult> aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)};
|
||||
Value aTile = tensor::ExtractSliceOp::create(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
@@ -10,6 +11,7 @@
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
@@ -131,46 +133,92 @@ static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
||||
return result.getUses().begin()->getOperandNumber();
|
||||
}
|
||||
|
||||
struct BatchFragmentAssemblyPlan {
|
||||
unsigned returnIndex = 0;
|
||||
int64_t localSourceElementOffset = 0;
|
||||
int64_t fragmentByteSize = 0;
|
||||
SmallVector<int64_t, 8> hostOffsetsByLane;
|
||||
};
|
||||
static FailureOr<SmallVector<FragmentAssemblyCopy, 8>>
|
||||
collectFragmentAssemblyCopiesFromBlueprint(spatial::SpatBlueprintOp blueprint,
|
||||
IRMapping& mapper,
|
||||
int64_t lane,
|
||||
unsigned hostTargetIndex,
|
||||
Value fixedSource = {}) {
|
||||
SmallVector<FragmentAssemblyCopy, 8> copies;
|
||||
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor results");
|
||||
|
||||
static Value createLaneIndexedOffset(IRRewriter& rewriter, Operation* anchor, Value laneArg, ArrayRef<int64_t> values, Location loc) {
|
||||
assert(!values.empty() && "expected lane-indexed values");
|
||||
if (llvm::all_of(values.drop_front(), [&](int64_t value) { return value == values.front(); }))
|
||||
return getOrCreateIndexConstant(rewriter, anchor, values.front());
|
||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = blueprint.getFragmentStrides();
|
||||
if (!operandIndicesAttr || !fragmentStridesAttr)
|
||||
return blueprint.emitOpError(
|
||||
"fragment assembly lowering requires explicit operand indices and unit strides");
|
||||
|
||||
if (values.size() >= 2) {
|
||||
int64_t step = values[1] - values[0];
|
||||
bool arithmetic = llvm::all_of(llvm::seq<size_t>(2, values.size()), [&](size_t index) {
|
||||
return values[index] == values.front() + static_cast<int64_t>(index) * step;
|
||||
});
|
||||
if (arithmetic) {
|
||||
Value base = getOrCreateIndexConstant(rewriter, anchor, values.front());
|
||||
if (step == 0)
|
||||
return base;
|
||||
Value stepValue = getOrCreateIndexConstant(rewriter, anchor, step);
|
||||
Value scaledLane = arith::MulIOp::create(rewriter, loc, laneArg, stepValue).getResult();
|
||||
return arith::AddIOp::create(rewriter, loc, base, scaledLane).getResult();
|
||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||
if (!sourceOffsetsAttr)
|
||||
return blueprint.emitOpError("fragment assembly lowering requires explicit source offsets");
|
||||
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
|
||||
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
||||
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||
int64_t rank = resultType.getRank();
|
||||
|
||||
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
||||
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
||||
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||
rank,
|
||||
fragmentOperands.size(),
|
||||
operandIndices,
|
||||
sourceOffsets,
|
||||
flatOffsets,
|
||||
flatSizes,
|
||||
flatStrides)))
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> hostStrides = computeRowMajorStrides(resultType.getShape());
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||
Value source = fixedSource ? fixedSource : mapper.lookupOrDefault(fragmentOperands[operandIndices[fragmentIndex]]);
|
||||
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||
|
||||
size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType());
|
||||
SmallVector<int64_t, 4> fragmentOffsets;
|
||||
SmallVector<int64_t, 4> fragmentSizes;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||
if (flatStrides[flatIndex] != 1)
|
||||
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
|
||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||
fragmentSizes.push_back(flatSizes[flatIndex]);
|
||||
}
|
||||
|
||||
if (failed(forEachContiguousDestinationChunk(
|
||||
resultType.getShape(),
|
||||
fragmentOffsets,
|
||||
fragmentSizes,
|
||||
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
|
||||
int64_t hostElementOffset = 0;
|
||||
for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
|
||||
hostElementOffset += offset * hostStrides[dim];
|
||||
|
||||
FragmentAssemblyCopy copy;
|
||||
copy.source = source;
|
||||
copy.sourceType = sourceType;
|
||||
copy.hostTargetIndex = hostTargetIndex;
|
||||
copy.lane = lane;
|
||||
copy.sourceByteOffset = (sourceOffsets[fragmentIndex] + relativeSourceOffset) * static_cast<int64_t>(elementSize);
|
||||
copy.hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
|
||||
copy.byteSize = chunkElements * static_cast<int64_t>(elementSize);
|
||||
copies.push_back(copy);
|
||||
return success();
|
||||
})))
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value selected = getOrCreateIndexConstant(rewriter, anchor, values.front());
|
||||
for (auto [lane, value] : llvm::enumerate(values.drop_front())) {
|
||||
Value laneValue = getOrCreateIndexConstant(rewriter, anchor, static_cast<int64_t>(lane + 1));
|
||||
Value cmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, laneArg, laneValue);
|
||||
Value candidate = getOrCreateIndexConstant(rewriter, anchor, value);
|
||||
selected = arith::SelectOp::create(rewriter, loc, cmp, candidate, selected);
|
||||
}
|
||||
return selected;
|
||||
return copies;
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<BatchFragmentAssemblyPlan, 8>>
|
||||
analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResultType, uint32_t laneCount) {
|
||||
SmallVector<BatchFragmentAssemblyPlan, 8> plans;
|
||||
static FailureOr<SmallVector<FragmentAssemblyCopy, 8>>
|
||||
collectTopLevelFragmentAssemblyCopies(OpResult result, RankedTensorType packedResultType, uint32_t laneCount) {
|
||||
SmallVector<FragmentAssemblyCopy, 8> copies;
|
||||
if (!packedResultType.hasStaticShape() || laneCount == 0)
|
||||
return failure();
|
||||
|
||||
@@ -187,15 +235,14 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
|
||||
std::optional<StringRef> mode = blueprint.getMode();
|
||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
|
||||
if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr)
|
||||
if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr)
|
||||
return failure();
|
||||
if (!blueprint.getOutput().hasOneUse() || !isa<func::ReturnOp>(*blueprint.getOutput().getUsers().begin()))
|
||||
return failure();
|
||||
|
||||
unsigned returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber();
|
||||
auto hostResultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||
if (!hostResultType || !hostResultType.hasStaticShape())
|
||||
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
|
||||
if (!hostResultType || !hostResultType.hasStaticShape() || !stridesAttr)
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||
@@ -204,6 +251,7 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
|
||||
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
||||
ArrayRef<int64_t> flatStrides = *stridesAttr;
|
||||
int64_t rank = hostResultType.getRank();
|
||||
unsigned returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber();
|
||||
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
||||
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
||||
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||
@@ -215,16 +263,15 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
|
||||
flatSizes,
|
||||
flatStrides)))
|
||||
return failure();
|
||||
SmallVector<int64_t> hostStrides = computeRowMajorStrides(hostResultType.getShape());
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||
if (operandIndices[fragmentIndex] != static_cast<int64_t>(use.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
int64_t sourceElementOffset = sourceOffsets[fragmentIndex];
|
||||
int64_t lane = sourceElementOffset / payloadElementCount;
|
||||
int64_t localSourceElementOffset = sourceElementOffset % payloadElementCount;
|
||||
if (lane < 0 || lane >= static_cast<int64_t>(laneCount))
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t, 4> fragmentOffsets;
|
||||
SmallVector<int64_t, 4> fragmentSizes;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
@@ -236,44 +283,31 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
|
||||
}
|
||||
|
||||
if (failed(forEachContiguousDestinationChunk(
|
||||
hostResultType.getShape(),
|
||||
fragmentOffsets,
|
||||
fragmentSizes,
|
||||
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
|
||||
int64_t hostElementOffset = 0;
|
||||
SmallVector<int64_t> hostStrides = computeRowMajorStrides(hostResultType.getShape());
|
||||
for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
|
||||
hostElementOffset += offset * hostStrides[dim];
|
||||
int64_t hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
|
||||
int64_t fragmentByteSize = chunkElements * static_cast<int64_t>(elementSize);
|
||||
int64_t chunkSourceOffset = localSourceElementOffset + relativeSourceOffset;
|
||||
hostResultType.getShape(),
|
||||
fragmentOffsets,
|
||||
fragmentSizes,
|
||||
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
|
||||
int64_t hostElementOffset = 0;
|
||||
for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
|
||||
hostElementOffset += offset * hostStrides[dim];
|
||||
|
||||
auto planIt = llvm::find_if(plans, [&](const BatchFragmentAssemblyPlan& plan) {
|
||||
return plan.returnIndex == returnIndex && plan.localSourceElementOffset == chunkSourceOffset
|
||||
&& plan.fragmentByteSize == fragmentByteSize;
|
||||
});
|
||||
if (planIt == plans.end()) {
|
||||
BatchFragmentAssemblyPlan plan;
|
||||
plan.returnIndex = returnIndex;
|
||||
plan.localSourceElementOffset = chunkSourceOffset;
|
||||
plan.fragmentByteSize = fragmentByteSize;
|
||||
plan.hostOffsetsByLane.assign(laneCount, std::numeric_limits<int64_t>::min());
|
||||
plan.hostOffsetsByLane[static_cast<size_t>(lane)] = hostByteOffset;
|
||||
plans.push_back(std::move(plan));
|
||||
return success();
|
||||
}
|
||||
|
||||
planIt->hostOffsetsByLane[static_cast<size_t>(lane)] = hostByteOffset;
|
||||
return success();
|
||||
})))
|
||||
FragmentAssemblyCopy copy;
|
||||
copy.source = result;
|
||||
copy.sourceType = packedResultType;
|
||||
copy.hostTargetIndex = returnIndex;
|
||||
copy.lane = lane;
|
||||
copy.sourceByteOffset =
|
||||
((sourceElementOffset % payloadElementCount) + relativeSourceOffset) * static_cast<int64_t>(elementSize);
|
||||
copy.hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
|
||||
copy.byteSize = chunkElements * static_cast<int64_t>(elementSize);
|
||||
copies.push_back(copy);
|
||||
return success();
|
||||
})))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
for (const BatchFragmentAssemblyPlan& plan : plans)
|
||||
if (llvm::any_of(plan.hostOffsetsByLane, [](int64_t offset) { return offset == std::numeric_limits<int64_t>::min(); }))
|
||||
return failure();
|
||||
return plans;
|
||||
return copies;
|
||||
}
|
||||
|
||||
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
|
||||
@@ -284,22 +318,6 @@ static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value ba
|
||||
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
|
||||
}
|
||||
|
||||
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
|
||||
SmallVector<OpFoldResult, 4> attrs;
|
||||
attrs.reserve(values.size());
|
||||
for (int64_t value : values)
|
||||
attrs.push_back(builder.getIndexAttr(value));
|
||||
return attrs;
|
||||
}
|
||||
|
||||
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
|
||||
SmallVector<OpFoldResult, 4> strides;
|
||||
strides.reserve(rank);
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
strides.push_back(builder.getIndexAttr(1));
|
||||
return strides;
|
||||
}
|
||||
|
||||
static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
ShapedType destinationType,
|
||||
@@ -351,123 +369,6 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||
mapper);
|
||||
}
|
||||
|
||||
static SmallVector<OpFoldResult, 4> buildFragmentOffsets(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
ArrayRef<OpFoldResult> baseOffsets,
|
||||
ArrayRef<int64_t> fragmentOffsets,
|
||||
IRMapping& mapper) {
|
||||
SmallVector<OpFoldResult, 4> combined;
|
||||
combined.reserve(fragmentOffsets.size());
|
||||
for (auto [dim, baseOffset] : llvm::enumerate(baseOffsets)) {
|
||||
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
||||
int64_t base = cast<IntegerAttr>(attr).getInt();
|
||||
combined.push_back(rewriter.getIndexAttr(base + fragmentOffsets[dim]));
|
||||
continue;
|
||||
}
|
||||
|
||||
Value dynamicBase = mapper.lookupOrDefault(cast<Value>(baseOffset));
|
||||
if (fragmentOffsets[dim] == 0) {
|
||||
combined.push_back(dynamicBase);
|
||||
continue;
|
||||
}
|
||||
|
||||
Value staticOffset =
|
||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), fragmentOffsets[dim]);
|
||||
combined.push_back(arith::AddIOp::create(rewriter, loc, dynamicBase, staticOffset).getResult());
|
||||
}
|
||||
return combined;
|
||||
}
|
||||
|
||||
static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
|
||||
spatial::SpatBlueprintOp blueprint,
|
||||
Value hostTarget,
|
||||
ArrayRef<OpFoldResult> baseOffsets,
|
||||
IRMapping& mapper) {
|
||||
auto hostTargetType = dyn_cast<RankedTensorType>(hostTarget.getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||
if (!hostTargetType || !resultType || !resultType.hasStaticShape())
|
||||
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor results");
|
||||
|
||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = blueprint.getFragmentStrides();
|
||||
if (!operandIndicesAttr || !fragmentStridesAttr)
|
||||
return blueprint.emitOpError(
|
||||
"fragment assembly lowering requires explicit operand indices and unit strides");
|
||||
|
||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||
if (!sourceOffsetsAttr)
|
||||
return blueprint.emitOpError("fragment assembly lowering requires explicit source offsets");
|
||||
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
|
||||
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
||||
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||
int64_t rank = resultType.getRank();
|
||||
|
||||
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
||||
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
||||
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||
rank,
|
||||
fragmentOperands.size(),
|
||||
operandIndices,
|
||||
sourceOffsets,
|
||||
flatOffsets,
|
||||
flatSizes,
|
||||
flatStrides)))
|
||||
return failure();
|
||||
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||
|
||||
SmallVector<int64_t, 4> fragmentOffsets;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||
if (flatStrides[flatIndex] != 1)
|
||||
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
|
||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||
}
|
||||
|
||||
Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]);
|
||||
auto sourceType = dyn_cast<ShapedType>(source.getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||
|
||||
SmallVector<int64_t, 4> fragmentShape;
|
||||
fragmentShape.reserve(rank);
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
|
||||
|
||||
Value fragment = source;
|
||||
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
|
||||
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
|
||||
blueprint, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
|
||||
if (failed(extractOffsets))
|
||||
return failure();
|
||||
fragment = tensor::ExtractSliceOp::create(rewriter,
|
||||
blueprint.getLoc(),
|
||||
source,
|
||||
getStaticIndexAttrs(rewriter, *extractOffsets),
|
||||
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||
getUnitStrides(rewriter, rank));
|
||||
}
|
||||
|
||||
hostTarget = tensor::InsertSliceOp::create(rewriter,
|
||||
blueprint.getLoc(),
|
||||
fragment,
|
||||
hostTarget,
|
||||
buildFragmentOffsets(rewriter,
|
||||
blueprint.getLoc(),
|
||||
baseOffsets,
|
||||
fragmentOffsets,
|
||||
mapper),
|
||||
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||
getUnitStrides(rewriter, rank))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
return hostTarget;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||
@@ -505,10 +406,10 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds));
|
||||
|
||||
SmallVector<unsigned> returnOperandIndices;
|
||||
SmallVector<SmallVector<BatchFragmentAssemblyPlan, 1>, 4> fragmentAssemblyPlansByResult;
|
||||
SmallVector<SmallVector<FragmentAssemblyCopyRun, 1>, 4> fragmentAssemblyRunsByResult;
|
||||
if (computeBatchOp.getNumResults() != 0) {
|
||||
returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits<unsigned>::max());
|
||||
fragmentAssemblyPlansByResult.resize(computeBatchOp.getNumResults());
|
||||
fragmentAssemblyRunsByResult.resize(computeBatchOp.getNumResults());
|
||||
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
@@ -522,12 +423,15 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return computeBatchOp.emitOpError(
|
||||
"resultful compute_batch publication lowering requires static ranked tensor results");
|
||||
FailureOr<SmallVector<BatchFragmentAssemblyPlan, 8>> fragmentAssemblyPlans =
|
||||
analyzeTopLevelFragmentAssemblyUses(cast<OpResult>(result), resultType, computeBatchOp.getLaneCount());
|
||||
if (failed(fragmentAssemblyPlans))
|
||||
return computeBatchOp.emitOpError(
|
||||
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
||||
fragmentAssemblyPlansByResult[resultIndex].assign(fragmentAssemblyPlans->begin(), fragmentAssemblyPlans->end());
|
||||
FailureOr<SmallVector<FragmentAssemblyCopy, 8>> fragmentAssemblyCopies =
|
||||
collectTopLevelFragmentAssemblyCopies(cast<OpResult>(result), resultType, computeBatchOp.getLaneCount());
|
||||
if (failed(fragmentAssemblyCopies))
|
||||
return computeBatchOp.emitOpError("failed to collect top-level fragment assembly copies for compute_batch result");
|
||||
FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>> fragmentAssemblyRuns =
|
||||
groupFragmentAssemblyCopyRuns(*fragmentAssemblyCopies, computeBatchOp.getLaneCount());
|
||||
if (failed(fragmentAssemblyRuns))
|
||||
return computeBatchOp.emitOpError("failed to group top-level fragment assembly copies into regular runs");
|
||||
fragmentAssemblyRunsByResult[resultIndex].assign(fragmentAssemblyRuns->begin(), fragmentAssemblyRuns->end());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -614,8 +518,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
if (resultIndex >= returnOperandIndices.size())
|
||||
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
||||
bool hasDirectReturn = returnOperandIndices[resultIndex] != std::numeric_limits<unsigned>::max();
|
||||
bool hasFragmentAssembly = resultIndex < fragmentAssemblyPlansByResult.size()
|
||||
&& !fragmentAssemblyPlansByResult[resultIndex].empty();
|
||||
bool hasFragmentAssembly = resultIndex < fragmentAssemblyRunsByResult.size()
|
||||
&& !fragmentAssemblyRunsByResult[resultIndex].empty();
|
||||
if (!hasDirectReturn && !hasFragmentAssembly)
|
||||
continue;
|
||||
|
||||
@@ -626,27 +530,23 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
auto mappedSourceType = dyn_cast<ShapedType>(mappedSource.getType());
|
||||
if (!mappedSourceType || !mappedSourceType.hasStaticShape())
|
||||
return insertSlice.emitOpError("fragment assembly batch lowering requires a static ranked lane-local source");
|
||||
for (const BatchFragmentAssemblyPlan& plan : fragmentAssemblyPlansByResult[resultIndex]) {
|
||||
Value outputTensor = outputTensors[plan.returnIndex](rewriter, insertSlice.getLoc());
|
||||
auto sizeAttr = pim::getCheckedI32Attr(
|
||||
rewriter, coreBatchOp.getOperation(), plan.fragmentByteSize, "fragment assembly host copy byte size");
|
||||
if (failed(sizeAttr))
|
||||
DenseMap<unsigned, Value> updatedOutputs;
|
||||
for (const FragmentAssemblyCopyRun& run : fragmentAssemblyRunsByResult[resultIndex]) {
|
||||
Value outputTensor = updatedOutputs.lookup(run.hostTargetIndex);
|
||||
if (!outputTensor)
|
||||
outputTensor = outputTensors[run.hostTargetIndex](rewriter, insertSlice.getLoc());
|
||||
FragmentAssemblyCopyRun mappedRun = run;
|
||||
mappedRun.source = mappedSource;
|
||||
FailureOr<Value> updated =
|
||||
emitFragmentAssemblyCopyRuns(rewriter,
|
||||
insertSlice.getLoc(),
|
||||
ArrayRef<FragmentAssemblyCopyRun> {mappedRun},
|
||||
outputTensor,
|
||||
coreBatchOp.getOperation(),
|
||||
laneArg);
|
||||
if (failed(updated))
|
||||
return failure();
|
||||
Value hostTargetOffset =
|
||||
createLaneIndexedOffset(rewriter, coreBatchOp.getOperation(), laneArg, plan.hostOffsetsByLane, insertSlice.getLoc());
|
||||
Value deviceSourceOffset = getOrCreateIndexConstant(
|
||||
rewriter, coreBatchOp.getOperation(),
|
||||
plan.localSourceElementOffset * static_cast<int64_t>(getElementTypeSizeInBytes(mappedSourceType.getElementType())));
|
||||
outputTensor =
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
insertSlice.getLoc(),
|
||||
outputTensor.getType(),
|
||||
hostTargetOffset,
|
||||
deviceSourceOffset,
|
||||
outputTensor,
|
||||
mappedSource,
|
||||
*sizeAttr)
|
||||
.getOutput();
|
||||
updatedOutputs[run.hostTargetIndex] = *updated;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@@ -657,11 +557,28 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
insertSlice.getSource().getDefiningOp<spatial::SpatBlueprintOp>()) {
|
||||
std::optional<StringRef> modeAttr = blueprint.getMode();
|
||||
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||
FailureOr<Value> updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter,
|
||||
blueprint,
|
||||
hostTarget,
|
||||
insertSlice.getMixedOffsets(),
|
||||
mapper);
|
||||
FailureOr<SmallVector<FragmentAssemblyCopy, 8>> fragmentAssemblyCopies =
|
||||
collectFragmentAssemblyCopiesFromBlueprint(blueprint, mapper, /*lane=*/0, /*hostTargetIndex=*/0);
|
||||
if (failed(fragmentAssemblyCopies))
|
||||
return failure();
|
||||
FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>> fragmentAssemblyRuns =
|
||||
groupFragmentAssemblyCopyRuns(*fragmentAssemblyCopies, /*laneCount=*/1);
|
||||
if (failed(fragmentAssemblyRuns))
|
||||
return failure();
|
||||
SmallVector<int64_t> zeroOffsets(hostTargetType.getRank(), 0);
|
||||
Value baseHostOffset = createHostTargetOffset(rewriter,
|
||||
blueprint.getLoc(),
|
||||
hostTargetType,
|
||||
insertSlice.getMixedOffsets(),
|
||||
zeroOffsets,
|
||||
mapper);
|
||||
FailureOr<Value> updatedHostTarget = emitFragmentAssemblyCopyRuns(rewriter,
|
||||
blueprint.getLoc(),
|
||||
*fragmentAssemblyRuns,
|
||||
hostTarget,
|
||||
coreBatchOp.getOperation(),
|
||||
std::nullopt,
|
||||
baseHostOffset);
|
||||
if (failed(updatedHostTarget))
|
||||
return failure();
|
||||
hostOutputTensors[resultIndex] = *updatedHostTarget;
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
@@ -186,4 +191,375 @@ forEachContiguousDestinationChunk(ArrayRef<int64_t> destShape,
|
||||
return visit(visit, 0);
|
||||
}
|
||||
|
||||
static mlir::Value
|
||||
createSteppedOffset(OpBuilder& builder, Location loc, mlir::Value start, mlir::Value index, int64_t stepBytes) {
|
||||
if (stepBytes == 0)
|
||||
return start;
|
||||
mlir::Value step = arith::ConstantIndexOp::create(builder, loc, stepBytes);
|
||||
mlir::Value scaled = arith::MulIOp::create(builder, loc, index, step).getResult();
|
||||
return arith::AddIOp::create(builder, loc, start, scaled).getResult();
|
||||
}
|
||||
|
||||
static mlir::Value createIndexedOffset(OpBuilder& builder,
|
||||
Location loc,
|
||||
mlir::Value indexArg,
|
||||
ArrayRef<int64_t> values) {
|
||||
assert(!values.empty() && "expected lane-indexed values");
|
||||
if (llvm::all_of(values.drop_front(), [&](int64_t value) { return value == values.front(); }))
|
||||
return arith::ConstantIndexOp::create(builder, loc, values.front());
|
||||
|
||||
if (values.size() >= 2) {
|
||||
int64_t step = values[1] - values[0];
|
||||
bool arithmetic = llvm::all_of(llvm::seq<size_t>(2, values.size()), [&](size_t index) {
|
||||
return values[index] == values.front() + static_cast<int64_t>(index) * step;
|
||||
});
|
||||
if (arithmetic) {
|
||||
mlir::Value base = arith::ConstantIndexOp::create(builder, loc, values.front());
|
||||
mlir::Value stepValue = arith::ConstantIndexOp::create(builder, loc, step);
|
||||
mlir::Value scaledIndex = arith::MulIOp::create(builder, loc, indexArg, stepValue).getResult();
|
||||
return arith::AddIOp::create(builder, loc, base, scaledIndex).getResult();
|
||||
}
|
||||
}
|
||||
|
||||
mlir::Value selected = arith::ConstantIndexOp::create(builder, loc, values.front());
|
||||
for (auto [lane, value] : llvm::enumerate(values.drop_front())) {
|
||||
mlir::Value indexValue = arith::ConstantIndexOp::create(builder, loc, static_cast<int64_t>(lane + 1));
|
||||
mlir::Value cmp = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, indexArg, indexValue);
|
||||
mlir::Value candidate = arith::ConstantIndexOp::create(builder, loc, value);
|
||||
selected = arith::SelectOp::create(builder, loc, cmp, candidate, selected);
|
||||
}
|
||||
return selected;
|
||||
}
|
||||
|
||||
struct FragmentAssemblyCopyRunFamily {
|
||||
FragmentAssemblyCopyRun prototype;
|
||||
SmallVector<int64_t, 8> sourceRunStartDeltas;
|
||||
SmallVector<int64_t, 8> hostRunStartDeltas;
|
||||
};
|
||||
|
||||
static bool computeUniformRunStartDelta(ArrayRef<int64_t> prototypeStarts,
|
||||
ArrayRef<int64_t> runStarts,
|
||||
int64_t& delta) {
|
||||
if (prototypeStarts.size() != runStarts.size() || prototypeStarts.empty())
|
||||
return false;
|
||||
|
||||
delta = runStarts.front() - prototypeStarts.front();
|
||||
return llvm::all_of(llvm::zip_equal(prototypeStarts, runStarts), [&](auto pair) {
|
||||
auto [prototypeStart, runStart] = pair;
|
||||
return runStart - prototypeStart == delta;
|
||||
});
|
||||
}
|
||||
|
||||
static bool canMergeFragmentAssemblyCopyRunIntoFamily(const FragmentAssemblyCopyRunFamily& family,
|
||||
const FragmentAssemblyCopyRun& run,
|
||||
int64_t& sourceRunStartDelta,
|
||||
int64_t& hostRunStartDelta) {
|
||||
const FragmentAssemblyCopyRun& prototype = family.prototype;
|
||||
if (prototype.source != run.source || prototype.sourceType != run.sourceType
|
||||
|| prototype.hostTargetIndex != run.hostTargetIndex || prototype.count != run.count
|
||||
|| prototype.sourceStepBytes != run.sourceStepBytes || prototype.hostStepBytes != run.hostStepBytes
|
||||
|| prototype.byteSize != run.byteSize)
|
||||
return false;
|
||||
|
||||
if (!computeUniformRunStartDelta(prototype.sourceStartBytesByLane, run.sourceStartBytesByLane, sourceRunStartDelta))
|
||||
return false;
|
||||
return computeUniformRunStartDelta(prototype.hostStartBytesByLane, run.hostStartBytesByLane, hostRunStartDelta);
|
||||
}
|
||||
|
||||
static SmallVector<FragmentAssemblyCopyRunFamily, 8>
|
||||
groupFragmentAssemblyCopyRunFamilies(ArrayRef<FragmentAssemblyCopyRun> runs) {
|
||||
auto compareRunStarts = [](ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs) {
|
||||
return std::lexicographical_compare(lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
|
||||
};
|
||||
|
||||
SmallVector<FragmentAssemblyCopyRun, 8> sortedRuns(runs.begin(), runs.end());
|
||||
llvm::sort(sortedRuns, [&](const FragmentAssemblyCopyRun& lhs, const FragmentAssemblyCopyRun& rhs) {
|
||||
if (lhs.hostTargetIndex != rhs.hostTargetIndex)
|
||||
return lhs.hostTargetIndex < rhs.hostTargetIndex;
|
||||
if (lhs.source != rhs.source)
|
||||
return lhs.source.getAsOpaquePointer() < rhs.source.getAsOpaquePointer();
|
||||
if (lhs.byteSize != rhs.byteSize)
|
||||
return lhs.byteSize < rhs.byteSize;
|
||||
if (lhs.count != rhs.count)
|
||||
return lhs.count < rhs.count;
|
||||
if (lhs.sourceStepBytes != rhs.sourceStepBytes)
|
||||
return lhs.sourceStepBytes < rhs.sourceStepBytes;
|
||||
if (lhs.hostStepBytes != rhs.hostStepBytes)
|
||||
return lhs.hostStepBytes < rhs.hostStepBytes;
|
||||
if (compareRunStarts(lhs.sourceStartBytesByLane, rhs.sourceStartBytesByLane))
|
||||
return true;
|
||||
if (compareRunStarts(rhs.sourceStartBytesByLane, lhs.sourceStartBytesByLane))
|
||||
return false;
|
||||
return compareRunStarts(lhs.hostStartBytesByLane, rhs.hostStartBytesByLane);
|
||||
});
|
||||
|
||||
SmallVector<FragmentAssemblyCopyRunFamily, 8> families;
|
||||
for (const FragmentAssemblyCopyRun& run : sortedRuns) {
|
||||
int64_t sourceRunStartDelta = 0;
|
||||
int64_t hostRunStartDelta = 0;
|
||||
if (!families.empty()
|
||||
&& canMergeFragmentAssemblyCopyRunIntoFamily(
|
||||
families.back(), run, sourceRunStartDelta, hostRunStartDelta)) {
|
||||
families.back().sourceRunStartDeltas.push_back(sourceRunStartDelta);
|
||||
families.back().hostRunStartDeltas.push_back(hostRunStartDelta);
|
||||
continue;
|
||||
}
|
||||
|
||||
FragmentAssemblyCopyRunFamily family;
|
||||
family.prototype = run;
|
||||
family.sourceRunStartDeltas.push_back(0);
|
||||
family.hostRunStartDeltas.push_back(0);
|
||||
families.push_back(std::move(family));
|
||||
}
|
||||
|
||||
return families;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>>
|
||||
groupFragmentAssemblyCopyRuns(ArrayRef<FragmentAssemblyCopy> copies, uint32_t laneCount) {
|
||||
if (laneCount == 0)
|
||||
return failure();
|
||||
|
||||
struct LaneLocalCopyRun {
|
||||
FragmentAssemblyCopyRun run;
|
||||
int64_t lane = 0;
|
||||
};
|
||||
|
||||
SmallVector<FragmentAssemblyCopy, 8> sortedCopies(copies.begin(), copies.end());
|
||||
llvm::sort(sortedCopies, [](const FragmentAssemblyCopy& lhs, const FragmentAssemblyCopy& rhs) {
|
||||
if (lhs.hostTargetIndex != rhs.hostTargetIndex)
|
||||
return lhs.hostTargetIndex < rhs.hostTargetIndex;
|
||||
if (lhs.source != rhs.source)
|
||||
return lhs.source.getAsOpaquePointer() < rhs.source.getAsOpaquePointer();
|
||||
if (lhs.lane != rhs.lane)
|
||||
return lhs.lane < rhs.lane;
|
||||
if (lhs.byteSize != rhs.byteSize)
|
||||
return lhs.byteSize < rhs.byteSize;
|
||||
if (lhs.sourceByteOffset != rhs.sourceByteOffset)
|
||||
return lhs.sourceByteOffset < rhs.sourceByteOffset;
|
||||
return lhs.hostByteOffset < rhs.hostByteOffset;
|
||||
});
|
||||
|
||||
SmallVector<LaneLocalCopyRun, 8> laneRuns;
|
||||
for (const FragmentAssemblyCopy& copy : sortedCopies) {
|
||||
if (copy.lane < 0 || copy.lane >= static_cast<int64_t>(laneCount))
|
||||
return failure();
|
||||
|
||||
if (!laneRuns.empty()) {
|
||||
LaneLocalCopyRun& laneRun = laneRuns.back();
|
||||
FragmentAssemblyCopyRun& run = laneRun.run;
|
||||
if (run.source == copy.source && run.sourceType == copy.sourceType
|
||||
&& run.hostTargetIndex == copy.hostTargetIndex && laneRun.lane == copy.lane && run.byteSize == copy.byteSize
|
||||
&& run.sourceStartBytesByLane.size() == 1 && run.hostStartBytesByLane.size() == 1) {
|
||||
int64_t previousSourceOffset = run.sourceStartBytesByLane.front() + (run.count - 1) * run.sourceStepBytes;
|
||||
int64_t previousHostOffset = run.hostStartBytesByLane.front() + (run.count - 1) * run.hostStepBytes;
|
||||
int64_t sourceDelta = copy.sourceByteOffset - previousSourceOffset;
|
||||
int64_t hostDelta = copy.hostByteOffset - previousHostOffset;
|
||||
if (run.count == 1) {
|
||||
run.sourceStepBytes = sourceDelta;
|
||||
run.hostStepBytes = hostDelta;
|
||||
++run.count;
|
||||
continue;
|
||||
}
|
||||
if (run.sourceStepBytes == sourceDelta && run.hostStepBytes == hostDelta) {
|
||||
++run.count;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LaneLocalCopyRun laneRun;
|
||||
laneRun.run.source = copy.source;
|
||||
laneRun.run.sourceType = copy.sourceType;
|
||||
laneRun.run.hostTargetIndex = copy.hostTargetIndex;
|
||||
laneRun.run.count = 1;
|
||||
laneRun.run.byteSize = copy.byteSize;
|
||||
laneRun.run.sourceStartBytesByLane.push_back(copy.sourceByteOffset);
|
||||
laneRun.run.hostStartBytesByLane.push_back(copy.hostByteOffset);
|
||||
laneRun.lane = copy.lane;
|
||||
laneRuns.push_back(std::move(laneRun));
|
||||
}
|
||||
|
||||
SmallVector<FragmentAssemblyCopyRun, 8> mergedRuns;
|
||||
for (const LaneLocalCopyRun& laneRun : laneRuns) {
|
||||
size_t laneIndex = static_cast<size_t>(laneRun.lane);
|
||||
auto mergedIt = llvm::find_if(mergedRuns, [&](const FragmentAssemblyCopyRun& run) {
|
||||
return run.source == laneRun.run.source && run.sourceType == laneRun.run.sourceType
|
||||
&& run.hostTargetIndex == laneRun.run.hostTargetIndex && run.count == laneRun.run.count
|
||||
&& run.byteSize == laneRun.run.byteSize && run.sourceStepBytes == laneRun.run.sourceStepBytes
|
||||
&& run.hostStepBytes == laneRun.run.hostStepBytes && laneIndex < run.sourceStartBytesByLane.size()
|
||||
&& run.sourceStartBytesByLane[laneIndex] == std::numeric_limits<int64_t>::min();
|
||||
});
|
||||
|
||||
if (mergedIt == mergedRuns.end()) {
|
||||
FragmentAssemblyCopyRun merged = laneRun.run;
|
||||
merged.sourceStartBytesByLane.assign(laneCount, std::numeric_limits<int64_t>::min());
|
||||
merged.hostStartBytesByLane.assign(laneCount, std::numeric_limits<int64_t>::min());
|
||||
merged.sourceStartBytesByLane[laneIndex] = laneRun.run.sourceStartBytesByLane.front();
|
||||
merged.hostStartBytesByLane[laneIndex] = laneRun.run.hostStartBytesByLane.front();
|
||||
mergedRuns.push_back(std::move(merged));
|
||||
continue;
|
||||
}
|
||||
|
||||
mergedIt->sourceStartBytesByLane[laneIndex] = laneRun.run.sourceStartBytesByLane.front();
|
||||
mergedIt->hostStartBytesByLane[laneIndex] = laneRun.run.hostStartBytesByLane.front();
|
||||
}
|
||||
|
||||
for (const FragmentAssemblyCopyRun& run : mergedRuns) {
|
||||
if (llvm::any_of(run.sourceStartBytesByLane,
|
||||
[](int64_t value) { return value == std::numeric_limits<int64_t>::min(); }))
|
||||
return failure();
|
||||
if (llvm::any_of(run.hostStartBytesByLane,
|
||||
[](int64_t value) { return value == std::numeric_limits<int64_t>::min(); }))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return mergedRuns;
|
||||
}
|
||||
|
||||
static FailureOr<mlir::Value> emitFragmentAssemblyCopyRun(OpBuilder& builder,
|
||||
Location loc,
|
||||
const FragmentAssemblyCopyRun& run,
|
||||
mlir::Value hostTarget,
|
||||
Operation* anchor,
|
||||
std::optional<mlir::Value> laneArg,
|
||||
mlir::Value baseHostOffset,
|
||||
mlir::Value sourceRunStartDelta = {},
|
||||
mlir::Value hostRunStartDelta = {}) {
|
||||
auto sizeAttr = pim::getCheckedI32Attr(builder, anchor, run.byteSize, "fragment assembly host copy byte size");
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
|
||||
mlir::Value hostStart;
|
||||
mlir::Value sourceStart;
|
||||
if (laneArg) {
|
||||
hostStart = createIndexedOffset(builder, loc, *laneArg, run.hostStartBytesByLane);
|
||||
sourceStart = createIndexedOffset(builder, loc, *laneArg, run.sourceStartBytesByLane);
|
||||
} else {
|
||||
hostStart = arith::ConstantIndexOp::create(builder, loc, run.hostStartBytesByLane.front());
|
||||
sourceStart = arith::ConstantIndexOp::create(builder, loc, run.sourceStartBytesByLane.front());
|
||||
}
|
||||
|
||||
if (hostRunStartDelta)
|
||||
hostStart = arith::AddIOp::create(builder, loc, hostStart, hostRunStartDelta).getResult();
|
||||
if (sourceRunStartDelta)
|
||||
sourceStart = arith::AddIOp::create(builder, loc, sourceStart, sourceRunStartDelta).getResult();
|
||||
if (baseHostOffset)
|
||||
hostStart = arith::AddIOp::create(builder, loc, baseHostOffset, hostStart).getResult();
|
||||
|
||||
if (run.count == 1) {
|
||||
return pim::PimMemCopyDevToHostOp::create(builder,
|
||||
loc,
|
||||
hostTarget.getType(),
|
||||
hostStart,
|
||||
sourceStart,
|
||||
hostTarget,
|
||||
run.source,
|
||||
*sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
mlir::Value lowerBound = arith::ConstantIndexOp::create(builder, loc, 0);
|
||||
mlir::Value upperBound = arith::ConstantIndexOp::create(builder, loc, run.count);
|
||||
mlir::Value step = arith::ConstantIndexOp::create(builder, loc, 1);
|
||||
FailureOr<NormalizedLoopResult> loop = buildNormalizedScfFor(
|
||||
builder,
|
||||
loc,
|
||||
lowerBound,
|
||||
upperBound,
|
||||
step,
|
||||
ValueRange {hostTarget},
|
||||
[&](OpBuilder& loopBuilder,
|
||||
Location bodyLoc,
|
||||
mlir::Value flatIndex,
|
||||
ValueRange iterArgs,
|
||||
SmallVectorImpl<mlir::Value>& yielded) {
|
||||
mlir::Value hostOffset = createSteppedOffset(loopBuilder, bodyLoc, hostStart, flatIndex, run.hostStepBytes);
|
||||
mlir::Value sourceOffset =
|
||||
createSteppedOffset(loopBuilder, bodyLoc, sourceStart, flatIndex, run.sourceStepBytes);
|
||||
mlir::Value copied =
|
||||
pim::PimMemCopyDevToHostOp::create(loopBuilder,
|
||||
bodyLoc,
|
||||
iterArgs.front().getType(),
|
||||
hostOffset,
|
||||
sourceOffset,
|
||||
iterArgs.front(),
|
||||
run.source,
|
||||
*sizeAttr)
|
||||
.getOutput();
|
||||
yielded.push_back(copied);
|
||||
return success();
|
||||
});
|
||||
if (failed(loop))
|
||||
return failure();
|
||||
return loop->results.front();
|
||||
}
|
||||
|
||||
static FailureOr<mlir::Value> emitFragmentAssemblyCopyRunFamily(OpBuilder& builder,
|
||||
Location loc,
|
||||
const FragmentAssemblyCopyRunFamily& family,
|
||||
mlir::Value hostTarget,
|
||||
Operation* anchor,
|
||||
std::optional<mlir::Value> laneArg,
|
||||
mlir::Value baseHostOffset) {
|
||||
if (family.sourceRunStartDeltas.size() == 1)
|
||||
return emitFragmentAssemblyCopyRun(
|
||||
builder, loc, family.prototype, hostTarget, anchor, laneArg, baseHostOffset);
|
||||
|
||||
mlir::Value lowerBound = arith::ConstantIndexOp::create(builder, loc, 0);
|
||||
mlir::Value upperBound = arith::ConstantIndexOp::create(builder, loc, family.sourceRunStartDeltas.size());
|
||||
mlir::Value step = arith::ConstantIndexOp::create(builder, loc, 1);
|
||||
FailureOr<NormalizedLoopResult> outerLoop = buildNormalizedScfFor(
|
||||
builder,
|
||||
loc,
|
||||
lowerBound,
|
||||
upperBound,
|
||||
step,
|
||||
ValueRange {hostTarget},
|
||||
[&](OpBuilder& loopBuilder,
|
||||
Location bodyLoc,
|
||||
mlir::Value runIndex,
|
||||
ValueRange iterArgs,
|
||||
SmallVectorImpl<mlir::Value>& yielded) {
|
||||
mlir::Value sourceRunStartDelta =
|
||||
createIndexedOffset(loopBuilder, bodyLoc, runIndex, family.sourceRunStartDeltas);
|
||||
mlir::Value hostRunStartDelta =
|
||||
createIndexedOffset(loopBuilder, bodyLoc, runIndex, family.hostRunStartDeltas);
|
||||
FailureOr<mlir::Value> copied = emitFragmentAssemblyCopyRun(loopBuilder,
|
||||
bodyLoc,
|
||||
family.prototype,
|
||||
iterArgs.front(),
|
||||
anchor,
|
||||
laneArg,
|
||||
baseHostOffset,
|
||||
sourceRunStartDelta,
|
||||
hostRunStartDelta);
|
||||
if (failed(copied))
|
||||
return failure();
|
||||
yielded.push_back(*copied);
|
||||
return success();
|
||||
});
|
||||
if (failed(outerLoop))
|
||||
return failure();
|
||||
return outerLoop->results.front();
|
||||
}
|
||||
|
||||
FailureOr<mlir::Value> emitFragmentAssemblyCopyRuns(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
ArrayRef<FragmentAssemblyCopyRun> runs,
|
||||
mlir::Value hostTarget,
|
||||
Operation* anchor,
|
||||
std::optional<mlir::Value> laneArg,
|
||||
mlir::Value baseHostOffset) {
|
||||
for (const FragmentAssemblyCopyRunFamily& family : groupFragmentAssemblyCopyRunFamilies(runs)) {
|
||||
FailureOr<mlir::Value> updatedHostTarget =
|
||||
emitFragmentAssemblyCopyRunFamily(rewriter, loc, family, hostTarget, anchor, laneArg, baseHostOffset);
|
||||
if (failed(updatedHostTarget))
|
||||
return failure();
|
||||
hostTarget = *updatedHostTarget;
|
||||
}
|
||||
|
||||
return hostTarget;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
@@ -59,6 +65,39 @@ forEachContiguousDestinationChunk(llvm::ArrayRef<int64_t> destShape,
|
||||
llvm::function_ref<mlir::LogicalResult(llvm::ArrayRef<int64_t>, int64_t, int64_t)>
|
||||
callback);
|
||||
|
||||
struct FragmentAssemblyCopy {
|
||||
mlir::Value source;
|
||||
mlir::RankedTensorType sourceType;
|
||||
unsigned hostTargetIndex = 0;
|
||||
int64_t lane = 0;
|
||||
int64_t sourceByteOffset = 0;
|
||||
int64_t hostByteOffset = 0;
|
||||
int64_t byteSize = 0;
|
||||
};
|
||||
|
||||
struct FragmentAssemblyCopyRun {
|
||||
mlir::Value source;
|
||||
mlir::RankedTensorType sourceType;
|
||||
unsigned hostTargetIndex = 0;
|
||||
int64_t count = 0;
|
||||
int64_t sourceStepBytes = 0;
|
||||
int64_t hostStepBytes = 0;
|
||||
int64_t byteSize = 0;
|
||||
mlir::SmallVector<int64_t, 8> sourceStartBytesByLane;
|
||||
mlir::SmallVector<int64_t, 8> hostStartBytesByLane;
|
||||
};
|
||||
|
||||
mlir::FailureOr<mlir::SmallVector<FragmentAssemblyCopyRun, 8>>
|
||||
groupFragmentAssemblyCopyRuns(llvm::ArrayRef<FragmentAssemblyCopy> copies, uint32_t laneCount = 1);
|
||||
|
||||
mlir::FailureOr<mlir::Value> emitFragmentAssemblyCopyRuns(mlir::IRRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
llvm::ArrayRef<FragmentAssemblyCopyRun> runs,
|
||||
mlir::Value hostTarget,
|
||||
mlir::Operation* anchor,
|
||||
std::optional<mlir::Value> laneArg = std::nullopt,
|
||||
mlir::Value baseHostOffset = {});
|
||||
|
||||
inline mlir::tensor::EmptyOp
|
||||
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
||||
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
@@ -8,6 +9,7 @@
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
@@ -30,22 +32,9 @@ static bool isChannelUseChainOp(Operation* op) {
|
||||
pim::PimTransposeOp>(op);
|
||||
}
|
||||
|
||||
static Value createStaticHostTargetOffset(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
ShapedType destinationType,
|
||||
ArrayRef<int64_t> fragmentOffsets) {
|
||||
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
||||
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
|
||||
|
||||
int64_t byteOffset = 0;
|
||||
for (auto [dim, offset] : llvm::enumerate(fragmentOffsets))
|
||||
byteOffset += offset * strides[dim] * elementBytes;
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), byteOffset);
|
||||
}
|
||||
|
||||
static FailureOr<Value> lowerFragmentAssemblyBlueprint(IRRewriter& rewriter,
|
||||
spatial::SpatBlueprintOp blueprint,
|
||||
IRMapping& mapping) {
|
||||
spatial::SpatBlueprintOp blueprint,
|
||||
IRMapping& mapping) {
|
||||
auto resultType = dyn_cast<ShapedType>(blueprint.getOutput().getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return blueprint.emitOpError("fragment assembly lowering requires a static ranked tensor result");
|
||||
@@ -77,56 +66,54 @@ static FailureOr<Value> lowerFragmentAssemblyBlueprint(IRRewriter& rewriter,
|
||||
flatStrides)))
|
||||
return failure();
|
||||
|
||||
Value currentOutput = createEmptyTensorFromShaped(rewriter, blueprint.getLoc(), resultType);
|
||||
SmallVector<int64_t> hostStrides = computeRowMajorStrides(resultType.getShape());
|
||||
SmallVector<FragmentAssemblyCopy, 8> copies;
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||
|
||||
SmallVector<int64_t, 4> fragmentOffsets;
|
||||
int64_t fragmentElements = 1;
|
||||
SmallVector<int64_t, 4> fragmentSizes;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||
if (flatStrides[flatIndex] != 1)
|
||||
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
|
||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||
fragmentElements *= flatSizes[flatIndex];
|
||||
fragmentSizes.push_back(flatSizes[flatIndex]);
|
||||
}
|
||||
|
||||
Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]);
|
||||
auto sourceType = dyn_cast<ShapedType>(source.getType());
|
||||
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||
size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType());
|
||||
if (failed(forEachContiguousDestinationChunk(
|
||||
resultType.getShape(),
|
||||
fragmentOffsets,
|
||||
fragmentSizes,
|
||||
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
|
||||
int64_t hostElementOffset = 0;
|
||||
for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
|
||||
hostElementOffset += offset * hostStrides[dim];
|
||||
|
||||
int64_t fragmentBytes =
|
||||
fragmentElements * static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
|
||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter,
|
||||
blueprint.getOperation(),
|
||||
fragmentBytes,
|
||||
"fragment assembly host copy size");
|
||||
if (failed(sizeAttr))
|
||||
FragmentAssemblyCopy copy;
|
||||
copy.source = source;
|
||||
copy.sourceType = sourceType;
|
||||
copy.sourceByteOffset =
|
||||
(sourceOffsets[fragmentIndex] + relativeSourceOffset) * static_cast<int64_t>(elementSize);
|
||||
copy.hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
|
||||
copy.byteSize = chunkElements * static_cast<int64_t>(elementSize);
|
||||
copies.push_back(copy);
|
||||
return success();
|
||||
})))
|
||||
return failure();
|
||||
|
||||
Value hostTargetOffset = createStaticHostTargetOffset(rewriter, blueprint.getLoc(), resultType, fragmentOffsets);
|
||||
auto deviceSourceOffsetBytes = pim::checkedMul(static_cast<uint64_t>(sourceOffsets[fragmentIndex]),
|
||||
static_cast<uint64_t>(getElementTypeSizeInBytes(sourceType.getElementType())),
|
||||
blueprint,
|
||||
"fragment assembly device source offset");
|
||||
if (failed(deviceSourceOffsetBytes))
|
||||
return failure();
|
||||
Value deviceSourceOffset = getOrCreateIndexConstant(rewriter,
|
||||
rewriter.getInsertionBlock()->getParentOp(),
|
||||
static_cast<int64_t>(*deviceSourceOffsetBytes));
|
||||
currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
blueprint.getLoc(),
|
||||
currentOutput.getType(),
|
||||
hostTargetOffset,
|
||||
deviceSourceOffset,
|
||||
currentOutput,
|
||||
source,
|
||||
*sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
return currentOutput;
|
||||
Value currentOutput = createEmptyTensorFromShaped(rewriter, blueprint.getLoc(), resultType);
|
||||
FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>> runs = groupFragmentAssemblyCopyRuns(copies);
|
||||
if (failed(runs))
|
||||
return failure();
|
||||
return emitFragmentAssemblyCopyRuns(
|
||||
rewriter, blueprint.getLoc(), *runs, currentOutput, blueprint.getOperation());
|
||||
}
|
||||
|
||||
static void
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
@@ -11,6 +12,7 @@
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
@@ -638,6 +640,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
flatSizes,
|
||||
flatStrides)))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
SmallVector<FragmentAssemblyCopy, 8> copies;
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||
if (operandIndices[fragmentIndex] != static_cast<int64_t>(operandNumber))
|
||||
continue;
|
||||
@@ -675,29 +678,27 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
failedChunk = true;
|
||||
return failure();
|
||||
}
|
||||
auto sizeAttr =
|
||||
pim::getCheckedI32Attr(rewriter, producerOp, *fragmentBytes, "fragment assembly host copy byte size");
|
||||
if (failed(sizeAttr)) {
|
||||
failedChunk = true;
|
||||
return failure();
|
||||
}
|
||||
|
||||
outputTensor =
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
blueprint.getLoc(),
|
||||
outputTensor.getType(),
|
||||
getOrCreateIndexConstant(rewriter, producerOp, *hostOffset),
|
||||
getOrCreateIndexConstant(rewriter, producerOp, *sourceOffset),
|
||||
outputTensor,
|
||||
storedValue,
|
||||
*sizeAttr)
|
||||
.getOutput();
|
||||
FragmentAssemblyCopy copy;
|
||||
copy.source = storedValue;
|
||||
copy.sourceType = sourceType;
|
||||
copy.hostByteOffset = *hostOffset;
|
||||
copy.sourceByteOffset = *sourceOffset;
|
||||
copy.byteSize = *fragmentBytes;
|
||||
copies.push_back(copy);
|
||||
return success();
|
||||
})))
|
||||
failedChunk = true;
|
||||
if (failedChunk)
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>> runs = groupFragmentAssemblyCopyRuns(copies);
|
||||
if (failed(runs))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
FailureOr<Value> updatedOutput =
|
||||
emitFragmentAssemblyCopyRuns(rewriter, blueprint.getLoc(), *runs, outputTensor, producerOp);
|
||||
if (failed(updatedOutput))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
outputTensor = *updatedOutput;
|
||||
markOpToRemove(blueprint.getOperation());
|
||||
}
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
|
||||
+686
-140
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user