Bose
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
ilgeco
2026-06-26 17:45:27 +02:00
parent 984f362623
commit 78e97f9fd8
23 changed files with 513 additions and 17489 deletions
+3 -2
View File
@@ -232,8 +232,8 @@ def SpatReluPlanOp : SpatOp<"relu_plan", []> {
let hasVerifier = 1;
}
def SpatReconciliatorOp : SpatOp<"reconciliator", []> {
let summary = "Logical-to-physical layout record or explicit fragment assembly";
def SpatBlueprintOp : SpatOp<"blueprint", []> {
let summary = "Blueprint for assembling logical tensors from published fragments";
let arguments = (ins
SpatTensor:$input,
@@ -256,6 +256,7 @@ def SpatReconciliatorOp : SpatOp<"reconciliator", []> {
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatMaterializeLayoutOp : SpatOp<"materialize_layout", []> {
+133
View File
@@ -32,6 +32,14 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
return parser.getBuilder().getI32IntegerAttr(value);
}
static ParseResult parseBareStringAttr(OpAsmParser& parser, StringAttr& attr) {
StringRef value;
if (parser.parseKeyword(&value))
return failure();
attr = parser.getBuilder().getStringAttr(value);
return success();
}
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
printer << "(";
for (auto [index, argument] : llvm::enumerate(arguments)) {
@@ -466,6 +474,131 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
return success();
}
void SpatBlueprintOp::print(OpAsmPrinter& printer) {
SmallVector<Value> operands {getInput()};
llvm::append_range(operands, getFragments());
printer << " fragments";
printCompressedValueList(printer, operands, ListDelimiter::Paren);
printer << " layout " << getLogicalLayout();
printer << " physical " << getPhysicalLayout();
printer << " offsets ";
printCompressedIntegerList(printer, getFragmentOffsets());
printer << " sizes ";
printCompressedIntegerList(printer, getFragmentSizes());
printer << " map " << getIndexMap();
if (std::optional<StringRef> mode = getMode())
printer << " mode " << *mode;
if (std::optional<ArrayRef<int64_t>> operandIndices = getFragmentOperandIndices()) {
printer << " operandIndices ";
printCompressedIntegerList(printer, *operandIndices);
}
if (std::optional<ArrayRef<int64_t>> sourceOffsets = getFragmentSourceOffsets()) {
printer << " sourceOffsets ";
printCompressedIntegerList(printer, *sourceOffsets);
}
if (std::optional<ArrayRef<int64_t>> strides = getFragmentStrides()) {
printer << " strides ";
printCompressedIntegerList(printer, *strides);
}
if (std::optional<StringRef> conflictPolicy = getConflictPolicy())
printer << " conflict " << *conflictPolicy;
if (std::optional<StringRef> coveragePolicy = getCoveragePolicy())
printer << " coverage " << *coveragePolicy;
printer.printOptionalAttrDict((*this)->getAttrs(),
{getLogicalLayoutAttrName().getValue(),
getPhysicalLayoutAttrName().getValue(),
getFragmentOffsetsAttrName().getValue(),
getFragmentSizesAttrName().getValue(),
getIndexMapAttrName().getValue(),
getModeAttrName().getValue(),
getFragmentOperandIndicesAttrName().getValue(),
getFragmentSourceOffsetsAttrName().getValue(),
getFragmentStridesAttrName().getValue(),
getConflictPolicyAttrName().getValue(),
getCoveragePolicyAttrName().getValue()});
printer << " : ";
printCompressedTypeList(printer, TypeRange(operands), ListDelimiter::Paren);
printer << " -> ";
printer.printType(getOutput().getType());
}
ParseResult SpatBlueprintOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> operands;
SmallVector<Type> operandTypes;
Type outputType;
StringAttr logicalLayout;
StringAttr physicalLayout;
StringAttr indexMap;
StringAttr mode;
StringAttr conflictPolicy;
StringAttr coveragePolicy;
SmallVector<int64_t> fragmentOffsets;
SmallVector<int64_t> fragmentSizes;
SmallVector<int64_t> fragmentOperandIndices;
SmallVector<int64_t> fragmentSourceOffsets;
SmallVector<int64_t> fragmentStrides;
if (parser.parseKeyword("fragments")
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)
|| parser.parseKeyword("layout") || parseBareStringAttr(parser, logicalLayout)
|| parser.parseKeyword("physical") || parseBareStringAttr(parser, physicalLayout)
|| parser.parseKeyword("offsets") || parseCompressedIntegerList(parser, fragmentOffsets)
|| parser.parseKeyword("sizes") || parseCompressedIntegerList(parser, fragmentSizes)
|| parser.parseKeyword("map") || parseBareStringAttr(parser, indexMap))
return failure();
if (succeeded(parser.parseOptionalKeyword("mode")) && parseBareStringAttr(parser, mode))
return failure();
if (succeeded(parser.parseOptionalKeyword("operandIndices"))
&& parseCompressedIntegerList(parser, fragmentOperandIndices))
return failure();
if (succeeded(parser.parseOptionalKeyword("sourceOffsets"))
&& parseCompressedIntegerList(parser, fragmentSourceOffsets))
return failure();
if (succeeded(parser.parseOptionalKeyword("strides")) && parseCompressedIntegerList(parser, fragmentStrides))
return failure();
if (succeeded(parser.parseOptionalKeyword("conflict")) && parseBareStringAttr(parser, conflictPolicy))
return failure();
if (succeeded(parser.parseOptionalKeyword("coverage")) && parseBareStringAttr(parser, coveragePolicy))
return failure();
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, operandTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parser.parseType(outputType))
return failure();
if (operands.empty())
return parser.emitError(parser.getCurrentLocation(), "spat.blueprint requires at least one fragment operand");
if (operands.size() != operandTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of fragment operands and types must match");
auto& builder = parser.getBuilder();
result.addAttribute("logicalLayout", logicalLayout);
result.addAttribute("physicalLayout", physicalLayout);
result.addAttribute("fragmentOffsets", builder.getDenseI64ArrayAttr(fragmentOffsets));
result.addAttribute("fragmentSizes", builder.getDenseI64ArrayAttr(fragmentSizes));
result.addAttribute("indexMap", indexMap);
if (mode)
result.addAttribute("mode", mode);
if (!fragmentOperandIndices.empty())
result.addAttribute("fragmentOperandIndices", builder.getDenseI64ArrayAttr(fragmentOperandIndices));
if (!fragmentSourceOffsets.empty())
result.addAttribute("fragmentSourceOffsets", builder.getDenseI64ArrayAttr(fragmentSourceOffsets));
if (!fragmentStrides.empty())
result.addAttribute("fragmentStrides", builder.getDenseI64ArrayAttr(fragmentStrides));
if (conflictPolicy)
result.addAttribute("conflictPolicy", conflictPolicy);
if (coveragePolicy)
result.addAttribute("coveragePolicy", coveragePolicy);
if (parser.resolveOperands(operands, operandTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputType);
return success();
}
void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) {
return parseComputeLikeOp<SpatGraphCompute>(parser, result);
+15 -15
View File
@@ -436,10 +436,10 @@ LogicalResult SpatReluPlanOp::verify() {
return success();
}
LogicalResult SpatReconciliatorOp::verify() {
LogicalResult SpatBlueprintOp::verify() {
auto modeAttr = getModeAttr();
bool isFragmentAssembly = modeAttr && modeAttr.getValue() == "fragment_assembly";
if (!isFragmentAssembly && failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.reconciliator")))
if (!isFragmentAssembly && failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.blueprint")))
return failure();
if (!isKnownLogicalLayout(getLogicalLayout()))
return emitError("requires a known logical layout");
@@ -482,10 +482,10 @@ LogicalResult SpatReconciliatorOp::verify() {
if (failed(verifyBoundsOnly({})))
return failure();
if (!getFragments().empty())
return emitError("legacy reconciliator does not accept extra fragment operands");
return emitError("legacy blueprint does not accept extra fragment operands");
if (getFragmentSourceOffsetsAttr() || getFragmentStridesAttr() || getConflictPolicyAttr()
|| getCoveragePolicyAttr())
return emitError("legacy reconciliator does not accept fragment assembly attributes");
return emitError("legacy blueprint does not accept fragment assembly attributes");
return success();
}
@@ -493,11 +493,11 @@ LogicalResult SpatReconciliatorOp::verify() {
auto operandIndicesAttr = getFragmentOperandIndicesAttr();
auto sourceOffsetsAttr = getFragmentSourceOffsetsAttr();
if (!operandIndicesAttr)
return emitError("fragment assembly reconciliator requires fragment operand indices");
return emitError("fragment assembly blueprint requires fragment operand indices");
if (!sourceOffsetsAttr)
return emitError("fragment assembly reconciliator requires fragment source offsets");
return emitError("fragment assembly blueprint requires fragment source offsets");
if (!stridesAttr)
return emitError("fragment assembly reconciliator requires fragment strides");
return emitError("fragment assembly blueprint requires fragment strides");
ArrayRef<int64_t> operandIndices = operandIndicesAttr.asArrayRef();
ArrayRef<int64_t> sourceOffsets = sourceOffsetsAttr.asArrayRef();
ArrayRef<int64_t> strides = stridesAttr.asArrayRef();
@@ -506,11 +506,11 @@ LogicalResult SpatReconciliatorOp::verify() {
if (sourceOffsets.size() != operandIndices.size())
return emitError("fragment source offset count must match fragment operand index count");
if (!getConflictPolicyAttr() || !getCoveragePolicyAttr())
return emitError("fragment assembly reconciliator requires conflict and coverage policies");
return emitError("fragment assembly blueprint requires conflict and coverage policies");
if (getConflictPolicy() != "disjoint")
return emitError("fragment assembly reconciliator currently supports only conflict_policy=\"disjoint\"");
return emitError("fragment assembly blueprint currently supports only conflict_policy=\"disjoint\"");
if (getCoveragePolicy() != "complete" && getCoveragePolicy() != "partial")
return emitError("fragment assembly reconciliator coverage_policy must be \"complete\" or \"partial\"");
return emitError("fragment assembly blueprint coverage_policy must be \"complete\" or \"partial\"");
SmallVector<Value> operands;
operands.push_back(getInput());
@@ -518,7 +518,7 @@ LogicalResult SpatReconciliatorOp::verify() {
int64_t operandCount = static_cast<int64_t>(operands.size());
int64_t fragmentCount = static_cast<int64_t>(operandIndices.size());
if (operandCount == 0)
return emitError("fragment assembly reconciliator requires at least one operand");
return emitError("fragment assembly blueprint requires at least one operand");
if (static_cast<int64_t>(offsets.size()) != fragmentCount * rank)
return emitError("fragment assembly metadata count must match operand count * result rank");
if (failed(verifyBoundsOnly(strides)))
@@ -544,9 +544,9 @@ LogicalResult SpatReconciliatorOp::verify() {
auto operandType = dyn_cast<RankedTensorType>(operands[operandIndex].getType());
if (!operandType || !operandType.hasStaticShape())
return emitError("fragment assembly reconciliator requires static ranked tensor operands");
return emitError("fragment assembly blueprint requires static ranked tensor operands");
if (operandType.getRank() != rank)
return emitError("fragment assembly reconciliator requires operand/result rank match");
return emitError("fragment assembly blueprint requires operand/result rank match");
SmallVector<int64_t, 4> fragmentOffsets;
SmallVector<int64_t, 4> fragmentSizes;
@@ -583,14 +583,14 @@ LogicalResult SpatReconciliatorOp::verify() {
}
}
if (overlaps)
return emitError("fragment assembly reconciliator requires disjoint static slices");
return emitError("fragment assembly blueprint requires disjoint static slices");
}
slices.push_back({std::move(fragmentOffsets), std::move(fragmentSizes)});
}
for (int64_t operandIndex = 0; operandIndex < operandCount; ++operandIndex) {
if (fragmentCountsByOperand[static_cast<size_t>(operandIndex)] == 0)
return emitError("fragment assembly reconciliator requires every operand to contribute at least one fragment");
return emitError("fragment assembly blueprint requires every operand to contribute at least one fragment");
}
if (getCoveragePolicy() == "complete") {
@@ -30,6 +30,7 @@
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -308,7 +309,7 @@ struct PendingProjectedHostOutputFragment {
Value originalOutput;
ClassId sourceClass = 0;
ProducerKey producerKey;
Value publicationValue;
unsigned publicationResultIndex = 0;
int64_t sourceFragmentOrdinal = 0;
int64_t sourceElementOffset = 0;
SmallVector<int64_t, 4> offsets;
@@ -1220,36 +1221,13 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali
return std::get<1>(*arg);
}
void refreshPendingProjectedHostOutputPublicationValues(MaterializerState& state,
Operation* oldOwner,
Operation* newOwner) {
if (!oldOwner || oldOwner == newOwner)
return;
for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments) {
auto publicationResult = dyn_cast_or_null<OpResult>(fragment.publicationValue);
if (!publicationResult || publicationResult.getOwner() != oldOwner)
publicationResult = OpResult();
else
fragment.publicationValue = newOwner->getResult(publicationResult.getResultNumber());
if (auto originalResult = dyn_cast_or_null<OpResult>(fragment.originalOutput); originalResult
&& originalResult.getOwner() == oldOwner) {
fragment.originalOutput = newOwner->getResult(originalResult.getResultNumber());
}
if (fragment.producerKey.instance.op == oldOwner)
fragment.producerKey.instance.op = newOwner;
}
}
FailureOr<Value> appendScalarPublicationResult(MaterializerState& state,
MaterializedClass& materializedClass,
Value payload,
Location loc) {
FailureOr<unsigned> appendScalarPublicationResult(MaterializerState& state,
MaterializedClass& materializedClass,
Value payload,
Location loc) {
auto existing = materializedClass.publicationOutputToResultIndex.find(payload);
if (existing != materializedClass.publicationOutputToResultIndex.end())
return materializedClass.op->getResult(existing->second);
return existing->second;
auto compute = dyn_cast<SpatScheduledCompute>(materializedClass.op);
if (!compute)
@@ -1264,27 +1242,25 @@ FailureOr<Value> appendScalarPublicationResult(MaterializerState& state,
if (failed(inserted))
return materializedClass.op->emitError("failed to append scalar publication result");
Operation* oldOp = materializedClass.op;
auto [result, newCompute] = *inserted;
materializedClass.op = newCompute.getOperation();
materializedClass.body = &newCompute.getBody().front();
refreshPendingProjectedHostOutputPublicationValues(state, oldOp, materializedClass.op);
materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber();
auto yieldOp = dyn_cast<SpatYieldOp>(materializedClass.body->getTerminator());
if (!yieldOp)
return materializedClass.op->emitError("expected spat.yield terminator while appending scalar publication result");
state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->insertOperands(yieldOp.getNumOperands(), payload); });
return result;
return result.getResultNumber();
}
FailureOr<Value> appendBatchPublicationResult(MaterializerState& state,
MaterializedClass& materializedClass,
Value payload,
Location loc) {
FailureOr<unsigned> appendBatchPublicationResult(MaterializerState& state,
MaterializedClass& materializedClass,
Value payload,
Location loc) {
auto existing = materializedClass.publicationOutputToResultIndex.find(payload);
if (existing != materializedClass.publicationOutputToResultIndex.end())
return materializedClass.op->getResult(existing->second);
return existing->second;
auto batch = dyn_cast<SpatScheduledComputeBatch>(materializedClass.op);
if (!batch)
@@ -1305,11 +1281,9 @@ FailureOr<Value> appendBatchPublicationResult(MaterializerState& state,
if (failed(inserted))
return materializedClass.op->emitError("failed to append batch publication result");
Operation* oldOp = materializedClass.op;
auto [result, outputArg, newBatch] = *inserted;
materializedClass.op = newBatch.getOperation();
materializedClass.body = &newBatch.getBody().front();
refreshPendingProjectedHostOutputPublicationValues(state, oldOp, materializedClass.op);
materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber();
auto inParallelOp = dyn_cast<SpatInParallelOp>(materializedClass.body->getTerminator());
@@ -1330,7 +1304,7 @@ FailureOr<Value> appendBatchPublicationResult(MaterializerState& state,
Value firstOffset =
scaleIndexByDim0Size(state, materializedClass.op, *laneArg, payloadType.getDimSize(0), loc);
createDim0ParallelInsertSlice(state, loc, payload, outputArg, firstOffset);
return result;
return result.getResultNumber();
}
// -----------------------------------------------------------------------------
@@ -1563,7 +1537,7 @@ void attachMaterializerValueOriginNote(InFlightDiagnostic& diagnostic, Value val
void attachMaterializedClassBodySummary(InFlightDiagnostic& diagnostic, const MaterializedClass& targetClass) {
Block& body = *targetClass.body;
diagnostic.attachNote(targetClass.op->getLoc())
<< "RAPTOR_MATERIALIZER_DEBUG target class " << targetClass.id << " op '" << targetClass.op->getName()
<< "target class " << targetClass.id << " op '" << targetClass.op->getName()
<< "' body has " << body.getNumArguments() << " block arguments and "
<< std::distance(body.begin(), body.end()) << " top-level operations";
}
@@ -1687,7 +1661,7 @@ FailureOr<Value> rematerializeIndexValueInClass(MaterializerState& state,
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
InFlightDiagnostic diagnostic = targetClass.op->emitError(
"RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external block argument in materialized class body");
"cannot rematerialize external block argument in materialized class body");
diagnostic << " currentArg#" << blockArg.getArgNumber() << " currentType=" << blockArg.getType()
<< " targetClass=" << targetClass.id << " targetOp='" << targetClass.op->getName() << "'";
if (Operation* owner = blockArg.getOwner()->getParentOp()) {
@@ -1709,16 +1683,16 @@ FailureOr<Value> rematerializeIndexValueInClass(MaterializerState& state,
if (mapperHadOriginalValue && mappedOriginalValue != value)
attachMaterializerValueOriginNote(diagnostic, mappedOriginalValue, "mapper value");
if (Operation* owner = blockArg.getOwner()->getParentOp()) {
attachMaterializerOperationPrintNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner op");
attachMaterializerParentChainNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner parent chain");
attachMaterializerOperationPrintNote(diagnostic, owner, "external block argument owner op");
attachMaterializerParentChainNote(diagnostic, owner, "external block argument owner parent chain");
}
attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op");
attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "target materialized op");
attachMaterializedClassBodySummary(diagnostic, targetClass);
return failure();
}
InFlightDiagnostic diagnostic =
targetClass.op->emitError("RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external index value in materialized class body");
targetClass.op->emitError("cannot rematerialize external index value in materialized class body");
diagnostic << " type=" << value.getType() << " targetClass=" << targetClass.id << " targetOp='"
<< targetClass.op->getName() << "'";
attachMaterializerValueOriginNote(diagnostic, originalValue, "original value");
@@ -1793,8 +1767,12 @@ FailureOr<Value> rematerializeTensorValueInClass(MaterializerState& state,
strides.push_back(*localized);
}
return tensor::ExtractSliceOp::create(state.rewriter, anchor->getLoc(), *localizedSource, offsets, sizes, strides)
.getResult();
auto resultType = dyn_cast<RankedTensorType>(extractSlice.getResult().getType());
if (!resultType)
return anchor->emitError("expected ranked tensor extract_slice while rematerializing tensor capture");
return extractStaticSliceOrIdentity(
state.rewriter, anchor->getLoc(), *localizedSource, resultType, offsets, sizes, strides);
}
if (auto collapseShape = value.getDefiningOp<tensor::CollapseShapeOp>()) {
@@ -2108,8 +2086,10 @@ Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value in
if (dim0Size == 1)
return index;
Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, anchor, dim0Size);
return arith::MulIOp::create(state.rewriter, loc, index, dim0SizeValue).getResult();
MLIRContext* context = state.func.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0 * dim0Size);
return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {index}, anchor);
}
FailureOr<Value> scaleIndexByDim0SizeInClass(MaterializerState& state,
@@ -2123,8 +2103,7 @@ FailureOr<Value> scaleIndexByDim0SizeInClass(MaterializerState& state,
if (dim0Size == 1)
return *localizedIndex;
Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, dim0Size);
return arith::MulIOp::create(state.rewriter, loc, *localizedIndex, dim0SizeValue).getResult();
return scaleIndexByDim0Size(state, targetClass.op, *localizedIndex, dim0Size, loc);
}
bool sameProducerResult(ProducerKey lhs, ProducerKey rhs) {
@@ -3677,10 +3656,13 @@ FailureOr<Value> buildProjectedPackedPayload(MaterializerState& state,
ValueRange {init},
[&](OpBuilder&, Location, Value fragmentIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
Value acc = iterArgs.front();
Value payloadFragmentCount =
getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount);
Value flatBase = arith::MulIOp::create(state.rewriter, loc, *localizedMessageIndex, payloadFragmentCount).getResult();
Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult();
MLIRContext* context = state.func.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
AffineExpr d1 = getAffineDimExpr(1, context);
AffineMap flatIndexMap =
AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * descriptor.layout.payloadFragmentCount + d1);
Value flatIndex = createOrFoldAffineApply(
state.rewriter, loc, flatIndexMap, ValueRange {*localizedMessageIndex, fragmentIndex}, targetClass.op);
FailureOr<SmallVector<OpFoldResult, 4>> fragmentOffsets =
buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, flatIndex, loc);
@@ -5618,8 +5600,8 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat
return failure();
}
FailureOr<Value> publicationResult = appendScalarPublicationResult(state, sourceClass, packed, loc);
if (failed(publicationResult))
FailureOr<unsigned> publicationResultIndex = appendScalarPublicationResult(state, sourceClass, packed, loc);
if (failed(publicationResultIndex))
return failure();
int64_t fragmentElementCount = fragmentType.getNumElements();
@@ -5657,7 +5639,7 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat
originalOutput,
sourceClass.id,
ProducerKey {peer, resultIndex},
*publicationResult,
*publicationResultIndex,
static_cast<int64_t>(runIndex),
static_cast<int64_t>(runIndex) * fragmentElementCount,
SmallVector<int64_t, 4>(*offsets),
@@ -5711,8 +5693,8 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
if (fragmentType == originalOutput.getType())
return false;
FailureOr<Value> publicationResult = appendBatchPublicationResult(state, sourceClass, packed, loc);
if (failed(publicationResult))
FailureOr<unsigned> publicationResultIndex = appendBatchPublicationResult(state, sourceClass, packed, loc);
if (failed(publicationResultIndex))
return failure();
if (packedType != fragmentType) {
@@ -5764,7 +5746,7 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
originalOutput,
sourceClass.id,
key,
*publicationResult,
*publicationResultIndex,
static_cast<int64_t>(fragmentIndex),
static_cast<int64_t>(*publishedLaneIndex) * payloadElementCount + localFragmentOffsetWithinPublishedPayload,
SmallVector<int64_t, 4>(*offsets),
@@ -5787,18 +5769,26 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
SmallVector<Value, 8> outputs;
outputs.reserve(byOutput.size());
for (const auto& entry : byOutput)
outputs.push_back(entry.first);
llvm::sort(outputs, [](Value lhs, Value rhs) {
return reinterpret_cast<uintptr_t>(lhs.getAsOpaquePointer())
< reinterpret_cast<uintptr_t>(rhs.getAsOpaquePointer());
});
auto returnOp = dyn_cast<func::ReturnOp>(state.func.getBody().front().getTerminator());
if (!returnOp)
return state.func.emitError("expected func.return terminator while finalizing projected host output fragments");
DenseSet<Value> seenOutputs;
for (Value returned : returnOp.getOperands()) {
if (!byOutput.contains(returned) || !seenOutputs.insert(returned).second)
continue;
outputs.push_back(returned);
}
if (outputs.size() != byOutput.size())
return state.func.emitError("projected host output fragments must be keyed by returned logical host outputs");
for (Value originalOutput : outputs) {
if (isa_and_present<SpatScheduledCompute, SpatScheduledComputeBatch>(originalOutput.getDefiningOp())) {
return state.func.emitError(
"projected host output assembly must be keyed by the original logical host output, not by a materialized scheduled result");
}
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
if (!resultType || !resultType.hasStaticShape())
return state.func.emitError("projected host output must have static ranked tensor type");
@@ -5806,13 +5796,12 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
SmallVector<PendingProjectedHostOutputFragment*, 16>& fragments = byOutput[originalOutput];
llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs,
const PendingProjectedHostOutputFragment* rhs) {
if (lhs->publicationValue != rhs->publicationValue)
return reinterpret_cast<uintptr_t>(lhs->publicationValue.getAsOpaquePointer())
< reinterpret_cast<uintptr_t>(rhs->publicationValue.getAsOpaquePointer());
if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal)
return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal;
if (lhs->sourceClass != rhs->sourceClass)
return lhs->sourceClass < rhs->sourceClass;
if (lhs->publicationResultIndex != rhs->publicationResultIndex)
return lhs->publicationResultIndex < rhs->publicationResultIndex;
if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal)
return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal;
return std::lexicographical_compare(lhs->offsets.begin(),
lhs->offsets.end(),
rhs->offsets.begin(),
@@ -5821,7 +5810,7 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
state.rewriter.setInsertionPoint(returnOp);
Location loc = fragments.front()->loc;
SmallVector<Value, 16> reconciliatorOperands;
SmallVector<Value, 16> blueprintOperands;
SmallVector<int64_t, 16> fragmentOperandIndices;
SmallVector<int64_t, 16> fragmentSourceOffsets;
SmallVector<int64_t, 64> flatOffsets;
@@ -5830,12 +5819,23 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
DenseMap<Value, int64_t> operandIndicesByValue;
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
Value operand = fragmentRecord->publicationValue;
if (fragmentRecord->sourceClass >= state.classes.size())
return state.func.emitError("projected host output fragment references an invalid source class");
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
if (fragmentRecord->publicationResultIndex >= sourceClass.op->getNumResults()) {
return sourceClass.op->emitError("projected host output fragment references an invalid publication result")
<< " sourceClass=" << sourceClass.id
<< " resultIndex=" << fragmentRecord->publicationResultIndex
<< " resultCount=" << sourceClass.op->getNumResults();
}
Value operand = sourceClass.op->getResult(fragmentRecord->publicationResultIndex);
auto [operandIt, inserted] =
operandIndicesByValue.try_emplace(operand, static_cast<int64_t>(reconciliatorOperands.size()));
operandIndicesByValue.try_emplace(operand, static_cast<int64_t>(blueprintOperands.size()));
if (inserted)
reconciliatorOperands.push_back(operand);
blueprintOperands.push_back(operand);
fragmentOperandIndices.push_back(operandIt->second);
fragmentSourceOffsets.push_back(fragmentRecord->sourceElementOffset);
llvm::append_range(flatOffsets, fragmentRecord->offsets);
@@ -5847,12 +5847,12 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
return state.func.emitError("projected host output assembly requires static ranked tensor operands");
}
if (reconciliatorOperands.empty())
if (blueprintOperands.empty())
return state.func.emitError("missing projected host output fragments");
Value input = reconciliatorOperands.front();
ValueRange extraFragments = ValueRange(reconciliatorOperands).drop_front();
auto reconciliator = spatial::SpatReconciliatorOp::create(
Value input = blueprintOperands.front();
ValueRange extraFragments = ValueRange(blueprintOperands).drop_front();
auto blueprint = spatial::SpatBlueprintOp::create(
state.rewriter,
loc,
resultType,
@@ -5870,7 +5870,7 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
state.rewriter.getStringAttr("disjoint"),
state.rewriter.getStringAttr("complete"));
state.hostReplacements[originalOutput] = reconciliator.getOutput();
state.hostReplacements[originalOutput] = blueprint.getOutput();
}
return success();
@@ -6284,6 +6284,32 @@ LogicalResult cloneComputeTemplateBody(MaterializerState& state,
mapper.map(operand, *localized);
}
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op)) {
auto remapFoldResult = [&](OpFoldResult value) -> OpFoldResult {
if (auto mappedValue = dyn_cast_if_present<Value>(value))
return mapper.lookupOrDefault(mappedValue);
return value;
};
SmallVector<OpFoldResult, 4> offsets;
SmallVector<OpFoldResult, 4> sizes;
SmallVector<OpFoldResult, 4> strides;
offsets.reserve(extract.getMixedOffsets().size());
sizes.reserve(extract.getMixedSizes().size());
strides.reserve(extract.getMixedStrides().size());
llvm::append_range(offsets, llvm::map_range(extract.getMixedOffsets(), remapFoldResult));
llvm::append_range(sizes, llvm::map_range(extract.getMixedSizes(), remapFoldResult));
llvm::append_range(strides, llvm::map_range(extract.getMixedStrides(), remapFoldResult));
auto resultType = cast<RankedTensorType>(extract.getType());
Value localizedSource = mapper.lookupOrDefault(extract.getSource());
Value localizedExtract = extractStaticSliceOrIdentity(
state.rewriter, extract.getLoc(), localizedSource, resultType, offsets, sizes, strides);
mapper.map(extract.getResult(), localizedExtract);
continue;
}
Operation* cloned = state.rewriter.clone(op, mapper);
if (failed(mapClonedRegionBlockArguments(op, *cloned, mapper)))
return failure();
@@ -6350,18 +6376,20 @@ FailureOr<Value> materializeProjectedExtractReplacement(MaterializerState& state
if (failed(localizedIv))
return failure();
Value iv = *localizedIv;
Value lowerBound =
getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopLowerBounds[index]);
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopSteps[index]);
Value tripCount =
getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopTripCounts[index]);
MLIRContext* context = state.func.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
AffineMap normalizedMap =
AffineMap::get(/*dimCount=*/1,
/*symbolCount=*/0,
(d0 - replacement.layout.loopLowerBounds[index]).floorDiv(replacement.layout.loopSteps[index]));
Value normalized =
createOrFoldAffineApply(state.rewriter, extract.getLoc(), normalizedMap, ValueRange {iv}, targetClass.op);
Value normalized = arith::SubIOp::create(state.rewriter, extract.getLoc(), iv, lowerBound).getResult();
if (replacement.layout.loopSteps[index] != 1)
normalized = arith::DivUIOp::create(state.rewriter, extract.getLoc(), normalized, step).getResult();
linearizedIndex = arith::MulIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, tripCount).getResult();
linearizedIndex =
arith::AddIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, normalized).getResult();
AffineExpr d1 = getAffineDimExpr(1, context);
AffineMap linearizedMap = AffineMap::get(
/*dimCount=*/2, /*symbolCount=*/0, d0 * replacement.layout.loopTripCounts[index] + d1);
linearizedIndex = createOrFoldAffineApply(
state.rewriter, extract.getLoc(), linearizedMap, ValueRange {linearizedIndex, normalized}, targetClass.op);
}
return linearizedIndex;
};
@@ -6386,12 +6414,16 @@ FailureOr<Value> materializeProjectedExtractReplacement(MaterializerState& state
if (failed(localProjectionSlotIndex))
return failure();
Value fragmentsPerLogicalSlot =
getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.fragmentsPerLogicalSlot);
Value base =
arith::MulIOp::create(state.rewriter, extract.getLoc(), *localProjectionSlotIndex, fragmentsPerLogicalSlot)
.getResult();
return arith::AddIOp::create(state.rewriter, extract.getLoc(), base, intraSlotFragmentIndex).getResult();
MLIRContext* context = state.func.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
AffineExpr d1 = getAffineDimExpr(1, context);
AffineMap packedIndexMap = AffineMap::get(
/*dimCount=*/2, /*symbolCount=*/0, d0 * replacement.layout.fragmentsPerLogicalSlot + d1);
return createOrFoldAffineApply(state.rewriter,
extract.getLoc(),
packedIndexMap,
ValueRange {*localProjectionSlotIndex, intraSlotFragmentIndex},
targetClass.op);
};
FailureOr<Value> packedFragmentIndex = computeProjectedPayloadFragmentIndex();
@@ -6445,18 +6477,18 @@ LogicalResult localizeCapturesInOperationTree(MaterializerState& state,
localizeMaterializedClassOperand(state, targetClass, current, nestedOp, tensorContext, genericContext, mapper);
if (failed(localized)) {
InFlightDiagnostic diagnostic = targetClass.op->emitError(
"RAPTOR_MATERIALIZER_DEBUG failed to localize cloned scheduled-body operand");
"failed to localize cloned scheduled-body operand");
diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName()
<< "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType()
<< " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp))
<< "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass)
<< "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\"";
diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation";
attachMaterializerOperationPrintNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation IR");
attachMaterializerOperandListNote(diagnostic, nestedOp, targetClass, "RAPTOR_MATERIALIZER_DEBUG offending nested operation operands");
attachMaterializerParentChainNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation parent chain");
attachMaterializerOperationPrintNote(diagnostic, nestedOp, "offending nested operation IR");
attachMaterializerOperandListNote(diagnostic, nestedOp, targetClass, "offending nested operation operands");
attachMaterializerParentChainNote(diagnostic, nestedOp, "offending nested operation parent chain");
attachMaterializerValueOriginNote(diagnostic, current, "offending operand");
attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op");
attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "target materialized op");
attachMaterializedClassBodySummary(diagnostic, targetClass);
return WalkResult::interrupt();
}
@@ -6505,7 +6537,7 @@ LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, Materia
"final scheduled body capture localization found an unsupported external non-tensor operand");
if (failed(localized)) {
InFlightDiagnostic diagnostic = targetClass.op->emitError(
"RAPTOR_MATERIALIZER_DEBUG failed to localize final scheduled-body operand");
"failed to localize final scheduled-body operand");
diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName()
<< "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType()
<< " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp))
@@ -1,128 +0,0 @@
--- src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp 2026-06-24 18:51:29.043731129 +0000
+++ src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp 2026-06-24 18:51:29.026726895 +0000
@@ -4112,104 +4112,8 @@
Value originalOutput,
Location loc);
-FailureOr<SmallVector<OpFoldResult, 4>> rematerializeProjectionIndexListForBatchHostOutput(
- MaterializerState& state,
- MaterializedClass& sourceClass,
- ArrayRef<OpFoldResult> values,
- IRMapping& mapper,
- Location loc) {
- SmallVector<OpFoldResult, 4> localized;
- localized.reserve(values.size());
- for (OpFoldResult value : values) {
- FailureOr<OpFoldResult> remapped =
- rematerializeIndexOpFoldResultInClass(state, sourceClass, value, loc, &mapper);
- if (failed(remapped))
- return failure();
- localized.push_back(*remapped);
- }
- return localized;
-}
-
-LogicalResult createProjectionAwareBatchHostInsert(MaterializerState& state,
- MaterializedClass& sourceClass,
- Value originalOutput,
- Value payload,
- Value destination,
- ArrayRef<ProducerKey> keys,
- Location loc) {
- auto originalResult = dyn_cast<OpResult>(originalOutput);
- if (!originalResult)
- return failure();
-
- auto sourceBatch = dyn_cast_or_null<SpatComputeBatch>(originalResult.getOwner());
- if (!sourceBatch || sourceBatch.getNumResults() == 0)
- return failure();
-
- FailureOr<tensor::ParallelInsertSliceOp> projection =
- getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber());
- if (failed(projection))
- return failure();
-
- auto sourceLaneArg = sourceBatch.getLaneArgument();
- if (!sourceLaneArg)
- return failure();
-
- auto materializedBatch = dyn_cast<SpatScheduledComputeBatch>(sourceClass.op);
- if (!materializedBatch)
- return failure();
-
- auto materializedLaneArg = materializedBatch.getLaneArgument();
- if (!materializedLaneArg)
- return failure();
-
- if (keys.size() != sourceClass.cpus.size())
- return failure();
-
- SmallVector<int64_t, 8> logicalLanes;
- logicalLanes.reserve(keys.size());
- for (ProducerKey key : keys) {
- if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != originalResult.getResultNumber())
- return failure();
- logicalLanes.push_back(key.instance.laneStart);
- }
-
- IRMapping mapper;
- Value logicalLane = createIndexedIndexValue(state,
- sourceClass.op,
- ArrayRef<int64_t>(logicalLanes),
- *materializedLaneArg,
- loc,
- static_cast<int64_t>(sourceClass.cpus.size()),
- /*allowExhaustiveTiledSearch=*/false);
- mapper.map(*sourceLaneArg, logicalLane);
-
- FailureOr<SmallVector<OpFoldResult, 4>> offsets =
- rematerializeProjectionIndexListForBatchHostOutput(
- state, sourceClass, projection->getMixedOffsets(), mapper, loc);
- if (failed(offsets))
- return failure();
- FailureOr<SmallVector<OpFoldResult, 4>> sizes =
- rematerializeProjectionIndexListForBatchHostOutput(
- state, sourceClass, projection->getMixedSizes(), mapper, loc);
- if (failed(sizes))
- return failure();
- FailureOr<SmallVector<OpFoldResult, 4>> strides =
- rematerializeProjectionIndexListForBatchHostOutput(
- state, sourceClass, projection->getMixedStrides(), mapper, loc);
- if (failed(strides))
- return failure();
-
- tensor::ParallelInsertSliceOp::create(
- state.rewriter, loc, payload, destination, *offsets, *sizes, *strides);
- return success();
-}
-
LogicalResult
-setHostOutputValue(MaterializerState& state,
- MaterializedClass& sourceClass,
- Value originalOutput,
- Value payload,
- ArrayRef<ProducerKey> keys = {}) {
+setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) {
auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput);
if (resultIt == sourceClass.hostOutputToResultIndex.end())
return sourceClass.op->emitError("missing host result slot for materialized output")
@@ -4253,10 +4157,6 @@
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
- if (succeeded(createProjectionAwareBatchHostInsert(
- state, sourceClass, originalOutput, payload, *outputArg, keys, payload.getLoc())))
- return success();
-
createDim0ParallelInsertSlice(state, payload.getLoc(), payload, *outputArg, *laneArg);
return success();
}
@@ -4276,7 +4176,7 @@
MaterializedClass& ownerClass = state.classes[ownerIt->second];
if (sourceClass.id == ownerClass.id)
- return setHostOutputValue(state, ownerClass, originalOutput, payload, keys);
+ return setHostOutputValue(state, ownerClass, originalOutput, payload);
// Keep the old deadlock-free communication discipline: only scalar-to-scalar
// host-owner forwarding is introduced here. Batch host publication remains on
@@ -354,13 +354,13 @@ public:
void runOnOperation() override {
func::FuncOp func = getOperation();
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed at the start of MergeComputeNodes");
func.emitOpError("logical Spatial graph verification failed at the start of MergeComputeNodes");
signalPassFailure();
return;
}
mergeTriviallyConnectedComputes(func);
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after trivial merge simplification");
func.emitOpError("logical Spatial graph verification failed after trivial merge simplification");
signalPassFailure();
return;
}
@@ -378,7 +378,7 @@ public:
return;
}
if (failed(verifyScheduledSpatialInvariants(func))) {
func.emitOpError("RAPTOR_PHASE_CHECK scheduled Spatial verification failed after merge materialization");
func.emitOpError("scheduled Spatial verification failed after merge materialization");
signalPassFailure();
return;
}