@@ -113,6 +113,19 @@ void verifyScheduledInputs(ComputeOpTy compute,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
void verifyNoNestedFragmentAssemblyReconciliators(ComputeOpTy compute,
|
||||
pim::CappedDiagnosticReporter& diagnostics) {
|
||||
compute.getBody().walk([&](spatial::SpatReconciliatorOp reconciliator) {
|
||||
std::optional<StringRef> mode = reconciliator.getMode();
|
||||
if (!mode || *mode != "fragment_assembly")
|
||||
return;
|
||||
diagnostics.report(reconciliator.getOperation(), [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("fragment assembly reconciliator must be host-level after merge materialization");
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
for (Operation& op : funcOp.getOps()) {
|
||||
if (isa<func::ReturnOp,
|
||||
@@ -188,10 +201,14 @@ LogicalResult verifyLogicalSpatialGraphInvariants(func::FuncOp funcOp) {
|
||||
LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
verifyScheduledTopLevelOps(funcOp, diagnostics);
|
||||
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
|
||||
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>()) {
|
||||
verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
|
||||
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
|
||||
verifyNoNestedFragmentAssemblyReconciliators(compute, diagnostics);
|
||||
}
|
||||
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
|
||||
verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
|
||||
verifyNoNestedFragmentAssemblyReconciliators(batch, diagnostics);
|
||||
}
|
||||
if (failed(verifyNoComputeBodyCaptures(funcOp)))
|
||||
return failure();
|
||||
diagnostics.emitSuppressedSummary(funcOp, "scheduled Spatial verification failed");
|
||||
|
||||
@@ -1702,8 +1702,6 @@ static bool canUseStructuredRewrite(const ConvLoweringState& state) {
|
||||
state.outWidth);
|
||||
if (!tiling)
|
||||
return false;
|
||||
if (tiling->numChannelTiles > static_cast<int64_t>(crossbarCountInCore.getValue()))
|
||||
return false;
|
||||
|
||||
if (!state.hasBias)
|
||||
return true;
|
||||
|
||||
@@ -107,6 +107,7 @@ static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewr
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
|
||||
@@ -131,6 +131,151 @@ static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
||||
return result.getUses().begin()->getOperandNumber();
|
||||
}
|
||||
|
||||
struct BatchFragmentAssemblyPlan {
|
||||
unsigned returnIndex = 0;
|
||||
int64_t localSourceElementOffset = 0;
|
||||
int64_t fragmentByteSize = 0;
|
||||
SmallVector<int64_t, 8> hostOffsetsByLane;
|
||||
};
|
||||
|
||||
static Value createLaneIndexedOffset(IRRewriter& rewriter, Operation* anchor, Value laneArg, ArrayRef<int64_t> values, Location loc) {
|
||||
assert(!values.empty() && "expected lane-indexed values");
|
||||
if (llvm::all_of(values.drop_front(), [&](int64_t value) { return value == values.front(); }))
|
||||
return getOrCreateIndexConstant(rewriter, anchor, values.front());
|
||||
|
||||
if (values.size() >= 2) {
|
||||
int64_t step = values[1] - values[0];
|
||||
bool arithmetic = llvm::all_of(llvm::seq<size_t>(2, values.size()), [&](size_t index) {
|
||||
return values[index] == values.front() + static_cast<int64_t>(index) * step;
|
||||
});
|
||||
if (arithmetic) {
|
||||
Value base = getOrCreateIndexConstant(rewriter, anchor, values.front());
|
||||
if (step == 0)
|
||||
return base;
|
||||
Value stepValue = getOrCreateIndexConstant(rewriter, anchor, step);
|
||||
Value scaledLane = arith::MulIOp::create(rewriter, loc, laneArg, stepValue).getResult();
|
||||
return arith::AddIOp::create(rewriter, loc, base, scaledLane).getResult();
|
||||
}
|
||||
}
|
||||
|
||||
Value selected = getOrCreateIndexConstant(rewriter, anchor, values.front());
|
||||
for (auto [lane, value] : llvm::enumerate(values.drop_front())) {
|
||||
Value laneValue = getOrCreateIndexConstant(rewriter, anchor, static_cast<int64_t>(lane + 1));
|
||||
Value cmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, laneArg, laneValue);
|
||||
Value candidate = getOrCreateIndexConstant(rewriter, anchor, value);
|
||||
selected = arith::SelectOp::create(rewriter, loc, cmp, candidate, selected);
|
||||
}
|
||||
return selected;
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<BatchFragmentAssemblyPlan, 8>>
|
||||
analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResultType, uint32_t laneCount) {
|
||||
SmallVector<BatchFragmentAssemblyPlan, 8> plans;
|
||||
if (!packedResultType.hasStaticShape() || laneCount == 0)
|
||||
return failure();
|
||||
|
||||
int64_t packedElementCount = packedResultType.getNumElements();
|
||||
if (packedElementCount % static_cast<int64_t>(laneCount) != 0)
|
||||
return failure();
|
||||
int64_t payloadElementCount = packedElementCount / static_cast<int64_t>(laneCount);
|
||||
size_t elementSize = getElementTypeSizeInBytes(packedResultType.getElementType());
|
||||
|
||||
for (OpOperand& use : result.getUses()) {
|
||||
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(use.getOwner());
|
||||
if (!reconciliator || reconciliator->getParentOp() != reconciliator->getParentOfType<func::FuncOp>())
|
||||
return failure();
|
||||
std::optional<StringRef> mode = reconciliator.getMode();
|
||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
|
||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
|
||||
std::optional<ArrayRef<int64_t>> stridesAttr = reconciliator.getFragmentStrides();
|
||||
if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr)
|
||||
return failure();
|
||||
if (!reconciliator.getOutput().hasOneUse() || !isa<func::ReturnOp>(*reconciliator.getOutput().getUsers().begin()))
|
||||
return failure();
|
||||
|
||||
unsigned returnIndex = reconciliator.getOutput().getUses().begin()->getOperandNumber();
|
||||
auto hostResultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
|
||||
if (!hostResultType || !hostResultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
|
||||
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
|
||||
ArrayRef<int64_t> flatStrides = *stridesAttr;
|
||||
int64_t rank = hostResultType.getRank();
|
||||
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
|
||||
llvm::append_range(fragmentOperands, reconciliator.getFragments());
|
||||
if (failed(validateFragmentAssemblyMetadata(reconciliator,
|
||||
rank,
|
||||
fragmentOperands.size(),
|
||||
operandIndices,
|
||||
sourceOffsets,
|
||||
flatOffsets,
|
||||
flatSizes,
|
||||
flatStrides)))
|
||||
return failure();
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||
if (operandIndices[fragmentIndex] != static_cast<int64_t>(use.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
int64_t sourceElementOffset = sourceOffsets[fragmentIndex];
|
||||
int64_t lane = sourceElementOffset / payloadElementCount;
|
||||
int64_t localSourceElementOffset = sourceElementOffset % payloadElementCount;
|
||||
if (lane < 0 || lane >= static_cast<int64_t>(laneCount))
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t, 4> fragmentOffsets;
|
||||
SmallVector<int64_t, 4> fragmentSizes;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||
if (flatStrides[flatIndex] != 1)
|
||||
return failure();
|
||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||
fragmentSizes.push_back(flatSizes[flatIndex]);
|
||||
}
|
||||
|
||||
if (failed(forEachContiguousDestinationChunk(
|
||||
hostResultType.getShape(),
|
||||
fragmentOffsets,
|
||||
fragmentSizes,
|
||||
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
|
||||
int64_t hostElementOffset = 0;
|
||||
SmallVector<int64_t> hostStrides = computeRowMajorStrides(hostResultType.getShape());
|
||||
for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
|
||||
hostElementOffset += offset * hostStrides[dim];
|
||||
int64_t hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
|
||||
int64_t fragmentByteSize = chunkElements * static_cast<int64_t>(elementSize);
|
||||
int64_t chunkSourceOffset = localSourceElementOffset + relativeSourceOffset;
|
||||
|
||||
auto planIt = llvm::find_if(plans, [&](const BatchFragmentAssemblyPlan& plan) {
|
||||
return plan.returnIndex == returnIndex && plan.localSourceElementOffset == chunkSourceOffset
|
||||
&& plan.fragmentByteSize == fragmentByteSize;
|
||||
});
|
||||
if (planIt == plans.end()) {
|
||||
BatchFragmentAssemblyPlan plan;
|
||||
plan.returnIndex = returnIndex;
|
||||
plan.localSourceElementOffset = chunkSourceOffset;
|
||||
plan.fragmentByteSize = fragmentByteSize;
|
||||
plan.hostOffsetsByLane.assign(laneCount, std::numeric_limits<int64_t>::min());
|
||||
plan.hostOffsetsByLane[static_cast<size_t>(lane)] = hostByteOffset;
|
||||
plans.push_back(std::move(plan));
|
||||
return success();
|
||||
}
|
||||
|
||||
planIt->hostOffsetsByLane[static_cast<size_t>(lane)] = hostByteOffset;
|
||||
return success();
|
||||
})))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
for (const BatchFragmentAssemblyPlan& plan : plans)
|
||||
if (llvm::any_of(plan.hostOffsetsByLane, [](int64_t offset) { return offset == std::numeric_limits<int64_t>::min(); }))
|
||||
return failure();
|
||||
return plans;
|
||||
}
|
||||
|
||||
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
|
||||
if (scale == 1)
|
||||
return base;
|
||||
@@ -250,6 +395,10 @@ static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
|
||||
"fragment assembly lowering requires explicit operand indices and unit strides");
|
||||
|
||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
|
||||
if (!sourceOffsetsAttr)
|
||||
return reconciliator.emitOpError("fragment assembly lowering requires explicit source offsets");
|
||||
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
|
||||
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
|
||||
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||
@@ -257,21 +406,25 @@ static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
|
||||
|
||||
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
|
||||
llvm::append_range(fragmentOperands, reconciliator.getFragments());
|
||||
if (failed(validateFragmentAssemblyMetadata(reconciliator,
|
||||
rank,
|
||||
fragmentOperands.size(),
|
||||
operandIndices,
|
||||
sourceOffsets,
|
||||
flatOffsets,
|
||||
flatSizes,
|
||||
flatStrides)))
|
||||
return failure();
|
||||
|
||||
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
|
||||
return reconciliator.emitOpError("fragment assembly operand index is out of range");
|
||||
|
||||
SmallVector<int64_t, 4> fragmentOffsets;
|
||||
int64_t fragmentElements = 1;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||
if (flatStrides[flatIndex] != 1)
|
||||
return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
|
||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||
fragmentElements *= flatSizes[flatIndex];
|
||||
}
|
||||
|
||||
Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]);
|
||||
@@ -279,20 +432,21 @@ static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||
|
||||
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
|
||||
SmallVector<int64_t, 4> fragmentShape;
|
||||
fragmentShape.reserve(rank);
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
|
||||
|
||||
Value fragment = source;
|
||||
if (llvm::to_vector(sourceType.getShape()) != fragmentShape) {
|
||||
SmallVector<int64_t, 4> extractOffsets(rank, 0);
|
||||
extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
|
||||
fragment = tensor::ExtractSliceOp::create(rewriter,
|
||||
reconciliator.getLoc(),
|
||||
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
|
||||
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
|
||||
reconciliator, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
|
||||
if (failed(extractOffsets))
|
||||
return failure();
|
||||
fragment = tensor::ExtractSliceOp::create(rewriter,
|
||||
reconciliator.getLoc(),
|
||||
source,
|
||||
getStaticIndexAttrs(rewriter, extractOffsets),
|
||||
getStaticIndexAttrs(rewriter, *extractOffsets),
|
||||
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||
getUnitStrides(rewriter, rank));
|
||||
}
|
||||
@@ -351,16 +505,29 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds));
|
||||
|
||||
SmallVector<unsigned> returnOperandIndices;
|
||||
SmallVector<SmallVector<BatchFragmentAssemblyPlan, 1>, 4> fragmentAssemblyPlansByResult;
|
||||
if (computeBatchOp.getNumResults() != 0) {
|
||||
returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits<unsigned>::max());
|
||||
fragmentAssemblyPlansByResult.resize(computeBatchOp.getNumResults());
|
||||
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
|
||||
if (failed(returnOperandIndex))
|
||||
if (succeeded(returnOperandIndex)) {
|
||||
returnOperandIndices[resultIndex] = *returnOperandIndex;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto resultType = dyn_cast<RankedTensorType>(result.getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return computeBatchOp.emitOpError(
|
||||
"resultful compute_batch publication lowering requires static ranked tensor results");
|
||||
FailureOr<SmallVector<BatchFragmentAssemblyPlan, 8>> fragmentAssemblyPlans =
|
||||
analyzeTopLevelFragmentAssemblyUses(cast<OpResult>(result), resultType, computeBatchOp.getLaneCount());
|
||||
if (failed(fragmentAssemblyPlans))
|
||||
return computeBatchOp.emitOpError(
|
||||
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
||||
returnOperandIndices[resultIndex] = *returnOperandIndex;
|
||||
fragmentAssemblyPlansByResult[resultIndex].assign(fragmentAssemblyPlans->begin(), fragmentAssemblyPlans->end());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -446,9 +613,44 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
|
||||
if (resultIndex >= returnOperandIndices.size())
|
||||
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
||||
if (returnOperandIndices[resultIndex] == std::numeric_limits<unsigned>::max())
|
||||
bool hasDirectReturn = returnOperandIndices[resultIndex] != std::numeric_limits<unsigned>::max();
|
||||
bool hasFragmentAssembly = resultIndex < fragmentAssemblyPlansByResult.size()
|
||||
&& !fragmentAssemblyPlansByResult[resultIndex].empty();
|
||||
if (!hasDirectReturn && !hasFragmentAssembly)
|
||||
continue;
|
||||
|
||||
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
||||
|
||||
if (hasFragmentAssembly) {
|
||||
BlockArgument laneArg = coreBatchOp.getLaneArgument();
|
||||
auto mappedSourceType = dyn_cast<ShapedType>(mappedSource.getType());
|
||||
if (!mappedSourceType || !mappedSourceType.hasStaticShape())
|
||||
return insertSlice.emitOpError("fragment assembly batch lowering requires a static ranked lane-local source");
|
||||
for (const BatchFragmentAssemblyPlan& plan : fragmentAssemblyPlansByResult[resultIndex]) {
|
||||
Value outputTensor = outputTensors[plan.returnIndex](rewriter, insertSlice.getLoc());
|
||||
auto sizeAttr = pim::getCheckedI32Attr(
|
||||
rewriter, coreBatchOp.getOperation(), plan.fragmentByteSize, "fragment assembly host copy byte size");
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
Value hostTargetOffset =
|
||||
createLaneIndexedOffset(rewriter, coreBatchOp.getOperation(), laneArg, plan.hostOffsetsByLane, insertSlice.getLoc());
|
||||
Value deviceSourceOffset = getOrCreateIndexConstant(
|
||||
rewriter, coreBatchOp.getOperation(),
|
||||
plan.localSourceElementOffset * static_cast<int64_t>(getElementTypeSizeInBytes(mappedSourceType.getElementType())));
|
||||
outputTensor =
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
insertSlice.getLoc(),
|
||||
outputTensor.getType(),
|
||||
hostTargetOffset,
|
||||
deviceSourceOffset,
|
||||
outputTensor,
|
||||
mappedSource,
|
||||
*sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
|
||||
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
||||
if (auto reconciliator =
|
||||
@@ -467,7 +669,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
||||
}
|
||||
}
|
||||
|
||||
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
||||
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource);
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
@@ -72,4 +73,117 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Op
|
||||
rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
|
||||
}
|
||||
|
||||
LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reconciliator,
|
||||
int64_t resultRank,
|
||||
size_t operandCount,
|
||||
ArrayRef<int64_t> operandIndices,
|
||||
ArrayRef<int64_t> sourceOffsets,
|
||||
ArrayRef<int64_t> flatOffsets,
|
||||
ArrayRef<int64_t> flatSizes,
|
||||
ArrayRef<int64_t> flatStrides) {
|
||||
if (operandIndices.size() != sourceOffsets.size())
|
||||
return reconciliator.emitOpError("fragment assembly operand index and source offset counts must match");
|
||||
if (flatOffsets.size() != flatSizes.size())
|
||||
return reconciliator.emitOpError("fragment assembly offset and size arrays must have matching lengths");
|
||||
if (flatStrides.size() != flatOffsets.size())
|
||||
return reconciliator.emitOpError("fragment assembly stride and offset arrays must have matching lengths");
|
||||
if (flatOffsets.size() != operandIndices.size() * static_cast<size_t>(resultRank))
|
||||
return reconciliator.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment");
|
||||
|
||||
for (auto [fragmentIndex, operandIndex] : llvm::enumerate(operandIndices)) {
|
||||
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(operandCount))
|
||||
return reconciliator.emitOpError("fragment assembly operand index is out of range");
|
||||
if (sourceOffsets[fragmentIndex] < 0)
|
||||
return reconciliator.emitOpError("fragment assembly source offsets must be nonnegative");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static SmallVector<int64_t, 4> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t, 4> indices(shape.size(), 0);
|
||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
|
||||
indices[dim] = flatIndex % shape[dim];
|
||||
flatIndex /= shape[dim];
|
||||
}
|
||||
return indices;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t, 4>>
|
||||
getStaticSliceOffsetsForElementOffset(Operation* anchor,
|
||||
ShapedType sourceType,
|
||||
ArrayRef<int64_t> fragmentShape,
|
||||
int64_t sourceElementOffset,
|
||||
StringRef fieldName) {
|
||||
if (!sourceType.hasStaticShape())
|
||||
return (anchor->emitOpError() << fieldName << " requires a static source shape"), failure();
|
||||
if (sourceElementOffset < 0)
|
||||
return (anchor->emitOpError() << fieldName << " requires a nonnegative source element offset"), failure();
|
||||
if (sourceType.getRank() != static_cast<int64_t>(fragmentShape.size()))
|
||||
return (anchor->emitOpError() << fieldName << " requires fragment rank to match source rank"), failure();
|
||||
|
||||
int64_t sourceElementCount = sourceType.getNumElements();
|
||||
int64_t fragmentElementCount = 1;
|
||||
for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
|
||||
if (fragmentShape[dim] < 0)
|
||||
return (anchor->emitOpError() << fieldName << " requires nonnegative fragment sizes"), failure();
|
||||
fragmentElementCount *= fragmentShape[dim];
|
||||
}
|
||||
if (sourceElementOffset + fragmentElementCount > sourceElementCount)
|
||||
return (anchor->emitOpError() << fieldName << " exceeds the source tensor bounds"), failure();
|
||||
|
||||
SmallVector<int64_t, 4> sliceOffsets = expandFlatElementIndex(sourceElementOffset, sourceType.getShape());
|
||||
for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
|
||||
if (sliceOffsets[dim] + fragmentShape[dim] > sourceType.getDimSize(dim))
|
||||
return (anchor->emitOpError() << fieldName << " does not describe a valid unit-stride slice"), failure();
|
||||
}
|
||||
return sliceOffsets;
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
forEachContiguousDestinationChunk(ArrayRef<int64_t> destShape,
|
||||
ArrayRef<int64_t> baseOffsets,
|
||||
ArrayRef<int64_t> sizes,
|
||||
llvm::function_ref<LogicalResult(ArrayRef<int64_t>, int64_t, int64_t)> callback) {
|
||||
int64_t rank = static_cast<int64_t>(sizes.size());
|
||||
int64_t suffixStart = rank - 1;
|
||||
while (suffixStart > 0 && sizes[suffixStart] == destShape[suffixStart])
|
||||
--suffixStart;
|
||||
if (sizes[suffixStart] == destShape[suffixStart] && suffixStart == 0)
|
||||
suffixStart = 0;
|
||||
else
|
||||
++suffixStart;
|
||||
|
||||
int64_t chunkElements = 1;
|
||||
for (int64_t dim = suffixStart; dim < rank; ++dim)
|
||||
chunkElements *= sizes[dim];
|
||||
|
||||
SmallVector<int64_t, 4> prefixExtents(sizes.begin(), sizes.begin() + suffixStart);
|
||||
SmallVector<int64_t, 4> current(prefixExtents.size(), 0);
|
||||
int64_t sourceChunkOrdinal = 0;
|
||||
|
||||
auto visit = [&](auto&& visit, int64_t dim) -> LogicalResult {
|
||||
if (dim == static_cast<int64_t>(prefixExtents.size())) {
|
||||
SmallVector<int64_t, 4> chunkOffsets(baseOffsets.begin(), baseOffsets.end());
|
||||
for (int64_t prefixDim = 0; prefixDim < static_cast<int64_t>(current.size()); ++prefixDim)
|
||||
chunkOffsets[prefixDim] += current[prefixDim];
|
||||
if (failed(callback(chunkOffsets, sourceChunkOrdinal * chunkElements, chunkElements)))
|
||||
return failure();
|
||||
++sourceChunkOrdinal;
|
||||
return success();
|
||||
}
|
||||
|
||||
for (int64_t index = 0; index < prefixExtents[dim]; ++index) {
|
||||
current[dim] = index;
|
||||
if (failed(visit(visit, dim + 1)))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
};
|
||||
|
||||
if (prefixExtents.empty())
|
||||
return callback(baseOffsets, 0, chunkElements);
|
||||
return visit(visit, 0);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
class SpatReconciliatorOp;
|
||||
}
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::FailureOr<mlir::IntegerAttr>
|
||||
@@ -29,6 +36,29 @@ mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operat
|
||||
|
||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation);
|
||||
|
||||
mlir::LogicalResult validateFragmentAssemblyMetadata(onnx_mlir::spatial::SpatReconciliatorOp reconciliator,
|
||||
int64_t resultRank,
|
||||
size_t operandCount,
|
||||
llvm::ArrayRef<int64_t> operandIndices,
|
||||
llvm::ArrayRef<int64_t> sourceOffsets,
|
||||
llvm::ArrayRef<int64_t> flatOffsets,
|
||||
llvm::ArrayRef<int64_t> flatSizes,
|
||||
llvm::ArrayRef<int64_t> flatStrides);
|
||||
|
||||
mlir::FailureOr<mlir::SmallVector<int64_t, 4>>
|
||||
getStaticSliceOffsetsForElementOffset(mlir::Operation* anchor,
|
||||
mlir::ShapedType sourceType,
|
||||
llvm::ArrayRef<int64_t> fragmentShape,
|
||||
int64_t sourceElementOffset,
|
||||
llvm::StringRef fieldName);
|
||||
|
||||
mlir::LogicalResult
|
||||
forEachContiguousDestinationChunk(llvm::ArrayRef<int64_t> destShape,
|
||||
llvm::ArrayRef<int64_t> baseOffsets,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
llvm::function_ref<mlir::LogicalResult(llvm::ArrayRef<int64_t>, int64_t, int64_t)>
|
||||
callback);
|
||||
|
||||
inline mlir::tensor::EmptyOp
|
||||
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
||||
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
|
||||
|
||||
@@ -52,11 +52,14 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
|
||||
|
||||
std::optional<StringRef> modeAttr = reconciliator.getMode();
|
||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
|
||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
|
||||
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
|
||||
if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !fragmentStridesAttr)
|
||||
if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr
|
||||
|| !fragmentStridesAttr)
|
||||
return reconciliator.emitOpError("fragment assembly lowering requires explicit fragment metadata");
|
||||
|
||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
|
||||
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
|
||||
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||
@@ -64,13 +67,19 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
|
||||
|
||||
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
|
||||
llvm::append_range(fragmentOperands, reconciliator.getFragments());
|
||||
if (failed(validateFragmentAssemblyMetadata(reconciliator,
|
||||
rank,
|
||||
fragmentOperands.size(),
|
||||
operandIndices,
|
||||
sourceOffsets,
|
||||
flatOffsets,
|
||||
flatSizes,
|
||||
flatStrides)))
|
||||
return failure();
|
||||
|
||||
Value currentOutput = createEmptyTensorFromShaped(rewriter, reconciliator.getLoc(), resultType);
|
||||
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
|
||||
return reconciliator.emitOpError("fragment assembly operand index is out of range");
|
||||
|
||||
SmallVector<int64_t, 4> fragmentOffsets;
|
||||
int64_t fragmentElements = 1;
|
||||
@@ -96,11 +105,16 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
|
||||
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
|
||||
Value hostTargetOffset = createStaticHostTargetOffset(rewriter, reconciliator.getLoc(), resultType, fragmentOffsets);
|
||||
auto deviceSourceOffsetBytes = pim::checkedMul(static_cast<uint64_t>(sourceOffsets[fragmentIndex]),
|
||||
static_cast<uint64_t>(getElementTypeSizeInBytes(sourceType.getElementType())),
|
||||
reconciliator,
|
||||
"fragment assembly device source offset");
|
||||
if (failed(deviceSourceOffsetBytes))
|
||||
return failure();
|
||||
Value deviceSourceOffset = getOrCreateIndexConstant(rewriter,
|
||||
rewriter.getInsertionBlock()->getParentOp(),
|
||||
packedFragmentOrdinal * fragmentBytes);
|
||||
static_cast<int64_t>(*deviceSourceOffsetBytes));
|
||||
currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
reconciliator.getLoc(),
|
||||
currentOutput.getType(),
|
||||
|
||||
@@ -47,11 +47,13 @@ struct LowerFragmentAssemblyReconciliatorPattern
|
||||
return op.emitOpError("fragment assembly lowering requires a static ranked tensor result");
|
||||
|
||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = op.getFragmentOperandIndices();
|
||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = op.getFragmentSourceOffsets();
|
||||
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = op.getFragmentStrides();
|
||||
if (!operandIndicesAttr || !fragmentStridesAttr)
|
||||
if (!operandIndicesAttr || !sourceOffsetsAttr || !fragmentStridesAttr)
|
||||
return op.emitOpError("fragment assembly lowering requires explicit fragment metadata");
|
||||
|
||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||
ArrayRef<int64_t> flatOffsets = op.getFragmentOffsets();
|
||||
ArrayRef<int64_t> flatSizes = op.getFragmentSizes();
|
||||
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||
@@ -59,23 +61,21 @@ struct LowerFragmentAssemblyReconciliatorPattern
|
||||
|
||||
SmallVector<Value> fragmentOperands {adaptor.getInput()};
|
||||
llvm::append_range(fragmentOperands, adaptor.getFragments());
|
||||
if (failed(validateFragmentAssemblyMetadata(
|
||||
op, rank, fragmentOperands.size(), operandIndices, sourceOffsets, flatOffsets, flatSizes, flatStrides)))
|
||||
return failure();
|
||||
|
||||
Value currentOutput =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), resultType.getElementType()).getResult();
|
||||
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
|
||||
return op.emitOpError("fragment assembly operand index is out of range");
|
||||
|
||||
SmallVector<int64_t, 4> fragmentOffsets;
|
||||
int64_t fragmentElements = 1;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||
if (flatStrides[flatIndex] != 1)
|
||||
return op.emitOpError("fragment assembly lowering only supports unit strides");
|
||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||
fragmentElements *= flatSizes[flatIndex];
|
||||
}
|
||||
|
||||
Value source = fragmentOperands[operandIndex];
|
||||
@@ -83,20 +83,21 @@ struct LowerFragmentAssemblyReconciliatorPattern
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return op.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||
|
||||
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
|
||||
SmallVector<int64_t, 4> fragmentShape;
|
||||
fragmentShape.reserve(rank);
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
|
||||
|
||||
Value fragment = source;
|
||||
if (llvm::to_vector(sourceType.getShape()) != fragmentShape) {
|
||||
SmallVector<int64_t, 4> extractOffsets(rank, 0);
|
||||
extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
|
||||
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
|
||||
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
|
||||
op, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
|
||||
if (failed(extractOffsets))
|
||||
return failure();
|
||||
fragment = tensor::ExtractSliceOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
source,
|
||||
getStaticIndexAttrs(rewriter, extractOffsets),
|
||||
getStaticIndexAttrs(rewriter, *extractOffsets),
|
||||
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||
getUnitStrides(rewriter, rank));
|
||||
}
|
||||
|
||||
@@ -149,6 +149,40 @@ static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
|
||||
};
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<std::pair<spatial::SpatReconciliatorOp, size_t>, 4>>
|
||||
analyzeTopLevelFragmentAssemblyUses(Value value) {
|
||||
SmallVector<std::pair<spatial::SpatReconciliatorOp, size_t>, 4> uses;
|
||||
for (OpOperand& use : value.getUses()) {
|
||||
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(use.getOwner());
|
||||
if (!reconciliator || reconciliator->getParentOp() != reconciliator->getParentOfType<func::FuncOp>())
|
||||
return failure();
|
||||
std::optional<StringRef> mode = reconciliator.getMode();
|
||||
if (!mode || *mode != "fragment_assembly")
|
||||
return failure();
|
||||
if (!reconciliator.getOutput().hasOneUse() || !isa<func::ReturnOp>(*reconciliator.getOutput().getUsers().begin()))
|
||||
return failure();
|
||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
|
||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
|
||||
std::optional<ArrayRef<int64_t>> stridesAttr = reconciliator.getFragmentStrides();
|
||||
auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
|
||||
if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr || !resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
|
||||
llvm::append_range(fragmentOperands, reconciliator.getFragments());
|
||||
if (failed(validateFragmentAssemblyMetadata(reconciliator,
|
||||
resultType.getRank(),
|
||||
fragmentOperands.size(),
|
||||
*operandIndicesAttr,
|
||||
*sourceOffsetsAttr,
|
||||
reconciliator.getFragmentOffsets(),
|
||||
reconciliator.getFragmentSizes(),
|
||||
*stridesAttr)))
|
||||
return failure();
|
||||
uses.emplace_back(reconciliator, use.getOperandNumber());
|
||||
}
|
||||
return uses;
|
||||
}
|
||||
|
||||
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||
auto getConcatResult = [](Operation* op) -> Value {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
@@ -559,6 +593,116 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
}
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<std::pair<spatial::SpatReconciliatorOp, size_t>, 4>> fragmentAssemblyUses =
|
||||
analyzeTopLevelFragmentAssemblyUses(producedValue);
|
||||
if (succeeded(fragmentAssemblyUses)) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(storedValue.getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape()) {
|
||||
producerOp->emitOpError("fragment assembly publication requires a static ranked tensor source");
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
|
||||
size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType());
|
||||
for (auto [reconciliator, operandNumber] : *fragmentAssemblyUses) {
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
|
||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
|
||||
std::optional<ArrayRef<int64_t>> stridesAttr = reconciliator.getFragmentStrides();
|
||||
if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr) {
|
||||
reconciliator.emitOpError(
|
||||
"fragment assembly lowering requires explicit operand, source-offset, and stride metadata");
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
|
||||
size_t returnIndex = reconciliator.getOutput().getUses().begin()->getOperandNumber();
|
||||
Value outputTensor = outputTensors[returnIndex](rewriter, loc);
|
||||
auto outputType = dyn_cast<RankedTensorType>(outputTensor.getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
|
||||
if (!outputType || !resultType || !resultType.hasStaticShape()) {
|
||||
reconciliator.emitOpError("fragment assembly lowering requires static ranked host outputs");
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
|
||||
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
|
||||
ArrayRef<int64_t> flatStrides = *stridesAttr;
|
||||
int64_t rank = resultType.getRank();
|
||||
if (failed(validateFragmentAssemblyMetadata(reconciliator,
|
||||
rank,
|
||||
1 + reconciliator.getFragments().size(),
|
||||
operandIndices,
|
||||
sourceOffsets,
|
||||
flatOffsets,
|
||||
flatSizes,
|
||||
flatStrides)))
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||
if (operandIndices[fragmentIndex] != static_cast<int64_t>(operandNumber))
|
||||
continue;
|
||||
|
||||
SmallVector<int64_t, 4> fragmentOffsets;
|
||||
SmallVector<int64_t, 4> fragmentSizes;
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||
if (flatStrides[flatIndex] != 1) {
|
||||
reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||
fragmentSizes.push_back(flatSizes[flatIndex]);
|
||||
}
|
||||
|
||||
bool failedChunk = false;
|
||||
if (failed(forEachContiguousDestinationChunk(
|
||||
outputType.getShape(),
|
||||
fragmentOffsets,
|
||||
fragmentSizes,
|
||||
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
|
||||
auto hostOffset =
|
||||
getCheckedByteOffset(computeFlatElementIndex(chunkOffsets, outputType.getShape()),
|
||||
elementSize,
|
||||
producerOp,
|
||||
"fragment assembly host offset");
|
||||
auto sourceOffset = getCheckedByteOffset(sourceOffsets[fragmentIndex] + relativeSourceOffset,
|
||||
elementSize,
|
||||
producerOp,
|
||||
"fragment assembly source offset");
|
||||
auto fragmentBytes =
|
||||
getCheckedByteOffset(chunkElements, elementSize, producerOp, "fragment assembly host copy byte size");
|
||||
if (failed(hostOffset) || failed(sourceOffset) || failed(fragmentBytes)) {
|
||||
failedChunk = true;
|
||||
return failure();
|
||||
}
|
||||
auto sizeAttr =
|
||||
pim::getCheckedI32Attr(rewriter, producerOp, *fragmentBytes, "fragment assembly host copy byte size");
|
||||
if (failed(sizeAttr)) {
|
||||
failedChunk = true;
|
||||
return failure();
|
||||
}
|
||||
|
||||
outputTensor =
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
reconciliator.getLoc(),
|
||||
outputTensor.getType(),
|
||||
getOrCreateIndexConstant(rewriter, producerOp, *hostOffset),
|
||||
getOrCreateIndexConstant(rewriter, producerOp, *sourceOffset),
|
||||
outputTensor,
|
||||
storedValue,
|
||||
*sizeAttr)
|
||||
.getOutput();
|
||||
return success();
|
||||
})))
|
||||
failedChunk = true;
|
||||
if (failedChunk)
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
markOpToRemove(reconciliator.getOperation());
|
||||
}
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
||||
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||
auto storedByteSize =
|
||||
@@ -669,6 +813,16 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
|
||||
std::optional<StringRef> mode = reconciliator.getMode();
|
||||
if (mode && *mode == "fragment_assembly") {
|
||||
markOpToRemove(reconciliator.getOperation());
|
||||
for (Value operand : reconciliator->getOperands())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
|
||||
markOpToRemove(computeOp);
|
||||
if (!computeOp.getInputs().empty())
|
||||
|
||||
@@ -245,6 +245,7 @@ def SpatReconciliatorOp : SpatOp<"reconciliator", []> {
|
||||
StrAttr:$indexMap,
|
||||
OptionalAttr<StrAttr>:$mode,
|
||||
OptionalAttr<DenseI64ArrayAttr>:$fragmentOperandIndices,
|
||||
OptionalAttr<DenseI64ArrayAttr>:$fragmentSourceOffsets,
|
||||
OptionalAttr<DenseI64ArrayAttr>:$fragmentStrides,
|
||||
OptionalAttr<StrAttr>:$conflictPolicy,
|
||||
OptionalAttr<StrAttr>:$coveragePolicy
|
||||
|
||||
@@ -483,21 +483,28 @@ LogicalResult SpatReconciliatorOp::verify() {
|
||||
return failure();
|
||||
if (!getFragments().empty())
|
||||
return emitError("legacy reconciliator does not accept extra fragment operands");
|
||||
if (getFragmentStridesAttr() || getConflictPolicyAttr() || getCoveragePolicyAttr())
|
||||
if (getFragmentSourceOffsetsAttr() || getFragmentStridesAttr() || getConflictPolicyAttr()
|
||||
|| getCoveragePolicyAttr())
|
||||
return emitError("legacy reconciliator does not accept fragment assembly attributes");
|
||||
return success();
|
||||
}
|
||||
|
||||
auto stridesAttr = getFragmentStridesAttr();
|
||||
auto operandIndicesAttr = getFragmentOperandIndicesAttr();
|
||||
auto sourceOffsetsAttr = getFragmentSourceOffsetsAttr();
|
||||
if (!operandIndicesAttr)
|
||||
return emitError("fragment assembly reconciliator requires fragment operand indices");
|
||||
if (!sourceOffsetsAttr)
|
||||
return emitError("fragment assembly reconciliator requires fragment source offsets");
|
||||
if (!stridesAttr)
|
||||
return emitError("fragment assembly reconciliator requires fragment strides");
|
||||
ArrayRef<int64_t> operandIndices = operandIndicesAttr.asArrayRef();
|
||||
ArrayRef<int64_t> sourceOffsets = sourceOffsetsAttr.asArrayRef();
|
||||
ArrayRef<int64_t> strides = stridesAttr.asArrayRef();
|
||||
if (strides.size() != offsets.size())
|
||||
return emitError("fragment stride and offset arrays must have the same length");
|
||||
if (sourceOffsets.size() != operandIndices.size())
|
||||
return emitError("fragment source offset count must match fragment operand index count");
|
||||
if (!getConflictPolicyAttr() || !getCoveragePolicyAttr())
|
||||
return emitError("fragment assembly reconciliator requires conflict and coverage policies");
|
||||
if (getConflictPolicy() != "disjoint")
|
||||
@@ -519,11 +526,21 @@ LogicalResult SpatReconciliatorOp::verify() {
|
||||
|
||||
SmallVector<std::pair<SmallVector<int64_t, 4>, SmallVector<int64_t, 4>>, 8> slices;
|
||||
slices.reserve(static_cast<size_t>(fragmentCount));
|
||||
SmallVector<SmallVector<SmallVector<int64_t, 4>, 4>, 8> sizesByOperand(static_cast<size_t>(operandCount));
|
||||
SmallVector<int64_t, 8> fragmentCountsByOperand(static_cast<size_t>(operandCount), 0);
|
||||
auto expandFlatElementIndex = [](int64_t flatIndex, ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t, 4> indices(shape.size(), 0);
|
||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
|
||||
indices[dim] = flatIndex % shape[dim];
|
||||
flatIndex /= shape[dim];
|
||||
}
|
||||
return indices;
|
||||
};
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
|
||||
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||
if (operandIndex < 0 || operandIndex >= operandCount)
|
||||
return emitError("fragment assembly operand index is out of range");
|
||||
if (sourceOffsets[fragmentIndex] < 0)
|
||||
return emitError("fragment assembly source offsets must be nonnegative");
|
||||
|
||||
auto operandType = dyn_cast<RankedTensorType>(operands[operandIndex].getType());
|
||||
if (!operandType || !operandType.hasStaticShape())
|
||||
@@ -541,7 +558,17 @@ LogicalResult SpatReconciliatorOp::verify() {
|
||||
fragmentSizes.push_back(sizes[flatIndex]);
|
||||
}
|
||||
|
||||
sizesByOperand[static_cast<size_t>(operandIndex)].push_back(fragmentSizes);
|
||||
++fragmentCountsByOperand[static_cast<size_t>(operandIndex)];
|
||||
int64_t fragmentElements = 1;
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
fragmentElements *= fragmentSizes[dim];
|
||||
if (sourceOffsets[fragmentIndex] + fragmentElements > operandType.getNumElements())
|
||||
return emitError("fragment assembly source offset exceeds the operand bounds");
|
||||
SmallVector<int64_t, 4> sourceSliceOffsets =
|
||||
expandFlatElementIndex(sourceOffsets[fragmentIndex], operandType.getShape());
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
if (sourceSliceOffsets[dim] + fragmentSizes[dim] > operandType.getDimSize(dim))
|
||||
return emitError("fragment assembly source offset must describe a valid unit-stride slice");
|
||||
|
||||
for (const auto& [existingOffsets, existingSizes] : slices) {
|
||||
bool overlaps = true;
|
||||
@@ -562,28 +589,8 @@ LogicalResult SpatReconciliatorOp::verify() {
|
||||
}
|
||||
|
||||
for (int64_t operandIndex = 0; operandIndex < operandCount; ++operandIndex) {
|
||||
if (sizesByOperand[static_cast<size_t>(operandIndex)].empty())
|
||||
if (fragmentCountsByOperand[static_cast<size_t>(operandIndex)] == 0)
|
||||
return emitError("fragment assembly reconciliator requires every operand to contribute at least one fragment");
|
||||
|
||||
auto operandType = cast<RankedTensorType>(operands[operandIndex].getType());
|
||||
ArrayRef<int64_t> operandShape = operandType.getShape();
|
||||
auto& fragmentShapes = sizesByOperand[static_cast<size_t>(operandIndex)];
|
||||
if (fragmentShapes.size() == 1) {
|
||||
if (!llvm::equal(operandShape, fragmentShapes.front()))
|
||||
return emitError("single-fragment reconciliator operand shape must match declared fragment size");
|
||||
continue;
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> fragmentShape = fragmentShapes.front();
|
||||
for (ArrayRef<int64_t> otherShape : fragmentShapes)
|
||||
if (!llvm::equal(fragmentShape, otherShape))
|
||||
return emitError("packed reconciliator operand requires equal fragment sizes per operand");
|
||||
if (llvm::equal(operandShape, fragmentShape))
|
||||
continue;
|
||||
if (!llvm::equal(operandShape.drop_front(), fragmentShape.drop_front()))
|
||||
return emitError("packed reconciliator operand must match fragment shape on non-packed dimensions");
|
||||
if (operandShape.front() != static_cast<int64_t>(fragmentShapes.size()) * fragmentShape.front())
|
||||
return emitError("packed reconciliator operand first dimension must equal fragment_count * fragment_size");
|
||||
}
|
||||
|
||||
if (getCoveragePolicy() == "complete") {
|
||||
|
||||
+197
-286
@@ -194,6 +194,7 @@ struct MaterializedClass {
|
||||
SmallVector<Value, 8> weights;
|
||||
SmallVector<Value, 8> inputs;
|
||||
SmallVector<Value, 4> hostOutputs;
|
||||
DenseMap<Value, unsigned> publicationOutputToResultIndex;
|
||||
DenseMap<Value, BlockArgument> weightArgs;
|
||||
DenseMap<Value, BlockArgument> inputArgs;
|
||||
DenseMap<Value, unsigned> hostOutputToResultIndex;
|
||||
@@ -307,11 +308,9 @@ struct PendingProjectedHostOutputFragment {
|
||||
Value originalOutput;
|
||||
ClassId sourceClass = 0;
|
||||
ProducerKey producerKey;
|
||||
Value operand;
|
||||
RankedTensorType operandType;
|
||||
RankedTensorType fragmentType;
|
||||
int64_t packedFragmentIndex = -1;
|
||||
int64_t currentLane = -1;
|
||||
Value publicationValue;
|
||||
int64_t sourceFragmentOrdinal = 0;
|
||||
int64_t sourceElementOffset = 0;
|
||||
SmallVector<int64_t, 4> offsets;
|
||||
SmallVector<int64_t, 4> sizes;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
@@ -379,6 +378,9 @@ LogicalResult localizeCapturesInClonedOp(MaterializerState& state,
|
||||
Operation& clonedOp,
|
||||
IRMapping* mapper = nullptr);
|
||||
LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, MaterializedClass& targetClass);
|
||||
void createDim0ParallelInsertSlice(
|
||||
MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset);
|
||||
Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value index, int64_t dim0Size, Location loc);
|
||||
bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch,
|
||||
const AffineProjectedInputSliceMatch& match,
|
||||
ProducerKey producer,
|
||||
@@ -627,6 +629,31 @@ ComputeInstance getScheduledChunkForLogicalInstance(MaterializerState& state, Co
|
||||
return logicalInstance;
|
||||
}
|
||||
|
||||
FailureOr<unsigned>
|
||||
getPublicationLaneForProducerKey(MaterializerState& state, const MaterializedClass& sourceClass, ProducerKey key) {
|
||||
if (!sourceClass.isBatch)
|
||||
return 0;
|
||||
|
||||
ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, key.instance);
|
||||
auto cpuIt = state.schedule.computeToCpuMap.find(scheduledProducer);
|
||||
if (cpuIt == state.schedule.computeToCpuMap.end()) {
|
||||
sourceClass.op->emitError("projected packed host publication could not resolve the producer CPU for a publication lane")
|
||||
<< " laneStart=" << key.instance.laneStart << " laneCount=" << key.instance.laneCount
|
||||
<< " resultIndex=" << key.resultIndex;
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto laneIt = sourceClass.cpuToLane.find(cpuIt->second);
|
||||
if (laneIt == sourceClass.cpuToLane.end()) {
|
||||
sourceClass.op->emitError("projected packed host publication could not map a producer key to a publication lane")
|
||||
<< " cpu=" << cpuIt->second << " laneStart=" << key.instance.laneStart << " laneCount=" << key.instance.laneCount
|
||||
<< " resultIndex=" << key.resultIndex;
|
||||
return failure();
|
||||
}
|
||||
|
||||
return laneIt->second;
|
||||
}
|
||||
|
||||
SmallVector<ProducerKey, 4>
|
||||
collectProducerKeysForDestinations(Value value, std::optional<ComputeInstance> logicalConsumer = std::nullopt) {
|
||||
// Destination collection works in the materializer's logical one-lane key domain.
|
||||
@@ -1043,6 +1070,7 @@ LogicalResult collectHostOutputs(MaterializerState& state) {
|
||||
for (MaterializedClass& materializedClass : state.classes) {
|
||||
materializedClass.hostOutputs.clear();
|
||||
materializedClass.hostOutputToResultIndex.clear();
|
||||
materializedClass.publicationOutputToResultIndex.clear();
|
||||
}
|
||||
state.hostOutputOwners.clear();
|
||||
|
||||
@@ -1150,48 +1178,6 @@ void setInsertionPointForNewMaterializedOp(MaterializerState& state) {
|
||||
state.rewriter.setInsertionPointToEnd(&funcBlock);
|
||||
}
|
||||
|
||||
FailureOr<ClassId> createProjectedHostAssemblyClass(MaterializerState& state, Value originalOutput, Location loc) {
|
||||
DenseSet<CpuId> usedCpus;
|
||||
for (const auto& [cpu, _] : state.cpuToClass)
|
||||
usedCpus.insert(cpu);
|
||||
|
||||
CpuId assemblyCpu = 0;
|
||||
while (usedCpus.contains(assemblyCpu))
|
||||
++assemblyCpu;
|
||||
|
||||
setInsertionPointForNewMaterializedOp(state);
|
||||
|
||||
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return state.func.emitError("projected host assembly class requires a static ranked tensor output");
|
||||
|
||||
auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange {resultType}, ValueRange {}, ValueRange {});
|
||||
compute.getProperties().setOperandSegmentSizes({0, 0});
|
||||
auto coreIdAttr = pim::getCheckedI32Attr(state.rewriter, state.func, assemblyCpu, "projected host assembly core id");
|
||||
if (failed(coreIdAttr))
|
||||
return failure();
|
||||
compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr);
|
||||
|
||||
Block* body = state.rewriter.createBlock(&compute.getBody());
|
||||
state.rewriter.setInsertionPointToEnd(body);
|
||||
Value placeholder =
|
||||
tensor::EmptyOp::create(state.rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
|
||||
SpatYieldOp::create(state.rewriter, loc, ValueRange {placeholder});
|
||||
state.rewriter.setInsertionPointAfter(compute.getOperation());
|
||||
|
||||
MaterializedClass materializedClass;
|
||||
materializedClass.id = state.classes.size();
|
||||
materializedClass.cpus.push_back(assemblyCpu);
|
||||
materializedClass.op = compute.getOperation();
|
||||
materializedClass.body = body;
|
||||
materializedClass.hostOutputToResultIndex[originalOutput] = 0;
|
||||
materializedClass.hostOutputs.push_back(originalOutput);
|
||||
state.cpuToClass[assemblyCpu] = materializedClass.id;
|
||||
state.hostOutputOwners[originalOutput] = materializedClass.id;
|
||||
state.classes.push_back(std::move(materializedClass));
|
||||
return state.classes.back().id;
|
||||
}
|
||||
|
||||
BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) {
|
||||
auto it = materializedClass.weightArgs.find(weight);
|
||||
if (it != materializedClass.weightArgs.end())
|
||||
@@ -1226,13 +1212,125 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali
|
||||
materializedClass.inputArgs[input] = std::get<1>(*arg);
|
||||
return std::get<1>(*arg);
|
||||
}
|
||||
if (auto compute = dyn_cast<SpatScheduledComputeBatch>(materializedClass.op)) {
|
||||
auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc());
|
||||
assert(arg && "expected compute_batch body while inserting an input argument");
|
||||
materializedClass.inputArgs[input] = std::get<1>(*arg);
|
||||
return std::get<1>(*arg);
|
||||
|
||||
auto compute = cast<SpatScheduledComputeBatch>(materializedClass.op);
|
||||
auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc());
|
||||
assert(arg && "expected compute_batch body while inserting an input argument");
|
||||
materializedClass.inputArgs[input] = std::get<1>(*arg);
|
||||
return std::get<1>(*arg);
|
||||
}
|
||||
|
||||
void refreshPendingProjectedHostOutputPublicationValues(MaterializerState& state,
|
||||
Operation* oldOwner,
|
||||
Operation* newOwner) {
|
||||
if (!oldOwner || oldOwner == newOwner)
|
||||
return;
|
||||
|
||||
for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments) {
|
||||
auto publicationResult = dyn_cast_or_null<OpResult>(fragment.publicationValue);
|
||||
if (!publicationResult || publicationResult.getOwner() != oldOwner)
|
||||
publicationResult = OpResult();
|
||||
else
|
||||
fragment.publicationValue = newOwner->getResult(publicationResult.getResultNumber());
|
||||
|
||||
if (auto originalResult = dyn_cast_or_null<OpResult>(fragment.originalOutput); originalResult
|
||||
&& originalResult.getOwner() == oldOwner) {
|
||||
fragment.originalOutput = newOwner->getResult(originalResult.getResultNumber());
|
||||
}
|
||||
|
||||
if (fragment.producerKey.instance.op == oldOwner)
|
||||
fragment.producerKey.instance.op = newOwner;
|
||||
}
|
||||
llvm_unreachable("Cannot reach here");
|
||||
}
|
||||
|
||||
FailureOr<Value> appendScalarPublicationResult(MaterializerState& state,
|
||||
MaterializedClass& materializedClass,
|
||||
Value payload,
|
||||
Location loc) {
|
||||
auto existing = materializedClass.publicationOutputToResultIndex.find(payload);
|
||||
if (existing != materializedClass.publicationOutputToResultIndex.end())
|
||||
return materializedClass.op->getResult(existing->second);
|
||||
|
||||
auto compute = dyn_cast<SpatScheduledCompute>(materializedClass.op);
|
||||
if (!compute)
|
||||
return materializedClass.op->emitError("scalar publication result requires spat.scheduled_compute owner");
|
||||
|
||||
auto payloadType = dyn_cast<RankedTensorType>(payload.getType());
|
||||
if (!payloadType || !payloadType.hasStaticShape())
|
||||
return materializedClass.op->emitError("scalar publication result requires static ranked tensor payload");
|
||||
|
||||
FailureOr<std::tuple<OpResult, SpatScheduledCompute>> inserted =
|
||||
compute.insertOutput(state.rewriter, compute.getNumResults(), payloadType, loc);
|
||||
if (failed(inserted))
|
||||
return materializedClass.op->emitError("failed to append scalar publication result");
|
||||
|
||||
Operation* oldOp = materializedClass.op;
|
||||
auto [result, newCompute] = *inserted;
|
||||
materializedClass.op = newCompute.getOperation();
|
||||
materializedClass.body = &newCompute.getBody().front();
|
||||
refreshPendingProjectedHostOutputPublicationValues(state, oldOp, materializedClass.op);
|
||||
materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber();
|
||||
|
||||
auto yieldOp = dyn_cast<SpatYieldOp>(materializedClass.body->getTerminator());
|
||||
if (!yieldOp)
|
||||
return materializedClass.op->emitError("expected spat.yield terminator while appending scalar publication result");
|
||||
state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->insertOperands(yieldOp.getNumOperands(), payload); });
|
||||
return result;
|
||||
}
|
||||
|
||||
FailureOr<Value> appendBatchPublicationResult(MaterializerState& state,
|
||||
MaterializedClass& materializedClass,
|
||||
Value payload,
|
||||
Location loc) {
|
||||
auto existing = materializedClass.publicationOutputToResultIndex.find(payload);
|
||||
if (existing != materializedClass.publicationOutputToResultIndex.end())
|
||||
return materializedClass.op->getResult(existing->second);
|
||||
|
||||
auto batch = dyn_cast<SpatScheduledComputeBatch>(materializedClass.op);
|
||||
if (!batch)
|
||||
return materializedClass.op->emitError("batch publication result requires spat.scheduled_compute_batch owner");
|
||||
|
||||
auto payloadType = dyn_cast<RankedTensorType>(payload.getType());
|
||||
if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0)
|
||||
return materializedClass.op->emitError(
|
||||
"batch publication result requires a static ranked tensor payload with rank > 0");
|
||||
|
||||
SmallVector<int64_t, 4> publishedShape(payloadType.getShape());
|
||||
publishedShape[0] *= static_cast<int64_t>(materializedClass.cpus.size());
|
||||
auto publishedType =
|
||||
RankedTensorType::get(publishedShape, payloadType.getElementType(), payloadType.getEncoding());
|
||||
|
||||
FailureOr<std::tuple<OpResult, BlockArgument, SpatScheduledComputeBatch>> inserted =
|
||||
batch.insertOutput(state.rewriter, batch.getNumResults(), publishedType, loc);
|
||||
if (failed(inserted))
|
||||
return materializedClass.op->emitError("failed to append batch publication result");
|
||||
|
||||
Operation* oldOp = materializedClass.op;
|
||||
auto [result, outputArg, newBatch] = *inserted;
|
||||
materializedClass.op = newBatch.getOperation();
|
||||
materializedClass.body = &newBatch.getBody().front();
|
||||
refreshPendingProjectedHostOutputPublicationValues(state, oldOp, materializedClass.op);
|
||||
materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber();
|
||||
|
||||
auto inParallelOp = dyn_cast<SpatInParallelOp>(materializedClass.body->getTerminator());
|
||||
auto laneArg = newBatch.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return materializedClass.op->emitError("batch publication result requires a lane argument");
|
||||
if (!inParallelOp) {
|
||||
auto yieldOp = dyn_cast<SpatYieldOp>(materializedClass.body->getTerminator());
|
||||
if (!yieldOp || yieldOp.getNumOperands() != 0)
|
||||
return materializedClass.op->emitError(
|
||||
"batch publication result requires either spat.in_parallel or an empty spat.yield terminator");
|
||||
state.rewriter.setInsertionPoint(yieldOp);
|
||||
inParallelOp = SpatInParallelOp::create(state.rewriter, loc);
|
||||
state.rewriter.eraseOp(yieldOp);
|
||||
}
|
||||
|
||||
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
Value firstOffset =
|
||||
scaleIndexByDim0Size(state, materializedClass.op, *laneArg, payloadType.getDimSize(0), loc);
|
||||
createDim0ParallelInsertSlice(state, loc, payload, outputArg, firstOffset);
|
||||
return result;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
@@ -5520,6 +5618,12 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat
|
||||
return failure();
|
||||
}
|
||||
|
||||
FailureOr<Value> publicationResult = appendScalarPublicationResult(state, sourceClass, packed, loc);
|
||||
if (failed(publicationResult))
|
||||
return failure();
|
||||
|
||||
int64_t fragmentElementCount = fragmentType.getNumElements();
|
||||
|
||||
for (auto [runIndex, slot] : llvm::enumerate(run)) {
|
||||
if (slot.peers.size() != 1) {
|
||||
sourceClass.op->emitError("projected scalar host output publication expects scalar one-peer run slots");
|
||||
@@ -5553,11 +5657,9 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat
|
||||
originalOutput,
|
||||
sourceClass.id,
|
||||
ProducerKey {peer, resultIndex},
|
||||
packed,
|
||||
cast<RankedTensorType>(packed.getType()),
|
||||
fragmentType,
|
||||
static_cast<int64_t>(runIndex),
|
||||
*publicationResult,
|
||||
static_cast<int64_t>(runIndex),
|
||||
static_cast<int64_t>(runIndex) * fragmentElementCount,
|
||||
SmallVector<int64_t, 4>(*offsets),
|
||||
SmallVector<int64_t, 4>(*sizes),
|
||||
SmallVector<int64_t, 4>(*strides),
|
||||
@@ -5609,7 +5711,10 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
|
||||
if (fragmentType == originalOutput.getType())
|
||||
return false;
|
||||
|
||||
bool operandIsDim0Packed = false;
|
||||
FailureOr<Value> publicationResult = appendBatchPublicationResult(state, sourceClass, packed, loc);
|
||||
if (failed(publicationResult))
|
||||
return failure();
|
||||
|
||||
if (packedType != fragmentType) {
|
||||
if (packedType.getRank() == 0 || packedType.getDimSize(0) % static_cast<int64_t>(keys.size()) != 0)
|
||||
return sourceClass.op->emitError(
|
||||
@@ -5622,9 +5727,16 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
|
||||
return sourceClass.op->emitError(
|
||||
"projected packed host publication fragment shape does not match projected slice size")
|
||||
<< " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size();
|
||||
operandIsDim0Packed = true;
|
||||
}
|
||||
|
||||
int64_t payloadElementCount = packedType.getNumElements();
|
||||
int64_t fragmentElementCount = fragmentType.getNumElements();
|
||||
int64_t fragmentsPerPublishedPayload = payloadElementCount / fragmentElementCount;
|
||||
if (fragmentsPerPublishedPayload <= 0 || static_cast<int64_t>(keys.size()) % fragmentsPerPublishedPayload != 0)
|
||||
return sourceClass.op->emitOpError(
|
||||
"projected packed host publication requires a deterministic publication packing layout")
|
||||
<< " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size();
|
||||
|
||||
for (auto [fragmentIndex, key] : llvm::enumerate(keys)) {
|
||||
if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != keys.front().resultIndex || key.instance.laneCount != 1)
|
||||
return sourceClass.op->emitError("projected packed host publication requires one-lane keys from one producer result");
|
||||
@@ -5642,15 +5754,19 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
|
||||
return sourceClass.op->emitError(
|
||||
"projected packed host publication requires one operand to map to a consistent fragment shape");
|
||||
|
||||
FailureOr<unsigned> publishedLaneIndex = getPublicationLaneForProducerKey(state, sourceClass, key);
|
||||
if (failed(publishedLaneIndex))
|
||||
return failure();
|
||||
int64_t localFragmentOffsetWithinPublishedPayload =
|
||||
(static_cast<int64_t>(fragmentIndex) % fragmentsPerPublishedPayload) * fragmentElementCount;
|
||||
|
||||
state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment {
|
||||
originalOutput,
|
||||
sourceClass.id,
|
||||
key,
|
||||
packed,
|
||||
packedType,
|
||||
fragmentType,
|
||||
operandIsDim0Packed ? static_cast<int64_t>(fragmentIndex) : -1,
|
||||
*publicationResult,
|
||||
static_cast<int64_t>(fragmentIndex),
|
||||
static_cast<int64_t>(*publishedLaneIndex) * payloadElementCount + localFragmentOffsetWithinPublishedPayload,
|
||||
SmallVector<int64_t, 4>(*offsets),
|
||||
SmallVector<int64_t, 4>(*sizes),
|
||||
SmallVector<int64_t, 4>(*strides),
|
||||
@@ -5678,24 +5794,23 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
||||
< reinterpret_cast<uintptr_t>(rhs.getAsOpaquePointer());
|
||||
});
|
||||
|
||||
auto returnOp = dyn_cast<func::ReturnOp>(state.func.getBody().front().getTerminator());
|
||||
if (!returnOp)
|
||||
return state.func.emitError("expected func.return terminator while finalizing projected host output fragments");
|
||||
|
||||
for (Value originalOutput : outputs) {
|
||||
auto ownerIt = state.hostOutputOwners.find(originalOutput);
|
||||
if (ownerIt == state.hostOutputOwners.end()) {
|
||||
Operation* anchor = originalOutput.getDefiningOp() ? originalOutput.getDefiningOp() : state.func.getOperation();
|
||||
return anchor->emitError("missing host owner for projected host output fragments");
|
||||
}
|
||||
|
||||
MaterializedClass* ownerClass = &state.classes[ownerIt->second];
|
||||
|
||||
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return ownerClass->op->emitError("projected host output must have static ranked tensor type");
|
||||
return state.func.emitError("projected host output must have static ranked tensor type");
|
||||
|
||||
SmallVector<PendingProjectedHostOutputFragment*, 16>& fragments = byOutput[originalOutput];
|
||||
llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs,
|
||||
const PendingProjectedHostOutputFragment* rhs) {
|
||||
if (lhs->sourceLane != rhs->sourceLane)
|
||||
return lhs->sourceLane < rhs->sourceLane;
|
||||
if (lhs->publicationValue != rhs->publicationValue)
|
||||
return reinterpret_cast<uintptr_t>(lhs->publicationValue.getAsOpaquePointer())
|
||||
< reinterpret_cast<uintptr_t>(rhs->publicationValue.getAsOpaquePointer());
|
||||
if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal)
|
||||
return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal;
|
||||
if (lhs->sourceClass != rhs->sourceClass)
|
||||
return lhs->sourceClass < rhs->sourceClass;
|
||||
return std::lexicographical_compare(lhs->offsets.begin(),
|
||||
@@ -5704,240 +5819,36 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
||||
rhs->offsets.end());
|
||||
});
|
||||
|
||||
bool allFromSameSourceClass =
|
||||
llvm::all_of(fragments, [&](const PendingProjectedHostOutputFragment* fragment) {
|
||||
return fragment->sourceClass == fragments.front()->sourceClass;
|
||||
});
|
||||
if (allFromSameSourceClass) {
|
||||
ownerClass = &state.classes[fragments.front()->sourceClass];
|
||||
state.hostOutputOwners[originalOutput] = ownerClass->id;
|
||||
} else {
|
||||
if (!ownerClass->isBatch && ownerClass->hostOutputToResultIndex.contains(originalOutput))
|
||||
goto owner_selected;
|
||||
FailureOr<ClassId> createdOwner =
|
||||
createProjectedHostAssemblyClass(state, originalOutput, fragments.front()->loc);
|
||||
if (failed(createdOwner))
|
||||
return failure();
|
||||
ownerClass = &state.classes[*createdOwner];
|
||||
}
|
||||
owner_selected:
|
||||
|
||||
if (ownerClass->isBatch && allFromSameSourceClass && ownerClass->id == fragments.front()->sourceClass) {
|
||||
auto sourceBatch = dyn_cast<SpatComputeBatch>(fragments.front()->producerKey.instance.op);
|
||||
auto batch = dyn_cast<SpatScheduledComputeBatch>(ownerClass->op);
|
||||
auto inParallelOp = dyn_cast_or_null<SpatInParallelOp>(ownerClass->body->getTerminator());
|
||||
auto resultIt = ownerClass->hostOutputToResultIndex.find(originalOutput);
|
||||
if (!sourceBatch || !batch || !inParallelOp || resultIt == ownerClass->hostOutputToResultIndex.end())
|
||||
return ownerClass->op->emitError("missing batch host assembly state for projected host output");
|
||||
FailureOr<tensor::ParallelInsertSliceOp> sourceProjection =
|
||||
getBatchResultProjectionInsert(sourceBatch, fragments.front()->producerKey.resultIndex);
|
||||
std::optional<BlockArgument> sourceLaneArg = sourceBatch.getLaneArgument();
|
||||
if (failed(sourceProjection) || !sourceLaneArg)
|
||||
return ownerClass->op->emitError(
|
||||
"direct batch host output assembly requires the source batch projection metadata");
|
||||
|
||||
auto outputArg = batch.getOutputArgument(resultIt->second);
|
||||
auto laneArg = batch.getLaneArgument();
|
||||
if (!outputArg || !laneArg)
|
||||
return ownerClass->op->emitError("missing compute_batch output block argument for projected host output");
|
||||
|
||||
if (fragments.size() != ownerClass->cpus.size())
|
||||
return ownerClass->op->emitError(
|
||||
"direct batch host output assembly expects exactly one fragment per materialized lane");
|
||||
|
||||
SmallVector<PendingProjectedHostOutputFragment*, 8> fragmentsByLane(ownerClass->cpus.size(), nullptr);
|
||||
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
|
||||
int64_t currentLane = fragmentRecord->currentLane >= 0 ? fragmentRecord->currentLane : fragmentRecord->sourceLane;
|
||||
if (currentLane < 0 || currentLane >= static_cast<int64_t>(fragmentsByLane.size()))
|
||||
return ownerClass->op->emitError("projected batch host output fragment current lane is out of bounds");
|
||||
if (fragmentsByLane[currentLane])
|
||||
return ownerClass->op->emitError("projected batch host output has duplicate fragments for one lane");
|
||||
fragmentsByLane[currentLane] = fragmentRecord;
|
||||
}
|
||||
|
||||
if (llvm::any_of(fragmentsByLane, [](PendingProjectedHostOutputFragment* fragment) { return fragment == nullptr; }))
|
||||
return ownerClass->op->emitError("projected batch host output is missing a fragment for one or more lanes");
|
||||
|
||||
FailureOr<SmallVector<int64_t, 4>> firstSizes =
|
||||
evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentsByLane.front()->sourceLane);
|
||||
FailureOr<SmallVector<int64_t, 4>> firstStrides =
|
||||
evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentsByLane.front()->sourceLane);
|
||||
if (failed(firstSizes) || failed(firstStrides))
|
||||
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape");
|
||||
SmallVector<int64_t, 4> referenceSizes(*firstSizes);
|
||||
SmallVector<int64_t, 4> referenceStrides(*firstStrides);
|
||||
Value laneOperand;
|
||||
for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) {
|
||||
FailureOr<SmallVector<int64_t, 4>> fragmentSizes =
|
||||
evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentRecord->sourceLane);
|
||||
FailureOr<SmallVector<int64_t, 4>> fragmentStrides =
|
||||
evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentRecord->sourceLane);
|
||||
if (failed(fragmentSizes) || failed(fragmentStrides))
|
||||
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape");
|
||||
if (SmallVector<int64_t, 4>(*fragmentSizes) != referenceSizes
|
||||
|| SmallVector<int64_t, 4>(*fragmentStrides) != referenceStrides)
|
||||
return ownerClass->op->emitError(
|
||||
"direct batch host output assembly expects a uniform fragment shape and strides");
|
||||
|
||||
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
|
||||
Value operand;
|
||||
if (std::optional<Value> availableValue =
|
||||
state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) {
|
||||
operand = *availableValue;
|
||||
} else {
|
||||
operand = fragmentRecord->operand;
|
||||
}
|
||||
if (!isValueLegalInMaterializedClassBody(operand, *ownerClass))
|
||||
return ownerClass->op->emitError(
|
||||
"projected batch host output assembly requires source-local fragment operands");
|
||||
if (laneOperand && laneOperand != operand)
|
||||
return ownerClass->op->emitError(
|
||||
"direct batch host output assembly expects one shared lane-local fragment producer");
|
||||
laneOperand = operand;
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult, 4> mixedOffsets;
|
||||
mixedOffsets.reserve(referenceSizes.size());
|
||||
for (size_t dim = 0; dim < referenceSizes.size(); ++dim) {
|
||||
SmallVector<int64_t, 8> offsetsByLane;
|
||||
offsetsByLane.reserve(fragmentsByLane.size());
|
||||
for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) {
|
||||
FailureOr<SmallVector<int64_t, 4>> fragmentOffsets =
|
||||
evaluateStaticProjectionIndices(sourceProjection->getMixedOffsets(), *sourceLaneArg, fragmentRecord->sourceLane);
|
||||
if (failed(fragmentOffsets))
|
||||
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment offsets");
|
||||
offsetsByLane.push_back((*fragmentOffsets)[dim]);
|
||||
}
|
||||
mixedOffsets.push_back(allEqual(offsetsByLane)
|
||||
? OpFoldResult(state.rewriter.getIndexAttr(offsetsByLane.front()))
|
||||
: OpFoldResult(createLaneIndexedIndexValue(
|
||||
state, *ownerClass, ArrayRef<int64_t>(offsetsByLane), fragments.front()->loc)));
|
||||
}
|
||||
|
||||
state.hostReplacements[originalOutput] = ownerClass->op->getResult(resultIt->second);
|
||||
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
tensor::ParallelInsertSliceOp::create(state.rewriter,
|
||||
fragments.front()->loc,
|
||||
laneOperand,
|
||||
*outputArg,
|
||||
mixedOffsets,
|
||||
getStaticIndexAttrs(state.rewriter, referenceSizes),
|
||||
getStaticIndexAttrs(state.rewriter, referenceStrides));
|
||||
continue;
|
||||
}
|
||||
|
||||
state.rewriter.setInsertionPoint(ownerClass->body->getTerminator());
|
||||
state.rewriter.setInsertionPoint(returnOp);
|
||||
Location loc = fragments.front()->loc;
|
||||
SmallVector<Value, 16> reconciliatorOperands;
|
||||
SmallVector<int64_t, 16> fragmentOperandIndices;
|
||||
SmallVector<int64_t, 16> fragmentSourceOffsets;
|
||||
SmallVector<int64_t, 64> flatOffsets;
|
||||
SmallVector<int64_t, 64> flatSizes;
|
||||
SmallVector<int64_t, 64> flatStrides;
|
||||
DenseMap<Value, int64_t> operandIndicesByValue;
|
||||
DenseSet<ClassId> emittedBatchForwarding;
|
||||
|
||||
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
|
||||
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
|
||||
Value operand;
|
||||
|
||||
if (std::optional<Value> availableValue =
|
||||
state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) {
|
||||
operand = *availableValue;
|
||||
} else if (fragmentRecord->sourceClass == sourceClass.id) {
|
||||
operand = fragmentRecord->operand;
|
||||
} else {
|
||||
return sourceClass.op->emitError(
|
||||
"projected host output fragment assembly is missing source-visible fragment operands before finalization");
|
||||
}
|
||||
|
||||
if (fragmentRecord->sourceClass != ownerClass->id) {
|
||||
if (sourceClass.isBatch && !ownerClass->isBatch) {
|
||||
if (!emittedBatchForwarding.insert(sourceClass.id).second) {
|
||||
std::optional<Value> localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id);
|
||||
if (!localized)
|
||||
return ownerClass->op->emitError(
|
||||
"projected host output fragment assembly is missing forwarded batch fragments");
|
||||
operand = *localized;
|
||||
} else {
|
||||
SmallVector<ProducerKey, 8> forwardedKeys;
|
||||
forwardedKeys.reserve(sourceClass.cpus.size());
|
||||
Value forwardedPayload = fragmentRecord->operand;
|
||||
for (PendingProjectedHostOutputFragment* candidate : fragments) {
|
||||
if (candidate->sourceClass != sourceClass.id)
|
||||
continue;
|
||||
if (candidate->operand != forwardedPayload)
|
||||
return ownerClass->op->emitError(
|
||||
"projected host output batch forwarding expects one shared batch payload per source class");
|
||||
forwardedKeys.push_back(candidate->producerKey);
|
||||
}
|
||||
llvm::sort(forwardedKeys, [](ProducerKey lhs, ProducerKey rhs) {
|
||||
return lhs.instance.laneStart < rhs.instance.laneStart;
|
||||
});
|
||||
if (failed(emitClassToClassCommunication(
|
||||
state, sourceClass, *ownerClass, forwardedKeys, forwardedPayload, fragmentRecord->loc)))
|
||||
return failure();
|
||||
std::optional<Value> localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id);
|
||||
if (!localized)
|
||||
return ownerClass->op->emitError(
|
||||
"projected host output fragment assembly failed to recover forwarded batch fragment");
|
||||
operand = *localized;
|
||||
}
|
||||
} else {
|
||||
MessageVector messages;
|
||||
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op,
|
||||
sourceClass.cpus.front(),
|
||||
"projected host output source core id");
|
||||
auto checkedTargetCpu = getCheckedCoreId(ownerClass->op,
|
||||
ownerClass->cpus.front(),
|
||||
"projected host output target core id");
|
||||
if (failed(checkedSourceCpu) || failed(checkedTargetCpu))
|
||||
return failure();
|
||||
messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
|
||||
if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc)))
|
||||
return failure();
|
||||
operand = appendReceive(state,
|
||||
*ownerClass,
|
||||
cast<RankedTensorType>(operand.getType()),
|
||||
messages,
|
||||
fragmentRecord->loc);
|
||||
}
|
||||
} else if (!ownerClass->isBatch) {
|
||||
FailureOr<Value> localOperand = materializeTensorValueForMaterializedClassUse(
|
||||
state,
|
||||
*ownerClass,
|
||||
operand,
|
||||
ownerClass->op,
|
||||
"projected host output assembly tried to reuse a non-local fragment tensor");
|
||||
if (failed(localOperand))
|
||||
return failure();
|
||||
operand = *localOperand;
|
||||
}
|
||||
Value operand = fragmentRecord->publicationValue;
|
||||
|
||||
auto [operandIt, inserted] =
|
||||
operandIndicesByValue.try_emplace(operand, static_cast<int64_t>(reconciliatorOperands.size()));
|
||||
if (inserted)
|
||||
reconciliatorOperands.push_back(operand);
|
||||
fragmentOperandIndices.push_back(operandIt->second);
|
||||
fragmentSourceOffsets.push_back(fragmentRecord->sourceElementOffset);
|
||||
llvm::append_range(flatOffsets, fragmentRecord->offsets);
|
||||
llvm::append_range(flatSizes, fragmentRecord->sizes);
|
||||
llvm::append_range(flatStrides, fragmentRecord->strides);
|
||||
|
||||
auto operandType = dyn_cast<RankedTensorType>(operand.getType());
|
||||
if (!operandType || !operandType.hasStaticShape())
|
||||
return ownerClass->op->emitError("projected host output assembly requires static ranked tensor operands");
|
||||
if (fragmentRecord->packedFragmentIndex >= 0) {
|
||||
int64_t fragmentSize0 = fragmentRecord->fragmentType.getDimSize(0);
|
||||
if (fragmentSize0 <= 0 || operandType.getRank() == 0)
|
||||
return ownerClass->op->emitError("packed projected host output assembly requires ranked fragment operands");
|
||||
int64_t start = fragmentRecord->packedFragmentIndex * fragmentSize0;
|
||||
int64_t end = start + fragmentSize0;
|
||||
if (start < 0 || end > operandType.getDimSize(0))
|
||||
return ownerClass->op->emitError("packed projected host output fragment index is out of bounds");
|
||||
}
|
||||
return state.func.emitError("projected host output assembly requires static ranked tensor operands");
|
||||
}
|
||||
|
||||
if (reconciliatorOperands.empty())
|
||||
return ownerClass->op->emitError("missing projected host output fragments");
|
||||
return state.func.emitError("missing projected host output fragments");
|
||||
|
||||
Value input = reconciliatorOperands.front();
|
||||
ValueRange extraFragments = ValueRange(reconciliatorOperands).drop_front();
|
||||
@@ -5954,12 +5865,12 @@ owner_selected:
|
||||
state.rewriter.getStringAttr("identity"),
|
||||
state.rewriter.getStringAttr("fragment_assembly"),
|
||||
state.rewriter.getDenseI64ArrayAttr(fragmentOperandIndices),
|
||||
state.rewriter.getDenseI64ArrayAttr(fragmentSourceOffsets),
|
||||
state.rewriter.getDenseI64ArrayAttr(flatStrides),
|
||||
state.rewriter.getStringAttr("disjoint"),
|
||||
state.rewriter.getStringAttr("complete"));
|
||||
|
||||
if (failed(setHostOutputValue(state, *ownerClass, originalOutput, reconciliator.getOutput())))
|
||||
return failure();
|
||||
state.hostReplacements[originalOutput] = reconciliator.getOutput();
|
||||
}
|
||||
|
||||
return success();
|
||||
|
||||
Reference in New Issue
Block a user