@@ -163,6 +163,38 @@ Value extractAxisSlice(
|
|||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value extractStaticSliceOrIdentity(RewriterBase& rewriter,
|
||||||
|
Location loc,
|
||||||
|
Value source,
|
||||||
|
RankedTensorType resultType,
|
||||||
|
ArrayRef<OpFoldResult> offsets,
|
||||||
|
ArrayRef<OpFoldResult> sizes,
|
||||||
|
ArrayRef<OpFoldResult> strides) {
|
||||||
|
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||||
|
size_t rank = static_cast<size_t>(sourceType.getRank());
|
||||||
|
|
||||||
|
bool isIdentitySlice =
|
||||||
|
sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank
|
||||||
|
&& strides.size() == rank;
|
||||||
|
if (isIdentitySlice) {
|
||||||
|
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
||||||
|
for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) {
|
||||||
|
std::optional<int64_t> staticOffset = mlir::getConstantIntValue(offset);
|
||||||
|
std::optional<int64_t> staticSize = mlir::getConstantIntValue(size);
|
||||||
|
std::optional<int64_t> staticStride = mlir::getConstantIntValue(stride);
|
||||||
|
if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim || *staticStride != 1) {
|
||||||
|
isIdentitySlice = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isIdentitySlice)
|
||||||
|
return source;
|
||||||
|
|
||||||
|
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
Value insertStaticSlice(
|
Value insertStaticSlice(
|
||||||
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
||||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||||
|
|||||||
@@ -105,6 +105,14 @@ llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPer
|
|||||||
mlir::Value extractAxisSlice(
|
mlir::Value extractAxisSlice(
|
||||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||||
|
|
||||||
|
mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::Value source,
|
||||||
|
mlir::RankedTensorType resultType,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> offsets,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> sizes,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> strides);
|
||||||
|
|
||||||
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
mlir::Value source,
|
mlir::Value source,
|
||||||
|
|||||||
@@ -44,17 +44,17 @@ static FailureOr<RowStripPhysicalValue> getRowStripValue(llvm::DenseMap<Value, R
|
|||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatReconciliatorOp reconciliator,
|
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatBlueprintOp blueprint,
|
||||||
Value physicalValue) {
|
Value physicalValue) {
|
||||||
auto logicalType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
|
auto logicalType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||||
if (!logicalType)
|
if (!logicalType)
|
||||||
return reconciliator.emitOpError("requires ranked logical output type"), failure();
|
return blueprint.emitOpError("requires ranked logical output type"), failure();
|
||||||
RowStripPhysicalValue value;
|
RowStripPhysicalValue value;
|
||||||
value.physicalValue = physicalValue;
|
value.physicalValue = physicalValue;
|
||||||
value.logicalType = logicalType;
|
value.logicalType = logicalType;
|
||||||
value.fragmentOffsets.append(reconciliator.getFragmentOffsets().begin(), reconciliator.getFragmentOffsets().end());
|
value.fragmentOffsets.append(blueprint.getFragmentOffsets().begin(), blueprint.getFragmentOffsets().end());
|
||||||
value.fragmentSizes.append(reconciliator.getFragmentSizes().begin(), reconciliator.getFragmentSizes().end());
|
value.fragmentSizes.append(blueprint.getFragmentSizes().begin(), blueprint.getFragmentSizes().end());
|
||||||
value.indexMap = reconciliator.getIndexMap().str();
|
value.indexMap = blueprint.getIndexMap().str();
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,7 +175,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
|||||||
auto verifyLogicalPhase = [&](StringRef stage) -> bool {
|
auto verifyLogicalPhase = [&](StringRef stage) -> bool {
|
||||||
if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc)))
|
if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc)))
|
||||||
return true;
|
return true;
|
||||||
moduleOp.emitError() << "RAPTOR_PHASE_CHECK logical Spatial graph verification failed " << stage;
|
moduleOp.emitError() << "logical Spatial graph verification failed " << stage;
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
@@ -185,11 +185,11 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
|||||||
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||||
if (auto planOp = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
|
if (auto planOp = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
|
||||||
FailureOr<RowStripPhysicalValue> rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
|
FailureOr<RowStripPhysicalValue> rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
|
||||||
auto rowStripReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
auto rowStripBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||||
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
|
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(user);
|
||||||
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
|
return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
|
||||||
});
|
});
|
||||||
if (rowStripReconciliator != planOp.getResult().getUsers().end()) {
|
if (rowStripBlueprint != planOp.getResult().getUsers().end()) {
|
||||||
rewriter.setInsertionPoint(planOp);
|
rewriter.setInsertionPoint(planOp);
|
||||||
FailureOr<Value> lowered = lowerSelectedConv2DPlan(
|
FailureOr<Value> lowered = lowerSelectedConv2DPlan(
|
||||||
planOp,
|
planOp,
|
||||||
@@ -201,15 +201,15 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
|||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*rowStripReconciliator);
|
auto blueprint = cast<spatial::SpatBlueprintOp>(*rowStripBlueprint);
|
||||||
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(reconciliator, *lowered);
|
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(blueprint, *lowered);
|
||||||
if (failed(rowStripValue)) {
|
if (failed(rowStripValue)) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
rowStripValues[reconciliator.getResult()] = *rowStripValue;
|
rowStripValues[blueprint.getResult()] = *rowStripValue;
|
||||||
eraseAfterLowering.insert(planOp);
|
eraseAfterLowering.insert(planOp);
|
||||||
eraseAfterLowering.insert(reconciliator);
|
eraseAfterLowering.insert(blueprint);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
rewriter.setInsertionPoint(planOp);
|
rewriter.setInsertionPoint(planOp);
|
||||||
@@ -226,12 +226,12 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
|||||||
|
|
||||||
if (auto planOp = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
|
if (auto planOp = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
|
||||||
if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) {
|
if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) {
|
||||||
auto outputReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
auto outputBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||||
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
|
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(user);
|
||||||
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
|
return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
|
||||||
});
|
});
|
||||||
if (outputReconciliator == planOp.getResult().getUsers().end()) {
|
if (outputBlueprint == planOp.getResult().getUsers().end()) {
|
||||||
planOp.emitOpError("row-strip Relu plan requires a row-strip reconciliator result");
|
planOp.emitOpError("row-strip Relu plan requires a row-strip blueprint result");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -244,15 +244,15 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
|||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*outputReconciliator);
|
auto blueprint = cast<spatial::SpatBlueprintOp>(*outputBlueprint);
|
||||||
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(reconciliator, *lowered);
|
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(blueprint, *lowered);
|
||||||
if (failed(output)) {
|
if (failed(output)) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
rowStripValues[reconciliator.getResult()] = *output;
|
rowStripValues[blueprint.getResult()] = *output;
|
||||||
eraseAfterLowering.insert(planOp);
|
eraseAfterLowering.insert(planOp);
|
||||||
eraseAfterLowering.insert(reconciliator);
|
eraseAfterLowering.insert(blueprint);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,7 +279,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
|||||||
}
|
}
|
||||||
FailureOr<RowStripPhysicalValue> rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
|
FailureOr<RowStripPhysicalValue> rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
|
||||||
if (failed(rowStripValue)) {
|
if (failed(rowStripValue)) {
|
||||||
materializeOp.emitOpError("expected a row-strip reconciliator input during row-strip materialization");
|
materializeOp.emitOpError("expected a row-strip blueprint input during row-strip materialization");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -293,18 +293,18 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
|||||||
rewriter.replaceOp(materializeOp, *dense);
|
rewriter.replaceOp(materializeOp, *dense);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (auto reconciliatorOp = dyn_cast<spatial::SpatReconciliatorOp>(&op)) {
|
if (auto blueprintOp = dyn_cast<spatial::SpatBlueprintOp>(&op)) {
|
||||||
if (reconciliatorOp.getPhysicalLayout() == kDenseLayout) {
|
if (blueprintOp.getPhysicalLayout() == kDenseLayout) {
|
||||||
rewriter.replaceOp(reconciliatorOp, reconciliatorOp.getInput());
|
rewriter.replaceOp(blueprintOp, blueprintOp.getInput());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (reconciliatorOp.getPhysicalLayout() != kRowStripLayout) {
|
if (blueprintOp.getPhysicalLayout() != kRowStripLayout) {
|
||||||
reconciliatorOp.emitOpError("non-dense reconciliator lowering is not supported yet");
|
blueprintOp.emitOpError("non-dense blueprint lowering is not supported yet");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (!eraseAfterLowering.contains(reconciliatorOp)) {
|
if (!eraseAfterLowering.contains(blueprintOp)) {
|
||||||
reconciliatorOp.emitOpError("unhandled row-strip reconciliator remained during LowerSpatialPlans");
|
blueprintOp.emitOpError("unhandled row-strip blueprint remained during LowerSpatialPlans");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -385,7 +385,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
|||||||
return;
|
return;
|
||||||
if (isa<spatial::SpatConv2DPlanOp,
|
if (isa<spatial::SpatConv2DPlanOp,
|
||||||
spatial::SpatReluPlanOp,
|
spatial::SpatReluPlanOp,
|
||||||
spatial::SpatReconciliatorOp,
|
spatial::SpatBlueprintOp,
|
||||||
spatial::SpatMaterializeLayoutOp>(op)
|
spatial::SpatMaterializeLayoutOp>(op)
|
||||||
|| op->getDialect()->getNamespace() == "onnx") {
|
|| op->getDialect()->getNamespace() == "onnx") {
|
||||||
op->emitOpError("operation must not remain after LowerSpatialPlans");
|
op->emitOpError("operation must not remain after LowerSpatialPlans");
|
||||||
|
|||||||
@@ -46,9 +46,9 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
|||||||
SmallVector<spatial::SpatGraphComputeBatch> computeBatches(funcOp.getOps<spatial::SpatGraphComputeBatch>());
|
SmallVector<spatial::SpatGraphComputeBatch> computeBatches(funcOp.getOps<spatial::SpatGraphComputeBatch>());
|
||||||
SmallVector<spatial::SpatConv2DPlanOp> convPlans(funcOp.getOps<spatial::SpatConv2DPlanOp>());
|
SmallVector<spatial::SpatConv2DPlanOp> convPlans(funcOp.getOps<spatial::SpatConv2DPlanOp>());
|
||||||
SmallVector<spatial::SpatReluPlanOp> reluPlans(funcOp.getOps<spatial::SpatReluPlanOp>());
|
SmallVector<spatial::SpatReluPlanOp> reluPlans(funcOp.getOps<spatial::SpatReluPlanOp>());
|
||||||
SmallVector<spatial::SpatReconciliatorOp> reconciliators(funcOp.getOps<spatial::SpatReconciliatorOp>());
|
SmallVector<spatial::SpatBlueprintOp> blueprints(funcOp.getOps<spatial::SpatBlueprintOp>());
|
||||||
SmallVector<spatial::SpatMaterializeLayoutOp> materializers(funcOp.getOps<spatial::SpatMaterializeLayoutOp>());
|
SmallVector<spatial::SpatMaterializeLayoutOp> materializers(funcOp.getOps<spatial::SpatMaterializeLayoutOp>());
|
||||||
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !reconciliators.empty()
|
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !blueprints.empty()
|
||||||
|| !materializers.empty()) {
|
|| !materializers.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -160,7 +160,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX conversion");
|
moduleOp.emitError("logical Spatial graph verification failed after ONNX conversion");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -181,7 +181,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
|
|
||||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after weight annotation");
|
moduleOp.emitError("logical Spatial graph verification failed after weight annotation");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -199,7 +199,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
[](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
[](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
||||||
|
|
||||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed before post rewrites");
|
moduleOp.emitError("logical Spatial graph verification failed before post rewrites");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -214,7 +214,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
populateEmptyFunction(*entryFunc);
|
populateEmptyFunction(*entryFunc);
|
||||||
|
|
||||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX-to-Spatial");
|
moduleOp.emitError("logical Spatial graph verification failed after ONNX-to-Spatial");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr StringLiteral kPhaseMarker = "RAPTOR_PHASE_CHECK";
|
constexpr StringLiteral kPhaseMarker = "phase-check";
|
||||||
|
|
||||||
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
func.walk([&](Operation* op) {
|
func.walk([&](Operation* op) {
|
||||||
@@ -114,14 +114,14 @@ void verifyScheduledInputs(ComputeOpTy compute,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename ComputeOpTy>
|
template <typename ComputeOpTy>
|
||||||
void verifyNoNestedFragmentAssemblyReconciliators(ComputeOpTy compute,
|
void verifyNoNestedFragmentAssemblyBlueprints(ComputeOpTy compute,
|
||||||
pim::CappedDiagnosticReporter& diagnostics) {
|
pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
compute.getBody().walk([&](spatial::SpatReconciliatorOp reconciliator) {
|
compute.getBody().walk([&](spatial::SpatBlueprintOp blueprint) {
|
||||||
std::optional<StringRef> mode = reconciliator.getMode();
|
std::optional<StringRef> mode = blueprint.getMode();
|
||||||
if (!mode || *mode != "fragment_assembly")
|
if (!mode || *mode != "fragment_assembly")
|
||||||
return;
|
return;
|
||||||
diagnostics.report(reconciliator.getOperation(), [&](Operation* illegalOp) {
|
diagnostics.report(blueprint.getOperation(), [&](Operation* illegalOp) {
|
||||||
illegalOp->emitOpError("fragment assembly reconciliator must be host-level after merge materialization");
|
illegalOp->emitOpError("fragment assembly blueprint must be host-level after merge materialization");
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -133,7 +133,7 @@ void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter
|
|||||||
spatial::SpatGraphComputeBatch,
|
spatial::SpatGraphComputeBatch,
|
||||||
spatial::SpatConv2DPlanOp,
|
spatial::SpatConv2DPlanOp,
|
||||||
spatial::SpatReluPlanOp,
|
spatial::SpatReluPlanOp,
|
||||||
spatial::SpatReconciliatorOp,
|
spatial::SpatBlueprintOp,
|
||||||
spatial::SpatMaterializeLayoutOp>(&op)) {
|
spatial::SpatMaterializeLayoutOp>(&op)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -203,11 +203,11 @@ LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
|
|||||||
verifyScheduledTopLevelOps(funcOp, diagnostics);
|
verifyScheduledTopLevelOps(funcOp, diagnostics);
|
||||||
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>()) {
|
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>()) {
|
||||||
verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
|
verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
|
||||||
verifyNoNestedFragmentAssemblyReconciliators(compute, diagnostics);
|
verifyNoNestedFragmentAssemblyBlueprints(compute, diagnostics);
|
||||||
}
|
}
|
||||||
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
|
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
|
||||||
verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
|
verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
|
||||||
verifyNoNestedFragmentAssemblyReconciliators(batch, diagnostics);
|
verifyNoNestedFragmentAssemblyBlueprints(batch, diagnostics);
|
||||||
}
|
}
|
||||||
if (failed(verifyNoComputeBodyCaptures(funcOp)))
|
if (failed(verifyNoComputeBodyCaptures(funcOp)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -2242,8 +2242,8 @@ static FailureOr<Value> rewriteInputKTiledConv(const ConvLoweringState& state,
|
|||||||
rewriter, reduceLoc, paddedRowType, paddedPatchRow, aOffsets, aSizes, unitStrides);
|
rewriter, reduceLoc, paddedRowType, paddedPatchRow, aOffsets, aSizes, unitStrides);
|
||||||
SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
|
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
|
||||||
Value bTile = tensor::ExtractSliceOp::create(
|
Value bTile = extractStaticSliceOrIdentity(
|
||||||
rewriter, reduceLoc, weightTileType, weightArg, bOffsets, bSizes, unitStrides);
|
rewriter, reduceLoc, weightArg, weightTileType, bOffsets, bSizes, unitStrides);
|
||||||
Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
|
Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
|
||||||
reduceYielded.push_back(
|
reduceYielded.push_back(
|
||||||
spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, acc, piece).getResult());
|
spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, acc, piece).getResult());
|
||||||
@@ -2912,8 +2912,13 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
|
|||||||
rewriter, reduceLoc, paddedRowType, paddedRow, aOffsets, aSizes, getUnitStrides(rewriter, 2));
|
rewriter, reduceLoc, paddedRowType, paddedRow, aOffsets, aSizes, getUnitStrides(rewriter, 2));
|
||||||
SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
|
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
|
||||||
Value bTile = tensor::ExtractSliceOp::create(
|
Value bTile = extractStaticSliceOrIdentity(rewriter,
|
||||||
rewriter, reduceLoc, paddedWeightTileType, args.weights.front(), bOffsets, bSizes, getUnitStrides(rewriter, 2));
|
reduceLoc,
|
||||||
|
args.weights.front(),
|
||||||
|
paddedWeightTileType,
|
||||||
|
bOffsets,
|
||||||
|
bSizes,
|
||||||
|
getUnitStrides(rewriter, 2));
|
||||||
Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
|
Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
|
||||||
reduceYielded.push_back(
|
reduceYielded.push_back(
|
||||||
spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, reduceIterArgs.front(), piece).getResult());
|
spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, reduceIterArgs.front(), piece).getResult());
|
||||||
|
|||||||
@@ -285,9 +285,8 @@ static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a,
|
|||||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
|
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
|
||||||
rewriter.getIndexAttr(crossbarSize.getValue())};
|
rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||||
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
|
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
|
||||||
Value bTile =
|
Value bTile = extractStaticSliceOrIdentity(
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, args.weights.front(), bOffsets, bSizes, unitStrides)
|
rewriter, loc, args.weights.front(), bTileType, bOffsets, bSizes, unitStrides);
|
||||||
.getResult();
|
|
||||||
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
||||||
|
|
||||||
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||||
|
|||||||
@@ -90,10 +90,10 @@ static SelectedLayout chooseReluLayout(spatial::SpatReluPlanOp reluPlan,
|
|||||||
return SelectedLayout::NchwRowStrip;
|
return SelectedLayout::NchwRowStrip;
|
||||||
}
|
}
|
||||||
|
|
||||||
static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewriter, Value value) {
|
static spatial::SpatBlueprintOp insertRowStripBlueprint(IRRewriter& rewriter, Value value) {
|
||||||
auto outputType = cast<RankedTensorType>(value.getType());
|
auto outputType = cast<RankedTensorType>(value.getType());
|
||||||
auto [offsets, sizes] = buildRowStripMetadata(outputType);
|
auto [offsets, sizes] = buildRowStripMetadata(outputType);
|
||||||
return spatial::SpatReconciliatorOp::create(rewriter,
|
return spatial::SpatBlueprintOp::create(rewriter,
|
||||||
value.getLoc(),
|
value.getLoc(),
|
||||||
outputType,
|
outputType,
|
||||||
value,
|
value,
|
||||||
@@ -189,12 +189,12 @@ struct SpatialLayoutPlanningPass final : PassWrapper<SpatialLayoutPlanningPass,
|
|||||||
continue;
|
continue;
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(&op);
|
rewriter.setInsertionPointAfter(&op);
|
||||||
auto reconciliator = insertRowStripReconciliator(rewriter, producedValue);
|
auto blueprint = insertRowStripBlueprint(rewriter, producedValue);
|
||||||
rewriter.replaceAllUsesExcept(producedValue, reconciliator.getResult(), reconciliator);
|
rewriter.replaceAllUsesExcept(producedValue, blueprint.getResult(), blueprint);
|
||||||
materializeDenseUses(rewriter, reconciliator.getResult(), layouts);
|
materializeDenseUses(rewriter, blueprint.getResult(), layouts);
|
||||||
}
|
}
|
||||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
getOperation().emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after SpatialLayoutPlanning");
|
getOperation().emitError("logical Spatial graph verification failed after SpatialLayoutPlanning");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -181,32 +181,32 @@ analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResu
|
|||||||
size_t elementSize = getElementTypeSizeInBytes(packedResultType.getElementType());
|
size_t elementSize = getElementTypeSizeInBytes(packedResultType.getElementType());
|
||||||
|
|
||||||
for (OpOperand& use : result.getUses()) {
|
for (OpOperand& use : result.getUses()) {
|
||||||
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(use.getOwner());
|
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(use.getOwner());
|
||||||
if (!reconciliator || reconciliator->getParentOp() != reconciliator->getParentOfType<func::FuncOp>())
|
if (!blueprint || blueprint->getParentOp() != blueprint->getParentOfType<func::FuncOp>())
|
||||||
return failure();
|
return failure();
|
||||||
std::optional<StringRef> mode = reconciliator.getMode();
|
std::optional<StringRef> mode = blueprint.getMode();
|
||||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||||
std::optional<ArrayRef<int64_t>> stridesAttr = reconciliator.getFragmentStrides();
|
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
|
||||||
if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr)
|
if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr)
|
||||||
return failure();
|
return failure();
|
||||||
if (!reconciliator.getOutput().hasOneUse() || !isa<func::ReturnOp>(*reconciliator.getOutput().getUsers().begin()))
|
if (!blueprint.getOutput().hasOneUse() || !isa<func::ReturnOp>(*blueprint.getOutput().getUsers().begin()))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
unsigned returnIndex = reconciliator.getOutput().getUses().begin()->getOperandNumber();
|
unsigned returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber();
|
||||||
auto hostResultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
|
auto hostResultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||||
if (!hostResultType || !hostResultType.hasStaticShape())
|
if (!hostResultType || !hostResultType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||||
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||||
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
|
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
|
||||||
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
|
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
||||||
ArrayRef<int64_t> flatStrides = *stridesAttr;
|
ArrayRef<int64_t> flatStrides = *stridesAttr;
|
||||||
int64_t rank = hostResultType.getRank();
|
int64_t rank = hostResultType.getRank();
|
||||||
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
|
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
||||||
llvm::append_range(fragmentOperands, reconciliator.getFragments());
|
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
||||||
if (failed(validateFragmentAssemblyMetadata(reconciliator,
|
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||||
rank,
|
rank,
|
||||||
fragmentOperands.size(),
|
fragmentOperands.size(),
|
||||||
operandIndices,
|
operandIndices,
|
||||||
@@ -379,34 +379,34 @@ static SmallVector<OpFoldResult, 4> buildFragmentOffsets(IRRewriter& rewriter,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
|
static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
|
||||||
spatial::SpatReconciliatorOp reconciliator,
|
spatial::SpatBlueprintOp blueprint,
|
||||||
Value hostTarget,
|
Value hostTarget,
|
||||||
ArrayRef<OpFoldResult> baseOffsets,
|
ArrayRef<OpFoldResult> baseOffsets,
|
||||||
IRMapping& mapper) {
|
IRMapping& mapper) {
|
||||||
auto hostTargetType = dyn_cast<RankedTensorType>(hostTarget.getType());
|
auto hostTargetType = dyn_cast<RankedTensorType>(hostTarget.getType());
|
||||||
auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
|
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||||
if (!hostTargetType || !resultType || !resultType.hasStaticShape())
|
if (!hostTargetType || !resultType || !resultType.hasStaticShape())
|
||||||
return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor results");
|
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor results");
|
||||||
|
|
||||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||||
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
|
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = blueprint.getFragmentStrides();
|
||||||
if (!operandIndicesAttr || !fragmentStridesAttr)
|
if (!operandIndicesAttr || !fragmentStridesAttr)
|
||||||
return reconciliator.emitOpError(
|
return blueprint.emitOpError(
|
||||||
"fragment assembly lowering requires explicit operand indices and unit strides");
|
"fragment assembly lowering requires explicit operand indices and unit strides");
|
||||||
|
|
||||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||||
if (!sourceOffsetsAttr)
|
if (!sourceOffsetsAttr)
|
||||||
return reconciliator.emitOpError("fragment assembly lowering requires explicit source offsets");
|
return blueprint.emitOpError("fragment assembly lowering requires explicit source offsets");
|
||||||
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||||
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
|
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
|
||||||
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
|
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
||||||
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||||
int64_t rank = resultType.getRank();
|
int64_t rank = resultType.getRank();
|
||||||
|
|
||||||
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
|
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
||||||
llvm::append_range(fragmentOperands, reconciliator.getFragments());
|
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
||||||
if (failed(validateFragmentAssemblyMetadata(reconciliator,
|
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||||
rank,
|
rank,
|
||||||
fragmentOperands.size(),
|
fragmentOperands.size(),
|
||||||
operandIndices,
|
operandIndices,
|
||||||
@@ -423,14 +423,14 @@ static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
|
|||||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||||
if (flatStrides[flatIndex] != 1)
|
if (flatStrides[flatIndex] != 1)
|
||||||
return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
|
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
|
||||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]);
|
Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]);
|
||||||
auto sourceType = dyn_cast<ShapedType>(source.getType());
|
auto sourceType = dyn_cast<ShapedType>(source.getType());
|
||||||
if (!sourceType || !sourceType.hasStaticShape())
|
if (!sourceType || !sourceType.hasStaticShape())
|
||||||
return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||||
|
|
||||||
SmallVector<int64_t, 4> fragmentShape;
|
SmallVector<int64_t, 4> fragmentShape;
|
||||||
fragmentShape.reserve(rank);
|
fragmentShape.reserve(rank);
|
||||||
@@ -440,11 +440,11 @@ static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
|
|||||||
Value fragment = source;
|
Value fragment = source;
|
||||||
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
|
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
|
||||||
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
|
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
|
||||||
reconciliator, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
|
blueprint, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
|
||||||
if (failed(extractOffsets))
|
if (failed(extractOffsets))
|
||||||
return failure();
|
return failure();
|
||||||
fragment = tensor::ExtractSliceOp::create(rewriter,
|
fragment = tensor::ExtractSliceOp::create(rewriter,
|
||||||
reconciliator.getLoc(),
|
blueprint.getLoc(),
|
||||||
source,
|
source,
|
||||||
getStaticIndexAttrs(rewriter, *extractOffsets),
|
getStaticIndexAttrs(rewriter, *extractOffsets),
|
||||||
getStaticIndexAttrs(rewriter, fragmentShape),
|
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||||
@@ -452,11 +452,11 @@ static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
|
|||||||
}
|
}
|
||||||
|
|
||||||
hostTarget = tensor::InsertSliceOp::create(rewriter,
|
hostTarget = tensor::InsertSliceOp::create(rewriter,
|
||||||
reconciliator.getLoc(),
|
blueprint.getLoc(),
|
||||||
fragment,
|
fragment,
|
||||||
hostTarget,
|
hostTarget,
|
||||||
buildFragmentOffsets(rewriter,
|
buildFragmentOffsets(rewriter,
|
||||||
reconciliator.getLoc(),
|
blueprint.getLoc(),
|
||||||
baseOffsets,
|
baseOffsets,
|
||||||
fragmentOffsets,
|
fragmentOffsets,
|
||||||
mapper),
|
mapper),
|
||||||
@@ -585,13 +585,13 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
|||||||
if (isa<spatial::SpatYieldOp>(op))
|
if (isa<spatial::SpatYieldOp>(op))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
|
if (auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(op)) {
|
||||||
std::optional<StringRef> modeAttr = reconciliator.getMode();
|
std::optional<StringRef> modeAttr = blueprint.getMode();
|
||||||
if (modeAttr && *modeAttr == "fragment_assembly") {
|
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||||
for (Operation* user : reconciliator.getOutput().getUsers()) {
|
for (Operation* user : blueprint.getOutput().getUsers()) {
|
||||||
if (!isa<tensor::ParallelInsertSliceOp>(user))
|
if (!isa<tensor::ParallelInsertSliceOp>(user))
|
||||||
return reconciliator.emitOpError(
|
return blueprint.emitOpError(
|
||||||
"fragment assembly reconciliator lowering expects only tensor.parallel_insert_slice users");
|
"fragment assembly blueprint lowering expects only tensor.parallel_insert_slice users");
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -653,12 +653,12 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
|||||||
|
|
||||||
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
|
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
|
||||||
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
||||||
if (auto reconciliator =
|
if (auto blueprint =
|
||||||
insertSlice.getSource().getDefiningOp<spatial::SpatReconciliatorOp>()) {
|
insertSlice.getSource().getDefiningOp<spatial::SpatBlueprintOp>()) {
|
||||||
std::optional<StringRef> modeAttr = reconciliator.getMode();
|
std::optional<StringRef> modeAttr = blueprint.getMode();
|
||||||
if (modeAttr && *modeAttr == "fragment_assembly") {
|
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||||
FailureOr<Value> updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter,
|
FailureOr<Value> updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter,
|
||||||
reconciliator,
|
blueprint,
|
||||||
hostTarget,
|
hostTarget,
|
||||||
insertSlice.getMixedOffsets(),
|
insertSlice.getMixedOffsets(),
|
||||||
mapper);
|
mapper);
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Op
|
|||||||
rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
|
rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reconciliator,
|
LogicalResult validateFragmentAssemblyMetadata(spatial::SpatBlueprintOp blueprint,
|
||||||
int64_t resultRank,
|
int64_t resultRank,
|
||||||
size_t operandCount,
|
size_t operandCount,
|
||||||
ArrayRef<int64_t> operandIndices,
|
ArrayRef<int64_t> operandIndices,
|
||||||
@@ -82,19 +82,19 @@ LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reco
|
|||||||
ArrayRef<int64_t> flatSizes,
|
ArrayRef<int64_t> flatSizes,
|
||||||
ArrayRef<int64_t> flatStrides) {
|
ArrayRef<int64_t> flatStrides) {
|
||||||
if (operandIndices.size() != sourceOffsets.size())
|
if (operandIndices.size() != sourceOffsets.size())
|
||||||
return reconciliator.emitOpError("fragment assembly operand index and source offset counts must match");
|
return blueprint.emitOpError("fragment assembly operand index and source offset counts must match");
|
||||||
if (flatOffsets.size() != flatSizes.size())
|
if (flatOffsets.size() != flatSizes.size())
|
||||||
return reconciliator.emitOpError("fragment assembly offset and size arrays must have matching lengths");
|
return blueprint.emitOpError("fragment assembly offset and size arrays must have matching lengths");
|
||||||
if (flatStrides.size() != flatOffsets.size())
|
if (flatStrides.size() != flatOffsets.size())
|
||||||
return reconciliator.emitOpError("fragment assembly stride and offset arrays must have matching lengths");
|
return blueprint.emitOpError("fragment assembly stride and offset arrays must have matching lengths");
|
||||||
if (flatOffsets.size() != operandIndices.size() * static_cast<size_t>(resultRank))
|
if (flatOffsets.size() != operandIndices.size() * static_cast<size_t>(resultRank))
|
||||||
return reconciliator.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment");
|
return blueprint.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment");
|
||||||
|
|
||||||
for (auto [fragmentIndex, operandIndex] : llvm::enumerate(operandIndices)) {
|
for (auto [fragmentIndex, operandIndex] : llvm::enumerate(operandIndices)) {
|
||||||
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(operandCount))
|
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(operandCount))
|
||||||
return reconciliator.emitOpError("fragment assembly operand index is out of range");
|
return blueprint.emitOpError("fragment assembly operand index is out of range");
|
||||||
if (sourceOffsets[fragmentIndex] < 0)
|
if (sourceOffsets[fragmentIndex] < 0)
|
||||||
return reconciliator.emitOpError("fragment assembly source offsets must be nonnegative");
|
return blueprint.emitOpError("fragment assembly source offsets must be nonnegative");
|
||||||
}
|
}
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir::spatial {
|
namespace onnx_mlir::spatial {
|
||||||
class SpatReconciliatorOp;
|
class SpatBlueprintOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -36,7 +36,7 @@ mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operat
|
|||||||
|
|
||||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation);
|
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation);
|
||||||
|
|
||||||
mlir::LogicalResult validateFragmentAssemblyMetadata(onnx_mlir::spatial::SpatReconciliatorOp reconciliator,
|
mlir::LogicalResult validateFragmentAssemblyMetadata(onnx_mlir::spatial::SpatBlueprintOp blueprint,
|
||||||
int64_t resultRank,
|
int64_t resultRank,
|
||||||
size_t operandCount,
|
size_t operandCount,
|
||||||
llvm::ArrayRef<int64_t> operandIndices,
|
llvm::ArrayRef<int64_t> operandIndices,
|
||||||
|
|||||||
@@ -43,31 +43,31 @@ static Value createStaticHostTargetOffset(IRRewriter& rewriter,
|
|||||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), byteOffset);
|
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), byteOffset);
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
|
static FailureOr<Value> lowerFragmentAssemblyBlueprint(IRRewriter& rewriter,
|
||||||
spatial::SpatReconciliatorOp reconciliator,
|
spatial::SpatBlueprintOp blueprint,
|
||||||
IRMapping& mapping) {
|
IRMapping& mapping) {
|
||||||
auto resultType = dyn_cast<ShapedType>(reconciliator.getOutput().getType());
|
auto resultType = dyn_cast<ShapedType>(blueprint.getOutput().getType());
|
||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
return reconciliator.emitOpError("fragment assembly lowering requires a static ranked tensor result");
|
return blueprint.emitOpError("fragment assembly lowering requires a static ranked tensor result");
|
||||||
|
|
||||||
std::optional<StringRef> modeAttr = reconciliator.getMode();
|
std::optional<StringRef> modeAttr = blueprint.getMode();
|
||||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||||
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
|
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = blueprint.getFragmentStrides();
|
||||||
if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr
|
if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr
|
||||||
|| !fragmentStridesAttr)
|
|| !fragmentStridesAttr)
|
||||||
return reconciliator.emitOpError("fragment assembly lowering requires explicit fragment metadata");
|
return blueprint.emitOpError("fragment assembly lowering requires explicit fragment metadata");
|
||||||
|
|
||||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||||
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||||
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
|
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
|
||||||
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
|
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
||||||
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||||
int64_t rank = resultType.getRank();
|
int64_t rank = resultType.getRank();
|
||||||
|
|
||||||
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
|
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
||||||
llvm::append_range(fragmentOperands, reconciliator.getFragments());
|
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
||||||
if (failed(validateFragmentAssemblyMetadata(reconciliator,
|
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||||
rank,
|
rank,
|
||||||
fragmentOperands.size(),
|
fragmentOperands.size(),
|
||||||
operandIndices,
|
operandIndices,
|
||||||
@@ -77,7 +77,7 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
|
|||||||
flatStrides)))
|
flatStrides)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value currentOutput = createEmptyTensorFromShaped(rewriter, reconciliator.getLoc(), resultType);
|
Value currentOutput = createEmptyTensorFromShaped(rewriter, blueprint.getLoc(), resultType);
|
||||||
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||||
int64_t operandIndex = operandIndices[fragmentIndex];
|
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||||
|
|
||||||
@@ -86,7 +86,7 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
|
|||||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||||
if (flatStrides[flatIndex] != 1)
|
if (flatStrides[flatIndex] != 1)
|
||||||
return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
|
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
|
||||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||||
fragmentElements *= flatSizes[flatIndex];
|
fragmentElements *= flatSizes[flatIndex];
|
||||||
}
|
}
|
||||||
@@ -94,21 +94,21 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
|
|||||||
Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]);
|
Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]);
|
||||||
auto sourceType = dyn_cast<ShapedType>(source.getType());
|
auto sourceType = dyn_cast<ShapedType>(source.getType());
|
||||||
if (!sourceType || !sourceType.hasStaticShape())
|
if (!sourceType || !sourceType.hasStaticShape())
|
||||||
return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||||
|
|
||||||
int64_t fragmentBytes =
|
int64_t fragmentBytes =
|
||||||
fragmentElements * static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
|
fragmentElements * static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
|
||||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter,
|
auto sizeAttr = pim::getCheckedI32Attr(rewriter,
|
||||||
reconciliator.getOperation(),
|
blueprint.getOperation(),
|
||||||
fragmentBytes,
|
fragmentBytes,
|
||||||
"fragment assembly host copy size");
|
"fragment assembly host copy size");
|
||||||
if (failed(sizeAttr))
|
if (failed(sizeAttr))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value hostTargetOffset = createStaticHostTargetOffset(rewriter, reconciliator.getLoc(), resultType, fragmentOffsets);
|
Value hostTargetOffset = createStaticHostTargetOffset(rewriter, blueprint.getLoc(), resultType, fragmentOffsets);
|
||||||
auto deviceSourceOffsetBytes = pim::checkedMul(static_cast<uint64_t>(sourceOffsets[fragmentIndex]),
|
auto deviceSourceOffsetBytes = pim::checkedMul(static_cast<uint64_t>(sourceOffsets[fragmentIndex]),
|
||||||
static_cast<uint64_t>(getElementTypeSizeInBytes(sourceType.getElementType())),
|
static_cast<uint64_t>(getElementTypeSizeInBytes(sourceType.getElementType())),
|
||||||
reconciliator,
|
blueprint,
|
||||||
"fragment assembly device source offset");
|
"fragment assembly device source offset");
|
||||||
if (failed(deviceSourceOffsetBytes))
|
if (failed(deviceSourceOffsetBytes))
|
||||||
return failure();
|
return failure();
|
||||||
@@ -116,7 +116,7 @@ static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
|
|||||||
rewriter.getInsertionBlock()->getParentOp(),
|
rewriter.getInsertionBlock()->getParentOp(),
|
||||||
static_cast<int64_t>(*deviceSourceOffsetBytes));
|
static_cast<int64_t>(*deviceSourceOffsetBytes));
|
||||||
currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter,
|
currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||||
reconciliator.getLoc(),
|
blueprint.getLoc(),
|
||||||
currentOutput.getType(),
|
currentOutput.getType(),
|
||||||
hostTargetOffset,
|
hostTargetOffset,
|
||||||
deviceSourceOffset,
|
deviceSourceOffset,
|
||||||
@@ -230,13 +230,13 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatSchedule
|
|||||||
mapping.map(*weightArg, weight);
|
mapping.map(*weightArg, weight);
|
||||||
}
|
}
|
||||||
for (Operation& op : block.without_terminator()) {
|
for (Operation& op : block.without_terminator()) {
|
||||||
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
|
if (auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(op)) {
|
||||||
std::optional<StringRef> modeAttr = reconciliator.getMode();
|
std::optional<StringRef> modeAttr = blueprint.getMode();
|
||||||
if (modeAttr && *modeAttr == "fragment_assembly") {
|
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||||
auto lowered = lowerFragmentAssemblyReconciliator(rewriter, reconciliator, mapping);
|
auto lowered = lowerFragmentAssemblyBlueprint(rewriter, blueprint, mapping);
|
||||||
if (failed(lowered))
|
if (failed(lowered))
|
||||||
return false;
|
return false;
|
||||||
mapping.map(reconciliator.getOutput(), *lowered);
|
mapping.map(blueprint.getOutput(), *lowered);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,11 +31,11 @@ static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t ran
|
|||||||
return strides;
|
return strides;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct LowerFragmentAssemblyReconciliatorPattern
|
struct LowerFragmentAssemblyBlueprintPattern
|
||||||
: OpConversionPattern<spatial::SpatReconciliatorOp> {
|
: OpConversionPattern<spatial::SpatBlueprintOp> {
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(spatial::SpatReconciliatorOp op,
|
LogicalResult matchAndRewrite(spatial::SpatBlueprintOp op,
|
||||||
OpAdaptor adaptor,
|
OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter& rewriter) const override {
|
ConversionPatternRewriter& rewriter) const override {
|
||||||
std::optional<StringRef> modeAttr = op.getMode();
|
std::optional<StringRef> modeAttr = op.getMode();
|
||||||
@@ -125,7 +125,7 @@ void populateInitialPatterns(RewritePatternSet& patterns) {
|
|||||||
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
|
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
|
||||||
raptor::populateWithGenerated(patterns);
|
raptor::populateWithGenerated(patterns);
|
||||||
populateTransposeLoweringPatterns(patterns);
|
populateTransposeLoweringPatterns(patterns);
|
||||||
patterns.add<LowerFragmentAssemblyReconciliatorPattern>(patterns.getContext());
|
patterns.add<LowerFragmentAssemblyBlueprintPattern>(patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -149,36 +149,36 @@ static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<SmallVector<std::pair<spatial::SpatReconciliatorOp, size_t>, 4>>
|
static FailureOr<SmallVector<std::pair<spatial::SpatBlueprintOp, size_t>, 4>>
|
||||||
analyzeTopLevelFragmentAssemblyUses(Value value) {
|
analyzeTopLevelFragmentAssemblyUses(Value value) {
|
||||||
SmallVector<std::pair<spatial::SpatReconciliatorOp, size_t>, 4> uses;
|
SmallVector<std::pair<spatial::SpatBlueprintOp, size_t>, 4> uses;
|
||||||
for (OpOperand& use : value.getUses()) {
|
for (OpOperand& use : value.getUses()) {
|
||||||
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(use.getOwner());
|
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(use.getOwner());
|
||||||
if (!reconciliator || reconciliator->getParentOp() != reconciliator->getParentOfType<func::FuncOp>())
|
if (!blueprint || blueprint->getParentOp() != blueprint->getParentOfType<func::FuncOp>())
|
||||||
return failure();
|
return failure();
|
||||||
std::optional<StringRef> mode = reconciliator.getMode();
|
std::optional<StringRef> mode = blueprint.getMode();
|
||||||
if (!mode || *mode != "fragment_assembly")
|
if (!mode || *mode != "fragment_assembly")
|
||||||
return failure();
|
return failure();
|
||||||
if (!reconciliator.getOutput().hasOneUse() || !isa<func::ReturnOp>(*reconciliator.getOutput().getUsers().begin()))
|
if (!blueprint.getOutput().hasOneUse() || !isa<func::ReturnOp>(*blueprint.getOutput().getUsers().begin()))
|
||||||
return failure();
|
return failure();
|
||||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||||
std::optional<ArrayRef<int64_t>> stridesAttr = reconciliator.getFragmentStrides();
|
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
|
||||||
auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
|
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||||
if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr || !resultType || !resultType.hasStaticShape())
|
if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr || !resultType || !resultType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
|
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
||||||
llvm::append_range(fragmentOperands, reconciliator.getFragments());
|
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
||||||
if (failed(validateFragmentAssemblyMetadata(reconciliator,
|
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||||
resultType.getRank(),
|
resultType.getRank(),
|
||||||
fragmentOperands.size(),
|
fragmentOperands.size(),
|
||||||
*operandIndicesAttr,
|
*operandIndicesAttr,
|
||||||
*sourceOffsetsAttr,
|
*sourceOffsetsAttr,
|
||||||
reconciliator.getFragmentOffsets(),
|
blueprint.getFragmentOffsets(),
|
||||||
reconciliator.getFragmentSizes(),
|
blueprint.getFragmentSizes(),
|
||||||
*stridesAttr)))
|
*stridesAttr)))
|
||||||
return failure();
|
return failure();
|
||||||
uses.emplace_back(reconciliator, use.getOperandNumber());
|
uses.emplace_back(blueprint, use.getOperandNumber());
|
||||||
}
|
}
|
||||||
return uses;
|
return uses;
|
||||||
}
|
}
|
||||||
@@ -593,7 +593,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FailureOr<SmallVector<std::pair<spatial::SpatReconciliatorOp, size_t>, 4>> fragmentAssemblyUses =
|
FailureOr<SmallVector<std::pair<spatial::SpatBlueprintOp, size_t>, 4>> fragmentAssemblyUses =
|
||||||
analyzeTopLevelFragmentAssemblyUses(producedValue);
|
analyzeTopLevelFragmentAssemblyUses(producedValue);
|
||||||
if (succeeded(fragmentAssemblyUses)) {
|
if (succeeded(fragmentAssemblyUses)) {
|
||||||
auto sourceType = dyn_cast<RankedTensorType>(storedValue.getType());
|
auto sourceType = dyn_cast<RankedTensorType>(storedValue.getType());
|
||||||
@@ -603,35 +603,35 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType());
|
size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType());
|
||||||
for (auto [reconciliator, operandNumber] : *fragmentAssemblyUses) {
|
for (auto [blueprint, operandNumber] : *fragmentAssemblyUses) {
|
||||||
rewriter.setInsertionPointAfterValue(storedValue);
|
rewriter.setInsertionPointAfterValue(storedValue);
|
||||||
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||||
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = reconciliator.getFragmentSourceOffsets();
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||||
std::optional<ArrayRef<int64_t>> stridesAttr = reconciliator.getFragmentStrides();
|
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
|
||||||
if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr) {
|
if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr) {
|
||||||
reconciliator.emitOpError(
|
blueprint.emitOpError(
|
||||||
"fragment assembly lowering requires explicit operand, source-offset, and stride metadata");
|
"fragment assembly lowering requires explicit operand, source-offset, and stride metadata");
|
||||||
return ReturnPathLoweringResult::Failure;
|
return ReturnPathLoweringResult::Failure;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t returnIndex = reconciliator.getOutput().getUses().begin()->getOperandNumber();
|
size_t returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber();
|
||||||
Value outputTensor = outputTensors[returnIndex](rewriter, loc);
|
Value outputTensor = outputTensors[returnIndex](rewriter, loc);
|
||||||
auto outputType = dyn_cast<RankedTensorType>(outputTensor.getType());
|
auto outputType = dyn_cast<RankedTensorType>(outputTensor.getType());
|
||||||
auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
|
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||||
if (!outputType || !resultType || !resultType.hasStaticShape()) {
|
if (!outputType || !resultType || !resultType.hasStaticShape()) {
|
||||||
reconciliator.emitOpError("fragment assembly lowering requires static ranked host outputs");
|
blueprint.emitOpError("fragment assembly lowering requires static ranked host outputs");
|
||||||
return ReturnPathLoweringResult::Failure;
|
return ReturnPathLoweringResult::Failure;
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||||
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||||
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
|
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
|
||||||
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
|
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
||||||
ArrayRef<int64_t> flatStrides = *stridesAttr;
|
ArrayRef<int64_t> flatStrides = *stridesAttr;
|
||||||
int64_t rank = resultType.getRank();
|
int64_t rank = resultType.getRank();
|
||||||
if (failed(validateFragmentAssemblyMetadata(reconciliator,
|
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||||
rank,
|
rank,
|
||||||
1 + reconciliator.getFragments().size(),
|
1 + blueprint.getFragments().size(),
|
||||||
operandIndices,
|
operandIndices,
|
||||||
sourceOffsets,
|
sourceOffsets,
|
||||||
flatOffsets,
|
flatOffsets,
|
||||||
@@ -647,7 +647,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
int64_t flatIndex = fragmentIndex * rank + dim;
|
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||||
if (flatStrides[flatIndex] != 1) {
|
if (flatStrides[flatIndex] != 1) {
|
||||||
reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
|
blueprint.emitOpError("fragment assembly lowering only supports unit strides");
|
||||||
return ReturnPathLoweringResult::Failure;
|
return ReturnPathLoweringResult::Failure;
|
||||||
}
|
}
|
||||||
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||||
@@ -684,7 +684,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
|
|
||||||
outputTensor =
|
outputTensor =
|
||||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||||
reconciliator.getLoc(),
|
blueprint.getLoc(),
|
||||||
outputTensor.getType(),
|
outputTensor.getType(),
|
||||||
getOrCreateIndexConstant(rewriter, producerOp, *hostOffset),
|
getOrCreateIndexConstant(rewriter, producerOp, *hostOffset),
|
||||||
getOrCreateIndexConstant(rewriter, producerOp, *sourceOffset),
|
getOrCreateIndexConstant(rewriter, producerOp, *sourceOffset),
|
||||||
@@ -698,7 +698,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
if (failedChunk)
|
if (failedChunk)
|
||||||
return ReturnPathLoweringResult::Failure;
|
return ReturnPathLoweringResult::Failure;
|
||||||
}
|
}
|
||||||
markOpToRemove(reconciliator.getOperation());
|
markOpToRemove(blueprint.getOperation());
|
||||||
}
|
}
|
||||||
return ReturnPathLoweringResult::Handled;
|
return ReturnPathLoweringResult::Handled;
|
||||||
}
|
}
|
||||||
@@ -813,11 +813,11 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
|
if (auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(op)) {
|
||||||
std::optional<StringRef> mode = reconciliator.getMode();
|
std::optional<StringRef> mode = blueprint.getMode();
|
||||||
if (mode && *mode == "fragment_assembly") {
|
if (mode && *mode == "fragment_assembly") {
|
||||||
markOpToRemove(reconciliator.getOperation());
|
markOpToRemove(blueprint.getOperation());
|
||||||
for (Value operand : reconciliator->getOperands())
|
for (Value operand : blueprint->getOperands())
|
||||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -203,7 +203,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
func::FuncOp funcOp = *entryFunc;
|
func::FuncOp funcOp = *entryFunc;
|
||||||
if (failed(verifyScheduledSpatialInvariants(funcOp))) {
|
if (failed(verifyScheduledSpatialInvariants(funcOp))) {
|
||||||
funcOp.emitOpError(
|
funcOp.emitOpError(
|
||||||
"RAPTOR_PHASE_CHECK scheduled Spatial verification failed at the start of SpatialToPim");
|
"scheduled Spatial verification failed at the start of SpatialToPim");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -232,8 +232,8 @@ def SpatReluPlanOp : SpatOp<"relu_plan", []> {
|
|||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatReconciliatorOp : SpatOp<"reconciliator", []> {
|
def SpatBlueprintOp : SpatOp<"blueprint", []> {
|
||||||
let summary = "Logical-to-physical layout record or explicit fragment assembly";
|
let summary = "Blueprint for assembling logical tensors from published fragments";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
SpatTensor:$input,
|
SpatTensor:$input,
|
||||||
@@ -256,6 +256,7 @@ def SpatReconciliatorOp : SpatOp<"reconciliator", []> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatMaterializeLayoutOp : SpatOp<"materialize_layout", []> {
|
def SpatMaterializeLayoutOp : SpatOp<"materialize_layout", []> {
|
||||||
|
|||||||
@@ -32,6 +32,14 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
|||||||
return parser.getBuilder().getI32IntegerAttr(value);
|
return parser.getBuilder().getI32IntegerAttr(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static ParseResult parseBareStringAttr(OpAsmParser& parser, StringAttr& attr) {
|
||||||
|
StringRef value;
|
||||||
|
if (parser.parseKeyword(&value))
|
||||||
|
return failure();
|
||||||
|
attr = parser.getBuilder().getStringAttr(value);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
||||||
printer << "(";
|
printer << "(";
|
||||||
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
||||||
@@ -466,6 +474,131 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SpatBlueprintOp::print(OpAsmPrinter& printer) {
|
||||||
|
SmallVector<Value> operands {getInput()};
|
||||||
|
llvm::append_range(operands, getFragments());
|
||||||
|
|
||||||
|
printer << " fragments";
|
||||||
|
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||||
|
printer << " layout " << getLogicalLayout();
|
||||||
|
printer << " physical " << getPhysicalLayout();
|
||||||
|
printer << " offsets ";
|
||||||
|
printCompressedIntegerList(printer, getFragmentOffsets());
|
||||||
|
printer << " sizes ";
|
||||||
|
printCompressedIntegerList(printer, getFragmentSizes());
|
||||||
|
printer << " map " << getIndexMap();
|
||||||
|
if (std::optional<StringRef> mode = getMode())
|
||||||
|
printer << " mode " << *mode;
|
||||||
|
if (std::optional<ArrayRef<int64_t>> operandIndices = getFragmentOperandIndices()) {
|
||||||
|
printer << " operandIndices ";
|
||||||
|
printCompressedIntegerList(printer, *operandIndices);
|
||||||
|
}
|
||||||
|
if (std::optional<ArrayRef<int64_t>> sourceOffsets = getFragmentSourceOffsets()) {
|
||||||
|
printer << " sourceOffsets ";
|
||||||
|
printCompressedIntegerList(printer, *sourceOffsets);
|
||||||
|
}
|
||||||
|
if (std::optional<ArrayRef<int64_t>> strides = getFragmentStrides()) {
|
||||||
|
printer << " strides ";
|
||||||
|
printCompressedIntegerList(printer, *strides);
|
||||||
|
}
|
||||||
|
if (std::optional<StringRef> conflictPolicy = getConflictPolicy())
|
||||||
|
printer << " conflict " << *conflictPolicy;
|
||||||
|
if (std::optional<StringRef> coveragePolicy = getCoveragePolicy())
|
||||||
|
printer << " coverage " << *coveragePolicy;
|
||||||
|
|
||||||
|
printer.printOptionalAttrDict((*this)->getAttrs(),
|
||||||
|
{getLogicalLayoutAttrName().getValue(),
|
||||||
|
getPhysicalLayoutAttrName().getValue(),
|
||||||
|
getFragmentOffsetsAttrName().getValue(),
|
||||||
|
getFragmentSizesAttrName().getValue(),
|
||||||
|
getIndexMapAttrName().getValue(),
|
||||||
|
getModeAttrName().getValue(),
|
||||||
|
getFragmentOperandIndicesAttrName().getValue(),
|
||||||
|
getFragmentSourceOffsetsAttrName().getValue(),
|
||||||
|
getFragmentStridesAttrName().getValue(),
|
||||||
|
getConflictPolicyAttrName().getValue(),
|
||||||
|
getCoveragePolicyAttrName().getValue()});
|
||||||
|
printer << " : ";
|
||||||
|
printCompressedTypeList(printer, TypeRange(operands), ListDelimiter::Paren);
|
||||||
|
printer << " -> ";
|
||||||
|
printer.printType(getOutput().getType());
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult SpatBlueprintOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> operands;
|
||||||
|
SmallVector<Type> operandTypes;
|
||||||
|
Type outputType;
|
||||||
|
StringAttr logicalLayout;
|
||||||
|
StringAttr physicalLayout;
|
||||||
|
StringAttr indexMap;
|
||||||
|
StringAttr mode;
|
||||||
|
StringAttr conflictPolicy;
|
||||||
|
StringAttr coveragePolicy;
|
||||||
|
SmallVector<int64_t> fragmentOffsets;
|
||||||
|
SmallVector<int64_t> fragmentSizes;
|
||||||
|
SmallVector<int64_t> fragmentOperandIndices;
|
||||||
|
SmallVector<int64_t> fragmentSourceOffsets;
|
||||||
|
SmallVector<int64_t> fragmentStrides;
|
||||||
|
|
||||||
|
if (parser.parseKeyword("fragments")
|
||||||
|
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)
|
||||||
|
|| parser.parseKeyword("layout") || parseBareStringAttr(parser, logicalLayout)
|
||||||
|
|| parser.parseKeyword("physical") || parseBareStringAttr(parser, physicalLayout)
|
||||||
|
|| parser.parseKeyword("offsets") || parseCompressedIntegerList(parser, fragmentOffsets)
|
||||||
|
|| parser.parseKeyword("sizes") || parseCompressedIntegerList(parser, fragmentSizes)
|
||||||
|
|| parser.parseKeyword("map") || parseBareStringAttr(parser, indexMap))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("mode")) && parseBareStringAttr(parser, mode))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("operandIndices"))
|
||||||
|
&& parseCompressedIntegerList(parser, fragmentOperandIndices))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("sourceOffsets"))
|
||||||
|
&& parseCompressedIntegerList(parser, fragmentSourceOffsets))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("strides")) && parseCompressedIntegerList(parser, fragmentStrides))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("conflict")) && parseBareStringAttr(parser, conflictPolicy))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("coverage")) && parseBareStringAttr(parser, coveragePolicy))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||||
|
|| parseCompressedRepeatedList(
|
||||||
|
parser, ListDelimiter::Paren, operandTypes, [&](Type& type) { return parser.parseType(type); })
|
||||||
|
|| parser.parseArrow() || parser.parseType(outputType))
|
||||||
|
return failure();
|
||||||
|
if (operands.empty())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "spat.blueprint requires at least one fragment operand");
|
||||||
|
if (operands.size() != operandTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of fragment operands and types must match");
|
||||||
|
|
||||||
|
auto& builder = parser.getBuilder();
|
||||||
|
result.addAttribute("logicalLayout", logicalLayout);
|
||||||
|
result.addAttribute("physicalLayout", physicalLayout);
|
||||||
|
result.addAttribute("fragmentOffsets", builder.getDenseI64ArrayAttr(fragmentOffsets));
|
||||||
|
result.addAttribute("fragmentSizes", builder.getDenseI64ArrayAttr(fragmentSizes));
|
||||||
|
result.addAttribute("indexMap", indexMap);
|
||||||
|
if (mode)
|
||||||
|
result.addAttribute("mode", mode);
|
||||||
|
if (!fragmentOperandIndices.empty())
|
||||||
|
result.addAttribute("fragmentOperandIndices", builder.getDenseI64ArrayAttr(fragmentOperandIndices));
|
||||||
|
if (!fragmentSourceOffsets.empty())
|
||||||
|
result.addAttribute("fragmentSourceOffsets", builder.getDenseI64ArrayAttr(fragmentSourceOffsets));
|
||||||
|
if (!fragmentStrides.empty())
|
||||||
|
result.addAttribute("fragmentStrides", builder.getDenseI64ArrayAttr(fragmentStrides));
|
||||||
|
if (conflictPolicy)
|
||||||
|
result.addAttribute("conflictPolicy", conflictPolicy);
|
||||||
|
if (coveragePolicy)
|
||||||
|
result.addAttribute("coveragePolicy", coveragePolicy);
|
||||||
|
|
||||||
|
if (parser.resolveOperands(operands, operandTypes, parser.getCurrentLocation(), result.operands))
|
||||||
|
return failure();
|
||||||
|
result.addTypes(outputType);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
|
void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
|
||||||
ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) {
|
ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
return parseComputeLikeOp<SpatGraphCompute>(parser, result);
|
return parseComputeLikeOp<SpatGraphCompute>(parser, result);
|
||||||
|
|||||||
@@ -436,10 +436,10 @@ LogicalResult SpatReluPlanOp::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatReconciliatorOp::verify() {
|
LogicalResult SpatBlueprintOp::verify() {
|
||||||
auto modeAttr = getModeAttr();
|
auto modeAttr = getModeAttr();
|
||||||
bool isFragmentAssembly = modeAttr && modeAttr.getValue() == "fragment_assembly";
|
bool isFragmentAssembly = modeAttr && modeAttr.getValue() == "fragment_assembly";
|
||||||
if (!isFragmentAssembly && failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.reconciliator")))
|
if (!isFragmentAssembly && failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.blueprint")))
|
||||||
return failure();
|
return failure();
|
||||||
if (!isKnownLogicalLayout(getLogicalLayout()))
|
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||||
return emitError("requires a known logical layout");
|
return emitError("requires a known logical layout");
|
||||||
@@ -482,10 +482,10 @@ LogicalResult SpatReconciliatorOp::verify() {
|
|||||||
if (failed(verifyBoundsOnly({})))
|
if (failed(verifyBoundsOnly({})))
|
||||||
return failure();
|
return failure();
|
||||||
if (!getFragments().empty())
|
if (!getFragments().empty())
|
||||||
return emitError("legacy reconciliator does not accept extra fragment operands");
|
return emitError("legacy blueprint does not accept extra fragment operands");
|
||||||
if (getFragmentSourceOffsetsAttr() || getFragmentStridesAttr() || getConflictPolicyAttr()
|
if (getFragmentSourceOffsetsAttr() || getFragmentStridesAttr() || getConflictPolicyAttr()
|
||||||
|| getCoveragePolicyAttr())
|
|| getCoveragePolicyAttr())
|
||||||
return emitError("legacy reconciliator does not accept fragment assembly attributes");
|
return emitError("legacy blueprint does not accept fragment assembly attributes");
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -493,11 +493,11 @@ LogicalResult SpatReconciliatorOp::verify() {
|
|||||||
auto operandIndicesAttr = getFragmentOperandIndicesAttr();
|
auto operandIndicesAttr = getFragmentOperandIndicesAttr();
|
||||||
auto sourceOffsetsAttr = getFragmentSourceOffsetsAttr();
|
auto sourceOffsetsAttr = getFragmentSourceOffsetsAttr();
|
||||||
if (!operandIndicesAttr)
|
if (!operandIndicesAttr)
|
||||||
return emitError("fragment assembly reconciliator requires fragment operand indices");
|
return emitError("fragment assembly blueprint requires fragment operand indices");
|
||||||
if (!sourceOffsetsAttr)
|
if (!sourceOffsetsAttr)
|
||||||
return emitError("fragment assembly reconciliator requires fragment source offsets");
|
return emitError("fragment assembly blueprint requires fragment source offsets");
|
||||||
if (!stridesAttr)
|
if (!stridesAttr)
|
||||||
return emitError("fragment assembly reconciliator requires fragment strides");
|
return emitError("fragment assembly blueprint requires fragment strides");
|
||||||
ArrayRef<int64_t> operandIndices = operandIndicesAttr.asArrayRef();
|
ArrayRef<int64_t> operandIndices = operandIndicesAttr.asArrayRef();
|
||||||
ArrayRef<int64_t> sourceOffsets = sourceOffsetsAttr.asArrayRef();
|
ArrayRef<int64_t> sourceOffsets = sourceOffsetsAttr.asArrayRef();
|
||||||
ArrayRef<int64_t> strides = stridesAttr.asArrayRef();
|
ArrayRef<int64_t> strides = stridesAttr.asArrayRef();
|
||||||
@@ -506,11 +506,11 @@ LogicalResult SpatReconciliatorOp::verify() {
|
|||||||
if (sourceOffsets.size() != operandIndices.size())
|
if (sourceOffsets.size() != operandIndices.size())
|
||||||
return emitError("fragment source offset count must match fragment operand index count");
|
return emitError("fragment source offset count must match fragment operand index count");
|
||||||
if (!getConflictPolicyAttr() || !getCoveragePolicyAttr())
|
if (!getConflictPolicyAttr() || !getCoveragePolicyAttr())
|
||||||
return emitError("fragment assembly reconciliator requires conflict and coverage policies");
|
return emitError("fragment assembly blueprint requires conflict and coverage policies");
|
||||||
if (getConflictPolicy() != "disjoint")
|
if (getConflictPolicy() != "disjoint")
|
||||||
return emitError("fragment assembly reconciliator currently supports only conflict_policy=\"disjoint\"");
|
return emitError("fragment assembly blueprint currently supports only conflict_policy=\"disjoint\"");
|
||||||
if (getCoveragePolicy() != "complete" && getCoveragePolicy() != "partial")
|
if (getCoveragePolicy() != "complete" && getCoveragePolicy() != "partial")
|
||||||
return emitError("fragment assembly reconciliator coverage_policy must be \"complete\" or \"partial\"");
|
return emitError("fragment assembly blueprint coverage_policy must be \"complete\" or \"partial\"");
|
||||||
|
|
||||||
SmallVector<Value> operands;
|
SmallVector<Value> operands;
|
||||||
operands.push_back(getInput());
|
operands.push_back(getInput());
|
||||||
@@ -518,7 +518,7 @@ LogicalResult SpatReconciliatorOp::verify() {
|
|||||||
int64_t operandCount = static_cast<int64_t>(operands.size());
|
int64_t operandCount = static_cast<int64_t>(operands.size());
|
||||||
int64_t fragmentCount = static_cast<int64_t>(operandIndices.size());
|
int64_t fragmentCount = static_cast<int64_t>(operandIndices.size());
|
||||||
if (operandCount == 0)
|
if (operandCount == 0)
|
||||||
return emitError("fragment assembly reconciliator requires at least one operand");
|
return emitError("fragment assembly blueprint requires at least one operand");
|
||||||
if (static_cast<int64_t>(offsets.size()) != fragmentCount * rank)
|
if (static_cast<int64_t>(offsets.size()) != fragmentCount * rank)
|
||||||
return emitError("fragment assembly metadata count must match operand count * result rank");
|
return emitError("fragment assembly metadata count must match operand count * result rank");
|
||||||
if (failed(verifyBoundsOnly(strides)))
|
if (failed(verifyBoundsOnly(strides)))
|
||||||
@@ -544,9 +544,9 @@ LogicalResult SpatReconciliatorOp::verify() {
|
|||||||
|
|
||||||
auto operandType = dyn_cast<RankedTensorType>(operands[operandIndex].getType());
|
auto operandType = dyn_cast<RankedTensorType>(operands[operandIndex].getType());
|
||||||
if (!operandType || !operandType.hasStaticShape())
|
if (!operandType || !operandType.hasStaticShape())
|
||||||
return emitError("fragment assembly reconciliator requires static ranked tensor operands");
|
return emitError("fragment assembly blueprint requires static ranked tensor operands");
|
||||||
if (operandType.getRank() != rank)
|
if (operandType.getRank() != rank)
|
||||||
return emitError("fragment assembly reconciliator requires operand/result rank match");
|
return emitError("fragment assembly blueprint requires operand/result rank match");
|
||||||
|
|
||||||
SmallVector<int64_t, 4> fragmentOffsets;
|
SmallVector<int64_t, 4> fragmentOffsets;
|
||||||
SmallVector<int64_t, 4> fragmentSizes;
|
SmallVector<int64_t, 4> fragmentSizes;
|
||||||
@@ -583,14 +583,14 @@ LogicalResult SpatReconciliatorOp::verify() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (overlaps)
|
if (overlaps)
|
||||||
return emitError("fragment assembly reconciliator requires disjoint static slices");
|
return emitError("fragment assembly blueprint requires disjoint static slices");
|
||||||
}
|
}
|
||||||
slices.push_back({std::move(fragmentOffsets), std::move(fragmentSizes)});
|
slices.push_back({std::move(fragmentOffsets), std::move(fragmentSizes)});
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int64_t operandIndex = 0; operandIndex < operandCount; ++operandIndex) {
|
for (int64_t operandIndex = 0; operandIndex < operandCount; ++operandIndex) {
|
||||||
if (fragmentCountsByOperand[static_cast<size_t>(operandIndex)] == 0)
|
if (fragmentCountsByOperand[static_cast<size_t>(operandIndex)] == 0)
|
||||||
return emitError("fragment assembly reconciliator requires every operand to contribute at least one fragment");
|
return emitError("fragment assembly blueprint requires every operand to contribute at least one fragment");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (getCoveragePolicy() == "complete") {
|
if (getCoveragePolicy() == "complete") {
|
||||||
|
|||||||
+137
-105
@@ -30,6 +30,7 @@
|
|||||||
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -308,7 +309,7 @@ struct PendingProjectedHostOutputFragment {
|
|||||||
Value originalOutput;
|
Value originalOutput;
|
||||||
ClassId sourceClass = 0;
|
ClassId sourceClass = 0;
|
||||||
ProducerKey producerKey;
|
ProducerKey producerKey;
|
||||||
Value publicationValue;
|
unsigned publicationResultIndex = 0;
|
||||||
int64_t sourceFragmentOrdinal = 0;
|
int64_t sourceFragmentOrdinal = 0;
|
||||||
int64_t sourceElementOffset = 0;
|
int64_t sourceElementOffset = 0;
|
||||||
SmallVector<int64_t, 4> offsets;
|
SmallVector<int64_t, 4> offsets;
|
||||||
@@ -1220,36 +1221,13 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali
|
|||||||
return std::get<1>(*arg);
|
return std::get<1>(*arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
void refreshPendingProjectedHostOutputPublicationValues(MaterializerState& state,
|
FailureOr<unsigned> appendScalarPublicationResult(MaterializerState& state,
|
||||||
Operation* oldOwner,
|
MaterializedClass& materializedClass,
|
||||||
Operation* newOwner) {
|
Value payload,
|
||||||
if (!oldOwner || oldOwner == newOwner)
|
Location loc) {
|
||||||
return;
|
|
||||||
|
|
||||||
for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments) {
|
|
||||||
auto publicationResult = dyn_cast_or_null<OpResult>(fragment.publicationValue);
|
|
||||||
if (!publicationResult || publicationResult.getOwner() != oldOwner)
|
|
||||||
publicationResult = OpResult();
|
|
||||||
else
|
|
||||||
fragment.publicationValue = newOwner->getResult(publicationResult.getResultNumber());
|
|
||||||
|
|
||||||
if (auto originalResult = dyn_cast_or_null<OpResult>(fragment.originalOutput); originalResult
|
|
||||||
&& originalResult.getOwner() == oldOwner) {
|
|
||||||
fragment.originalOutput = newOwner->getResult(originalResult.getResultNumber());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (fragment.producerKey.instance.op == oldOwner)
|
|
||||||
fragment.producerKey.instance.op = newOwner;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<Value> appendScalarPublicationResult(MaterializerState& state,
|
|
||||||
MaterializedClass& materializedClass,
|
|
||||||
Value payload,
|
|
||||||
Location loc) {
|
|
||||||
auto existing = materializedClass.publicationOutputToResultIndex.find(payload);
|
auto existing = materializedClass.publicationOutputToResultIndex.find(payload);
|
||||||
if (existing != materializedClass.publicationOutputToResultIndex.end())
|
if (existing != materializedClass.publicationOutputToResultIndex.end())
|
||||||
return materializedClass.op->getResult(existing->second);
|
return existing->second;
|
||||||
|
|
||||||
auto compute = dyn_cast<SpatScheduledCompute>(materializedClass.op);
|
auto compute = dyn_cast<SpatScheduledCompute>(materializedClass.op);
|
||||||
if (!compute)
|
if (!compute)
|
||||||
@@ -1264,27 +1242,25 @@ FailureOr<Value> appendScalarPublicationResult(MaterializerState& state,
|
|||||||
if (failed(inserted))
|
if (failed(inserted))
|
||||||
return materializedClass.op->emitError("failed to append scalar publication result");
|
return materializedClass.op->emitError("failed to append scalar publication result");
|
||||||
|
|
||||||
Operation* oldOp = materializedClass.op;
|
|
||||||
auto [result, newCompute] = *inserted;
|
auto [result, newCompute] = *inserted;
|
||||||
materializedClass.op = newCompute.getOperation();
|
materializedClass.op = newCompute.getOperation();
|
||||||
materializedClass.body = &newCompute.getBody().front();
|
materializedClass.body = &newCompute.getBody().front();
|
||||||
refreshPendingProjectedHostOutputPublicationValues(state, oldOp, materializedClass.op);
|
|
||||||
materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber();
|
materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber();
|
||||||
|
|
||||||
auto yieldOp = dyn_cast<SpatYieldOp>(materializedClass.body->getTerminator());
|
auto yieldOp = dyn_cast<SpatYieldOp>(materializedClass.body->getTerminator());
|
||||||
if (!yieldOp)
|
if (!yieldOp)
|
||||||
return materializedClass.op->emitError("expected spat.yield terminator while appending scalar publication result");
|
return materializedClass.op->emitError("expected spat.yield terminator while appending scalar publication result");
|
||||||
state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->insertOperands(yieldOp.getNumOperands(), payload); });
|
state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->insertOperands(yieldOp.getNumOperands(), payload); });
|
||||||
return result;
|
return result.getResultNumber();
|
||||||
}
|
}
|
||||||
|
|
||||||
FailureOr<Value> appendBatchPublicationResult(MaterializerState& state,
|
FailureOr<unsigned> appendBatchPublicationResult(MaterializerState& state,
|
||||||
MaterializedClass& materializedClass,
|
MaterializedClass& materializedClass,
|
||||||
Value payload,
|
Value payload,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
auto existing = materializedClass.publicationOutputToResultIndex.find(payload);
|
auto existing = materializedClass.publicationOutputToResultIndex.find(payload);
|
||||||
if (existing != materializedClass.publicationOutputToResultIndex.end())
|
if (existing != materializedClass.publicationOutputToResultIndex.end())
|
||||||
return materializedClass.op->getResult(existing->second);
|
return existing->second;
|
||||||
|
|
||||||
auto batch = dyn_cast<SpatScheduledComputeBatch>(materializedClass.op);
|
auto batch = dyn_cast<SpatScheduledComputeBatch>(materializedClass.op);
|
||||||
if (!batch)
|
if (!batch)
|
||||||
@@ -1305,11 +1281,9 @@ FailureOr<Value> appendBatchPublicationResult(MaterializerState& state,
|
|||||||
if (failed(inserted))
|
if (failed(inserted))
|
||||||
return materializedClass.op->emitError("failed to append batch publication result");
|
return materializedClass.op->emitError("failed to append batch publication result");
|
||||||
|
|
||||||
Operation* oldOp = materializedClass.op;
|
|
||||||
auto [result, outputArg, newBatch] = *inserted;
|
auto [result, outputArg, newBatch] = *inserted;
|
||||||
materializedClass.op = newBatch.getOperation();
|
materializedClass.op = newBatch.getOperation();
|
||||||
materializedClass.body = &newBatch.getBody().front();
|
materializedClass.body = &newBatch.getBody().front();
|
||||||
refreshPendingProjectedHostOutputPublicationValues(state, oldOp, materializedClass.op);
|
|
||||||
materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber();
|
materializedClass.publicationOutputToResultIndex[payload] = result.getResultNumber();
|
||||||
|
|
||||||
auto inParallelOp = dyn_cast<SpatInParallelOp>(materializedClass.body->getTerminator());
|
auto inParallelOp = dyn_cast<SpatInParallelOp>(materializedClass.body->getTerminator());
|
||||||
@@ -1330,7 +1304,7 @@ FailureOr<Value> appendBatchPublicationResult(MaterializerState& state,
|
|||||||
Value firstOffset =
|
Value firstOffset =
|
||||||
scaleIndexByDim0Size(state, materializedClass.op, *laneArg, payloadType.getDimSize(0), loc);
|
scaleIndexByDim0Size(state, materializedClass.op, *laneArg, payloadType.getDimSize(0), loc);
|
||||||
createDim0ParallelInsertSlice(state, loc, payload, outputArg, firstOffset);
|
createDim0ParallelInsertSlice(state, loc, payload, outputArg, firstOffset);
|
||||||
return result;
|
return result.getResultNumber();
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
@@ -1563,7 +1537,7 @@ void attachMaterializerValueOriginNote(InFlightDiagnostic& diagnostic, Value val
|
|||||||
void attachMaterializedClassBodySummary(InFlightDiagnostic& diagnostic, const MaterializedClass& targetClass) {
|
void attachMaterializedClassBodySummary(InFlightDiagnostic& diagnostic, const MaterializedClass& targetClass) {
|
||||||
Block& body = *targetClass.body;
|
Block& body = *targetClass.body;
|
||||||
diagnostic.attachNote(targetClass.op->getLoc())
|
diagnostic.attachNote(targetClass.op->getLoc())
|
||||||
<< "RAPTOR_MATERIALIZER_DEBUG target class " << targetClass.id << " op '" << targetClass.op->getName()
|
<< "target class " << targetClass.id << " op '" << targetClass.op->getName()
|
||||||
<< "' body has " << body.getNumArguments() << " block arguments and "
|
<< "' body has " << body.getNumArguments() << " block arguments and "
|
||||||
<< std::distance(body.begin(), body.end()) << " top-level operations";
|
<< std::distance(body.begin(), body.end()) << " top-level operations";
|
||||||
}
|
}
|
||||||
@@ -1687,7 +1661,7 @@ FailureOr<Value> rematerializeIndexValueInClass(MaterializerState& state,
|
|||||||
|
|
||||||
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
|
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
|
||||||
InFlightDiagnostic diagnostic = targetClass.op->emitError(
|
InFlightDiagnostic diagnostic = targetClass.op->emitError(
|
||||||
"RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external block argument in materialized class body");
|
"cannot rematerialize external block argument in materialized class body");
|
||||||
diagnostic << " currentArg#" << blockArg.getArgNumber() << " currentType=" << blockArg.getType()
|
diagnostic << " currentArg#" << blockArg.getArgNumber() << " currentType=" << blockArg.getType()
|
||||||
<< " targetClass=" << targetClass.id << " targetOp='" << targetClass.op->getName() << "'";
|
<< " targetClass=" << targetClass.id << " targetOp='" << targetClass.op->getName() << "'";
|
||||||
if (Operation* owner = blockArg.getOwner()->getParentOp()) {
|
if (Operation* owner = blockArg.getOwner()->getParentOp()) {
|
||||||
@@ -1709,16 +1683,16 @@ FailureOr<Value> rematerializeIndexValueInClass(MaterializerState& state,
|
|||||||
if (mapperHadOriginalValue && mappedOriginalValue != value)
|
if (mapperHadOriginalValue && mappedOriginalValue != value)
|
||||||
attachMaterializerValueOriginNote(diagnostic, mappedOriginalValue, "mapper value");
|
attachMaterializerValueOriginNote(diagnostic, mappedOriginalValue, "mapper value");
|
||||||
if (Operation* owner = blockArg.getOwner()->getParentOp()) {
|
if (Operation* owner = blockArg.getOwner()->getParentOp()) {
|
||||||
attachMaterializerOperationPrintNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner op");
|
attachMaterializerOperationPrintNote(diagnostic, owner, "external block argument owner op");
|
||||||
attachMaterializerParentChainNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner parent chain");
|
attachMaterializerParentChainNote(diagnostic, owner, "external block argument owner parent chain");
|
||||||
}
|
}
|
||||||
attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op");
|
attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "target materialized op");
|
||||||
attachMaterializedClassBodySummary(diagnostic, targetClass);
|
attachMaterializedClassBodySummary(diagnostic, targetClass);
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
InFlightDiagnostic diagnostic =
|
InFlightDiagnostic diagnostic =
|
||||||
targetClass.op->emitError("RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external index value in materialized class body");
|
targetClass.op->emitError("cannot rematerialize external index value in materialized class body");
|
||||||
diagnostic << " type=" << value.getType() << " targetClass=" << targetClass.id << " targetOp='"
|
diagnostic << " type=" << value.getType() << " targetClass=" << targetClass.id << " targetOp='"
|
||||||
<< targetClass.op->getName() << "'";
|
<< targetClass.op->getName() << "'";
|
||||||
attachMaterializerValueOriginNote(diagnostic, originalValue, "original value");
|
attachMaterializerValueOriginNote(diagnostic, originalValue, "original value");
|
||||||
@@ -1793,8 +1767,12 @@ FailureOr<Value> rematerializeTensorValueInClass(MaterializerState& state,
|
|||||||
strides.push_back(*localized);
|
strides.push_back(*localized);
|
||||||
}
|
}
|
||||||
|
|
||||||
return tensor::ExtractSliceOp::create(state.rewriter, anchor->getLoc(), *localizedSource, offsets, sizes, strides)
|
auto resultType = dyn_cast<RankedTensorType>(extractSlice.getResult().getType());
|
||||||
.getResult();
|
if (!resultType)
|
||||||
|
return anchor->emitError("expected ranked tensor extract_slice while rematerializing tensor capture");
|
||||||
|
|
||||||
|
return extractStaticSliceOrIdentity(
|
||||||
|
state.rewriter, anchor->getLoc(), *localizedSource, resultType, offsets, sizes, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto collapseShape = value.getDefiningOp<tensor::CollapseShapeOp>()) {
|
if (auto collapseShape = value.getDefiningOp<tensor::CollapseShapeOp>()) {
|
||||||
@@ -2108,8 +2086,10 @@ Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value in
|
|||||||
if (dim0Size == 1)
|
if (dim0Size == 1)
|
||||||
return index;
|
return index;
|
||||||
|
|
||||||
Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, anchor, dim0Size);
|
MLIRContext* context = state.func.getContext();
|
||||||
return arith::MulIOp::create(state.rewriter, loc, index, dim0SizeValue).getResult();
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
|
AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0 * dim0Size);
|
||||||
|
return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {index}, anchor);
|
||||||
}
|
}
|
||||||
|
|
||||||
FailureOr<Value> scaleIndexByDim0SizeInClass(MaterializerState& state,
|
FailureOr<Value> scaleIndexByDim0SizeInClass(MaterializerState& state,
|
||||||
@@ -2123,8 +2103,7 @@ FailureOr<Value> scaleIndexByDim0SizeInClass(MaterializerState& state,
|
|||||||
if (dim0Size == 1)
|
if (dim0Size == 1)
|
||||||
return *localizedIndex;
|
return *localizedIndex;
|
||||||
|
|
||||||
Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, dim0Size);
|
return scaleIndexByDim0Size(state, targetClass.op, *localizedIndex, dim0Size, loc);
|
||||||
return arith::MulIOp::create(state.rewriter, loc, *localizedIndex, dim0SizeValue).getResult();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool sameProducerResult(ProducerKey lhs, ProducerKey rhs) {
|
bool sameProducerResult(ProducerKey lhs, ProducerKey rhs) {
|
||||||
@@ -3677,10 +3656,13 @@ FailureOr<Value> buildProjectedPackedPayload(MaterializerState& state,
|
|||||||
ValueRange {init},
|
ValueRange {init},
|
||||||
[&](OpBuilder&, Location, Value fragmentIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
[&](OpBuilder&, Location, Value fragmentIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||||
Value acc = iterArgs.front();
|
Value acc = iterArgs.front();
|
||||||
Value payloadFragmentCount =
|
MLIRContext* context = state.func.getContext();
|
||||||
getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount);
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
Value flatBase = arith::MulIOp::create(state.rewriter, loc, *localizedMessageIndex, payloadFragmentCount).getResult();
|
AffineExpr d1 = getAffineDimExpr(1, context);
|
||||||
Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult();
|
AffineMap flatIndexMap =
|
||||||
|
AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * descriptor.layout.payloadFragmentCount + d1);
|
||||||
|
Value flatIndex = createOrFoldAffineApply(
|
||||||
|
state.rewriter, loc, flatIndexMap, ValueRange {*localizedMessageIndex, fragmentIndex}, targetClass.op);
|
||||||
|
|
||||||
FailureOr<SmallVector<OpFoldResult, 4>> fragmentOffsets =
|
FailureOr<SmallVector<OpFoldResult, 4>> fragmentOffsets =
|
||||||
buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, flatIndex, loc);
|
buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, flatIndex, loc);
|
||||||
@@ -5618,8 +5600,8 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
FailureOr<Value> publicationResult = appendScalarPublicationResult(state, sourceClass, packed, loc);
|
FailureOr<unsigned> publicationResultIndex = appendScalarPublicationResult(state, sourceClass, packed, loc);
|
||||||
if (failed(publicationResult))
|
if (failed(publicationResultIndex))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
int64_t fragmentElementCount = fragmentType.getNumElements();
|
int64_t fragmentElementCount = fragmentType.getNumElements();
|
||||||
@@ -5657,7 +5639,7 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat
|
|||||||
originalOutput,
|
originalOutput,
|
||||||
sourceClass.id,
|
sourceClass.id,
|
||||||
ProducerKey {peer, resultIndex},
|
ProducerKey {peer, resultIndex},
|
||||||
*publicationResult,
|
*publicationResultIndex,
|
||||||
static_cast<int64_t>(runIndex),
|
static_cast<int64_t>(runIndex),
|
||||||
static_cast<int64_t>(runIndex) * fragmentElementCount,
|
static_cast<int64_t>(runIndex) * fragmentElementCount,
|
||||||
SmallVector<int64_t, 4>(*offsets),
|
SmallVector<int64_t, 4>(*offsets),
|
||||||
@@ -5711,8 +5693,8 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
|
|||||||
if (fragmentType == originalOutput.getType())
|
if (fragmentType == originalOutput.getType())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
FailureOr<Value> publicationResult = appendBatchPublicationResult(state, sourceClass, packed, loc);
|
FailureOr<unsigned> publicationResultIndex = appendBatchPublicationResult(state, sourceClass, packed, loc);
|
||||||
if (failed(publicationResult))
|
if (failed(publicationResultIndex))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (packedType != fragmentType) {
|
if (packedType != fragmentType) {
|
||||||
@@ -5764,7 +5746,7 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
|
|||||||
originalOutput,
|
originalOutput,
|
||||||
sourceClass.id,
|
sourceClass.id,
|
||||||
key,
|
key,
|
||||||
*publicationResult,
|
*publicationResultIndex,
|
||||||
static_cast<int64_t>(fragmentIndex),
|
static_cast<int64_t>(fragmentIndex),
|
||||||
static_cast<int64_t>(*publishedLaneIndex) * payloadElementCount + localFragmentOffsetWithinPublishedPayload,
|
static_cast<int64_t>(*publishedLaneIndex) * payloadElementCount + localFragmentOffsetWithinPublishedPayload,
|
||||||
SmallVector<int64_t, 4>(*offsets),
|
SmallVector<int64_t, 4>(*offsets),
|
||||||
@@ -5787,18 +5769,26 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
|||||||
|
|
||||||
SmallVector<Value, 8> outputs;
|
SmallVector<Value, 8> outputs;
|
||||||
outputs.reserve(byOutput.size());
|
outputs.reserve(byOutput.size());
|
||||||
for (const auto& entry : byOutput)
|
|
||||||
outputs.push_back(entry.first);
|
|
||||||
llvm::sort(outputs, [](Value lhs, Value rhs) {
|
|
||||||
return reinterpret_cast<uintptr_t>(lhs.getAsOpaquePointer())
|
|
||||||
< reinterpret_cast<uintptr_t>(rhs.getAsOpaquePointer());
|
|
||||||
});
|
|
||||||
|
|
||||||
auto returnOp = dyn_cast<func::ReturnOp>(state.func.getBody().front().getTerminator());
|
auto returnOp = dyn_cast<func::ReturnOp>(state.func.getBody().front().getTerminator());
|
||||||
if (!returnOp)
|
if (!returnOp)
|
||||||
return state.func.emitError("expected func.return terminator while finalizing projected host output fragments");
|
return state.func.emitError("expected func.return terminator while finalizing projected host output fragments");
|
||||||
|
|
||||||
|
DenseSet<Value> seenOutputs;
|
||||||
|
for (Value returned : returnOp.getOperands()) {
|
||||||
|
if (!byOutput.contains(returned) || !seenOutputs.insert(returned).second)
|
||||||
|
continue;
|
||||||
|
outputs.push_back(returned);
|
||||||
|
}
|
||||||
|
if (outputs.size() != byOutput.size())
|
||||||
|
return state.func.emitError("projected host output fragments must be keyed by returned logical host outputs");
|
||||||
|
|
||||||
for (Value originalOutput : outputs) {
|
for (Value originalOutput : outputs) {
|
||||||
|
if (isa_and_present<SpatScheduledCompute, SpatScheduledComputeBatch>(originalOutput.getDefiningOp())) {
|
||||||
|
return state.func.emitError(
|
||||||
|
"projected host output assembly must be keyed by the original logical host output, not by a materialized scheduled result");
|
||||||
|
}
|
||||||
|
|
||||||
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
|
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
|
||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
return state.func.emitError("projected host output must have static ranked tensor type");
|
return state.func.emitError("projected host output must have static ranked tensor type");
|
||||||
@@ -5806,13 +5796,12 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
|||||||
SmallVector<PendingProjectedHostOutputFragment*, 16>& fragments = byOutput[originalOutput];
|
SmallVector<PendingProjectedHostOutputFragment*, 16>& fragments = byOutput[originalOutput];
|
||||||
llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs,
|
llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs,
|
||||||
const PendingProjectedHostOutputFragment* rhs) {
|
const PendingProjectedHostOutputFragment* rhs) {
|
||||||
if (lhs->publicationValue != rhs->publicationValue)
|
|
||||||
return reinterpret_cast<uintptr_t>(lhs->publicationValue.getAsOpaquePointer())
|
|
||||||
< reinterpret_cast<uintptr_t>(rhs->publicationValue.getAsOpaquePointer());
|
|
||||||
if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal)
|
|
||||||
return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal;
|
|
||||||
if (lhs->sourceClass != rhs->sourceClass)
|
if (lhs->sourceClass != rhs->sourceClass)
|
||||||
return lhs->sourceClass < rhs->sourceClass;
|
return lhs->sourceClass < rhs->sourceClass;
|
||||||
|
if (lhs->publicationResultIndex != rhs->publicationResultIndex)
|
||||||
|
return lhs->publicationResultIndex < rhs->publicationResultIndex;
|
||||||
|
if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal)
|
||||||
|
return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal;
|
||||||
return std::lexicographical_compare(lhs->offsets.begin(),
|
return std::lexicographical_compare(lhs->offsets.begin(),
|
||||||
lhs->offsets.end(),
|
lhs->offsets.end(),
|
||||||
rhs->offsets.begin(),
|
rhs->offsets.begin(),
|
||||||
@@ -5821,7 +5810,7 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
|||||||
|
|
||||||
state.rewriter.setInsertionPoint(returnOp);
|
state.rewriter.setInsertionPoint(returnOp);
|
||||||
Location loc = fragments.front()->loc;
|
Location loc = fragments.front()->loc;
|
||||||
SmallVector<Value, 16> reconciliatorOperands;
|
SmallVector<Value, 16> blueprintOperands;
|
||||||
SmallVector<int64_t, 16> fragmentOperandIndices;
|
SmallVector<int64_t, 16> fragmentOperandIndices;
|
||||||
SmallVector<int64_t, 16> fragmentSourceOffsets;
|
SmallVector<int64_t, 16> fragmentSourceOffsets;
|
||||||
SmallVector<int64_t, 64> flatOffsets;
|
SmallVector<int64_t, 64> flatOffsets;
|
||||||
@@ -5830,12 +5819,23 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
|||||||
DenseMap<Value, int64_t> operandIndicesByValue;
|
DenseMap<Value, int64_t> operandIndicesByValue;
|
||||||
|
|
||||||
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
|
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
|
||||||
Value operand = fragmentRecord->publicationValue;
|
if (fragmentRecord->sourceClass >= state.classes.size())
|
||||||
|
return state.func.emitError("projected host output fragment references an invalid source class");
|
||||||
|
|
||||||
|
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
|
||||||
|
if (fragmentRecord->publicationResultIndex >= sourceClass.op->getNumResults()) {
|
||||||
|
return sourceClass.op->emitError("projected host output fragment references an invalid publication result")
|
||||||
|
<< " sourceClass=" << sourceClass.id
|
||||||
|
<< " resultIndex=" << fragmentRecord->publicationResultIndex
|
||||||
|
<< " resultCount=" << sourceClass.op->getNumResults();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value operand = sourceClass.op->getResult(fragmentRecord->publicationResultIndex);
|
||||||
|
|
||||||
auto [operandIt, inserted] =
|
auto [operandIt, inserted] =
|
||||||
operandIndicesByValue.try_emplace(operand, static_cast<int64_t>(reconciliatorOperands.size()));
|
operandIndicesByValue.try_emplace(operand, static_cast<int64_t>(blueprintOperands.size()));
|
||||||
if (inserted)
|
if (inserted)
|
||||||
reconciliatorOperands.push_back(operand);
|
blueprintOperands.push_back(operand);
|
||||||
fragmentOperandIndices.push_back(operandIt->second);
|
fragmentOperandIndices.push_back(operandIt->second);
|
||||||
fragmentSourceOffsets.push_back(fragmentRecord->sourceElementOffset);
|
fragmentSourceOffsets.push_back(fragmentRecord->sourceElementOffset);
|
||||||
llvm::append_range(flatOffsets, fragmentRecord->offsets);
|
llvm::append_range(flatOffsets, fragmentRecord->offsets);
|
||||||
@@ -5847,12 +5847,12 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
|||||||
return state.func.emitError("projected host output assembly requires static ranked tensor operands");
|
return state.func.emitError("projected host output assembly requires static ranked tensor operands");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (reconciliatorOperands.empty())
|
if (blueprintOperands.empty())
|
||||||
return state.func.emitError("missing projected host output fragments");
|
return state.func.emitError("missing projected host output fragments");
|
||||||
|
|
||||||
Value input = reconciliatorOperands.front();
|
Value input = blueprintOperands.front();
|
||||||
ValueRange extraFragments = ValueRange(reconciliatorOperands).drop_front();
|
ValueRange extraFragments = ValueRange(blueprintOperands).drop_front();
|
||||||
auto reconciliator = spatial::SpatReconciliatorOp::create(
|
auto blueprint = spatial::SpatBlueprintOp::create(
|
||||||
state.rewriter,
|
state.rewriter,
|
||||||
loc,
|
loc,
|
||||||
resultType,
|
resultType,
|
||||||
@@ -5870,7 +5870,7 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
|||||||
state.rewriter.getStringAttr("disjoint"),
|
state.rewriter.getStringAttr("disjoint"),
|
||||||
state.rewriter.getStringAttr("complete"));
|
state.rewriter.getStringAttr("complete"));
|
||||||
|
|
||||||
state.hostReplacements[originalOutput] = reconciliator.getOutput();
|
state.hostReplacements[originalOutput] = blueprint.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
@@ -6284,6 +6284,32 @@ LogicalResult cloneComputeTemplateBody(MaterializerState& state,
|
|||||||
mapper.map(operand, *localized);
|
mapper.map(operand, *localized);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op)) {
|
||||||
|
auto remapFoldResult = [&](OpFoldResult value) -> OpFoldResult {
|
||||||
|
if (auto mappedValue = dyn_cast_if_present<Value>(value))
|
||||||
|
return mapper.lookupOrDefault(mappedValue);
|
||||||
|
return value;
|
||||||
|
};
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult, 4> offsets;
|
||||||
|
SmallVector<OpFoldResult, 4> sizes;
|
||||||
|
SmallVector<OpFoldResult, 4> strides;
|
||||||
|
offsets.reserve(extract.getMixedOffsets().size());
|
||||||
|
sizes.reserve(extract.getMixedSizes().size());
|
||||||
|
strides.reserve(extract.getMixedStrides().size());
|
||||||
|
|
||||||
|
llvm::append_range(offsets, llvm::map_range(extract.getMixedOffsets(), remapFoldResult));
|
||||||
|
llvm::append_range(sizes, llvm::map_range(extract.getMixedSizes(), remapFoldResult));
|
||||||
|
llvm::append_range(strides, llvm::map_range(extract.getMixedStrides(), remapFoldResult));
|
||||||
|
|
||||||
|
auto resultType = cast<RankedTensorType>(extract.getType());
|
||||||
|
Value localizedSource = mapper.lookupOrDefault(extract.getSource());
|
||||||
|
Value localizedExtract = extractStaticSliceOrIdentity(
|
||||||
|
state.rewriter, extract.getLoc(), localizedSource, resultType, offsets, sizes, strides);
|
||||||
|
mapper.map(extract.getResult(), localizedExtract);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
Operation* cloned = state.rewriter.clone(op, mapper);
|
Operation* cloned = state.rewriter.clone(op, mapper);
|
||||||
if (failed(mapClonedRegionBlockArguments(op, *cloned, mapper)))
|
if (failed(mapClonedRegionBlockArguments(op, *cloned, mapper)))
|
||||||
return failure();
|
return failure();
|
||||||
@@ -6350,18 +6376,20 @@ FailureOr<Value> materializeProjectedExtractReplacement(MaterializerState& state
|
|||||||
if (failed(localizedIv))
|
if (failed(localizedIv))
|
||||||
return failure();
|
return failure();
|
||||||
Value iv = *localizedIv;
|
Value iv = *localizedIv;
|
||||||
Value lowerBound =
|
MLIRContext* context = state.func.getContext();
|
||||||
getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopLowerBounds[index]);
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopSteps[index]);
|
AffineMap normalizedMap =
|
||||||
Value tripCount =
|
AffineMap::get(/*dimCount=*/1,
|
||||||
getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopTripCounts[index]);
|
/*symbolCount=*/0,
|
||||||
|
(d0 - replacement.layout.loopLowerBounds[index]).floorDiv(replacement.layout.loopSteps[index]));
|
||||||
|
Value normalized =
|
||||||
|
createOrFoldAffineApply(state.rewriter, extract.getLoc(), normalizedMap, ValueRange {iv}, targetClass.op);
|
||||||
|
|
||||||
Value normalized = arith::SubIOp::create(state.rewriter, extract.getLoc(), iv, lowerBound).getResult();
|
AffineExpr d1 = getAffineDimExpr(1, context);
|
||||||
if (replacement.layout.loopSteps[index] != 1)
|
AffineMap linearizedMap = AffineMap::get(
|
||||||
normalized = arith::DivUIOp::create(state.rewriter, extract.getLoc(), normalized, step).getResult();
|
/*dimCount=*/2, /*symbolCount=*/0, d0 * replacement.layout.loopTripCounts[index] + d1);
|
||||||
linearizedIndex = arith::MulIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, tripCount).getResult();
|
linearizedIndex = createOrFoldAffineApply(
|
||||||
linearizedIndex =
|
state.rewriter, extract.getLoc(), linearizedMap, ValueRange {linearizedIndex, normalized}, targetClass.op);
|
||||||
arith::AddIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, normalized).getResult();
|
|
||||||
}
|
}
|
||||||
return linearizedIndex;
|
return linearizedIndex;
|
||||||
};
|
};
|
||||||
@@ -6386,12 +6414,16 @@ FailureOr<Value> materializeProjectedExtractReplacement(MaterializerState& state
|
|||||||
if (failed(localProjectionSlotIndex))
|
if (failed(localProjectionSlotIndex))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value fragmentsPerLogicalSlot =
|
MLIRContext* context = state.func.getContext();
|
||||||
getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.fragmentsPerLogicalSlot);
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
Value base =
|
AffineExpr d1 = getAffineDimExpr(1, context);
|
||||||
arith::MulIOp::create(state.rewriter, extract.getLoc(), *localProjectionSlotIndex, fragmentsPerLogicalSlot)
|
AffineMap packedIndexMap = AffineMap::get(
|
||||||
.getResult();
|
/*dimCount=*/2, /*symbolCount=*/0, d0 * replacement.layout.fragmentsPerLogicalSlot + d1);
|
||||||
return arith::AddIOp::create(state.rewriter, extract.getLoc(), base, intraSlotFragmentIndex).getResult();
|
return createOrFoldAffineApply(state.rewriter,
|
||||||
|
extract.getLoc(),
|
||||||
|
packedIndexMap,
|
||||||
|
ValueRange {*localProjectionSlotIndex, intraSlotFragmentIndex},
|
||||||
|
targetClass.op);
|
||||||
};
|
};
|
||||||
|
|
||||||
FailureOr<Value> packedFragmentIndex = computeProjectedPayloadFragmentIndex();
|
FailureOr<Value> packedFragmentIndex = computeProjectedPayloadFragmentIndex();
|
||||||
@@ -6445,18 +6477,18 @@ LogicalResult localizeCapturesInOperationTree(MaterializerState& state,
|
|||||||
localizeMaterializedClassOperand(state, targetClass, current, nestedOp, tensorContext, genericContext, mapper);
|
localizeMaterializedClassOperand(state, targetClass, current, nestedOp, tensorContext, genericContext, mapper);
|
||||||
if (failed(localized)) {
|
if (failed(localized)) {
|
||||||
InFlightDiagnostic diagnostic = targetClass.op->emitError(
|
InFlightDiagnostic diagnostic = targetClass.op->emitError(
|
||||||
"RAPTOR_MATERIALIZER_DEBUG failed to localize cloned scheduled-body operand");
|
"failed to localize cloned scheduled-body operand");
|
||||||
diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName()
|
diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName()
|
||||||
<< "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType()
|
<< "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType()
|
||||||
<< " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp))
|
<< " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp))
|
||||||
<< "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass)
|
<< "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass)
|
||||||
<< "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\"";
|
<< "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\"";
|
||||||
diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation";
|
diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation";
|
||||||
attachMaterializerOperationPrintNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation IR");
|
attachMaterializerOperationPrintNote(diagnostic, nestedOp, "offending nested operation IR");
|
||||||
attachMaterializerOperandListNote(diagnostic, nestedOp, targetClass, "RAPTOR_MATERIALIZER_DEBUG offending nested operation operands");
|
attachMaterializerOperandListNote(diagnostic, nestedOp, targetClass, "offending nested operation operands");
|
||||||
attachMaterializerParentChainNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation parent chain");
|
attachMaterializerParentChainNote(diagnostic, nestedOp, "offending nested operation parent chain");
|
||||||
attachMaterializerValueOriginNote(diagnostic, current, "offending operand");
|
attachMaterializerValueOriginNote(diagnostic, current, "offending operand");
|
||||||
attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op");
|
attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "target materialized op");
|
||||||
attachMaterializedClassBodySummary(diagnostic, targetClass);
|
attachMaterializedClassBodySummary(diagnostic, targetClass);
|
||||||
return WalkResult::interrupt();
|
return WalkResult::interrupt();
|
||||||
}
|
}
|
||||||
@@ -6505,7 +6537,7 @@ LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, Materia
|
|||||||
"final scheduled body capture localization found an unsupported external non-tensor operand");
|
"final scheduled body capture localization found an unsupported external non-tensor operand");
|
||||||
if (failed(localized)) {
|
if (failed(localized)) {
|
||||||
InFlightDiagnostic diagnostic = targetClass.op->emitError(
|
InFlightDiagnostic diagnostic = targetClass.op->emitError(
|
||||||
"RAPTOR_MATERIALIZER_DEBUG failed to localize final scheduled-body operand");
|
"failed to localize final scheduled-body operand");
|
||||||
diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName()
|
diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName()
|
||||||
<< "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType()
|
<< "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType()
|
||||||
<< " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp))
|
<< " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp))
|
||||||
|
|||||||
-9510
File diff suppressed because it is too large
Load Diff
-7548
File diff suppressed because it is too large
Load Diff
-128
@@ -1,128 +0,0 @@
|
|||||||
--- src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp 2026-06-24 18:51:29.043731129 +0000
|
|
||||||
+++ src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp 2026-06-24 18:51:29.026726895 +0000
|
|
||||||
@@ -4112,104 +4112,8 @@
|
|
||||||
Value originalOutput,
|
|
||||||
Location loc);
|
|
||||||
|
|
||||||
-FailureOr<SmallVector<OpFoldResult, 4>> rematerializeProjectionIndexListForBatchHostOutput(
|
|
||||||
- MaterializerState& state,
|
|
||||||
- MaterializedClass& sourceClass,
|
|
||||||
- ArrayRef<OpFoldResult> values,
|
|
||||||
- IRMapping& mapper,
|
|
||||||
- Location loc) {
|
|
||||||
- SmallVector<OpFoldResult, 4> localized;
|
|
||||||
- localized.reserve(values.size());
|
|
||||||
- for (OpFoldResult value : values) {
|
|
||||||
- FailureOr<OpFoldResult> remapped =
|
|
||||||
- rematerializeIndexOpFoldResultInClass(state, sourceClass, value, loc, &mapper);
|
|
||||||
- if (failed(remapped))
|
|
||||||
- return failure();
|
|
||||||
- localized.push_back(*remapped);
|
|
||||||
- }
|
|
||||||
- return localized;
|
|
||||||
-}
|
|
||||||
-
|
|
||||||
-LogicalResult createProjectionAwareBatchHostInsert(MaterializerState& state,
|
|
||||||
- MaterializedClass& sourceClass,
|
|
||||||
- Value originalOutput,
|
|
||||||
- Value payload,
|
|
||||||
- Value destination,
|
|
||||||
- ArrayRef<ProducerKey> keys,
|
|
||||||
- Location loc) {
|
|
||||||
- auto originalResult = dyn_cast<OpResult>(originalOutput);
|
|
||||||
- if (!originalResult)
|
|
||||||
- return failure();
|
|
||||||
-
|
|
||||||
- auto sourceBatch = dyn_cast_or_null<SpatComputeBatch>(originalResult.getOwner());
|
|
||||||
- if (!sourceBatch || sourceBatch.getNumResults() == 0)
|
|
||||||
- return failure();
|
|
||||||
-
|
|
||||||
- FailureOr<tensor::ParallelInsertSliceOp> projection =
|
|
||||||
- getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber());
|
|
||||||
- if (failed(projection))
|
|
||||||
- return failure();
|
|
||||||
-
|
|
||||||
- auto sourceLaneArg = sourceBatch.getLaneArgument();
|
|
||||||
- if (!sourceLaneArg)
|
|
||||||
- return failure();
|
|
||||||
-
|
|
||||||
- auto materializedBatch = dyn_cast<SpatScheduledComputeBatch>(sourceClass.op);
|
|
||||||
- if (!materializedBatch)
|
|
||||||
- return failure();
|
|
||||||
-
|
|
||||||
- auto materializedLaneArg = materializedBatch.getLaneArgument();
|
|
||||||
- if (!materializedLaneArg)
|
|
||||||
- return failure();
|
|
||||||
-
|
|
||||||
- if (keys.size() != sourceClass.cpus.size())
|
|
||||||
- return failure();
|
|
||||||
-
|
|
||||||
- SmallVector<int64_t, 8> logicalLanes;
|
|
||||||
- logicalLanes.reserve(keys.size());
|
|
||||||
- for (ProducerKey key : keys) {
|
|
||||||
- if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != originalResult.getResultNumber())
|
|
||||||
- return failure();
|
|
||||||
- logicalLanes.push_back(key.instance.laneStart);
|
|
||||||
- }
|
|
||||||
-
|
|
||||||
- IRMapping mapper;
|
|
||||||
- Value logicalLane = createIndexedIndexValue(state,
|
|
||||||
- sourceClass.op,
|
|
||||||
- ArrayRef<int64_t>(logicalLanes),
|
|
||||||
- *materializedLaneArg,
|
|
||||||
- loc,
|
|
||||||
- static_cast<int64_t>(sourceClass.cpus.size()),
|
|
||||||
- /*allowExhaustiveTiledSearch=*/false);
|
|
||||||
- mapper.map(*sourceLaneArg, logicalLane);
|
|
||||||
-
|
|
||||||
- FailureOr<SmallVector<OpFoldResult, 4>> offsets =
|
|
||||||
- rematerializeProjectionIndexListForBatchHostOutput(
|
|
||||||
- state, sourceClass, projection->getMixedOffsets(), mapper, loc);
|
|
||||||
- if (failed(offsets))
|
|
||||||
- return failure();
|
|
||||||
- FailureOr<SmallVector<OpFoldResult, 4>> sizes =
|
|
||||||
- rematerializeProjectionIndexListForBatchHostOutput(
|
|
||||||
- state, sourceClass, projection->getMixedSizes(), mapper, loc);
|
|
||||||
- if (failed(sizes))
|
|
||||||
- return failure();
|
|
||||||
- FailureOr<SmallVector<OpFoldResult, 4>> strides =
|
|
||||||
- rematerializeProjectionIndexListForBatchHostOutput(
|
|
||||||
- state, sourceClass, projection->getMixedStrides(), mapper, loc);
|
|
||||||
- if (failed(strides))
|
|
||||||
- return failure();
|
|
||||||
-
|
|
||||||
- tensor::ParallelInsertSliceOp::create(
|
|
||||||
- state.rewriter, loc, payload, destination, *offsets, *sizes, *strides);
|
|
||||||
- return success();
|
|
||||||
-}
|
|
||||||
-
|
|
||||||
LogicalResult
|
|
||||||
-setHostOutputValue(MaterializerState& state,
|
|
||||||
- MaterializedClass& sourceClass,
|
|
||||||
- Value originalOutput,
|
|
||||||
- Value payload,
|
|
||||||
- ArrayRef<ProducerKey> keys = {}) {
|
|
||||||
+setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) {
|
|
||||||
auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput);
|
|
||||||
if (resultIt == sourceClass.hostOutputToResultIndex.end())
|
|
||||||
return sourceClass.op->emitError("missing host result slot for materialized output")
|
|
||||||
@@ -4253,10 +4157,6 @@
|
|
||||||
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
|
|
||||||
|
|
||||||
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
|
||||||
- if (succeeded(createProjectionAwareBatchHostInsert(
|
|
||||||
- state, sourceClass, originalOutput, payload, *outputArg, keys, payload.getLoc())))
|
|
||||||
- return success();
|
|
||||||
-
|
|
||||||
createDim0ParallelInsertSlice(state, payload.getLoc(), payload, *outputArg, *laneArg);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
@@ -4276,7 +4176,7 @@
|
|
||||||
|
|
||||||
MaterializedClass& ownerClass = state.classes[ownerIt->second];
|
|
||||||
if (sourceClass.id == ownerClass.id)
|
|
||||||
- return setHostOutputValue(state, ownerClass, originalOutput, payload, keys);
|
|
||||||
+ return setHostOutputValue(state, ownerClass, originalOutput, payload);
|
|
||||||
|
|
||||||
// Keep the old deadlock-free communication discipline: only scalar-to-scalar
|
|
||||||
// host-owner forwarding is introduced here. Batch host publication remains on
|
|
||||||
@@ -354,13 +354,13 @@ public:
|
|||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
func::FuncOp func = getOperation();
|
func::FuncOp func = getOperation();
|
||||||
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
|
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
|
||||||
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed at the start of MergeComputeNodes");
|
func.emitOpError("logical Spatial graph verification failed at the start of MergeComputeNodes");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
mergeTriviallyConnectedComputes(func);
|
mergeTriviallyConnectedComputes(func);
|
||||||
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
|
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
|
||||||
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after trivial merge simplification");
|
func.emitOpError("logical Spatial graph verification failed after trivial merge simplification");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -378,7 +378,7 @@ public:
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (failed(verifyScheduledSpatialInvariants(func))) {
|
if (failed(verifyScheduledSpatialInvariants(func))) {
|
||||||
func.emitOpError("RAPTOR_PHASE_CHECK scheduled Spatial verification failed after merge materialization");
|
func.emitOpError("scheduled Spatial verification failed after merge materialization");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user