roba
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-06-26 13:02:38 +02:00
parent 568fd90542
commit 984f362623
13 changed files with 797 additions and 347 deletions
@@ -113,6 +113,19 @@ void verifyScheduledInputs(ComputeOpTy compute,
}
}
template <typename ComputeOpTy>
void verifyNoNestedFragmentAssemblyReconciliators(ComputeOpTy compute,
pim::CappedDiagnosticReporter& diagnostics) {
compute.getBody().walk([&](spatial::SpatReconciliatorOp reconciliator) {
std::optional<StringRef> mode = reconciliator.getMode();
if (!mode || *mode != "fragment_assembly")
return;
diagnostics.report(reconciliator.getOperation(), [&](Operation* illegalOp) {
illegalOp->emitOpError("fragment assembly reconciliator must be host-level after merge materialization");
});
});
}
void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
for (Operation& op : funcOp.getOps()) {
if (isa<func::ReturnOp,
@@ -188,10 +201,14 @@ LogicalResult verifyLogicalSpatialGraphInvariants(func::FuncOp funcOp) {
LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
verifyScheduledTopLevelOps(funcOp, diagnostics);
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>()) {
verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
verifyNoNestedFragmentAssemblyReconciliators(compute, diagnostics);
}
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
verifyNoNestedFragmentAssemblyReconciliators(batch, diagnostics);
}
if (failed(verifyNoComputeBodyCaptures(funcOp)))
return failure();
diagnostics.emitSuppressedSummary(funcOp, "scheduled Spatial verification failed");
@@ -1702,8 +1702,6 @@ static bool canUseStructuredRewrite(const ConvLoweringState& state) {
state.outWidth);
if (!tiling)
return false;
if (tiling->numChannelTiles > static_cast<int64_t>(crossbarCountInCore.getValue()))
return false;
if (!state.hasBias)
return true;
@@ -107,6 +107,7 @@ static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewr
nullptr,
nullptr,
nullptr,
nullptr,
nullptr);
}
@@ -131,6 +131,151 @@ static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
return result.getUses().begin()->getOperandNumber();
}
struct BatchFragmentAssemblyPlan {
unsigned returnIndex = 0;
int64_t localSourceElementOffset = 0;
int64_t fragmentByteSize = 0;
SmallVector<int64_t, 8> hostOffsetsByLane;
};
static Value createLaneIndexedOffset(IRRewriter& rewriter, Operation* anchor, Value laneArg, ArrayRef<int64_t> values, Location loc) {
assert(!values.empty() && "expected lane-indexed values");
if (llvm::all_of(values.drop_front(), [&](int64_t value) { return value == values.front(); }))
return getOrCreateIndexConstant(rewriter, anchor, values.front());
if (values.size() >= 2) {
int64_t step = values[1] - values[0];
bool arithmetic = llvm::all_of(llvm::seq<size_t>(2, values.size()), [&](size_t index) {
return values[index] == values.front() + static_cast<int64_t>(index) * step;
});
if (arithmetic) {
Value base = getOrCreateIndexConstant(rewriter, anchor, values.front());
if (step == 0)
return base;
Value stepValue = getOrCreateIndexConstant(rewriter, anchor, step);
Value scaledLane = arith::MulIOp::create(rewriter, loc, laneArg, stepValue).getResult();
return arith::AddIOp::create(rewriter, loc, base, scaledLane).getResult();
}
}
Value selected = getOrCreateIndexConstant(rewriter, anchor, values.front());
for (auto [lane, value] : llvm::enumerate(values.drop_front())) {
Value laneValue = getOrCreateIndexConstant(rewriter, anchor, static_cast<int64_t>(lane + 1));
Value cmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, laneArg, laneValue);
Value candidate = getOrCreateIndexConstant(rewriter, anchor, value);
selected = arith::SelectOp::create(rewriter, loc, cmp, candidate, selected);
}
return selected;
}
static FailureOr<SmallVector<BatchFragmentAssemblyPlan, 8>>
analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResultType, uint32_t laneCount) {
SmallVector<BatchFragmentAssemblyPlan, 8> plans;
if (!packedResultType.hasStaticShape() || laneCount == 0)
return failure();
int64_t packedElementCount = packedResultType.getNumElements();
if (packedElementCount % static_cast<int64_t>(laneCount) != 0)
return failure();
int64_t payloadElementCount = packedElementCount / static_cast<int64_t>(laneCount);
size_t elementSize = getElementTypeSizeInBytes(packedResultType.getElementType());
for (OpOperand& use : result.getUses()) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(use.getOwner());
if (!reconciliator || reconciliator->getParentOp() != reconciliator->getParentOfType<func::FuncOp>())
return failure();
std::optional<StringRef> mode = reconciliator.getMode();
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> stridesAttr = reconciliator.getFragmentStrides();
if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr)
return failure();
if (!reconciliator.getOutput().hasOneUse() || !isa<func::ReturnOp>(*reconciliator.getOutput().getUsers().begin()))
return failure();
unsigned returnIndex = reconciliator.getOutput().getUses().begin()->getOperandNumber();
auto hostResultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
if (!hostResultType || !hostResultType.hasStaticShape())
return failure();
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *stridesAttr;
int64_t rank = hostResultType.getRank();
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
llvm::append_range(fragmentOperands, reconciliator.getFragments());
if (failed(validateFragmentAssemblyMetadata(reconciliator,
rank,
fragmentOperands.size(),
operandIndices,
sourceOffsets,
flatOffsets,
flatSizes,
flatStrides)))
return failure();
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
if (operandIndices[fragmentIndex] != static_cast<int64_t>(use.getOperandNumber()))
continue;
int64_t sourceElementOffset = sourceOffsets[fragmentIndex];
int64_t lane = sourceElementOffset / payloadElementCount;
int64_t localSourceElementOffset = sourceElementOffset % payloadElementCount;
if (lane < 0 || lane >= static_cast<int64_t>(laneCount))
return failure();
SmallVector<int64_t, 4> fragmentOffsets;
SmallVector<int64_t, 4> fragmentSizes;
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1)
return failure();
fragmentOffsets.push_back(flatOffsets[flatIndex]);
fragmentSizes.push_back(flatSizes[flatIndex]);
}
if (failed(forEachContiguousDestinationChunk(
hostResultType.getShape(),
fragmentOffsets,
fragmentSizes,
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
int64_t hostElementOffset = 0;
SmallVector<int64_t> hostStrides = computeRowMajorStrides(hostResultType.getShape());
for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
hostElementOffset += offset * hostStrides[dim];
int64_t hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
int64_t fragmentByteSize = chunkElements * static_cast<int64_t>(elementSize);
int64_t chunkSourceOffset = localSourceElementOffset + relativeSourceOffset;
auto planIt = llvm::find_if(plans, [&](const BatchFragmentAssemblyPlan& plan) {
return plan.returnIndex == returnIndex && plan.localSourceElementOffset == chunkSourceOffset
&& plan.fragmentByteSize == fragmentByteSize;
});
if (planIt == plans.end()) {
BatchFragmentAssemblyPlan plan;
plan.returnIndex = returnIndex;
plan.localSourceElementOffset = chunkSourceOffset;
plan.fragmentByteSize = fragmentByteSize;
plan.hostOffsetsByLane.assign(laneCount, std::numeric_limits<int64_t>::min());
plan.hostOffsetsByLane[static_cast<size_t>(lane)] = hostByteOffset;
plans.push_back(std::move(plan));
return success();
}
planIt->hostOffsetsByLane[static_cast<size_t>(lane)] = hostByteOffset;
return success();
})))
return failure();
}
}
for (const BatchFragmentAssemblyPlan& plan : plans)
if (llvm::any_of(plan.hostOffsetsByLane, [](int64_t offset) { return offset == std::numeric_limits<int64_t>::min(); }))
return failure();
return plans;
}
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
if (scale == 1)
return base;
@@ -250,6 +395,10 @@ static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
"fragment assembly lowering requires explicit operand indices and unit strides");
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
if (!sourceOffsetsAttr)
return reconciliator.emitOpError("fragment assembly lowering requires explicit source offsets");
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
@@ -257,21 +406,25 @@ static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
llvm::append_range(fragmentOperands, reconciliator.getFragments());
if (failed(validateFragmentAssemblyMetadata(reconciliator,
rank,
fragmentOperands.size(),
operandIndices,
sourceOffsets,
flatOffsets,
flatSizes,
flatStrides)))
return failure();
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
int64_t operandIndex = operandIndices[fragmentIndex];
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
return reconciliator.emitOpError("fragment assembly operand index is out of range");
SmallVector<int64_t, 4> fragmentOffsets;
int64_t fragmentElements = 1;
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1)
return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
fragmentOffsets.push_back(flatOffsets[flatIndex]);
fragmentElements *= flatSizes[flatIndex];
}
Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]);
@@ -279,20 +432,21 @@ static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
if (!sourceType || !sourceType.hasStaticShape())
return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
SmallVector<int64_t, 4> fragmentShape;
fragmentShape.reserve(rank);
for (int64_t dim = 0; dim < rank; ++dim)
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
Value fragment = source;
if (llvm::to_vector(sourceType.getShape()) != fragmentShape) {
SmallVector<int64_t, 4> extractOffsets(rank, 0);
extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
fragment = tensor::ExtractSliceOp::create(rewriter,
reconciliator.getLoc(),
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
reconciliator, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
if (failed(extractOffsets))
return failure();
fragment = tensor::ExtractSliceOp::create(rewriter,
reconciliator.getLoc(),
source,
getStaticIndexAttrs(rewriter, extractOffsets),
getStaticIndexAttrs(rewriter, *extractOffsets),
getStaticIndexAttrs(rewriter, fragmentShape),
getUnitStrides(rewriter, rank));
}
@@ -351,16 +505,29 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds));
SmallVector<unsigned> returnOperandIndices;
SmallVector<SmallVector<BatchFragmentAssemblyPlan, 1>, 4> fragmentAssemblyPlansByResult;
if (computeBatchOp.getNumResults() != 0) {
returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits<unsigned>::max());
fragmentAssemblyPlansByResult.resize(computeBatchOp.getNumResults());
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
if (result.use_empty())
continue;
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
if (failed(returnOperandIndex))
if (succeeded(returnOperandIndex)) {
returnOperandIndices[resultIndex] = *returnOperandIndex;
continue;
}
auto resultType = dyn_cast<RankedTensorType>(result.getType());
if (!resultType || !resultType.hasStaticShape())
return computeBatchOp.emitOpError(
"resultful compute_batch publication lowering requires static ranked tensor results");
FailureOr<SmallVector<BatchFragmentAssemblyPlan, 8>> fragmentAssemblyPlans =
analyzeTopLevelFragmentAssemblyUses(cast<OpResult>(result), resultType, computeBatchOp.getLaneCount());
if (failed(fragmentAssemblyPlans))
return computeBatchOp.emitOpError(
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
returnOperandIndices[resultIndex] = *returnOperandIndex;
fragmentAssemblyPlansByResult[resultIndex].assign(fragmentAssemblyPlans->begin(), fragmentAssemblyPlans->end());
}
}
@@ -446,9 +613,44 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
if (resultIndex >= returnOperandIndices.size())
return insertSlice.emitOpError("result index out of range while lowering host batch output");
if (returnOperandIndices[resultIndex] == std::numeric_limits<unsigned>::max())
bool hasDirectReturn = returnOperandIndices[resultIndex] != std::numeric_limits<unsigned>::max();
bool hasFragmentAssembly = resultIndex < fragmentAssemblyPlansByResult.size()
&& !fragmentAssemblyPlansByResult[resultIndex].empty();
if (!hasDirectReturn && !hasFragmentAssembly)
continue;
Value mappedSource = mapper.lookup(insertSlice.getSource());
if (hasFragmentAssembly) {
BlockArgument laneArg = coreBatchOp.getLaneArgument();
auto mappedSourceType = dyn_cast<ShapedType>(mappedSource.getType());
if (!mappedSourceType || !mappedSourceType.hasStaticShape())
return insertSlice.emitOpError("fragment assembly batch lowering requires a static ranked lane-local source");
for (const BatchFragmentAssemblyPlan& plan : fragmentAssemblyPlansByResult[resultIndex]) {
Value outputTensor = outputTensors[plan.returnIndex](rewriter, insertSlice.getLoc());
auto sizeAttr = pim::getCheckedI32Attr(
rewriter, coreBatchOp.getOperation(), plan.fragmentByteSize, "fragment assembly host copy byte size");
if (failed(sizeAttr))
return failure();
Value hostTargetOffset =
createLaneIndexedOffset(rewriter, coreBatchOp.getOperation(), laneArg, plan.hostOffsetsByLane, insertSlice.getLoc());
Value deviceSourceOffset = getOrCreateIndexConstant(
rewriter, coreBatchOp.getOperation(),
plan.localSourceElementOffset * static_cast<int64_t>(getElementTypeSizeInBytes(mappedSourceType.getElementType())));
outputTensor =
pim::PimMemCopyDevToHostOp::create(rewriter,
insertSlice.getLoc(),
outputTensor.getType(),
hostTargetOffset,
deviceSourceOffset,
outputTensor,
mappedSource,
*sizeAttr)
.getOutput();
}
continue;
}
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
if (auto reconciliator =
@@ -467,7 +669,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
}
}
Value mappedSource = mapper.lookup(insertSlice.getSource());
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource);
+114
View File
@@ -5,6 +5,7 @@
#include <cassert>
#include "Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
using namespace llvm;
@@ -72,4 +73,117 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Op
rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
}
LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reconciliator,
int64_t resultRank,
size_t operandCount,
ArrayRef<int64_t> operandIndices,
ArrayRef<int64_t> sourceOffsets,
ArrayRef<int64_t> flatOffsets,
ArrayRef<int64_t> flatSizes,
ArrayRef<int64_t> flatStrides) {
if (operandIndices.size() != sourceOffsets.size())
return reconciliator.emitOpError("fragment assembly operand index and source offset counts must match");
if (flatOffsets.size() != flatSizes.size())
return reconciliator.emitOpError("fragment assembly offset and size arrays must have matching lengths");
if (flatStrides.size() != flatOffsets.size())
return reconciliator.emitOpError("fragment assembly stride and offset arrays must have matching lengths");
if (flatOffsets.size() != operandIndices.size() * static_cast<size_t>(resultRank))
return reconciliator.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment");
for (auto [fragmentIndex, operandIndex] : llvm::enumerate(operandIndices)) {
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(operandCount))
return reconciliator.emitOpError("fragment assembly operand index is out of range");
if (sourceOffsets[fragmentIndex] < 0)
return reconciliator.emitOpError("fragment assembly source offsets must be nonnegative");
}
return success();
}
static SmallVector<int64_t, 4> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
SmallVector<int64_t, 4> indices(shape.size(), 0);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
indices[dim] = flatIndex % shape[dim];
flatIndex /= shape[dim];
}
return indices;
}
FailureOr<SmallVector<int64_t, 4>>
getStaticSliceOffsetsForElementOffset(Operation* anchor,
ShapedType sourceType,
ArrayRef<int64_t> fragmentShape,
int64_t sourceElementOffset,
StringRef fieldName) {
if (!sourceType.hasStaticShape())
return (anchor->emitOpError() << fieldName << " requires a static source shape"), failure();
if (sourceElementOffset < 0)
return (anchor->emitOpError() << fieldName << " requires a nonnegative source element offset"), failure();
if (sourceType.getRank() != static_cast<int64_t>(fragmentShape.size()))
return (anchor->emitOpError() << fieldName << " requires fragment rank to match source rank"), failure();
int64_t sourceElementCount = sourceType.getNumElements();
int64_t fragmentElementCount = 1;
for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
if (fragmentShape[dim] < 0)
return (anchor->emitOpError() << fieldName << " requires nonnegative fragment sizes"), failure();
fragmentElementCount *= fragmentShape[dim];
}
if (sourceElementOffset + fragmentElementCount > sourceElementCount)
return (anchor->emitOpError() << fieldName << " exceeds the source tensor bounds"), failure();
SmallVector<int64_t, 4> sliceOffsets = expandFlatElementIndex(sourceElementOffset, sourceType.getShape());
for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
if (sliceOffsets[dim] + fragmentShape[dim] > sourceType.getDimSize(dim))
return (anchor->emitOpError() << fieldName << " does not describe a valid unit-stride slice"), failure();
}
return sliceOffsets;
}
LogicalResult
forEachContiguousDestinationChunk(ArrayRef<int64_t> destShape,
ArrayRef<int64_t> baseOffsets,
ArrayRef<int64_t> sizes,
llvm::function_ref<LogicalResult(ArrayRef<int64_t>, int64_t, int64_t)> callback) {
int64_t rank = static_cast<int64_t>(sizes.size());
int64_t suffixStart = rank - 1;
while (suffixStart > 0 && sizes[suffixStart] == destShape[suffixStart])
--suffixStart;
if (sizes[suffixStart] == destShape[suffixStart] && suffixStart == 0)
suffixStart = 0;
else
++suffixStart;
int64_t chunkElements = 1;
for (int64_t dim = suffixStart; dim < rank; ++dim)
chunkElements *= sizes[dim];
SmallVector<int64_t, 4> prefixExtents(sizes.begin(), sizes.begin() + suffixStart);
SmallVector<int64_t, 4> current(prefixExtents.size(), 0);
int64_t sourceChunkOrdinal = 0;
auto visit = [&](auto&& visit, int64_t dim) -> LogicalResult {
if (dim == static_cast<int64_t>(prefixExtents.size())) {
SmallVector<int64_t, 4> chunkOffsets(baseOffsets.begin(), baseOffsets.end());
for (int64_t prefixDim = 0; prefixDim < static_cast<int64_t>(current.size()); ++prefixDim)
chunkOffsets[prefixDim] += current[prefixDim];
if (failed(callback(chunkOffsets, sourceChunkOrdinal * chunkElements, chunkElements)))
return failure();
++sourceChunkOrdinal;
return success();
}
for (int64_t index = 0; index < prefixExtents[dim]; ++index) {
current[dim] = index;
if (failed(visit(visit, dim + 1)))
return failure();
}
return success();
};
if (prefixExtents.empty())
return callback(baseOffsets, 0, chunkElements);
return visit(visit, 0);
}
} // namespace onnx_mlir
@@ -1,10 +1,17 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
namespace onnx_mlir::spatial {
class SpatReconciliatorOp;
}
namespace onnx_mlir {
mlir::FailureOr<mlir::IntegerAttr>
@@ -29,6 +36,29 @@ mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operat
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation);
mlir::LogicalResult validateFragmentAssemblyMetadata(onnx_mlir::spatial::SpatReconciliatorOp reconciliator,
int64_t resultRank,
size_t operandCount,
llvm::ArrayRef<int64_t> operandIndices,
llvm::ArrayRef<int64_t> sourceOffsets,
llvm::ArrayRef<int64_t> flatOffsets,
llvm::ArrayRef<int64_t> flatSizes,
llvm::ArrayRef<int64_t> flatStrides);
mlir::FailureOr<mlir::SmallVector<int64_t, 4>>
getStaticSliceOffsetsForElementOffset(mlir::Operation* anchor,
mlir::ShapedType sourceType,
llvm::ArrayRef<int64_t> fragmentShape,
int64_t sourceElementOffset,
llvm::StringRef fieldName);
mlir::LogicalResult
forEachContiguousDestinationChunk(llvm::ArrayRef<int64_t> destShape,
llvm::ArrayRef<int64_t> baseOffsets,
llvm::ArrayRef<int64_t> sizes,
llvm::function_ref<mlir::LogicalResult(llvm::ArrayRef<int64_t>, int64_t, int64_t)>
callback);
inline mlir::tensor::EmptyOp
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
@@ -52,11 +52,14 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
std::optional<StringRef> modeAttr = reconciliator.getMode();
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !fragmentStridesAttr)
if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr
|| !fragmentStridesAttr)
return reconciliator.emitOpError("fragment assembly lowering requires explicit fragment metadata");
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
@@ -64,13 +67,19 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
llvm::append_range(fragmentOperands, reconciliator.getFragments());
if (failed(validateFragmentAssemblyMetadata(reconciliator,
rank,
fragmentOperands.size(),
operandIndices,
sourceOffsets,
flatOffsets,
flatSizes,
flatStrides)))
return failure();
Value currentOutput = createEmptyTensorFromShaped(rewriter, reconciliator.getLoc(), resultType);
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
int64_t operandIndex = operandIndices[fragmentIndex];
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
return reconciliator.emitOpError("fragment assembly operand index is out of range");
SmallVector<int64_t, 4> fragmentOffsets;
int64_t fragmentElements = 1;
@@ -96,11 +105,16 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
if (failed(sizeAttr))
return failure();
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
Value hostTargetOffset = createStaticHostTargetOffset(rewriter, reconciliator.getLoc(), resultType, fragmentOffsets);
auto deviceSourceOffsetBytes = pim::checkedMul(static_cast<uint64_t>(sourceOffsets[fragmentIndex]),
static_cast<uint64_t>(getElementTypeSizeInBytes(sourceType.getElementType())),
reconciliator,
"fragment assembly device source offset");
if (failed(deviceSourceOffsetBytes))
return failure();
Value deviceSourceOffset = getOrCreateIndexConstant(rewriter,
rewriter.getInsertionBlock()->getParentOp(),
packedFragmentOrdinal * fragmentBytes);
static_cast<int64_t>(*deviceSourceOffsetBytes));
currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter,
reconciliator.getLoc(),
currentOutput.getType(),
+12 -11
View File
@@ -47,11 +47,13 @@ struct LowerFragmentAssemblyReconciliatorPattern
return op.emitOpError("fragment assembly lowering requires a static ranked tensor result");
std::optional<ArrayRef<int64_t>> operandIndicesAttr = op.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = op.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = op.getFragmentStrides();
if (!operandIndicesAttr || !fragmentStridesAttr)
if (!operandIndicesAttr || !sourceOffsetsAttr || !fragmentStridesAttr)
return op.emitOpError("fragment assembly lowering requires explicit fragment metadata");
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
ArrayRef<int64_t> flatOffsets = op.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = op.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
@@ -59,23 +61,21 @@ struct LowerFragmentAssemblyReconciliatorPattern
SmallVector<Value> fragmentOperands {adaptor.getInput()};
llvm::append_range(fragmentOperands, adaptor.getFragments());
if (failed(validateFragmentAssemblyMetadata(
op, rank, fragmentOperands.size(), operandIndices, sourceOffsets, flatOffsets, flatSizes, flatStrides)))
return failure();
Value currentOutput =
tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), resultType.getElementType()).getResult();
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
int64_t operandIndex = operandIndices[fragmentIndex];
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
return op.emitOpError("fragment assembly operand index is out of range");
SmallVector<int64_t, 4> fragmentOffsets;
int64_t fragmentElements = 1;
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1)
return op.emitOpError("fragment assembly lowering only supports unit strides");
fragmentOffsets.push_back(flatOffsets[flatIndex]);
fragmentElements *= flatSizes[flatIndex];
}
Value source = fragmentOperands[operandIndex];
@@ -83,20 +83,21 @@ struct LowerFragmentAssemblyReconciliatorPattern
if (!sourceType || !sourceType.hasStaticShape())
return op.emitOpError("fragment assembly lowering requires static ranked tensor operands");
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
SmallVector<int64_t, 4> fragmentShape;
fragmentShape.reserve(rank);
for (int64_t dim = 0; dim < rank; ++dim)
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
Value fragment = source;
if (llvm::to_vector(sourceType.getShape()) != fragmentShape) {
SmallVector<int64_t, 4> extractOffsets(rank, 0);
extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
op, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
if (failed(extractOffsets))
return failure();
fragment = tensor::ExtractSliceOp::create(rewriter,
op.getLoc(),
source,
getStaticIndexAttrs(rewriter, extractOffsets),
getStaticIndexAttrs(rewriter, *extractOffsets),
getStaticIndexAttrs(rewriter, fragmentShape),
getUnitStrides(rewriter, rank));
}
@@ -149,6 +149,40 @@ static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
};
}
static FailureOr<SmallVector<std::pair<spatial::SpatReconciliatorOp, size_t>, 4>>
analyzeTopLevelFragmentAssemblyUses(Value value) {
SmallVector<std::pair<spatial::SpatReconciliatorOp, size_t>, 4> uses;
for (OpOperand& use : value.getUses()) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(use.getOwner());
if (!reconciliator || reconciliator->getParentOp() != reconciliator->getParentOfType<func::FuncOp>())
return failure();
std::optional<StringRef> mode = reconciliator.getMode();
if (!mode || *mode != "fragment_assembly")
return failure();
if (!reconciliator.getOutput().hasOneUse() || !isa<func::ReturnOp>(*reconciliator.getOutput().getUsers().begin()))
return failure();
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> stridesAttr = reconciliator.getFragmentStrides();
auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr || !resultType || !resultType.hasStaticShape())
return failure();
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
llvm::append_range(fragmentOperands, reconciliator.getFragments());
if (failed(validateFragmentAssemblyMetadata(reconciliator,
resultType.getRank(),
fragmentOperands.size(),
*operandIndicesAttr,
*sourceOffsetsAttr,
reconciliator.getFragmentOffsets(),
reconciliator.getFragmentSizes(),
*stridesAttr)))
return failure();
uses.emplace_back(reconciliator, use.getOperandNumber());
}
return uses;
}
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
auto getConcatResult = [](Operation* op) -> Value {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
@@ -559,6 +593,116 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
}
}
FailureOr<SmallVector<std::pair<spatial::SpatReconciliatorOp, size_t>, 4>> fragmentAssemblyUses =
analyzeTopLevelFragmentAssemblyUses(producedValue);
if (succeeded(fragmentAssemblyUses)) {
auto sourceType = dyn_cast<RankedTensorType>(storedValue.getType());
if (!sourceType || !sourceType.hasStaticShape()) {
producerOp->emitOpError("fragment assembly publication requires a static ranked tensor source");
return ReturnPathLoweringResult::Failure;
}
size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType());
for (auto [reconciliator, operandNumber] : *fragmentAssemblyUses) {
rewriter.setInsertionPointAfterValue(storedValue);
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> stridesAttr = reconciliator.getFragmentStrides();
if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr) {
reconciliator.emitOpError(
"fragment assembly lowering requires explicit operand, source-offset, and stride metadata");
return ReturnPathLoweringResult::Failure;
}
size_t returnIndex = reconciliator.getOutput().getUses().begin()->getOperandNumber();
Value outputTensor = outputTensors[returnIndex](rewriter, loc);
auto outputType = dyn_cast<RankedTensorType>(outputTensor.getType());
auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
if (!outputType || !resultType || !resultType.hasStaticShape()) {
reconciliator.emitOpError("fragment assembly lowering requires static ranked host outputs");
return ReturnPathLoweringResult::Failure;
}
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *stridesAttr;
int64_t rank = resultType.getRank();
if (failed(validateFragmentAssemblyMetadata(reconciliator,
rank,
1 + reconciliator.getFragments().size(),
operandIndices,
sourceOffsets,
flatOffsets,
flatSizes,
flatStrides)))
return ReturnPathLoweringResult::Failure;
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
if (operandIndices[fragmentIndex] != static_cast<int64_t>(operandNumber))
continue;
SmallVector<int64_t, 4> fragmentOffsets;
SmallVector<int64_t, 4> fragmentSizes;
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1) {
reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
return ReturnPathLoweringResult::Failure;
}
fragmentOffsets.push_back(flatOffsets[flatIndex]);
fragmentSizes.push_back(flatSizes[flatIndex]);
}
bool failedChunk = false;
if (failed(forEachContiguousDestinationChunk(
outputType.getShape(),
fragmentOffsets,
fragmentSizes,
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
auto hostOffset =
getCheckedByteOffset(computeFlatElementIndex(chunkOffsets, outputType.getShape()),
elementSize,
producerOp,
"fragment assembly host offset");
auto sourceOffset = getCheckedByteOffset(sourceOffsets[fragmentIndex] + relativeSourceOffset,
elementSize,
producerOp,
"fragment assembly source offset");
auto fragmentBytes =
getCheckedByteOffset(chunkElements, elementSize, producerOp, "fragment assembly host copy byte size");
if (failed(hostOffset) || failed(sourceOffset) || failed(fragmentBytes)) {
failedChunk = true;
return failure();
}
auto sizeAttr =
pim::getCheckedI32Attr(rewriter, producerOp, *fragmentBytes, "fragment assembly host copy byte size");
if (failed(sizeAttr)) {
failedChunk = true;
return failure();
}
outputTensor =
pim::PimMemCopyDevToHostOp::create(rewriter,
reconciliator.getLoc(),
outputTensor.getType(),
getOrCreateIndexConstant(rewriter, producerOp, *hostOffset),
getOrCreateIndexConstant(rewriter, producerOp, *sourceOffset),
outputTensor,
storedValue,
*sizeAttr)
.getOutput();
return success();
})))
failedChunk = true;
if (failedChunk)
return ReturnPathLoweringResult::Failure;
}
markOpToRemove(reconciliator.getOperation());
}
return ReturnPathLoweringResult::Handled;
}
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
auto storedByteSize =
@@ -669,6 +813,16 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
return;
}
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
std::optional<StringRef> mode = reconciliator.getMode();
if (mode && *mode == "fragment_assembly") {
markOpToRemove(reconciliator.getOperation());
for (Value operand : reconciliator->getOperands())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
}
if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
markOpToRemove(computeOp);
if (!computeOp.getInputs().empty())