robba
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-06-29 12:22:33 +02:00
parent 78e97f9fd8
commit e8f09fd67f
8 changed files with 1376 additions and 492 deletions
+18 -30
View File
@@ -34,12 +34,25 @@ mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledg
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value); llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value); llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
template <typename... Args> template <typename... Args>
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) { CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...)); return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
} }
static mlir::Value resolveForYieldedAliasToInit(mlir::scf::ForOp forOp,
mlir::Value yieldedValue,
const StaticValueKnowledge* knowledge) {
yieldedValue = resolveLoopCarriedAliasImpl(yieldedValue, knowledge);
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size())
return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
}
return yieldedValue;
}
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge); value = resolveAlias(value, knowledge);
@@ -60,15 +73,8 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
auto result = mlir::dyn_cast<mlir::OpResult>(value); auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (result) { if (result) {
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator()); auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands()) { if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands())
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge); return resolveForYieldedAliasToInit(forOp, yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size())
return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
}
return yieldedValue;
}
} }
} }
@@ -515,16 +521,7 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
return mlir::failure(); return mlir::failure();
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator()); auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge); value = resolveForYieldedAliasToInit(forOp, yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
continue;
}
}
value = yieldedValue;
continue; continue;
} }
@@ -643,16 +640,7 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
return mlir::failure(); return mlir::failure();
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator()); auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber()); value = resolveForYieldedAliasToInit(forOp, yieldOp.getOperand(result.getResultNumber()), nullptr);
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1];
continue;
}
}
value = yieldedValue;
continue; continue;
} }
@@ -862,7 +850,7 @@ llvm::FailureOr<ResolvedContiguousAddress> CompiledAddressExpr::evaluate(const S
auto resolvedOffset = byteOffset.evaluate(knowledge); auto resolvedOffset = byteOffset.evaluate(knowledge);
if (failed(resolvedOffset)) if (failed(resolvedOffset))
return mlir::failure(); return mlir::failure();
return ResolvedContiguousAddress {base, *resolvedOffset}; return ResolvedContiguousAddress {resolveAlias(base, &knowledge), *resolvedOffset};
} }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -1334,6 +1334,38 @@ static Value affineAddConst(
return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor); return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor);
} }
static Value affineMulConst(
PatternRewriter& rewriter, Location loc, Value value, int64_t factor, Operation* constantAnchor) {
if (factor == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createOrFoldAffineApply(rewriter, loc, d0 * factor, ValueRange {value}, constantAnchor);
}
static Value affineFloorDivConst(
PatternRewriter& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
assert(divisor > 0 && "expected positive affine floordiv divisor");
if (divisor == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
}
static Value affineModConst(
PatternRewriter& rewriter, Location loc, Value value, int64_t modulus, Operation* constantAnchor) {
assert(modulus > 0 && "expected positive affine mod divisor");
if (modulus == 1)
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createOrFoldAffineApply(rewriter, loc, d0 % modulus, ValueRange {value}, constantAnchor);
}
static Value createConvInputPatch(Value input, static Value createConvInputPatch(Value input,
RankedTensorType patchType, RankedTensorType patchType,
Value batchIndex, Value batchIndex,
@@ -2414,11 +2446,6 @@ static Value createIm2colRows(const ConvLoweringState& state,
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkNumPatches); Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkNumPatches);
Value cChunkStart = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkStart);
Value cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, plan.numPatchesPerBatch);
Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth);
Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, state.strideHeight);
Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.strideWidth);
auto im2colLoop = buildNormalizedScfFor( auto im2colLoop = buildNormalizedScfFor(
rewriter, rewriter,
@@ -2429,13 +2456,17 @@ static Value createIm2colRows(const ConvLoweringState& state,
ValueRange {im2colInit}, ValueRange {im2colInit},
[&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) { [&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
Value im2colAcc = iterArgs.front(); Value im2colAcc = iterArgs.front();
Value globalPatchIndex = arith::AddIOp::create(rewriter, nestedLoc, patchIndex, cChunkStart); Value globalPatchIndex = affineAddConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, anchorOp);
Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, globalPatchIndex, cNumPatchesPerBatch); Value batchIndex =
Value batchPatchIndex = arith::RemUIOp::create(rewriter, nestedLoc, globalPatchIndex, cNumPatchesPerBatch); affineFloorDivConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
Value outHeightIndex = arith::DivUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth); Value batchPatchIndex =
Value outWidthIndex = arith::RemUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth); affineModConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
Value inputHeightOffset = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight); Value outHeightIndex = affineFloorDivConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
Value inputWidthOffset = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth); Value outWidthIndex = affineModConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
Value inputHeightOffset =
affineMulConst(rewriter, nestedLoc, outHeightIndex, state.strideHeight, anchorOp);
Value inputWidthOffset =
affineMulConst(rewriter, nestedLoc, outWidthIndex, state.strideWidth, anchorOp);
auto patchType = auto patchType =
RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elemType); RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elemType);
@@ -2844,11 +2875,9 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
Value cNumKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices); Value cNumKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices);
Value cXbar = getOrCreateIndexConstant(rewriter, anchorOp, xbarDim);
Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth); Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth);
Value cNumChannels = getOrCreateIndexConstant(rewriter, anchorOp, state.numChannelsIn); Value cNumChannels = getOrCreateIndexConstant(rewriter, anchorOp, state.numChannelsIn);
Value cPatchWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.wHeight * state.wWidth); Value localHeightOffset = args.lane;
Value localHeightOffset = arith::MulIOp::create(rewriter, loc, args.lane, c1);
Value packedRowInit = Value packedRowInit =
tensor::EmptyOp::create(rewriter, loc, ArrayRef<int64_t> {1, state.outWidth, state.numChannelsOut}, elementType); tensor::EmptyOp::create(rewriter, loc, ArrayRef<int64_t> {1, state.outWidth, state.numChannelsOut}, elementType);
auto widthLoop = buildNormalizedScfFor( auto widthLoop = buildNormalizedScfFor(
@@ -2859,7 +2888,7 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
c1, c1,
ValueRange {packedRowInit}, ValueRange {packedRowInit},
[&](OpBuilder&, Location widthLoc, Value widthIndex, ValueRange widthIterArgs, SmallVectorImpl<Value>& widthYielded) { [&](OpBuilder&, Location widthLoc, Value widthIndex, ValueRange widthIterArgs, SmallVectorImpl<Value>& widthYielded) {
Value localWidthOffset = arith::MulIOp::create(rewriter, widthLoc, widthIndex, c1); Value localWidthOffset = widthIndex;
Value rowInit = tensor::EmptyOp::create(rewriter, widthLoc, ArrayRef<int64_t> {1, patchSize}, elementType); Value rowInit = tensor::EmptyOp::create(rewriter, widthLoc, ArrayRef<int64_t> {1, patchSize}, elementType);
auto rowLoop = buildNormalizedScfFor( auto rowLoop = buildNormalizedScfFor(
rewriter, rewriter,
@@ -2878,7 +2907,8 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
rewriter, rowLoc, flatPatchType, channelPatch, SmallVector<ReassociationIndices> {{0, 1, 2}}); rewriter, rowLoc, flatPatchType, channelPatch, SmallVector<ReassociationIndices> {{0, 1, 2}});
Value rowChunk = tensor::ExpandShapeOp::create( Value rowChunk = tensor::ExpandShapeOp::create(
rewriter, rowLoc, rowChunkType, flatPatch, SmallVector<ReassociationIndices> {{0, 1}}); rewriter, rowLoc, rowChunkType, flatPatch, SmallVector<ReassociationIndices> {{0, 1}});
Value flatOffset = arith::MulIOp::create(rewriter, rowLoc, channel, cPatchWidth); Value flatOffset = affineMulConst(
rewriter, rowLoc, channel, state.wHeight * state.wWidth, anchorOp);
SmallVector<OpFoldResult> rowOffsets {rewriter.getIndexAttr(0), flatOffset}; SmallVector<OpFoldResult> rowOffsets {rewriter.getIndexAttr(0), flatOffset};
SmallVector<OpFoldResult> rowSizes { SmallVector<OpFoldResult> rowSizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.wHeight * state.wWidth)}; rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.wHeight * state.wWidth)};
@@ -2905,7 +2935,7 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
c1, c1,
ValueRange {zeroRow}, ValueRange {zeroRow},
[&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl<Value>& reduceYielded) { [&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl<Value>& reduceYielded) {
Value kOffset = arith::MulIOp::create(rewriter, reduceLoc, kSlice, cXbar); Value kOffset = affineMulConst(rewriter, reduceLoc, kSlice, xbarDim, anchorOp);
SmallVector<OpFoldResult> aOffsets {rewriter.getIndexAttr(0), kOffset}; SmallVector<OpFoldResult> aOffsets {rewriter.getIndexAttr(0), kOffset};
SmallVector<OpFoldResult> aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)}; SmallVector<OpFoldResult> aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)};
Value aTile = tensor::ExtractSliceOp::create( Value aTile = tensor::ExtractSliceOp::create(
@@ -2,6 +2,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
@@ -10,6 +11,7 @@
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
@@ -131,46 +133,92 @@ static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
return result.getUses().begin()->getOperandNumber(); return result.getUses().begin()->getOperandNumber();
} }
struct BatchFragmentAssemblyPlan { static FailureOr<SmallVector<FragmentAssemblyCopy, 8>>
unsigned returnIndex = 0; collectFragmentAssemblyCopiesFromBlueprint(spatial::SpatBlueprintOp blueprint,
int64_t localSourceElementOffset = 0; IRMapping& mapper,
int64_t fragmentByteSize = 0; int64_t lane,
SmallVector<int64_t, 8> hostOffsetsByLane; unsigned hostTargetIndex,
}; Value fixedSource = {}) {
SmallVector<FragmentAssemblyCopy, 8> copies;
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
if (!resultType || !resultType.hasStaticShape())
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor results");
static Value createLaneIndexedOffset(IRRewriter& rewriter, Operation* anchor, Value laneArg, ArrayRef<int64_t> values, Location loc) { std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
assert(!values.empty() && "expected lane-indexed values"); std::optional<ArrayRef<int64_t>> fragmentStridesAttr = blueprint.getFragmentStrides();
if (llvm::all_of(values.drop_front(), [&](int64_t value) { return value == values.front(); })) if (!operandIndicesAttr || !fragmentStridesAttr)
return getOrCreateIndexConstant(rewriter, anchor, values.front()); return blueprint.emitOpError(
"fragment assembly lowering requires explicit operand indices and unit strides");
if (values.size() >= 2) { ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
int64_t step = values[1] - values[0]; std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
bool arithmetic = llvm::all_of(llvm::seq<size_t>(2, values.size()), [&](size_t index) { if (!sourceOffsetsAttr)
return values[index] == values.front() + static_cast<int64_t>(index) * step; return blueprint.emitOpError("fragment assembly lowering requires explicit source offsets");
}); ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
if (arithmetic) { ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
Value base = getOrCreateIndexConstant(rewriter, anchor, values.front()); ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
if (step == 0) ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
return base; int64_t rank = resultType.getRank();
Value stepValue = getOrCreateIndexConstant(rewriter, anchor, step);
Value scaledLane = arith::MulIOp::create(rewriter, loc, laneArg, stepValue).getResult(); SmallVector<Value> fragmentOperands {blueprint.getInput()};
return arith::AddIOp::create(rewriter, loc, base, scaledLane).getResult(); llvm::append_range(fragmentOperands, blueprint.getFragments());
if (failed(validateFragmentAssemblyMetadata(blueprint,
rank,
fragmentOperands.size(),
operandIndices,
sourceOffsets,
flatOffsets,
flatSizes,
flatStrides)))
return failure();
SmallVector<int64_t> hostStrides = computeRowMajorStrides(resultType.getShape());
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
Value source = fixedSource ? fixedSource : mapper.lookupOrDefault(fragmentOperands[operandIndices[fragmentIndex]]);
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape())
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType());
SmallVector<int64_t, 4> fragmentOffsets;
SmallVector<int64_t, 4> fragmentSizes;
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1)
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
fragmentOffsets.push_back(flatOffsets[flatIndex]);
fragmentSizes.push_back(flatSizes[flatIndex]);
} }
if (failed(forEachContiguousDestinationChunk(
resultType.getShape(),
fragmentOffsets,
fragmentSizes,
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
int64_t hostElementOffset = 0;
for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
hostElementOffset += offset * hostStrides[dim];
FragmentAssemblyCopy copy;
copy.source = source;
copy.sourceType = sourceType;
copy.hostTargetIndex = hostTargetIndex;
copy.lane = lane;
copy.sourceByteOffset = (sourceOffsets[fragmentIndex] + relativeSourceOffset) * static_cast<int64_t>(elementSize);
copy.hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
copy.byteSize = chunkElements * static_cast<int64_t>(elementSize);
copies.push_back(copy);
return success();
})))
return failure();
} }
Value selected = getOrCreateIndexConstant(rewriter, anchor, values.front()); return copies;
for (auto [lane, value] : llvm::enumerate(values.drop_front())) {
Value laneValue = getOrCreateIndexConstant(rewriter, anchor, static_cast<int64_t>(lane + 1));
Value cmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, laneArg, laneValue);
Value candidate = getOrCreateIndexConstant(rewriter, anchor, value);
selected = arith::SelectOp::create(rewriter, loc, cmp, candidate, selected);
}
return selected;
} }
static FailureOr<SmallVector<BatchFragmentAssemblyPlan, 8>> static FailureOr<SmallVector<FragmentAssemblyCopy, 8>>
analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResultType, uint32_t laneCount) { collectTopLevelFragmentAssemblyCopies(OpResult result, RankedTensorType packedResultType, uint32_t laneCount) {
SmallVector<BatchFragmentAssemblyPlan, 8> plans; SmallVector<FragmentAssemblyCopy, 8> copies;
if (!packedResultType.hasStaticShape() || laneCount == 0) if (!packedResultType.hasStaticShape() || laneCount == 0)
return failure(); return failure();
@@ -187,15 +235,14 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
std::optional<StringRef> mode = blueprint.getMode(); std::optional<StringRef> mode = blueprint.getMode();
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices(); std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets(); std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides(); if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr)
if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr)
return failure(); return failure();
if (!blueprint.getOutput().hasOneUse() || !isa<func::ReturnOp>(*blueprint.getOutput().getUsers().begin())) if (!blueprint.getOutput().hasOneUse() || !isa<func::ReturnOp>(*blueprint.getOutput().getUsers().begin()))
return failure(); return failure();
unsigned returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber();
auto hostResultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType()); auto hostResultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
if (!hostResultType || !hostResultType.hasStaticShape()) std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
if (!hostResultType || !hostResultType.hasStaticShape() || !stridesAttr)
return failure(); return failure();
ArrayRef<int64_t> operandIndices = *operandIndicesAttr; ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
@@ -204,6 +251,7 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes(); ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *stridesAttr; ArrayRef<int64_t> flatStrides = *stridesAttr;
int64_t rank = hostResultType.getRank(); int64_t rank = hostResultType.getRank();
unsigned returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber();
SmallVector<Value> fragmentOperands {blueprint.getInput()}; SmallVector<Value> fragmentOperands {blueprint.getInput()};
llvm::append_range(fragmentOperands, blueprint.getFragments()); llvm::append_range(fragmentOperands, blueprint.getFragments());
if (failed(validateFragmentAssemblyMetadata(blueprint, if (failed(validateFragmentAssemblyMetadata(blueprint,
@@ -215,16 +263,15 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
flatSizes, flatSizes,
flatStrides))) flatStrides)))
return failure(); return failure();
SmallVector<int64_t> hostStrides = computeRowMajorStrides(hostResultType.getShape());
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) { for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
if (operandIndices[fragmentIndex] != static_cast<int64_t>(use.getOperandNumber())) if (operandIndices[fragmentIndex] != static_cast<int64_t>(use.getOperandNumber()))
continue; continue;
int64_t sourceElementOffset = sourceOffsets[fragmentIndex]; int64_t sourceElementOffset = sourceOffsets[fragmentIndex];
int64_t lane = sourceElementOffset / payloadElementCount; int64_t lane = sourceElementOffset / payloadElementCount;
int64_t localSourceElementOffset = sourceElementOffset % payloadElementCount;
if (lane < 0 || lane >= static_cast<int64_t>(laneCount)) if (lane < 0 || lane >= static_cast<int64_t>(laneCount))
return failure(); return failure();
SmallVector<int64_t, 4> fragmentOffsets; SmallVector<int64_t, 4> fragmentOffsets;
SmallVector<int64_t, 4> fragmentSizes; SmallVector<int64_t, 4> fragmentSizes;
for (int64_t dim = 0; dim < rank; ++dim) { for (int64_t dim = 0; dim < rank; ++dim) {
@@ -236,44 +283,31 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
} }
if (failed(forEachContiguousDestinationChunk( if (failed(forEachContiguousDestinationChunk(
hostResultType.getShape(), hostResultType.getShape(),
fragmentOffsets, fragmentOffsets,
fragmentSizes, fragmentSizes,
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult { [&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
int64_t hostElementOffset = 0; int64_t hostElementOffset = 0;
SmallVector<int64_t> hostStrides = computeRowMajorStrides(hostResultType.getShape()); for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
for (auto [dim, offset] : llvm::enumerate(chunkOffsets)) hostElementOffset += offset * hostStrides[dim];
hostElementOffset += offset * hostStrides[dim];
int64_t hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
int64_t fragmentByteSize = chunkElements * static_cast<int64_t>(elementSize);
int64_t chunkSourceOffset = localSourceElementOffset + relativeSourceOffset;
auto planIt = llvm::find_if(plans, [&](const BatchFragmentAssemblyPlan& plan) { FragmentAssemblyCopy copy;
return plan.returnIndex == returnIndex && plan.localSourceElementOffset == chunkSourceOffset copy.source = result;
&& plan.fragmentByteSize == fragmentByteSize; copy.sourceType = packedResultType;
}); copy.hostTargetIndex = returnIndex;
if (planIt == plans.end()) { copy.lane = lane;
BatchFragmentAssemblyPlan plan; copy.sourceByteOffset =
plan.returnIndex = returnIndex; ((sourceElementOffset % payloadElementCount) + relativeSourceOffset) * static_cast<int64_t>(elementSize);
plan.localSourceElementOffset = chunkSourceOffset; copy.hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
plan.fragmentByteSize = fragmentByteSize; copy.byteSize = chunkElements * static_cast<int64_t>(elementSize);
plan.hostOffsetsByLane.assign(laneCount, std::numeric_limits<int64_t>::min()); copies.push_back(copy);
plan.hostOffsetsByLane[static_cast<size_t>(lane)] = hostByteOffset; return success();
plans.push_back(std::move(plan)); })))
return success();
}
planIt->hostOffsetsByLane[static_cast<size_t>(lane)] = hostByteOffset;
return success();
})))
return failure(); return failure();
} }
} }
for (const BatchFragmentAssemblyPlan& plan : plans) return copies;
if (llvm::any_of(plan.hostOffsetsByLane, [](int64_t offset) { return offset == std::numeric_limits<int64_t>::min(); }))
return failure();
return plans;
} }
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) { static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
@@ -284,22 +318,6 @@ static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value ba
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult(); return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
} }
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
SmallVector<OpFoldResult, 4> attrs;
attrs.reserve(values.size());
for (int64_t value : values)
attrs.push_back(builder.getIndexAttr(value));
return attrs;
}
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
SmallVector<OpFoldResult, 4> strides;
strides.reserve(rank);
for (int64_t dim = 0; dim < rank; ++dim)
strides.push_back(builder.getIndexAttr(1));
return strides;
}
static Value createHostTargetOffset(IRRewriter& rewriter, static Value createHostTargetOffset(IRRewriter& rewriter,
Location loc, Location loc,
ShapedType destinationType, ShapedType destinationType,
@@ -351,123 +369,6 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
mapper); mapper);
} }
static SmallVector<OpFoldResult, 4> buildFragmentOffsets(IRRewriter& rewriter,
Location loc,
ArrayRef<OpFoldResult> baseOffsets,
ArrayRef<int64_t> fragmentOffsets,
IRMapping& mapper) {
SmallVector<OpFoldResult, 4> combined;
combined.reserve(fragmentOffsets.size());
for (auto [dim, baseOffset] : llvm::enumerate(baseOffsets)) {
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
int64_t base = cast<IntegerAttr>(attr).getInt();
combined.push_back(rewriter.getIndexAttr(base + fragmentOffsets[dim]));
continue;
}
Value dynamicBase = mapper.lookupOrDefault(cast<Value>(baseOffset));
if (fragmentOffsets[dim] == 0) {
combined.push_back(dynamicBase);
continue;
}
Value staticOffset =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), fragmentOffsets[dim]);
combined.push_back(arith::AddIOp::create(rewriter, loc, dynamicBase, staticOffset).getResult());
}
return combined;
}
static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
spatial::SpatBlueprintOp blueprint,
Value hostTarget,
ArrayRef<OpFoldResult> baseOffsets,
IRMapping& mapper) {
auto hostTargetType = dyn_cast<RankedTensorType>(hostTarget.getType());
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
if (!hostTargetType || !resultType || !resultType.hasStaticShape())
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor results");
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = blueprint.getFragmentStrides();
if (!operandIndicesAttr || !fragmentStridesAttr)
return blueprint.emitOpError(
"fragment assembly lowering requires explicit operand indices and unit strides");
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
if (!sourceOffsetsAttr)
return blueprint.emitOpError("fragment assembly lowering requires explicit source offsets");
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
int64_t rank = resultType.getRank();
SmallVector<Value> fragmentOperands {blueprint.getInput()};
llvm::append_range(fragmentOperands, blueprint.getFragments());
if (failed(validateFragmentAssemblyMetadata(blueprint,
rank,
fragmentOperands.size(),
operandIndices,
sourceOffsets,
flatOffsets,
flatSizes,
flatStrides)))
return failure();
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
int64_t operandIndex = operandIndices[fragmentIndex];
SmallVector<int64_t, 4> fragmentOffsets;
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1)
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
fragmentOffsets.push_back(flatOffsets[flatIndex]);
}
Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]);
auto sourceType = dyn_cast<ShapedType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape())
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
SmallVector<int64_t, 4> fragmentShape;
fragmentShape.reserve(rank);
for (int64_t dim = 0; dim < rank; ++dim)
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
Value fragment = source;
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
blueprint, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
if (failed(extractOffsets))
return failure();
fragment = tensor::ExtractSliceOp::create(rewriter,
blueprint.getLoc(),
source,
getStaticIndexAttrs(rewriter, *extractOffsets),
getStaticIndexAttrs(rewriter, fragmentShape),
getUnitStrides(rewriter, rank));
}
hostTarget = tensor::InsertSliceOp::create(rewriter,
blueprint.getLoc(),
fragment,
hostTarget,
buildFragmentOffsets(rewriter,
blueprint.getLoc(),
baseOffsets,
fragmentOffsets,
mapper),
getStaticIndexAttrs(rewriter, fragmentShape),
getUnitStrides(rewriter, rank))
.getResult();
}
return hostTarget;
}
} // namespace } // namespace
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp, LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
@@ -505,10 +406,10 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds)); coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds));
SmallVector<unsigned> returnOperandIndices; SmallVector<unsigned> returnOperandIndices;
SmallVector<SmallVector<BatchFragmentAssemblyPlan, 1>, 4> fragmentAssemblyPlansByResult; SmallVector<SmallVector<FragmentAssemblyCopyRun, 1>, 4> fragmentAssemblyRunsByResult;
if (computeBatchOp.getNumResults() != 0) { if (computeBatchOp.getNumResults() != 0) {
returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits<unsigned>::max()); returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits<unsigned>::max());
fragmentAssemblyPlansByResult.resize(computeBatchOp.getNumResults()); fragmentAssemblyRunsByResult.resize(computeBatchOp.getNumResults());
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) { for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
if (result.use_empty()) if (result.use_empty())
continue; continue;
@@ -522,12 +423,15 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
if (!resultType || !resultType.hasStaticShape()) if (!resultType || !resultType.hasStaticShape())
return computeBatchOp.emitOpError( return computeBatchOp.emitOpError(
"resultful compute_batch publication lowering requires static ranked tensor results"); "resultful compute_batch publication lowering requires static ranked tensor results");
FailureOr<SmallVector<BatchFragmentAssemblyPlan, 8>> fragmentAssemblyPlans = FailureOr<SmallVector<FragmentAssemblyCopy, 8>> fragmentAssemblyCopies =
analyzeTopLevelFragmentAssemblyUses(cast<OpResult>(result), resultType, computeBatchOp.getLaneCount()); collectTopLevelFragmentAssemblyCopies(cast<OpResult>(result), resultType, computeBatchOp.getLaneCount());
if (failed(fragmentAssemblyPlans)) if (failed(fragmentAssemblyCopies))
return computeBatchOp.emitOpError( return computeBatchOp.emitOpError("failed to collect top-level fragment assembly copies for compute_batch result");
"resultful compute_batch lowering currently requires each result to be used directly by func.return"); FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>> fragmentAssemblyRuns =
fragmentAssemblyPlansByResult[resultIndex].assign(fragmentAssemblyPlans->begin(), fragmentAssemblyPlans->end()); groupFragmentAssemblyCopyRuns(*fragmentAssemblyCopies, computeBatchOp.getLaneCount());
if (failed(fragmentAssemblyRuns))
return computeBatchOp.emitOpError("failed to group top-level fragment assembly copies into regular runs");
fragmentAssemblyRunsByResult[resultIndex].assign(fragmentAssemblyRuns->begin(), fragmentAssemblyRuns->end());
} }
} }
@@ -614,8 +518,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
if (resultIndex >= returnOperandIndices.size()) if (resultIndex >= returnOperandIndices.size())
return insertSlice.emitOpError("result index out of range while lowering host batch output"); return insertSlice.emitOpError("result index out of range while lowering host batch output");
bool hasDirectReturn = returnOperandIndices[resultIndex] != std::numeric_limits<unsigned>::max(); bool hasDirectReturn = returnOperandIndices[resultIndex] != std::numeric_limits<unsigned>::max();
bool hasFragmentAssembly = resultIndex < fragmentAssemblyPlansByResult.size() bool hasFragmentAssembly = resultIndex < fragmentAssemblyRunsByResult.size()
&& !fragmentAssemblyPlansByResult[resultIndex].empty(); && !fragmentAssemblyRunsByResult[resultIndex].empty();
if (!hasDirectReturn && !hasFragmentAssembly) if (!hasDirectReturn && !hasFragmentAssembly)
continue; continue;
@@ -626,27 +530,23 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
auto mappedSourceType = dyn_cast<ShapedType>(mappedSource.getType()); auto mappedSourceType = dyn_cast<ShapedType>(mappedSource.getType());
if (!mappedSourceType || !mappedSourceType.hasStaticShape()) if (!mappedSourceType || !mappedSourceType.hasStaticShape())
return insertSlice.emitOpError("fragment assembly batch lowering requires a static ranked lane-local source"); return insertSlice.emitOpError("fragment assembly batch lowering requires a static ranked lane-local source");
for (const BatchFragmentAssemblyPlan& plan : fragmentAssemblyPlansByResult[resultIndex]) { DenseMap<unsigned, Value> updatedOutputs;
Value outputTensor = outputTensors[plan.returnIndex](rewriter, insertSlice.getLoc()); for (const FragmentAssemblyCopyRun& run : fragmentAssemblyRunsByResult[resultIndex]) {
auto sizeAttr = pim::getCheckedI32Attr( Value outputTensor = updatedOutputs.lookup(run.hostTargetIndex);
rewriter, coreBatchOp.getOperation(), plan.fragmentByteSize, "fragment assembly host copy byte size"); if (!outputTensor)
if (failed(sizeAttr)) outputTensor = outputTensors[run.hostTargetIndex](rewriter, insertSlice.getLoc());
FragmentAssemblyCopyRun mappedRun = run;
mappedRun.source = mappedSource;
FailureOr<Value> updated =
emitFragmentAssemblyCopyRuns(rewriter,
insertSlice.getLoc(),
ArrayRef<FragmentAssemblyCopyRun> {mappedRun},
outputTensor,
coreBatchOp.getOperation(),
laneArg);
if (failed(updated))
return failure(); return failure();
Value hostTargetOffset = updatedOutputs[run.hostTargetIndex] = *updated;
createLaneIndexedOffset(rewriter, coreBatchOp.getOperation(), laneArg, plan.hostOffsetsByLane, insertSlice.getLoc());
Value deviceSourceOffset = getOrCreateIndexConstant(
rewriter, coreBatchOp.getOperation(),
plan.localSourceElementOffset * static_cast<int64_t>(getElementTypeSizeInBytes(mappedSourceType.getElementType())));
outputTensor =
pim::PimMemCopyDevToHostOp::create(rewriter,
insertSlice.getLoc(),
outputTensor.getType(),
hostTargetOffset,
deviceSourceOffset,
outputTensor,
mappedSource,
*sizeAttr)
.getOutput();
} }
continue; continue;
} }
@@ -657,11 +557,28 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
insertSlice.getSource().getDefiningOp<spatial::SpatBlueprintOp>()) { insertSlice.getSource().getDefiningOp<spatial::SpatBlueprintOp>()) {
std::optional<StringRef> modeAttr = blueprint.getMode(); std::optional<StringRef> modeAttr = blueprint.getMode();
if (modeAttr && *modeAttr == "fragment_assembly") { if (modeAttr && *modeAttr == "fragment_assembly") {
FailureOr<Value> updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter, FailureOr<SmallVector<FragmentAssemblyCopy, 8>> fragmentAssemblyCopies =
blueprint, collectFragmentAssemblyCopiesFromBlueprint(blueprint, mapper, /*lane=*/0, /*hostTargetIndex=*/0);
hostTarget, if (failed(fragmentAssemblyCopies))
insertSlice.getMixedOffsets(), return failure();
mapper); FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>> fragmentAssemblyRuns =
groupFragmentAssemblyCopyRuns(*fragmentAssemblyCopies, /*laneCount=*/1);
if (failed(fragmentAssemblyRuns))
return failure();
SmallVector<int64_t> zeroOffsets(hostTargetType.getRank(), 0);
Value baseHostOffset = createHostTargetOffset(rewriter,
blueprint.getLoc(),
hostTargetType,
insertSlice.getMixedOffsets(),
zeroOffsets,
mapper);
FailureOr<Value> updatedHostTarget = emitFragmentAssemblyCopyRuns(rewriter,
blueprint.getLoc(),
*fragmentAssemblyRuns,
hostTarget,
coreBatchOp.getOperation(),
std::nullopt,
baseHostOffset);
if (failed(updatedHostTarget)) if (failed(updatedHostTarget))
return failure(); return failure();
hostOutputTensors[resultIndex] = *updatedHostTarget; hostOutputTensors[resultIndex] = *updatedHostTarget;
+376
View File
@@ -1,12 +1,17 @@
#include "mlir/IR/ValueRange.h" #include "mlir/IR/ValueRange.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include <cassert> #include <cassert>
#include <limits>
#include "Common.hpp" #include "Common.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;
@@ -186,4 +191,375 @@ forEachContiguousDestinationChunk(ArrayRef<int64_t> destShape,
return visit(visit, 0); return visit(visit, 0);
} }
static mlir::Value
createSteppedOffset(OpBuilder& builder, Location loc, mlir::Value start, mlir::Value index, int64_t stepBytes) {
if (stepBytes == 0)
return start;
mlir::Value step = arith::ConstantIndexOp::create(builder, loc, stepBytes);
mlir::Value scaled = arith::MulIOp::create(builder, loc, index, step).getResult();
return arith::AddIOp::create(builder, loc, start, scaled).getResult();
}
static mlir::Value createIndexedOffset(OpBuilder& builder,
Location loc,
mlir::Value indexArg,
ArrayRef<int64_t> values) {
assert(!values.empty() && "expected lane-indexed values");
if (llvm::all_of(values.drop_front(), [&](int64_t value) { return value == values.front(); }))
return arith::ConstantIndexOp::create(builder, loc, values.front());
if (values.size() >= 2) {
int64_t step = values[1] - values[0];
bool arithmetic = llvm::all_of(llvm::seq<size_t>(2, values.size()), [&](size_t index) {
return values[index] == values.front() + static_cast<int64_t>(index) * step;
});
if (arithmetic) {
mlir::Value base = arith::ConstantIndexOp::create(builder, loc, values.front());
mlir::Value stepValue = arith::ConstantIndexOp::create(builder, loc, step);
mlir::Value scaledIndex = arith::MulIOp::create(builder, loc, indexArg, stepValue).getResult();
return arith::AddIOp::create(builder, loc, base, scaledIndex).getResult();
}
}
mlir::Value selected = arith::ConstantIndexOp::create(builder, loc, values.front());
for (auto [lane, value] : llvm::enumerate(values.drop_front())) {
mlir::Value indexValue = arith::ConstantIndexOp::create(builder, loc, static_cast<int64_t>(lane + 1));
mlir::Value cmp = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, indexArg, indexValue);
mlir::Value candidate = arith::ConstantIndexOp::create(builder, loc, value);
selected = arith::SelectOp::create(builder, loc, cmp, candidate, selected);
}
return selected;
}
struct FragmentAssemblyCopyRunFamily {
FragmentAssemblyCopyRun prototype;
SmallVector<int64_t, 8> sourceRunStartDeltas;
SmallVector<int64_t, 8> hostRunStartDeltas;
};
static bool computeUniformRunStartDelta(ArrayRef<int64_t> prototypeStarts,
ArrayRef<int64_t> runStarts,
int64_t& delta) {
if (prototypeStarts.size() != runStarts.size() || prototypeStarts.empty())
return false;
delta = runStarts.front() - prototypeStarts.front();
return llvm::all_of(llvm::zip_equal(prototypeStarts, runStarts), [&](auto pair) {
auto [prototypeStart, runStart] = pair;
return runStart - prototypeStart == delta;
});
}
static bool canMergeFragmentAssemblyCopyRunIntoFamily(const FragmentAssemblyCopyRunFamily& family,
const FragmentAssemblyCopyRun& run,
int64_t& sourceRunStartDelta,
int64_t& hostRunStartDelta) {
const FragmentAssemblyCopyRun& prototype = family.prototype;
if (prototype.source != run.source || prototype.sourceType != run.sourceType
|| prototype.hostTargetIndex != run.hostTargetIndex || prototype.count != run.count
|| prototype.sourceStepBytes != run.sourceStepBytes || prototype.hostStepBytes != run.hostStepBytes
|| prototype.byteSize != run.byteSize)
return false;
if (!computeUniformRunStartDelta(prototype.sourceStartBytesByLane, run.sourceStartBytesByLane, sourceRunStartDelta))
return false;
return computeUniformRunStartDelta(prototype.hostStartBytesByLane, run.hostStartBytesByLane, hostRunStartDelta);
}
static SmallVector<FragmentAssemblyCopyRunFamily, 8>
groupFragmentAssemblyCopyRunFamilies(ArrayRef<FragmentAssemblyCopyRun> runs) {
auto compareRunStarts = [](ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs) {
return std::lexicographical_compare(lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
};
SmallVector<FragmentAssemblyCopyRun, 8> sortedRuns(runs.begin(), runs.end());
llvm::sort(sortedRuns, [&](const FragmentAssemblyCopyRun& lhs, const FragmentAssemblyCopyRun& rhs) {
if (lhs.hostTargetIndex != rhs.hostTargetIndex)
return lhs.hostTargetIndex < rhs.hostTargetIndex;
if (lhs.source != rhs.source)
return lhs.source.getAsOpaquePointer() < rhs.source.getAsOpaquePointer();
if (lhs.byteSize != rhs.byteSize)
return lhs.byteSize < rhs.byteSize;
if (lhs.count != rhs.count)
return lhs.count < rhs.count;
if (lhs.sourceStepBytes != rhs.sourceStepBytes)
return lhs.sourceStepBytes < rhs.sourceStepBytes;
if (lhs.hostStepBytes != rhs.hostStepBytes)
return lhs.hostStepBytes < rhs.hostStepBytes;
if (compareRunStarts(lhs.sourceStartBytesByLane, rhs.sourceStartBytesByLane))
return true;
if (compareRunStarts(rhs.sourceStartBytesByLane, lhs.sourceStartBytesByLane))
return false;
return compareRunStarts(lhs.hostStartBytesByLane, rhs.hostStartBytesByLane);
});
SmallVector<FragmentAssemblyCopyRunFamily, 8> families;
for (const FragmentAssemblyCopyRun& run : sortedRuns) {
int64_t sourceRunStartDelta = 0;
int64_t hostRunStartDelta = 0;
if (!families.empty()
&& canMergeFragmentAssemblyCopyRunIntoFamily(
families.back(), run, sourceRunStartDelta, hostRunStartDelta)) {
families.back().sourceRunStartDeltas.push_back(sourceRunStartDelta);
families.back().hostRunStartDeltas.push_back(hostRunStartDelta);
continue;
}
FragmentAssemblyCopyRunFamily family;
family.prototype = run;
family.sourceRunStartDeltas.push_back(0);
family.hostRunStartDeltas.push_back(0);
families.push_back(std::move(family));
}
return families;
}
FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>>
groupFragmentAssemblyCopyRuns(ArrayRef<FragmentAssemblyCopy> copies, uint32_t laneCount) {
if (laneCount == 0)
return failure();
struct LaneLocalCopyRun {
FragmentAssemblyCopyRun run;
int64_t lane = 0;
};
SmallVector<FragmentAssemblyCopy, 8> sortedCopies(copies.begin(), copies.end());
llvm::sort(sortedCopies, [](const FragmentAssemblyCopy& lhs, const FragmentAssemblyCopy& rhs) {
if (lhs.hostTargetIndex != rhs.hostTargetIndex)
return lhs.hostTargetIndex < rhs.hostTargetIndex;
if (lhs.source != rhs.source)
return lhs.source.getAsOpaquePointer() < rhs.source.getAsOpaquePointer();
if (lhs.lane != rhs.lane)
return lhs.lane < rhs.lane;
if (lhs.byteSize != rhs.byteSize)
return lhs.byteSize < rhs.byteSize;
if (lhs.sourceByteOffset != rhs.sourceByteOffset)
return lhs.sourceByteOffset < rhs.sourceByteOffset;
return lhs.hostByteOffset < rhs.hostByteOffset;
});
SmallVector<LaneLocalCopyRun, 8> laneRuns;
for (const FragmentAssemblyCopy& copy : sortedCopies) {
if (copy.lane < 0 || copy.lane >= static_cast<int64_t>(laneCount))
return failure();
if (!laneRuns.empty()) {
LaneLocalCopyRun& laneRun = laneRuns.back();
FragmentAssemblyCopyRun& run = laneRun.run;
if (run.source == copy.source && run.sourceType == copy.sourceType
&& run.hostTargetIndex == copy.hostTargetIndex && laneRun.lane == copy.lane && run.byteSize == copy.byteSize
&& run.sourceStartBytesByLane.size() == 1 && run.hostStartBytesByLane.size() == 1) {
int64_t previousSourceOffset = run.sourceStartBytesByLane.front() + (run.count - 1) * run.sourceStepBytes;
int64_t previousHostOffset = run.hostStartBytesByLane.front() + (run.count - 1) * run.hostStepBytes;
int64_t sourceDelta = copy.sourceByteOffset - previousSourceOffset;
int64_t hostDelta = copy.hostByteOffset - previousHostOffset;
if (run.count == 1) {
run.sourceStepBytes = sourceDelta;
run.hostStepBytes = hostDelta;
++run.count;
continue;
}
if (run.sourceStepBytes == sourceDelta && run.hostStepBytes == hostDelta) {
++run.count;
continue;
}
}
}
LaneLocalCopyRun laneRun;
laneRun.run.source = copy.source;
laneRun.run.sourceType = copy.sourceType;
laneRun.run.hostTargetIndex = copy.hostTargetIndex;
laneRun.run.count = 1;
laneRun.run.byteSize = copy.byteSize;
laneRun.run.sourceStartBytesByLane.push_back(copy.sourceByteOffset);
laneRun.run.hostStartBytesByLane.push_back(copy.hostByteOffset);
laneRun.lane = copy.lane;
laneRuns.push_back(std::move(laneRun));
}
SmallVector<FragmentAssemblyCopyRun, 8> mergedRuns;
for (const LaneLocalCopyRun& laneRun : laneRuns) {
size_t laneIndex = static_cast<size_t>(laneRun.lane);
auto mergedIt = llvm::find_if(mergedRuns, [&](const FragmentAssemblyCopyRun& run) {
return run.source == laneRun.run.source && run.sourceType == laneRun.run.sourceType
&& run.hostTargetIndex == laneRun.run.hostTargetIndex && run.count == laneRun.run.count
&& run.byteSize == laneRun.run.byteSize && run.sourceStepBytes == laneRun.run.sourceStepBytes
&& run.hostStepBytes == laneRun.run.hostStepBytes && laneIndex < run.sourceStartBytesByLane.size()
&& run.sourceStartBytesByLane[laneIndex] == std::numeric_limits<int64_t>::min();
});
if (mergedIt == mergedRuns.end()) {
FragmentAssemblyCopyRun merged = laneRun.run;
merged.sourceStartBytesByLane.assign(laneCount, std::numeric_limits<int64_t>::min());
merged.hostStartBytesByLane.assign(laneCount, std::numeric_limits<int64_t>::min());
merged.sourceStartBytesByLane[laneIndex] = laneRun.run.sourceStartBytesByLane.front();
merged.hostStartBytesByLane[laneIndex] = laneRun.run.hostStartBytesByLane.front();
mergedRuns.push_back(std::move(merged));
continue;
}
mergedIt->sourceStartBytesByLane[laneIndex] = laneRun.run.sourceStartBytesByLane.front();
mergedIt->hostStartBytesByLane[laneIndex] = laneRun.run.hostStartBytesByLane.front();
}
for (const FragmentAssemblyCopyRun& run : mergedRuns) {
if (llvm::any_of(run.sourceStartBytesByLane,
[](int64_t value) { return value == std::numeric_limits<int64_t>::min(); }))
return failure();
if (llvm::any_of(run.hostStartBytesByLane,
[](int64_t value) { return value == std::numeric_limits<int64_t>::min(); }))
return failure();
}
return mergedRuns;
}
static FailureOr<mlir::Value> emitFragmentAssemblyCopyRun(OpBuilder& builder,
Location loc,
const FragmentAssemblyCopyRun& run,
mlir::Value hostTarget,
Operation* anchor,
std::optional<mlir::Value> laneArg,
mlir::Value baseHostOffset,
mlir::Value sourceRunStartDelta = {},
mlir::Value hostRunStartDelta = {}) {
auto sizeAttr = pim::getCheckedI32Attr(builder, anchor, run.byteSize, "fragment assembly host copy byte size");
if (failed(sizeAttr))
return failure();
mlir::Value hostStart;
mlir::Value sourceStart;
if (laneArg) {
hostStart = createIndexedOffset(builder, loc, *laneArg, run.hostStartBytesByLane);
sourceStart = createIndexedOffset(builder, loc, *laneArg, run.sourceStartBytesByLane);
} else {
hostStart = arith::ConstantIndexOp::create(builder, loc, run.hostStartBytesByLane.front());
sourceStart = arith::ConstantIndexOp::create(builder, loc, run.sourceStartBytesByLane.front());
}
if (hostRunStartDelta)
hostStart = arith::AddIOp::create(builder, loc, hostStart, hostRunStartDelta).getResult();
if (sourceRunStartDelta)
sourceStart = arith::AddIOp::create(builder, loc, sourceStart, sourceRunStartDelta).getResult();
if (baseHostOffset)
hostStart = arith::AddIOp::create(builder, loc, baseHostOffset, hostStart).getResult();
if (run.count == 1) {
return pim::PimMemCopyDevToHostOp::create(builder,
loc,
hostTarget.getType(),
hostStart,
sourceStart,
hostTarget,
run.source,
*sizeAttr)
.getOutput();
}
mlir::Value lowerBound = arith::ConstantIndexOp::create(builder, loc, 0);
mlir::Value upperBound = arith::ConstantIndexOp::create(builder, loc, run.count);
mlir::Value step = arith::ConstantIndexOp::create(builder, loc, 1);
FailureOr<NormalizedLoopResult> loop = buildNormalizedScfFor(
builder,
loc,
lowerBound,
upperBound,
step,
ValueRange {hostTarget},
[&](OpBuilder& loopBuilder,
Location bodyLoc,
mlir::Value flatIndex,
ValueRange iterArgs,
SmallVectorImpl<mlir::Value>& yielded) {
mlir::Value hostOffset = createSteppedOffset(loopBuilder, bodyLoc, hostStart, flatIndex, run.hostStepBytes);
mlir::Value sourceOffset =
createSteppedOffset(loopBuilder, bodyLoc, sourceStart, flatIndex, run.sourceStepBytes);
mlir::Value copied =
pim::PimMemCopyDevToHostOp::create(loopBuilder,
bodyLoc,
iterArgs.front().getType(),
hostOffset,
sourceOffset,
iterArgs.front(),
run.source,
*sizeAttr)
.getOutput();
yielded.push_back(copied);
return success();
});
if (failed(loop))
return failure();
return loop->results.front();
}
static FailureOr<mlir::Value> emitFragmentAssemblyCopyRunFamily(OpBuilder& builder,
Location loc,
const FragmentAssemblyCopyRunFamily& family,
mlir::Value hostTarget,
Operation* anchor,
std::optional<mlir::Value> laneArg,
mlir::Value baseHostOffset) {
if (family.sourceRunStartDeltas.size() == 1)
return emitFragmentAssemblyCopyRun(
builder, loc, family.prototype, hostTarget, anchor, laneArg, baseHostOffset);
mlir::Value lowerBound = arith::ConstantIndexOp::create(builder, loc, 0);
mlir::Value upperBound = arith::ConstantIndexOp::create(builder, loc, family.sourceRunStartDeltas.size());
mlir::Value step = arith::ConstantIndexOp::create(builder, loc, 1);
FailureOr<NormalizedLoopResult> outerLoop = buildNormalizedScfFor(
builder,
loc,
lowerBound,
upperBound,
step,
ValueRange {hostTarget},
[&](OpBuilder& loopBuilder,
Location bodyLoc,
mlir::Value runIndex,
ValueRange iterArgs,
SmallVectorImpl<mlir::Value>& yielded) {
mlir::Value sourceRunStartDelta =
createIndexedOffset(loopBuilder, bodyLoc, runIndex, family.sourceRunStartDeltas);
mlir::Value hostRunStartDelta =
createIndexedOffset(loopBuilder, bodyLoc, runIndex, family.hostRunStartDeltas);
FailureOr<mlir::Value> copied = emitFragmentAssemblyCopyRun(loopBuilder,
bodyLoc,
family.prototype,
iterArgs.front(),
anchor,
laneArg,
baseHostOffset,
sourceRunStartDelta,
hostRunStartDelta);
if (failed(copied))
return failure();
yielded.push_back(*copied);
return success();
});
if (failed(outerLoop))
return failure();
return outerLoop->results.front();
}
FailureOr<mlir::Value> emitFragmentAssemblyCopyRuns(IRRewriter& rewriter,
Location loc,
ArrayRef<FragmentAssemblyCopyRun> runs,
mlir::Value hostTarget,
Operation* anchor,
std::optional<mlir::Value> laneArg,
mlir::Value baseHostOffset) {
for (const FragmentAssemblyCopyRunFamily& family : groupFragmentAssemblyCopyRunFamilies(runs)) {
FailureOr<mlir::Value> updatedHostTarget =
emitFragmentAssemblyCopyRunFamily(rewriter, loc, family, hostTarget, anchor, laneArg, baseHostOffset);
if (failed(updatedHostTarget))
return failure();
hostTarget = *updatedHostTarget;
}
return hostTarget;
}
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -1,8 +1,14 @@
#pragma once #pragma once
#include <optional>
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/STLFunctionalExtras.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Value.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Support/LogicalResult.h" #include "mlir/Support/LogicalResult.h"
@@ -59,6 +65,39 @@ forEachContiguousDestinationChunk(llvm::ArrayRef<int64_t> destShape,
llvm::function_ref<mlir::LogicalResult(llvm::ArrayRef<int64_t>, int64_t, int64_t)> llvm::function_ref<mlir::LogicalResult(llvm::ArrayRef<int64_t>, int64_t, int64_t)>
callback); callback);
struct FragmentAssemblyCopy {
mlir::Value source;
mlir::RankedTensorType sourceType;
unsigned hostTargetIndex = 0;
int64_t lane = 0;
int64_t sourceByteOffset = 0;
int64_t hostByteOffset = 0;
int64_t byteSize = 0;
};
struct FragmentAssemblyCopyRun {
mlir::Value source;
mlir::RankedTensorType sourceType;
unsigned hostTargetIndex = 0;
int64_t count = 0;
int64_t sourceStepBytes = 0;
int64_t hostStepBytes = 0;
int64_t byteSize = 0;
mlir::SmallVector<int64_t, 8> sourceStartBytesByLane;
mlir::SmallVector<int64_t, 8> hostStartBytesByLane;
};
mlir::FailureOr<mlir::SmallVector<FragmentAssemblyCopyRun, 8>>
groupFragmentAssemblyCopyRuns(llvm::ArrayRef<FragmentAssemblyCopy> copies, uint32_t laneCount = 1);
mlir::FailureOr<mlir::Value> emitFragmentAssemblyCopyRuns(mlir::IRRewriter& rewriter,
mlir::Location loc,
llvm::ArrayRef<FragmentAssemblyCopyRun> runs,
mlir::Value hostTarget,
mlir::Operation* anchor,
std::optional<mlir::Value> laneArg = std::nullopt,
mlir::Value baseHostOffset = {});
inline mlir::tensor::EmptyOp inline mlir::tensor::EmptyOp
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) { createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType()); return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
@@ -8,6 +9,7 @@
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
@@ -30,22 +32,9 @@ static bool isChannelUseChainOp(Operation* op) {
pim::PimTransposeOp>(op); pim::PimTransposeOp>(op);
} }
static Value createStaticHostTargetOffset(IRRewriter& rewriter,
Location loc,
ShapedType destinationType,
ArrayRef<int64_t> fragmentOffsets) {
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
int64_t byteOffset = 0;
for (auto [dim, offset] : llvm::enumerate(fragmentOffsets))
byteOffset += offset * strides[dim] * elementBytes;
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), byteOffset);
}
static FailureOr<Value> lowerFragmentAssemblyBlueprint(IRRewriter& rewriter, static FailureOr<Value> lowerFragmentAssemblyBlueprint(IRRewriter& rewriter,
spatial::SpatBlueprintOp blueprint, spatial::SpatBlueprintOp blueprint,
IRMapping& mapping) { IRMapping& mapping) {
auto resultType = dyn_cast<ShapedType>(blueprint.getOutput().getType()); auto resultType = dyn_cast<ShapedType>(blueprint.getOutput().getType());
if (!resultType || !resultType.hasStaticShape()) if (!resultType || !resultType.hasStaticShape())
return blueprint.emitOpError("fragment assembly lowering requires a static ranked tensor result"); return blueprint.emitOpError("fragment assembly lowering requires a static ranked tensor result");
@@ -77,56 +66,54 @@ static FailureOr<Value> lowerFragmentAssemblyBlueprint(IRRewriter& rewriter,
flatStrides))) flatStrides)))
return failure(); return failure();
Value currentOutput = createEmptyTensorFromShaped(rewriter, blueprint.getLoc(), resultType); SmallVector<int64_t> hostStrides = computeRowMajorStrides(resultType.getShape());
SmallVector<FragmentAssemblyCopy, 8> copies;
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) { for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
int64_t operandIndex = operandIndices[fragmentIndex]; int64_t operandIndex = operandIndices[fragmentIndex];
SmallVector<int64_t, 4> fragmentOffsets; SmallVector<int64_t, 4> fragmentOffsets;
int64_t fragmentElements = 1; SmallVector<int64_t, 4> fragmentSizes;
for (int64_t dim = 0; dim < rank; ++dim) { for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim; int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1) if (flatStrides[flatIndex] != 1)
return blueprint.emitOpError("fragment assembly lowering only supports unit strides"); return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
fragmentOffsets.push_back(flatOffsets[flatIndex]); fragmentOffsets.push_back(flatOffsets[flatIndex]);
fragmentElements *= flatSizes[flatIndex]; fragmentSizes.push_back(flatSizes[flatIndex]);
} }
Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]); Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]);
auto sourceType = dyn_cast<ShapedType>(source.getType()); auto sourceType = dyn_cast<RankedTensorType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape()) if (!sourceType || !sourceType.hasStaticShape())
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands"); return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType());
if (failed(forEachContiguousDestinationChunk(
resultType.getShape(),
fragmentOffsets,
fragmentSizes,
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
int64_t hostElementOffset = 0;
for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
hostElementOffset += offset * hostStrides[dim];
int64_t fragmentBytes = FragmentAssemblyCopy copy;
fragmentElements * static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType())); copy.source = source;
auto sizeAttr = pim::getCheckedI32Attr(rewriter, copy.sourceType = sourceType;
blueprint.getOperation(), copy.sourceByteOffset =
fragmentBytes, (sourceOffsets[fragmentIndex] + relativeSourceOffset) * static_cast<int64_t>(elementSize);
"fragment assembly host copy size"); copy.hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
if (failed(sizeAttr)) copy.byteSize = chunkElements * static_cast<int64_t>(elementSize);
copies.push_back(copy);
return success();
})))
return failure(); return failure();
Value hostTargetOffset = createStaticHostTargetOffset(rewriter, blueprint.getLoc(), resultType, fragmentOffsets);
auto deviceSourceOffsetBytes = pim::checkedMul(static_cast<uint64_t>(sourceOffsets[fragmentIndex]),
static_cast<uint64_t>(getElementTypeSizeInBytes(sourceType.getElementType())),
blueprint,
"fragment assembly device source offset");
if (failed(deviceSourceOffsetBytes))
return failure();
Value deviceSourceOffset = getOrCreateIndexConstant(rewriter,
rewriter.getInsertionBlock()->getParentOp(),
static_cast<int64_t>(*deviceSourceOffsetBytes));
currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter,
blueprint.getLoc(),
currentOutput.getType(),
hostTargetOffset,
deviceSourceOffset,
currentOutput,
source,
*sizeAttr)
.getOutput();
} }
return currentOutput; Value currentOutput = createEmptyTensorFromShaped(rewriter, blueprint.getLoc(), resultType);
FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>> runs = groupFragmentAssemblyCopyRuns(copies);
if (failed(runs))
return failure();
return emitFragmentAssemblyCopyRuns(
rewriter, blueprint.getLoc(), *runs, currentOutput, blueprint.getOperation());
} }
static void static void
@@ -2,6 +2,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
@@ -11,6 +12,7 @@
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
@@ -638,6 +640,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
flatSizes, flatSizes,
flatStrides))) flatStrides)))
return ReturnPathLoweringResult::Failure; return ReturnPathLoweringResult::Failure;
SmallVector<FragmentAssemblyCopy, 8> copies;
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) { for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
if (operandIndices[fragmentIndex] != static_cast<int64_t>(operandNumber)) if (operandIndices[fragmentIndex] != static_cast<int64_t>(operandNumber))
continue; continue;
@@ -675,29 +678,27 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
failedChunk = true; failedChunk = true;
return failure(); return failure();
} }
auto sizeAttr = FragmentAssemblyCopy copy;
pim::getCheckedI32Attr(rewriter, producerOp, *fragmentBytes, "fragment assembly host copy byte size"); copy.source = storedValue;
if (failed(sizeAttr)) { copy.sourceType = sourceType;
failedChunk = true; copy.hostByteOffset = *hostOffset;
return failure(); copy.sourceByteOffset = *sourceOffset;
} copy.byteSize = *fragmentBytes;
copies.push_back(copy);
outputTensor =
pim::PimMemCopyDevToHostOp::create(rewriter,
blueprint.getLoc(),
outputTensor.getType(),
getOrCreateIndexConstant(rewriter, producerOp, *hostOffset),
getOrCreateIndexConstant(rewriter, producerOp, *sourceOffset),
outputTensor,
storedValue,
*sizeAttr)
.getOutput();
return success(); return success();
}))) })))
failedChunk = true; failedChunk = true;
if (failedChunk) if (failedChunk)
return ReturnPathLoweringResult::Failure; return ReturnPathLoweringResult::Failure;
} }
FailureOr<SmallVector<FragmentAssemblyCopyRun, 8>> runs = groupFragmentAssemblyCopyRuns(copies);
if (failed(runs))
return ReturnPathLoweringResult::Failure;
FailureOr<Value> updatedOutput =
emitFragmentAssemblyCopyRuns(rewriter, blueprint.getLoc(), *runs, outputTensor, producerOp);
if (failed(updatedOutput))
return ReturnPathLoweringResult::Failure;
outputTensor = *updatedOutput;
markOpToRemove(blueprint.getOperation()); markOpToRemove(blueprint.getOperation());
} }
return ReturnPathLoweringResult::Handled; return ReturnPathLoweringResult::Handled;
File diff suppressed because it is too large Load Diff