cose
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-06-25 18:57:12 +02:00
parent be0bcc9dcc
commit 568fd90542
6 changed files with 647 additions and 179 deletions
@@ -5,6 +5,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include <limits>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
@@ -138,26 +139,49 @@ static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value ba
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
}
/// Wraps each compile-time integer in `values` as an index attribute and
/// returns them, in order, as OpFoldResults.
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
  SmallVector<OpFoldResult, 4> indexAttrs;
  indexAttrs.reserve(values.size());
  for (size_t position = 0, count = values.size(); position != count; ++position)
    indexAttrs.push_back(builder.getIndexAttr(values[position]));
  return indexAttrs;
}
/// Returns `rank` index attributes, each holding the constant 1.
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
  SmallVector<OpFoldResult, 4> unitStrides;
  unitStrides.reserve(rank);
  // Count down instead of up; the number of entries appended is the same.
  for (int64_t remaining = rank; remaining > 0; --remaining)
    unitStrides.push_back(builder.getIndexAttr(1));
  return unitStrides;
}
static Value createHostTargetOffset(IRRewriter& rewriter,
tensor::ParallelInsertSliceOp insertSlice,
Location loc,
ShapedType destinationType,
ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<int64_t> additionalOffsets,
IRMapping& mapper) {
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
Value totalOffset;
Location loc = insertSlice.getLoc();
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
for (auto [dim, offset] : llvm::enumerate(mixedOffsets)) {
int64_t scale = strides[dim] * elementBytes;
Value scaledOffset;
if (auto attr = dyn_cast<Attribute>(offset)) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
assert(intAttr && "expected integer offset attribute");
scaledOffset =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), intAttr.getInt() * scale);
}
else {
scaledOffset = getOrCreateIndexConstant(rewriter,
rewriter.getInsertionBlock()->getParentOp(),
(intAttr.getInt() + additionalOffsets[dim]) * scale);
} else {
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
if (additionalOffsets[dim] != 0) {
Value staticOffset = getOrCreateIndexConstant(rewriter,
rewriter.getInsertionBlock()->getParentOp(),
additionalOffsets[dim] * scale);
scaledOffset = arith::AddIOp::create(rewriter, loc, scaledOffset, staticOffset).getResult();
}
}
totalOffset =
@@ -169,6 +193,127 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
return totalOffset;
}
/// Convenience overload: forwards the location and mixed offsets of
/// `insertSlice` to the general overload, with an all-zero additional offset
/// for every dimension of `destinationType`.
static Value createHostTargetOffset(IRRewriter& rewriter,
                                    tensor::ParallelInsertSliceOp insertSlice,
                                    ShapedType destinationType,
                                    IRMapping& mapper) {
  auto sliceOffsets = insertSlice.getMixedOffsets();
  SmallVector<int64_t> noAdditionalOffsets(destinationType.getRank(), 0);
  Location sliceLoc = insertSlice.getLoc();
  return createHostTargetOffset(rewriter, sliceLoc, destinationType, sliceOffsets, noAdditionalOffsets, mapper);
}
/// Adds the static per-dimension `fragmentOffsets` to `baseOffsets`.
///
/// Static base offsets are folded into a single index attribute. Dynamic base
/// offsets are first remapped through `mapper`; a zero fragment offset returns
/// the remapped value unchanged, otherwise an `arith.addi` with an index
/// constant is emitted at the rewriter's current insertion point.
///
/// `baseOffsets` and `fragmentOffsets` must have the same length, because
/// `fragmentOffsets` is indexed by the position in `baseOffsets`.
static SmallVector<OpFoldResult, 4> buildFragmentOffsets(IRRewriter& rewriter,
                                                         Location loc,
                                                         ArrayRef<OpFoldResult> baseOffsets,
                                                         ArrayRef<int64_t> fragmentOffsets,
                                                         IRMapping& mapper) {
  assert(baseOffsets.size() == fragmentOffsets.size() && "expected one fragment offset per base offset");

  SmallVector<OpFoldResult, 4> combined;
  // One entry is produced per base offset, so reserve for that count.
  combined.reserve(baseOffsets.size());
  for (auto [dim, baseOffset] : llvm::enumerate(baseOffsets)) {
    if (auto attr = dyn_cast<Attribute>(baseOffset)) {
      int64_t base = cast<IntegerAttr>(attr).getInt();
      combined.push_back(rewriter.getIndexAttr(base + fragmentOffsets[dim]));
      continue;
    }

    Value dynamicBase = mapper.lookupOrDefault(cast<Value>(baseOffset));
    if (fragmentOffsets[dim] == 0) {
      // Nothing to add: reuse the remapped value as is.
      combined.push_back(dynamicBase);
      continue;
    }

    Value staticOffset =
        getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), fragmentOffsets[dim]);
    combined.push_back(arith::AddIOp::create(rewriter, loc, dynamicBase, staticOffset).getResult());
  }
  return combined;
}
/// Lowers a fragment-assembly reconciliator that feeds a host output.
///
/// For every fragment described by the reconciliator's flat metadata arrays
/// (one offset/size/stride per result dimension per fragment), the fragment
/// is inserted into `hostTarget` with `tensor.insert_slice` at
/// `baseOffsets + fragment offset`. When the (remapped) operand shape differs
/// from the fragment shape, the fragment is first carved out with
/// `tensor.extract_slice`; the n-th fragment taken from the same operand is
/// read at offset `n * fragmentShape[0]` along dimension 0.
///
/// \param rewriter     Rewriter used to create the new ops.
/// \param reconciliator The fragment-assembly op being lowered.
/// \param hostTarget   Tensor the fragments are inserted into.
/// \param baseOffsets  Offsets of the assembled result inside `hostTarget`.
/// \param mapper       Mapping applied to operands and dynamic offsets.
/// \returns The tensor value after all insertions, or failure (with a
///          diagnostic emitted on `reconciliator`) for unsupported input.
static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
                                                        spatial::SpatReconciliatorOp reconciliator,
                                                        Value hostTarget,
                                                        ArrayRef<OpFoldResult> baseOffsets,
                                                        IRMapping& mapper) {
  auto hostTargetType = dyn_cast<RankedTensorType>(hostTarget.getType());
  auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
  if (!hostTargetType || !resultType || !resultType.hasStaticShape())
    return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor results");

  std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
  std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
  if (!operandIndicesAttr || !fragmentStridesAttr)
    return reconciliator.emitOpError(
        "fragment assembly lowering requires explicit operand indices and unit strides");

  ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
  ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
  ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
  ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
  const int64_t rank = resultType.getRank();
  const size_t rankCount = static_cast<size_t>(rank);
  Location loc = reconciliator.getLoc();

  // The insert_slice built below takes one offset and one size per dimension
  // of the destination, so all three ranks have to agree.
  if (hostTargetType.getRank() != rank || baseOffsets.size() != rankCount)
    return reconciliator.emitOpError(
        "fragment assembly lowering expects the host target and its base offsets to match the result rank");

  // Every fragment reads `rank` entries from each flat array; reject metadata
  // that is too short instead of reading past the end of the attribute.
  const size_t requiredEntries = operandIndices.size() * rankCount;
  if (flatOffsets.size() < requiredEntries || flatSizes.size() < requiredEntries
      || flatStrides.size() < requiredEntries)
    return reconciliator.emitOpError(
        "fragment assembly metadata must provide an offset, size and stride for every fragment dimension");

  SmallVector<Value> fragmentOperands {reconciliator.getInput()};
  llvm::append_range(fragmentOperands, reconciliator.getFragments());

  // Counts how many fragments have already been taken from each operand.
  DenseMap<int64_t, int64_t> packedFragmentOrdinals;
  for (size_t fragmentIndex = 0; fragmentIndex < operandIndices.size(); ++fragmentIndex) {
    int64_t operandIndex = operandIndices[fragmentIndex];
    if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
      return reconciliator.emitOpError("fragment assembly operand index is out of range");

    const size_t firstEntry = fragmentIndex * rankCount;
    ArrayRef<int64_t> fragmentStrides = flatStrides.slice(firstEntry, rankCount);
    if (llvm::any_of(fragmentStrides, [](int64_t stride) { return stride != 1; }))
      return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
    ArrayRef<int64_t> fragmentOffsets = flatOffsets.slice(firstEntry, rankCount);
    ArrayRef<int64_t> fragmentShape = flatSizes.slice(firstEntry, rankCount);

    Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]);
    // tensor.extract_slice / tensor.insert_slice only accept ranked tensors.
    auto sourceType = dyn_cast<RankedTensorType>(source.getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");

    int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;

    Value fragment = source;
    if (sourceType.getShape() != fragmentShape) {
      // Slicing uses `rank` offsets/sizes and touches dimension 0, so the
      // operand must have the result's rank (which also rules out rank 0).
      if (sourceType.getRank() != rank)
        return reconciliator.emitOpError(
            "fragment assembly lowering cannot slice an operand whose rank differs from the result rank");
      SmallVector<int64_t, 4> extractOffsets(rankCount, 0);
      extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
      fragment = tensor::ExtractSliceOp::create(rewriter,
                                                loc,
                                                source,
                                                getStaticIndexAttrs(rewriter, extractOffsets),
                                                getStaticIndexAttrs(rewriter, fragmentShape),
                                                getUnitStrides(rewriter, rank));
    }

    hostTarget = tensor::InsertSliceOp::create(rewriter,
                                               loc,
                                               fragment,
                                               hostTarget,
                                               buildFragmentOffsets(rewriter, loc, baseOffsets, fragmentOffsets, mapper),
                                               getStaticIndexAttrs(rewriter, fragmentShape),
                                               getUnitStrides(rewriter, rank))
                     .getResult();
  }

  return hostTarget;
}
} // namespace
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
@@ -207,8 +352,10 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
SmallVector<unsigned> returnOperandIndices;
if (computeBatchOp.getNumResults() != 0) {
returnOperandIndices.resize(computeBatchOp.getNumResults());
returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits<unsigned>::max());
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
if (result.use_empty())
continue;
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
if (failed(returnOperandIndex))
return computeBatchOp.emitOpError(
@@ -271,6 +418,18 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
if (isa<spatial::SpatYieldOp>(op))
continue;
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
std::optional<StringRef> modeAttr = reconciliator.getMode();
if (modeAttr && *modeAttr == "fragment_assembly") {
for (Operation* user : reconciliator.getOutput().getUsers()) {
if (!isa<tensor::ParallelInsertSliceOp>(user))
return reconciliator.emitOpError(
"fragment assembly reconciliator lowering expects only tensor.parallel_insert_slice users");
}
continue;
}
}
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
if (!firstOutputArg)
@@ -287,10 +446,28 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
if (resultIndex >= returnOperandIndices.size())
return insertSlice.emitOpError("result index out of range while lowering host batch output");
if (returnOperandIndices[resultIndex] == std::numeric_limits<unsigned>::max())
continue;
Value mappedSource = mapper.lookup(insertSlice.getSource());
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
if (auto reconciliator =
insertSlice.getSource().getDefiningOp<spatial::SpatReconciliatorOp>()) {
std::optional<StringRef> modeAttr = reconciliator.getMode();
if (modeAttr && *modeAttr == "fragment_assembly") {
FailureOr<Value> updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter,
reconciliator,
hostTarget,
insertSlice.getMixedOffsets(),
mapper);
if (failed(updatedHostTarget))
return failure();
hostOutputTensors[resultIndex] = *updatedHostTarget;
continue;
}
}
Value mappedSource = mapper.lookup(insertSlice.getSource());
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource);
@@ -30,6 +30,91 @@ static bool isChannelUseChainOp(Operation* op) {
pim::PimTransposeOp>(op);
}
/// Returns an index constant holding the byte offset of the element at
/// `fragmentOffsets` inside a `destinationType` buffer: the sum over all
/// dimensions of offset * row-major stride * element size in bytes.
/// `loc` is accepted for signature symmetry but is not used here.
static Value createStaticHostTargetOffset(IRRewriter& rewriter,
                                          Location loc,
                                          ShapedType destinationType,
                                          ArrayRef<int64_t> fragmentOffsets) {
  const SmallVector<int64_t> rowMajorStrides = computeRowMajorStrides(destinationType.getShape());
  const int64_t bytesPerElement =
      static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));

  int64_t totalBytes = 0;
  for (size_t dim = 0, dimCount = fragmentOffsets.size(); dim < dimCount; ++dim)
    totalBytes += fragmentOffsets[dim] * rowMajorStrides[dim] * bytesPerElement;

  Operation* constantAnchor = rewriter.getInsertionBlock()->getParentOp();
  return getOrCreateIndexConstant(rewriter, constantAnchor, totalBytes);
}
/// Lowers a fragment-assembly reconciliator into a chain of
/// `pim.memcp` device-to-host copies.
///
/// An empty tensor of the result type is created and each fragment is copied
/// into it: the host offset is the byte offset of the fragment's static
/// offsets inside the result, and the device offset is
/// `ordinal * fragmentBytes`, where `ordinal` counts the fragments already
/// taken from the same operand.
///
/// \param rewriter      Rewriter used to create the new ops.
/// \param reconciliator The op being lowered; must be in "fragment_assembly"
///                      mode and carry operand indices and strides.
/// \param mapping       Mapping applied to the fragment operands.
/// \returns The value produced by the last copy (or the empty tensor when
///          there are no fragments), or failure after emitting a diagnostic.
static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
                                                           spatial::SpatReconciliatorOp reconciliator,
                                                           IRMapping& mapping) {
  auto resultType = dyn_cast<ShapedType>(reconciliator.getOutput().getType());
  if (!resultType || !resultType.hasStaticShape())
    return reconciliator.emitOpError("fragment assembly lowering requires a static ranked tensor result");

  std::optional<StringRef> modeAttr = reconciliator.getMode();
  std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
  std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
  if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !fragmentStridesAttr)
    return reconciliator.emitOpError("fragment assembly lowering requires explicit fragment metadata");

  ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
  ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
  ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
  ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
  const size_t rankCount = static_cast<size_t>(resultType.getRank());

  // Every fragment reads `rank` entries from each flat array; reject metadata
  // that is too short instead of reading past the end of the attribute.
  const size_t requiredEntries = operandIndices.size() * rankCount;
  if (flatOffsets.size() < requiredEntries || flatSizes.size() < requiredEntries
      || flatStrides.size() < requiredEntries)
    return reconciliator.emitOpError(
        "fragment assembly metadata must provide an offset, size and stride for every fragment dimension");

  SmallVector<Value> fragmentOperands {reconciliator.getInput()};
  llvm::append_range(fragmentOperands, reconciliator.getFragments());

  Location loc = reconciliator.getLoc();
  Value currentOutput = createEmptyTensorFromShaped(rewriter, loc, resultType);
  // Counts how many fragments have already been taken from each operand.
  DenseMap<int64_t, int64_t> packedFragmentOrdinals;
  for (size_t fragmentIndex = 0; fragmentIndex < operandIndices.size(); ++fragmentIndex) {
    int64_t operandIndex = operandIndices[fragmentIndex];
    if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
      return reconciliator.emitOpError("fragment assembly operand index is out of range");

    const size_t firstEntry = fragmentIndex * rankCount;
    ArrayRef<int64_t> fragmentStrides = flatStrides.slice(firstEntry, rankCount);
    if (llvm::any_of(fragmentStrides, [](int64_t stride) { return stride != 1; }))
      return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
    ArrayRef<int64_t> fragmentOffsets = flatOffsets.slice(firstEntry, rankCount);

    int64_t fragmentElements = 1;
    for (int64_t dimSize : flatSizes.slice(firstEntry, rankCount))
      fragmentElements *= dimSize;

    Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]);
    auto sourceType = dyn_cast<ShapedType>(source.getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");

    int64_t fragmentBytes =
        fragmentElements * static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
    auto sizeAttr = pim::getCheckedI32Attr(rewriter,
                                           reconciliator.getOperation(),
                                           fragmentBytes,
                                           "fragment assembly host copy size");
    if (failed(sizeAttr))
      return failure();

    int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
    Value hostTargetOffset = createStaticHostTargetOffset(rewriter, loc, resultType, fragmentOffsets);
    // NOTE(review): the device offset is ordinal * size of *this* fragment, so
    // it presumably relies on all fragments of one operand having the same
    // byte size — confirm against the producer of the packed operand.
    Value deviceSourceOffset = getOrCreateIndexConstant(rewriter,
                                                        rewriter.getInsertionBlock()->getParentOp(),
                                                        packedFragmentOrdinal * fragmentBytes);
    currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter,
                                                       loc,
                                                       currentOutput.getType(),
                                                       hostTargetOffset,
                                                       deviceSourceOffset,
                                                       currentOutput,
                                                       source,
                                                       *sizeAttr)
                        .getOutput();
  }

  return currentOutput;
}
static void
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
for (Value operand : op->getOperands()) {
@@ -131,6 +216,17 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatSchedule
mapping.map(*weightArg, weight);
}
for (Operation& op : block.without_terminator()) {
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
std::optional<StringRef> modeAttr = reconciliator.getMode();
if (modeAttr && *modeAttr == "fragment_assembly") {
auto lowered = lowerFragmentAssemblyReconciliator(rewriter, reconciliator, mapping);
if (failed(lowered))
return false;
mapping.map(reconciliator.getOutput(), *lowered);
continue;
}
}
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
Operation* clonedOp = rewriter.clone(op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
@@ -1,6 +1,10 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -11,6 +15,107 @@ namespace raptor {
} // namespace raptor
/// Turns a list of static integers into index-typed attributes, one
/// OpFoldResult per input value, keeping the input order.
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
  SmallVector<OpFoldResult, 4> folded;
  folded.reserve(values.size());
  for (const int64_t* cursor = values.begin(); cursor != values.end(); ++cursor)
    folded.push_back(builder.getIndexAttr(*cursor));
  return folded;
}
/// Produces a stride list of length `rank` in which every stride is the
/// index attribute 1.
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
  SmallVector<OpFoldResult, 4> result;
  result.reserve(rank);
  int64_t emitted = 0;
  while (emitted < rank) {
    result.push_back(builder.getIndexAttr(1));
    ++emitted;
  }
  return result;
}
/// Conversion pattern that rewrites a reconciliator in "fragment_assembly"
/// mode into tensor ops.
///
/// The result is built from a `tensor.empty` of the result type into which
/// every fragment is placed with `tensor.insert_slice` at its static offsets.
/// When an operand's shape differs from the fragment shape, the fragment is
/// first taken with `tensor.extract_slice`; the n-th fragment read from the
/// same operand starts at `n * fragmentShape[0]` along dimension 0.
/// Reconciliators in any other mode are left to other patterns.
struct LowerFragmentAssemblyReconciliatorPattern
    : OpConversionPattern<spatial::SpatReconciliatorOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(spatial::SpatReconciliatorOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    std::optional<StringRef> modeAttr = op.getMode();
    if (!modeAttr || *modeAttr != "fragment_assembly")
      return failure();

    auto resultType = dyn_cast<ShapedType>(op.getOutput().getType());
    if (!resultType || !resultType.hasStaticShape())
      return op.emitOpError("fragment assembly lowering requires a static ranked tensor result");

    std::optional<ArrayRef<int64_t>> operandIndicesAttr = op.getFragmentOperandIndices();
    std::optional<ArrayRef<int64_t>> fragmentStridesAttr = op.getFragmentStrides();
    if (!operandIndicesAttr || !fragmentStridesAttr)
      return op.emitOpError("fragment assembly lowering requires explicit fragment metadata");

    ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
    ArrayRef<int64_t> flatOffsets = op.getFragmentOffsets();
    ArrayRef<int64_t> flatSizes = op.getFragmentSizes();
    ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
    const int64_t rank = resultType.getRank();
    const size_t rankCount = static_cast<size_t>(rank);

    // Every fragment reads `rank` entries from each flat array; reject
    // metadata that is too short before any IR is created.
    const size_t requiredEntries = operandIndices.size() * rankCount;
    if (flatOffsets.size() < requiredEntries || flatSizes.size() < requiredEntries
        || flatStrides.size() < requiredEntries)
      return op.emitOpError(
          "fragment assembly metadata must provide an offset, size and stride for every fragment dimension");

    SmallVector<Value> fragmentOperands {adaptor.getInput()};
    llvm::append_range(fragmentOperands, adaptor.getFragments());

    Location loc = op.getLoc();
    Value currentOutput =
        tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
    // Counts how many fragments have already been taken from each operand.
    DenseMap<int64_t, int64_t> packedFragmentOrdinals;
    for (size_t fragmentIndex = 0; fragmentIndex < operandIndices.size(); ++fragmentIndex) {
      int64_t operandIndex = operandIndices[fragmentIndex];
      if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
        return op.emitOpError("fragment assembly operand index is out of range");

      const size_t firstEntry = fragmentIndex * rankCount;
      ArrayRef<int64_t> fragmentStrides = flatStrides.slice(firstEntry, rankCount);
      if (llvm::any_of(fragmentStrides, [](int64_t stride) { return stride != 1; }))
        return op.emitOpError("fragment assembly lowering only supports unit strides");
      ArrayRef<int64_t> fragmentOffsets = flatOffsets.slice(firstEntry, rankCount);
      ArrayRef<int64_t> fragmentShape = flatSizes.slice(firstEntry, rankCount);

      Value source = fragmentOperands[operandIndex];
      auto sourceType = dyn_cast<RankedTensorType>(source.getType());
      if (!sourceType || !sourceType.hasStaticShape())
        return op.emitOpError("fragment assembly lowering requires static ranked tensor operands");

      int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;

      Value fragment = source;
      if (sourceType.getShape() != fragmentShape) {
        // Slicing uses `rank` offsets/sizes and touches dimension 0, so the
        // operand must have the result's rank (which also rules out rank 0).
        if (sourceType.getRank() != rank)
          return op.emitOpError(
              "fragment assembly lowering cannot slice an operand whose rank differs from the result rank");
        SmallVector<int64_t, 4> extractOffsets(rankCount, 0);
        extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
        fragment = tensor::ExtractSliceOp::create(rewriter,
                                                  loc,
                                                  source,
                                                  getStaticIndexAttrs(rewriter, extractOffsets),
                                                  getStaticIndexAttrs(rewriter, fragmentShape),
                                                  getUnitStrides(rewriter, rank));
      }

      currentOutput = tensor::InsertSliceOp::create(rewriter,
                                                    loc,
                                                    fragment,
                                                    currentOutput,
                                                    getStaticIndexAttrs(rewriter, fragmentOffsets),
                                                    getStaticIndexAttrs(rewriter, fragmentShape),
                                                    getUnitStrides(rewriter, rank))
                          .getResult();
    }

    rewriter.replaceOp(op, currentOutput);
    return success();
  }
};
void populateInitialPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
@@ -19,6 +124,7 @@ void populateInitialPatterns(RewritePatternSet& patterns) {
/// Fills `patterns` with the rewrites used for core bodies: the generated
/// patterns, the transpose lowering patterns, and the fragment-assembly
/// reconciliator lowering, registered in that order.
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
  raptor::populateWithGenerated(patterns);
  populateTransposeLoweringPatterns(patterns);

  MLIRContext* context = patterns.getContext();
  patterns.add<LowerFragmentAssemblyReconciliatorPattern>(context);
}
} // namespace onnx_mlir
@@ -306,10 +306,12 @@ struct ProjectedExtractReplacement {
struct PendingProjectedHostOutputFragment {
Value originalOutput;
ClassId sourceClass = 0;
ProducerKey producerKey;
Value operand;
RankedTensorType operandType;
RankedTensorType fragmentType;
int64_t packedFragmentIndex = -1;
int64_t currentLane = -1;
SmallVector<int64_t, 4> offsets;
SmallVector<int64_t, 4> sizes;
SmallVector<int64_t, 4> strides;
@@ -1137,6 +1139,59 @@ LogicalResult createEmptyMaterializedOps(MaterializerState& state) {
return success();
}
/// Positions the state's rewriter where a newly materialized op should go:
/// immediately before the first op of the function's entry block that is
/// recorded in `state.oldComputeOps`, or at the end of that block when no
/// such op exists.
void setInsertionPointForNewMaterializedOp(MaterializerState& state) {
  Block& entryBlock = state.func.getBody().front();
  auto firstOldCompute = llvm::find_if(
      entryBlock, [&state](Operation& candidate) { return state.oldComputeOps.contains(&candidate); });

  if (firstOldCompute == entryBlock.end()) {
    state.rewriter.setInsertionPointToEnd(&entryBlock);
    return;
  }
  state.rewriter.setInsertionPoint(&*firstOldCompute);
}
/// Creates a new scalar materialized class whose only job is to assemble the
/// host output `originalOutput`.
///
/// The class is bound to the smallest CPU id not yet present in
/// `state.cpuToClass`. Its scheduled-compute op has no operands, one result
/// of the output's type, and a body that yields a placeholder `tensor.empty`.
/// The class is registered as owner of `originalOutput` (result index 0).
/// On return the rewriter's insertion point is right after the new op.
///
/// All fallible checks run before any IR is created, so a failure never
/// leaves a partially built compute op in the function.
///
/// \returns The id of the new class, or failure after emitting a diagnostic.
FailureOr<ClassId> createProjectedHostAssemblyClass(MaterializerState& state, Value originalOutput, Location loc) {
  auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
  if (!resultType || !resultType.hasStaticShape())
    return state.func.emitError("projected host assembly class requires a static ranked tensor output");

  // Pick the smallest CPU id that no existing class is using.
  DenseSet<CpuId> usedCpus;
  for (const auto& [cpu, _] : state.cpuToClass)
    usedCpus.insert(cpu);
  CpuId assemblyCpu = 0;
  while (usedCpus.contains(assemblyCpu))
    ++assemblyCpu;

  // Validate the core id before creating the op: failing afterwards would
  // leave a compute op without a body in the function.
  auto coreIdAttr = pim::getCheckedI32Attr(state.rewriter, state.func, assemblyCpu, "projected host assembly core id");
  if (failed(coreIdAttr))
    return failure();

  setInsertionPointForNewMaterializedOp(state);
  auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange {resultType}, ValueRange {}, ValueRange {});
  compute.getProperties().setOperandSegmentSizes({0, 0});
  compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr);

  // The body only yields a placeholder of the result type.
  Block* body = state.rewriter.createBlock(&compute.getBody());
  state.rewriter.setInsertionPointToEnd(body);
  Value placeholder =
      tensor::EmptyOp::create(state.rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
  SpatYieldOp::create(state.rewriter, loc, ValueRange {placeholder});
  state.rewriter.setInsertionPointAfter(compute.getOperation());

  MaterializedClass materializedClass;
  materializedClass.id = state.classes.size();
  materializedClass.cpus.push_back(assemblyCpu);
  materializedClass.op = compute.getOperation();
  materializedClass.body = body;
  materializedClass.hostOutputToResultIndex[originalOutput] = 0;
  materializedClass.hostOutputs.push_back(originalOutput);

  state.cpuToClass[assemblyCpu] = materializedClass.id;
  state.hostOutputOwners[originalOutput] = materializedClass.id;
  state.classes.push_back(std::move(materializedClass));
  return state.classes.back().id;
}
BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) {
auto it = materializedClass.weightArgs.find(weight);
if (it != materializedClass.weightArgs.end())
@@ -1897,6 +1952,14 @@ FailureOr<SmallVector<OpFoldResult, 4>> buildProjectedFragmentOffsetsInClass(Mat
return fragmentOffsets;
}
/// Builds one index attribute per entry of `values` and returns them as
/// OpFoldResults in the same order.
SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
  const size_t valueCount = values.size();
  SmallVector<OpFoldResult, 4> staticAttrs;
  staticAttrs.reserve(valueCount);
  for (size_t slot = 0; slot < valueCount; ++slot)
    staticAttrs.push_back(builder.getIndexAttr(values[slot]));
  return staticAttrs;
}
Value createDim0InsertSlice(
MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) {
auto fragmentType = cast<RankedTensorType>(fragment.getType());
@@ -3639,6 +3702,9 @@ LogicalResult appendSend(MaterializerState& state,
if (sourceClass.isBatch) {
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
if (messages.size() != sourceClass.cpus.size())
return sourceClass.op->emitError("batch send expects exactly one message per materialized lane")
<< " messageCount=" << messages.size() << " laneCount=" << sourceClass.cpus.size();
Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc);
Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc);
@@ -3686,6 +3752,11 @@ Value appendReceive(
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
if (targetClass.isBatch) {
if (messages.size() != targetClass.cpus.size()) {
targetClass.op->emitOpError("batch receive expects exactly one message per materialized lane")
<< " messageCount=" << messages.size() << " laneCount=" << targetClass.cpus.size();
return Value();
}
Value channelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc);
Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc);
Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc);
@@ -5481,10 +5552,12 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat
state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment {
originalOutput,
sourceClass.id,
ProducerKey {peer, resultIndex},
packed,
cast<RankedTensorType>(packed.getType()),
fragmentType,
static_cast<int64_t>(runIndex),
static_cast<int64_t>(runIndex),
SmallVector<int64_t, 4>(*offsets),
SmallVector<int64_t, 4>(*sizes),
SmallVector<int64_t, 4>(*strides),
@@ -5572,10 +5645,12 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment {
originalOutput,
sourceClass.id,
key,
packed,
packedType,
fragmentType,
operandIsDim0Packed ? static_cast<int64_t>(fragmentIndex) : -1,
static_cast<int64_t>(fragmentIndex),
SmallVector<int64_t, 4>(*offsets),
SmallVector<int64_t, 4>(*sizes),
SmallVector<int64_t, 4>(*strides),
@@ -5611,16 +5686,6 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
}
MaterializedClass* ownerClass = &state.classes[ownerIt->second];
if (ownerClass->isBatch) {
auto scalarOwnerIt = llvm::find_if(state.classes, [](const MaterializedClass& candidate) {
return !candidate.isBatch;
});
if (scalarOwnerIt == state.classes.end())
return ownerClass->op->emitError(
"projected host output finalization requires a scalar assembly class when the preferred host owner is batch");
ownerClass = &*scalarOwnerIt;
state.hostOutputOwners[originalOutput] = ownerClass->id;
}
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
if (!resultType || !resultType.hasStaticShape())
@@ -5646,6 +5711,119 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
if (allFromSameSourceClass) {
ownerClass = &state.classes[fragments.front()->sourceClass];
state.hostOutputOwners[originalOutput] = ownerClass->id;
} else {
if (!ownerClass->isBatch && ownerClass->hostOutputToResultIndex.contains(originalOutput))
goto owner_selected;
FailureOr<ClassId> createdOwner =
createProjectedHostAssemblyClass(state, originalOutput, fragments.front()->loc);
if (failed(createdOwner))
return failure();
ownerClass = &state.classes[*createdOwner];
}
owner_selected:
if (ownerClass->isBatch && allFromSameSourceClass && ownerClass->id == fragments.front()->sourceClass) {
auto sourceBatch = dyn_cast<SpatComputeBatch>(fragments.front()->producerKey.instance.op);
auto batch = dyn_cast<SpatScheduledComputeBatch>(ownerClass->op);
auto inParallelOp = dyn_cast_or_null<SpatInParallelOp>(ownerClass->body->getTerminator());
auto resultIt = ownerClass->hostOutputToResultIndex.find(originalOutput);
if (!sourceBatch || !batch || !inParallelOp || resultIt == ownerClass->hostOutputToResultIndex.end())
return ownerClass->op->emitError("missing batch host assembly state for projected host output");
FailureOr<tensor::ParallelInsertSliceOp> sourceProjection =
getBatchResultProjectionInsert(sourceBatch, fragments.front()->producerKey.resultIndex);
std::optional<BlockArgument> sourceLaneArg = sourceBatch.getLaneArgument();
if (failed(sourceProjection) || !sourceLaneArg)
return ownerClass->op->emitError(
"direct batch host output assembly requires the source batch projection metadata");
auto outputArg = batch.getOutputArgument(resultIt->second);
auto laneArg = batch.getLaneArgument();
if (!outputArg || !laneArg)
return ownerClass->op->emitError("missing compute_batch output block argument for projected host output");
if (fragments.size() != ownerClass->cpus.size())
return ownerClass->op->emitError(
"direct batch host output assembly expects exactly one fragment per materialized lane");
SmallVector<PendingProjectedHostOutputFragment*, 8> fragmentsByLane(ownerClass->cpus.size(), nullptr);
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
int64_t currentLane = fragmentRecord->currentLane >= 0 ? fragmentRecord->currentLane : fragmentRecord->sourceLane;
if (currentLane < 0 || currentLane >= static_cast<int64_t>(fragmentsByLane.size()))
return ownerClass->op->emitError("projected batch host output fragment current lane is out of bounds");
if (fragmentsByLane[currentLane])
return ownerClass->op->emitError("projected batch host output has duplicate fragments for one lane");
fragmentsByLane[currentLane] = fragmentRecord;
}
if (llvm::any_of(fragmentsByLane, [](PendingProjectedHostOutputFragment* fragment) { return fragment == nullptr; }))
return ownerClass->op->emitError("projected batch host output is missing a fragment for one or more lanes");
FailureOr<SmallVector<int64_t, 4>> firstSizes =
evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentsByLane.front()->sourceLane);
FailureOr<SmallVector<int64_t, 4>> firstStrides =
evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentsByLane.front()->sourceLane);
if (failed(firstSizes) || failed(firstStrides))
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape");
SmallVector<int64_t, 4> referenceSizes(*firstSizes);
SmallVector<int64_t, 4> referenceStrides(*firstStrides);
Value laneOperand;
for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) {
FailureOr<SmallVector<int64_t, 4>> fragmentSizes =
evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentRecord->sourceLane);
FailureOr<SmallVector<int64_t, 4>> fragmentStrides =
evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentRecord->sourceLane);
if (failed(fragmentSizes) || failed(fragmentStrides))
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape");
if (SmallVector<int64_t, 4>(*fragmentSizes) != referenceSizes
|| SmallVector<int64_t, 4>(*fragmentStrides) != referenceStrides)
return ownerClass->op->emitError(
"direct batch host output assembly expects a uniform fragment shape and strides");
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
Value operand;
if (std::optional<Value> availableValue =
state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) {
operand = *availableValue;
} else {
operand = fragmentRecord->operand;
}
if (!isValueLegalInMaterializedClassBody(operand, *ownerClass))
return ownerClass->op->emitError(
"projected batch host output assembly requires source-local fragment operands");
if (laneOperand && laneOperand != operand)
return ownerClass->op->emitError(
"direct batch host output assembly expects one shared lane-local fragment producer");
laneOperand = operand;
}
SmallVector<OpFoldResult, 4> mixedOffsets;
mixedOffsets.reserve(referenceSizes.size());
for (size_t dim = 0; dim < referenceSizes.size(); ++dim) {
SmallVector<int64_t, 8> offsetsByLane;
offsetsByLane.reserve(fragmentsByLane.size());
for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) {
FailureOr<SmallVector<int64_t, 4>> fragmentOffsets =
evaluateStaticProjectionIndices(sourceProjection->getMixedOffsets(), *sourceLaneArg, fragmentRecord->sourceLane);
if (failed(fragmentOffsets))
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment offsets");
offsetsByLane.push_back((*fragmentOffsets)[dim]);
}
mixedOffsets.push_back(allEqual(offsetsByLane)
? OpFoldResult(state.rewriter.getIndexAttr(offsetsByLane.front()))
: OpFoldResult(createLaneIndexedIndexValue(
state, *ownerClass, ArrayRef<int64_t>(offsetsByLane), fragments.front()->loc)));
}
state.hostReplacements[originalOutput] = ownerClass->op->getResult(resultIt->second);
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
tensor::ParallelInsertSliceOp::create(state.rewriter,
fragments.front()->loc,
laneOperand,
*outputArg,
mixedOffsets,
getStaticIndexAttrs(state.rewriter, referenceSizes),
getStaticIndexAttrs(state.rewriter, referenceStrides));
continue;
}
state.rewriter.setInsertionPoint(ownerClass->body->getTerminator());
@@ -5656,28 +5834,73 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
SmallVector<int64_t, 64> flatSizes;
SmallVector<int64_t, 64> flatStrides;
DenseMap<Value, int64_t> operandIndicesByValue;
DenseSet<ClassId> emittedBatchForwarding;
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
Value operand = fragmentRecord->operand;
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
Value operand;
if (std::optional<Value> availableValue =
state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) {
operand = *availableValue;
} else if (fragmentRecord->sourceClass == sourceClass.id) {
operand = fragmentRecord->operand;
} else {
return sourceClass.op->emitError(
"projected host output fragment assembly is missing source-visible fragment operands before finalization");
}
if (fragmentRecord->sourceClass != ownerClass->id) {
if (sourceClass.isBatch || ownerClass->isBatch)
return sourceClass.op->emitError(
"projected host output fragment assembly requires scalarized cross-class operands before finalization");
MessageVector messages;
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op,
sourceClass.cpus.front(),
"projected host output source core id");
auto checkedTargetCpu = getCheckedCoreId(ownerClass->op,
ownerClass->cpus.front(),
"projected host output target core id");
if (failed(checkedSourceCpu) || failed(checkedTargetCpu))
return failure();
messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc)))
return failure();
operand = appendReceive(state, *ownerClass, fragmentRecord->operandType, messages, fragmentRecord->loc);
if (sourceClass.isBatch && !ownerClass->isBatch) {
if (!emittedBatchForwarding.insert(sourceClass.id).second) {
std::optional<Value> localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id);
if (!localized)
return ownerClass->op->emitError(
"projected host output fragment assembly is missing forwarded batch fragments");
operand = *localized;
} else {
SmallVector<ProducerKey, 8> forwardedKeys;
forwardedKeys.reserve(sourceClass.cpus.size());
Value forwardedPayload = fragmentRecord->operand;
for (PendingProjectedHostOutputFragment* candidate : fragments) {
if (candidate->sourceClass != sourceClass.id)
continue;
if (candidate->operand != forwardedPayload)
return ownerClass->op->emitError(
"projected host output batch forwarding expects one shared batch payload per source class");
forwardedKeys.push_back(candidate->producerKey);
}
llvm::sort(forwardedKeys, [](ProducerKey lhs, ProducerKey rhs) {
return lhs.instance.laneStart < rhs.instance.laneStart;
});
if (failed(emitClassToClassCommunication(
state, sourceClass, *ownerClass, forwardedKeys, forwardedPayload, fragmentRecord->loc)))
return failure();
std::optional<Value> localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id);
if (!localized)
return ownerClass->op->emitError(
"projected host output fragment assembly failed to recover forwarded batch fragment");
operand = *localized;
}
} else {
MessageVector messages;
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op,
sourceClass.cpus.front(),
"projected host output source core id");
auto checkedTargetCpu = getCheckedCoreId(ownerClass->op,
ownerClass->cpus.front(),
"projected host output target core id");
if (failed(checkedSourceCpu) || failed(checkedTargetCpu))
return failure();
messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc)))
return failure();
operand = appendReceive(state,
*ownerClass,
cast<RankedTensorType>(operand.getType()),
messages,
fragmentRecord->loc);
}
} else if (!ownerClass->isBatch) {
FailureOr<Value> localOperand = materializeTensorValueForMaterializedClassUse(
state,