Bose
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-06-26 17:45:27 +02:00
parent 984f362623
commit 78e97f9fd8
23 changed files with 513 additions and 17489 deletions
@@ -163,6 +163,38 @@ Value extractAxisSlice(
.getResult();
}
Value extractStaticSliceOrIdentity(RewriterBase& rewriter,
Location loc,
Value source,
RankedTensorType resultType,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
auto sourceType = cast<RankedTensorType>(source.getType());
size_t rank = static_cast<size_t>(sourceType.getRank());
bool isIdentitySlice =
sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank
&& strides.size() == rank;
if (isIdentitySlice) {
ArrayRef<int64_t> sourceShape = sourceType.getShape();
for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) {
std::optional<int64_t> staticOffset = mlir::getConstantIntValue(offset);
std::optional<int64_t> staticSize = mlir::getConstantIntValue(size);
std::optional<int64_t> staticStride = mlir::getConstantIntValue(stride);
if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim || *staticStride != 1) {
isIdentitySlice = false;
break;
}
}
}
if (isIdentitySlice)
return source;
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
}
Value insertStaticSlice(
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
auto sourceType = cast<RankedTensorType>(source.getType());
@@ -105,6 +105,14 @@ llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPer
mlir::Value extractAxisSlice(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::RankedTensorType resultType,
llvm::ArrayRef<mlir::OpFoldResult> offsets,
llvm::ArrayRef<mlir::OpFoldResult> sizes,
llvm::ArrayRef<mlir::OpFoldResult> strides);
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
@@ -44,17 +44,17 @@ static FailureOr<RowStripPhysicalValue> getRowStripValue(llvm::DenseMap<Value, R
return it->second;
}
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatReconciliatorOp reconciliator,
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatBlueprintOp blueprint,
Value physicalValue) {
auto logicalType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
auto logicalType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
if (!logicalType)
return reconciliator.emitOpError("requires ranked logical output type"), failure();
return blueprint.emitOpError("requires ranked logical output type"), failure();
RowStripPhysicalValue value;
value.physicalValue = physicalValue;
value.logicalType = logicalType;
value.fragmentOffsets.append(reconciliator.getFragmentOffsets().begin(), reconciliator.getFragmentOffsets().end());
value.fragmentSizes.append(reconciliator.getFragmentSizes().begin(), reconciliator.getFragmentSizes().end());
value.indexMap = reconciliator.getIndexMap().str();
value.fragmentOffsets.append(blueprint.getFragmentOffsets().begin(), blueprint.getFragmentOffsets().end());
value.fragmentSizes.append(blueprint.getFragmentSizes().begin(), blueprint.getFragmentSizes().end());
value.indexMap = blueprint.getIndexMap().str();
return value;
}
@@ -175,7 +175,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
auto verifyLogicalPhase = [&](StringRef stage) -> bool {
if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc)))
return true;
moduleOp.emitError() << "RAPTOR_PHASE_CHECK logical Spatial graph verification failed " << stage;
moduleOp.emitError() << "logical Spatial graph verification failed " << stage;
signalPassFailure();
return false;
};
@@ -185,11 +185,11 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
if (auto planOp = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
FailureOr<RowStripPhysicalValue> rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
auto rowStripReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
auto rowStripBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(user);
return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
});
if (rowStripReconciliator != planOp.getResult().getUsers().end()) {
if (rowStripBlueprint != planOp.getResult().getUsers().end()) {
rewriter.setInsertionPoint(planOp);
FailureOr<Value> lowered = lowerSelectedConv2DPlan(
planOp,
@@ -201,15 +201,15 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
signalPassFailure();
return;
}
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*rowStripReconciliator);
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(reconciliator, *lowered);
auto blueprint = cast<spatial::SpatBlueprintOp>(*rowStripBlueprint);
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(blueprint, *lowered);
if (failed(rowStripValue)) {
signalPassFailure();
return;
}
rowStripValues[reconciliator.getResult()] = *rowStripValue;
rowStripValues[blueprint.getResult()] = *rowStripValue;
eraseAfterLowering.insert(planOp);
eraseAfterLowering.insert(reconciliator);
eraseAfterLowering.insert(blueprint);
continue;
}
rewriter.setInsertionPoint(planOp);
@@ -226,12 +226,12 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
if (auto planOp = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) {
auto outputReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
auto outputBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(user);
return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
});
if (outputReconciliator == planOp.getResult().getUsers().end()) {
planOp.emitOpError("row-strip Relu plan requires a row-strip reconciliator result");
if (outputBlueprint == planOp.getResult().getUsers().end()) {
planOp.emitOpError("row-strip Relu plan requires a row-strip blueprint result");
signalPassFailure();
return;
}
@@ -244,15 +244,15 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
signalPassFailure();
return;
}
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*outputReconciliator);
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(reconciliator, *lowered);
auto blueprint = cast<spatial::SpatBlueprintOp>(*outputBlueprint);
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(blueprint, *lowered);
if (failed(output)) {
signalPassFailure();
return;
}
rowStripValues[reconciliator.getResult()] = *output;
rowStripValues[blueprint.getResult()] = *output;
eraseAfterLowering.insert(planOp);
eraseAfterLowering.insert(reconciliator);
eraseAfterLowering.insert(blueprint);
continue;
}
@@ -279,7 +279,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
}
FailureOr<RowStripPhysicalValue> rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
if (failed(rowStripValue)) {
materializeOp.emitOpError("expected a row-strip reconciliator input during row-strip materialization");
materializeOp.emitOpError("expected a row-strip blueprint input during row-strip materialization");
signalPassFailure();
return;
}
@@ -293,18 +293,18 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
rewriter.replaceOp(materializeOp, *dense);
continue;
}
if (auto reconciliatorOp = dyn_cast<spatial::SpatReconciliatorOp>(&op)) {
if (reconciliatorOp.getPhysicalLayout() == kDenseLayout) {
rewriter.replaceOp(reconciliatorOp, reconciliatorOp.getInput());
if (auto blueprintOp = dyn_cast<spatial::SpatBlueprintOp>(&op)) {
if (blueprintOp.getPhysicalLayout() == kDenseLayout) {
rewriter.replaceOp(blueprintOp, blueprintOp.getInput());
continue;
}
if (reconciliatorOp.getPhysicalLayout() != kRowStripLayout) {
reconciliatorOp.emitOpError("non-dense reconciliator lowering is not supported yet");
if (blueprintOp.getPhysicalLayout() != kRowStripLayout) {
blueprintOp.emitOpError("non-dense blueprint lowering is not supported yet");
signalPassFailure();
return;
}
if (!eraseAfterLowering.contains(reconciliatorOp)) {
reconciliatorOp.emitOpError("unhandled row-strip reconciliator remained during LowerSpatialPlans");
if (!eraseAfterLowering.contains(blueprintOp)) {
blueprintOp.emitOpError("unhandled row-strip blueprint remained during LowerSpatialPlans");
signalPassFailure();
return;
}
@@ -385,7 +385,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
return;
if (isa<spatial::SpatConv2DPlanOp,
spatial::SpatReluPlanOp,
spatial::SpatReconciliatorOp,
spatial::SpatBlueprintOp,
spatial::SpatMaterializeLayoutOp>(op)
|| op->getDialect()->getNamespace() == "onnx") {
op->emitOpError("operation must not remain after LowerSpatialPlans");
@@ -46,9 +46,9 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
SmallVector<spatial::SpatGraphComputeBatch> computeBatches(funcOp.getOps<spatial::SpatGraphComputeBatch>());
SmallVector<spatial::SpatConv2DPlanOp> convPlans(funcOp.getOps<spatial::SpatConv2DPlanOp>());
SmallVector<spatial::SpatReluPlanOp> reluPlans(funcOp.getOps<spatial::SpatReluPlanOp>());
SmallVector<spatial::SpatReconciliatorOp> reconciliators(funcOp.getOps<spatial::SpatReconciliatorOp>());
SmallVector<spatial::SpatBlueprintOp> blueprints(funcOp.getOps<spatial::SpatBlueprintOp>());
SmallVector<spatial::SpatMaterializeLayoutOp> materializers(funcOp.getOps<spatial::SpatMaterializeLayoutOp>());
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !reconciliators.empty()
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !blueprints.empty()
|| !materializers.empty()) {
return;
}
@@ -160,7 +160,7 @@ void ONNXToSpatialPass::runOnOperation() {
}
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX conversion");
moduleOp.emitError("logical Spatial graph verification failed after ONNX conversion");
signalPassFailure();
return;
}
@@ -181,7 +181,7 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc);
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after weight annotation");
moduleOp.emitError("logical Spatial graph verification failed after weight annotation");
signalPassFailure();
return;
}
@@ -199,7 +199,7 @@ void ONNXToSpatialPass::runOnOperation() {
[](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed before post rewrites");
moduleOp.emitError("logical Spatial graph verification failed before post rewrites");
signalPassFailure();
return;
}
@@ -214,7 +214,7 @@ void ONNXToSpatialPass::runOnOperation() {
populateEmptyFunction(*entryFunc);
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX-to-Spatial");
moduleOp.emitError("logical Spatial graph verification failed after ONNX-to-Spatial");
signalPassFailure();
return;
}
@@ -15,7 +15,7 @@ namespace onnx_mlir {
namespace {
constexpr StringLiteral kPhaseMarker = "RAPTOR_PHASE_CHECK";
constexpr StringLiteral kPhaseMarker = "phase-check";
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
func.walk([&](Operation* op) {
@@ -114,14 +114,14 @@ void verifyScheduledInputs(ComputeOpTy compute,
}
template <typename ComputeOpTy>
void verifyNoNestedFragmentAssemblyReconciliators(ComputeOpTy compute,
void verifyNoNestedFragmentAssemblyBlueprints(ComputeOpTy compute,
pim::CappedDiagnosticReporter& diagnostics) {
compute.getBody().walk([&](spatial::SpatReconciliatorOp reconciliator) {
std::optional<StringRef> mode = reconciliator.getMode();
compute.getBody().walk([&](spatial::SpatBlueprintOp blueprint) {
std::optional<StringRef> mode = blueprint.getMode();
if (!mode || *mode != "fragment_assembly")
return;
diagnostics.report(reconciliator.getOperation(), [&](Operation* illegalOp) {
illegalOp->emitOpError("fragment assembly reconciliator must be host-level after merge materialization");
diagnostics.report(blueprint.getOperation(), [&](Operation* illegalOp) {
illegalOp->emitOpError("fragment assembly blueprint must be host-level after merge materialization");
});
});
}
@@ -133,7 +133,7 @@ void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter
spatial::SpatGraphComputeBatch,
spatial::SpatConv2DPlanOp,
spatial::SpatReluPlanOp,
spatial::SpatReconciliatorOp,
spatial::SpatBlueprintOp,
spatial::SpatMaterializeLayoutOp>(&op)) {
continue;
}
@@ -203,11 +203,11 @@ LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
verifyScheduledTopLevelOps(funcOp, diagnostics);
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>()) {
verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
verifyNoNestedFragmentAssemblyReconciliators(compute, diagnostics);
verifyNoNestedFragmentAssemblyBlueprints(compute, diagnostics);
}
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
verifyNoNestedFragmentAssemblyReconciliators(batch, diagnostics);
verifyNoNestedFragmentAssemblyBlueprints(batch, diagnostics);
}
if (failed(verifyNoComputeBodyCaptures(funcOp)))
return failure();
@@ -2242,8 +2242,8 @@ static FailureOr<Value> rewriteInputKTiledConv(const ConvLoweringState& state,
rewriter, reduceLoc, paddedRowType, paddedPatchRow, aOffsets, aSizes, unitStrides);
SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
Value bTile = tensor::ExtractSliceOp::create(
rewriter, reduceLoc, weightTileType, weightArg, bOffsets, bSizes, unitStrides);
Value bTile = extractStaticSliceOrIdentity(
rewriter, reduceLoc, weightArg, weightTileType, bOffsets, bSizes, unitStrides);
Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
reduceYielded.push_back(
spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, acc, piece).getResult());
@@ -2912,8 +2912,13 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
rewriter, reduceLoc, paddedRowType, paddedRow, aOffsets, aSizes, getUnitStrides(rewriter, 2));
SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
Value bTile = tensor::ExtractSliceOp::create(
rewriter, reduceLoc, paddedWeightTileType, args.weights.front(), bOffsets, bSizes, getUnitStrides(rewriter, 2));
Value bTile = extractStaticSliceOrIdentity(rewriter,
reduceLoc,
args.weights.front(),
paddedWeightTileType,
bOffsets,
bSizes,
getUnitStrides(rewriter, 2));
Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
reduceYielded.push_back(
spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, reduceIterArgs.front(), piece).getResult());
@@ -285,9 +285,8 @@ static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a,
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
rewriter.getIndexAttr(crossbarSize.getValue())};
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
Value bTile =
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, args.weights.front(), bOffsets, bSizes, unitStrides)
.getResult();
Value bTile = extractStaticSliceOrIdentity(
rewriter, loc, args.weights.front(), bTileType, bOffsets, bSizes, unitStrides);
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
@@ -90,10 +90,10 @@ static SelectedLayout chooseReluLayout(spatial::SpatReluPlanOp reluPlan,
return SelectedLayout::NchwRowStrip;
}
static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewriter, Value value) {
static spatial::SpatBlueprintOp insertRowStripBlueprint(IRRewriter& rewriter, Value value) {
auto outputType = cast<RankedTensorType>(value.getType());
auto [offsets, sizes] = buildRowStripMetadata(outputType);
return spatial::SpatReconciliatorOp::create(rewriter,
return spatial::SpatBlueprintOp::create(rewriter,
value.getLoc(),
outputType,
value,
@@ -189,12 +189,12 @@ struct SpatialLayoutPlanningPass final : PassWrapper<SpatialLayoutPlanningPass,
continue;
rewriter.setInsertionPointAfter(&op);
auto reconciliator = insertRowStripReconciliator(rewriter, producedValue);
rewriter.replaceAllUsesExcept(producedValue, reconciliator.getResult(), reconciliator);
materializeDenseUses(rewriter, reconciliator.getResult(), layouts);
auto blueprint = insertRowStripBlueprint(rewriter, producedValue);
rewriter.replaceAllUsesExcept(producedValue, blueprint.getResult(), blueprint);
materializeDenseUses(rewriter, blueprint.getResult(), layouts);
}
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
getOperation().emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after SpatialLayoutPlanning");
getOperation().emitError("logical Spatial graph verification failed after SpatialLayoutPlanning");
signalPassFailure();
}
}
@@ -181,32 +181,32 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
size_t elementSize = getElementTypeSizeInBytes(packedResultType.getElementType());
for (OpOperand& use : result.getUses()) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(use.getOwner());
if (!reconciliator || reconciliator->getParentOp() != reconciliator->getParentOfType<func::FuncOp>())
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(use.getOwner());
if (!blueprint || blueprint->getParentOp() != blueprint->getParentOfType<func::FuncOp>())
return failure();
std::optional<StringRef> mode = reconciliator.getMode();
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> stridesAttr = reconciliator.getFragmentStrides();
std::optional<StringRef> mode = blueprint.getMode();
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr)
return failure();
if (!reconciliator.getOutput().hasOneUse() || !isa<func::ReturnOp>(*reconciliator.getOutput().getUsers().begin()))
if (!blueprint.getOutput().hasOneUse() || !isa<func::ReturnOp>(*blueprint.getOutput().getUsers().begin()))
return failure();
unsigned returnIndex = reconciliator.getOutput().getUses().begin()->getOperandNumber();
auto hostResultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
unsigned returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber();
auto hostResultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
if (!hostResultType || !hostResultType.hasStaticShape())
return failure();
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *stridesAttr;
int64_t rank = hostResultType.getRank();
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
llvm::append_range(fragmentOperands, reconciliator.getFragments());
if (failed(validateFragmentAssemblyMetadata(reconciliator,
SmallVector<Value> fragmentOperands {blueprint.getInput()};
llvm::append_range(fragmentOperands, blueprint.getFragments());
if (failed(validateFragmentAssemblyMetadata(blueprint,
rank,
fragmentOperands.size(),
operandIndices,
@@ -379,34 +379,34 @@ static SmallVector<OpFoldResult, 4> buildFragmentOffsets(IRRewriter& rewriter,
}
static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
spatial::SpatReconciliatorOp reconciliator,
spatial::SpatBlueprintOp blueprint,
Value hostTarget,
ArrayRef<OpFoldResult> baseOffsets,
IRMapping& mapper) {
auto hostTargetType = dyn_cast<RankedTensorType>(hostTarget.getType());
auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
if (!hostTargetType || !resultType || !resultType.hasStaticShape())
return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor results");
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor results");
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = blueprint.getFragmentStrides();
if (!operandIndicesAttr || !fragmentStridesAttr)
return reconciliator.emitOpError(
return blueprint.emitOpError(
"fragment assembly lowering requires explicit operand indices and unit strides");
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
if (!sourceOffsetsAttr)
return reconciliator.emitOpError("fragment assembly lowering requires explicit source offsets");
return blueprint.emitOpError("fragment assembly lowering requires explicit source offsets");
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
int64_t rank = resultType.getRank();
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
llvm::append_range(fragmentOperands, reconciliator.getFragments());
if (failed(validateFragmentAssemblyMetadata(reconciliator,
SmallVector<Value> fragmentOperands {blueprint.getInput()};
llvm::append_range(fragmentOperands, blueprint.getFragments());
if (failed(validateFragmentAssemblyMetadata(blueprint,
rank,
fragmentOperands.size(),
operandIndices,
@@ -423,14 +423,14 @@ static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1)
return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
fragmentOffsets.push_back(flatOffsets[flatIndex]);
}
Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]);
auto sourceType = dyn_cast<ShapedType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape())
return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
SmallVector<int64_t, 4> fragmentShape;
fragmentShape.reserve(rank);
@@ -440,11 +440,11 @@ static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
Value fragment = source;
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
reconciliator, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
blueprint, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
if (failed(extractOffsets))
return failure();
fragment = tensor::ExtractSliceOp::create(rewriter,
reconciliator.getLoc(),
blueprint.getLoc(),
source,
getStaticIndexAttrs(rewriter, *extractOffsets),
getStaticIndexAttrs(rewriter, fragmentShape),
@@ -452,11 +452,11 @@ static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
}
hostTarget = tensor::InsertSliceOp::create(rewriter,
reconciliator.getLoc(),
blueprint.getLoc(),
fragment,
hostTarget,
buildFragmentOffsets(rewriter,
reconciliator.getLoc(),
blueprint.getLoc(),
baseOffsets,
fragmentOffsets,
mapper),
@@ -585,13 +585,13 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
if (isa<spatial::SpatYieldOp>(op))
continue;
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
std::optional<StringRef> modeAttr = reconciliator.getMode();
if (auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(op)) {
std::optional<StringRef> modeAttr = blueprint.getMode();
if (modeAttr && *modeAttr == "fragment_assembly") {
for (Operation* user : reconciliator.getOutput().getUsers()) {
for (Operation* user : blueprint.getOutput().getUsers()) {
if (!isa<tensor::ParallelInsertSliceOp>(user))
return reconciliator.emitOpError(
"fragment assembly reconciliator lowering expects only tensor.parallel_insert_slice users");
return blueprint.emitOpError(
"fragment assembly blueprint lowering expects only tensor.parallel_insert_slice users");
}
continue;
}
@@ -653,12 +653,12 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
if (auto reconciliator =
insertSlice.getSource().getDefiningOp<spatial::SpatReconciliatorOp>()) {
std::optional<StringRef> modeAttr = reconciliator.getMode();
if (auto blueprint =
insertSlice.getSource().getDefiningOp<spatial::SpatBlueprintOp>()) {
std::optional<StringRef> modeAttr = blueprint.getMode();
if (modeAttr && *modeAttr == "fragment_assembly") {
FailureOr<Value> updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter,
reconciliator,
blueprint,
hostTarget,
insertSlice.getMixedOffsets(),
mapper);
+7 -7
View File
@@ -73,7 +73,7 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Op
rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
}
LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reconciliator,
LogicalResult validateFragmentAssemblyMetadata(spatial::SpatBlueprintOp blueprint,
int64_t resultRank,
size_t operandCount,
ArrayRef<int64_t> operandIndices,
@@ -82,19 +82,19 @@ LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reco
ArrayRef<int64_t> flatSizes,
ArrayRef<int64_t> flatStrides) {
if (operandIndices.size() != sourceOffsets.size())
return reconciliator.emitOpError("fragment assembly operand index and source offset counts must match");
return blueprint.emitOpError("fragment assembly operand index and source offset counts must match");
if (flatOffsets.size() != flatSizes.size())
return reconciliator.emitOpError("fragment assembly offset and size arrays must have matching lengths");
return blueprint.emitOpError("fragment assembly offset and size arrays must have matching lengths");
if (flatStrides.size() != flatOffsets.size())
return reconciliator.emitOpError("fragment assembly stride and offset arrays must have matching lengths");
return blueprint.emitOpError("fragment assembly stride and offset arrays must have matching lengths");
if (flatOffsets.size() != operandIndices.size() * static_cast<size_t>(resultRank))
return reconciliator.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment");
return blueprint.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment");
for (auto [fragmentIndex, operandIndex] : llvm::enumerate(operandIndices)) {
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(operandCount))
return reconciliator.emitOpError("fragment assembly operand index is out of range");
return blueprint.emitOpError("fragment assembly operand index is out of range");
if (sourceOffsets[fragmentIndex] < 0)
return reconciliator.emitOpError("fragment assembly source offsets must be nonnegative");
return blueprint.emitOpError("fragment assembly source offsets must be nonnegative");
}
return success();
+2 -2
View File
@@ -9,7 +9,7 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
namespace onnx_mlir::spatial {
class SpatReconciliatorOp;
class SpatBlueprintOp;
}
namespace onnx_mlir {
@@ -36,7 +36,7 @@ mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operat
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation);
mlir::LogicalResult validateFragmentAssemblyMetadata(onnx_mlir::spatial::SpatReconciliatorOp reconciliator,
mlir::LogicalResult validateFragmentAssemblyMetadata(onnx_mlir::spatial::SpatBlueprintOp blueprint,
int64_t resultRank,
size_t operandCount,
llvm::ArrayRef<int64_t> operandIndices,
@@ -43,31 +43,31 @@ static Value createStaticHostTargetOffset(IRRewriter& rewriter,
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), byteOffset);
}
static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
spatial::SpatReconciliatorOp reconciliator,
static FailureOr<Value> lowerFragmentAssemblyBlueprint(IRRewriter& rewriter,
spatial::SpatBlueprintOp blueprint,
IRMapping& mapping) {
auto resultType = dyn_cast<ShapedType>(reconciliator.getOutput().getType());
auto resultType = dyn_cast<ShapedType>(blueprint.getOutput().getType());
if (!resultType || !resultType.hasStaticShape())
return reconciliator.emitOpError("fragment assembly lowering requires a static ranked tensor result");
return blueprint.emitOpError("fragment assembly lowering requires a static ranked tensor result");
std::optional<StringRef> modeAttr = reconciliator.getMode();
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
std::optional<StringRef> modeAttr = blueprint.getMode();
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = blueprint.getFragmentStrides();
if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr
|| !fragmentStridesAttr)
return reconciliator.emitOpError("fragment assembly lowering requires explicit fragment metadata");
return blueprint.emitOpError("fragment assembly lowering requires explicit fragment metadata");
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
int64_t rank = resultType.getRank();
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
llvm::append_range(fragmentOperands, reconciliator.getFragments());
if (failed(validateFragmentAssemblyMetadata(reconciliator,
SmallVector<Value> fragmentOperands {blueprint.getInput()};
llvm::append_range(fragmentOperands, blueprint.getFragments());
if (failed(validateFragmentAssemblyMetadata(blueprint,
rank,
fragmentOperands.size(),
operandIndices,
@@ -77,7 +77,7 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
flatStrides)))
return failure();
Value currentOutput = createEmptyTensorFromShaped(rewriter, reconciliator.getLoc(), resultType);
Value currentOutput = createEmptyTensorFromShaped(rewriter, blueprint.getLoc(), resultType);
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
int64_t operandIndex = operandIndices[fragmentIndex];
@@ -86,7 +86,7 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1)
return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
fragmentOffsets.push_back(flatOffsets[flatIndex]);
fragmentElements *= flatSizes[flatIndex];
}
@@ -94,21 +94,21 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]);
auto sourceType = dyn_cast<ShapedType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape())
return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
int64_t fragmentBytes =
fragmentElements * static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
auto sizeAttr = pim::getCheckedI32Attr(rewriter,
reconciliator.getOperation(),
blueprint.getOperation(),
fragmentBytes,
"fragment assembly host copy size");
if (failed(sizeAttr))
return failure();
Value hostTargetOffset = createStaticHostTargetOffset(rewriter, reconciliator.getLoc(), resultType, fragmentOffsets);
Value hostTargetOffset = createStaticHostTargetOffset(rewriter, blueprint.getLoc(), resultType, fragmentOffsets);
auto deviceSourceOffsetBytes = pim::checkedMul(static_cast<uint64_t>(sourceOffsets[fragmentIndex]),
static_cast<uint64_t>(getElementTypeSizeInBytes(sourceType.getElementType())),
reconciliator,
blueprint,
"fragment assembly device source offset");
if (failed(deviceSourceOffsetBytes))
return failure();
@@ -116,7 +116,7 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
rewriter.getInsertionBlock()->getParentOp(),
static_cast<int64_t>(*deviceSourceOffsetBytes));
currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter,
reconciliator.getLoc(),
blueprint.getLoc(),
currentOutput.getType(),
hostTargetOffset,
deviceSourceOffset,
@@ -230,13 +230,13 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatSchedule
mapping.map(*weightArg, weight);
}
for (Operation& op : block.without_terminator()) {
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
std::optional<StringRef> modeAttr = reconciliator.getMode();
if (auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(op)) {
std::optional<StringRef> modeAttr = blueprint.getMode();
if (modeAttr && *modeAttr == "fragment_assembly") {
auto lowered = lowerFragmentAssemblyReconciliator(rewriter, reconciliator, mapping);
auto lowered = lowerFragmentAssemblyBlueprint(rewriter, blueprint, mapping);
if (failed(lowered))
return false;
mapping.map(reconciliator.getOutput(), *lowered);
mapping.map(blueprint.getOutput(), *lowered);
continue;
}
}
+4 -4
View File
@@ -31,11 +31,11 @@ static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t ran
return strides;
}
struct LowerFragmentAssemblyReconciliatorPattern
: OpConversionPattern<spatial::SpatReconciliatorOp> {
struct LowerFragmentAssemblyBlueprintPattern
: OpConversionPattern<spatial::SpatBlueprintOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(spatial::SpatReconciliatorOp op,
LogicalResult matchAndRewrite(spatial::SpatBlueprintOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
std::optional<StringRef> modeAttr = op.getMode();
@@ -125,7 +125,7 @@ void populateInitialPatterns(RewritePatternSet& patterns) {
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
patterns.add<LowerFragmentAssemblyReconciliatorPattern>(patterns.getContext());
patterns.add<LowerFragmentAssemblyBlueprintPattern>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -149,36 +149,36 @@ static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
};
}
static FailureOr<SmallVector<std::pair<spatial::SpatReconciliatorOp, size_t>, 4>>
static FailureOr<SmallVector<std::pair<spatial::SpatBlueprintOp, size_t>, 4>>
analyzeTopLevelFragmentAssemblyUses(Value value) {
SmallVector<std::pair<spatial::SpatReconciliatorOp, size_t>, 4> uses;
SmallVector<std::pair<spatial::SpatBlueprintOp, size_t>, 4> uses;
for (OpOperand& use : value.getUses()) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(use.getOwner());
if (!reconciliator || reconciliator->getParentOp() != reconciliator->getParentOfType<func::FuncOp>())
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(use.getOwner());
if (!blueprint || blueprint->getParentOp() != blueprint->getParentOfType<func::FuncOp>())
return failure();
std::optional<StringRef> mode = reconciliator.getMode();
std::optional<StringRef> mode = blueprint.getMode();
if (!mode || *mode != "fragment_assembly")
return failure();
if (!reconciliator.getOutput().hasOneUse() || !isa<func::ReturnOp>(*reconciliator.getOutput().getUsers().begin()))
if (!blueprint.getOutput().hasOneUse() || !isa<func::ReturnOp>(*blueprint.getOutput().getUsers().begin()))
return failure();
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> stridesAttr = reconciliator.getFragmentStrides();
auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr || !resultType || !resultType.hasStaticShape())
return failure();
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
llvm::append_range(fragmentOperands, reconciliator.getFragments());
if (failed(validateFragmentAssemblyMetadata(reconciliator,
SmallVector<Value> fragmentOperands {blueprint.getInput()};
llvm::append_range(fragmentOperands, blueprint.getFragments());
if (failed(validateFragmentAssemblyMetadata(blueprint,
resultType.getRank(),
fragmentOperands.size(),
*operandIndicesAttr,
*sourceOffsetsAttr,
reconciliator.getFragmentOffsets(),
reconciliator.getFragmentSizes(),
blueprint.getFragmentOffsets(),
blueprint.getFragmentSizes(),
*stridesAttr)))
return failure();
uses.emplace_back(reconciliator, use.getOperandNumber());
uses.emplace_back(blueprint, use.getOperandNumber());
}
return uses;
}
@@ -593,7 +593,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
}
}
FailureOr<SmallVector<std::pair<spatial::SpatReconciliatorOp, size_t>, 4>> fragmentAssemblyUses =
FailureOr<SmallVector<std::pair<spatial::SpatBlueprintOp, size_t>, 4>> fragmentAssemblyUses =
analyzeTopLevelFragmentAssemblyUses(producedValue);
if (succeeded(fragmentAssemblyUses)) {
auto sourceType = dyn_cast<RankedTensorType>(storedValue.getType());
@@ -603,35 +603,35 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
}
size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType());
for (auto [reconciliator, operandNumber] : *fragmentAssemblyUses) {
for (auto [blueprint, operandNumber] : *fragmentAssemblyUses) {
rewriter.setInsertionPointAfterValue(storedValue);
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> stridesAttr = reconciliator.getFragmentStrides();
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr) {
reconciliator.emitOpError(
blueprint.emitOpError(
"fragment assembly lowering requires explicit operand, source-offset, and stride metadata");
return ReturnPathLoweringResult::Failure;
}
size_t returnIndex = reconciliator.getOutput().getUses().begin()->getOperandNumber();
size_t returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber();
Value outputTensor = outputTensors[returnIndex](rewriter, loc);
auto outputType = dyn_cast<RankedTensorType>(outputTensor.getType());
auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
if (!outputType || !resultType || !resultType.hasStaticShape()) {
reconciliator.emitOpError("fragment assembly lowering requires static ranked host outputs");
blueprint.emitOpError("fragment assembly lowering requires static ranked host outputs");
return ReturnPathLoweringResult::Failure;
}
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *stridesAttr;
int64_t rank = resultType.getRank();
if (failed(validateFragmentAssemblyMetadata(reconciliator,
if (failed(validateFragmentAssemblyMetadata(blueprint,
rank,
1 + reconciliator.getFragments().size(),
1 + blueprint.getFragments().size(),
operandIndices,
sourceOffsets,
flatOffsets,
@@ -647,7 +647,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1) {
reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
blueprint.emitOpError("fragment assembly lowering only supports unit strides");
return ReturnPathLoweringResult::Failure;
}
fragmentOffsets.push_back(flatOffsets[flatIndex]);
@@ -684,7 +684,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
outputTensor =
pim::PimMemCopyDevToHostOp::create(rewriter,
reconciliator.getLoc(),
blueprint.getLoc(),
outputTensor.getType(),
getOrCreateIndexConstant(rewriter, producerOp, *hostOffset),
getOrCreateIndexConstant(rewriter, producerOp, *sourceOffset),
@@ -698,7 +698,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
if (failedChunk)
return ReturnPathLoweringResult::Failure;
}
markOpToRemove(reconciliator.getOperation());
markOpToRemove(blueprint.getOperation());
}
return ReturnPathLoweringResult::Handled;
}
@@ -813,11 +813,11 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
return;
}
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
std::optional<StringRef> mode = reconciliator.getMode();
if (auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(op)) {
std::optional<StringRef> mode = blueprint.getMode();
if (mode && *mode == "fragment_assembly") {
markOpToRemove(reconciliator.getOperation());
for (Value operand : reconciliator->getOperands())
markOpToRemove(blueprint.getOperation());
for (Value operand : blueprint->getOperands())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
@@ -203,7 +203,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
func::FuncOp funcOp = *entryFunc;
if (failed(verifyScheduledSpatialInvariants(funcOp))) {
funcOp.emitOpError(
"RAPTOR_PHASE_CHECK scheduled Spatial verification failed at the start of SpatialToPim");
"scheduled Spatial verification failed at the start of SpatialToPim");
signalPassFailure();
return;
}