This commit is contained in:
ilgeco
2026-06-24 15:52:07 +02:00
parent 2b4115699a
commit 62dd40ee89
47 changed files with 7993 additions and 1100 deletions
@@ -26,6 +26,8 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp
SpatialLayoutPlanningPass.cpp
LowerSpatialPlansPass.cpp
Common/AttributeUtils.cpp
Common/ComputeRegionBuilder.cpp
Common/IndexingUtils.cpp
@@ -9,7 +9,7 @@ using namespace mlir;
namespace onnx_mlir {
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
Value sumTensors(ArrayRef<Value> tensors, PatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
@@ -87,17 +87,17 @@ inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
}
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
/// Builds a `spat.graph_compute` with a fixed number of SSA inputs and erases it if
/// the body callback reports failure.
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto createSpatGraphCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
@@ -124,23 +124,23 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
}
}
/// Builds a `spat.compute` whose body consumes the block arguments as a single
/// Builds a `spat.graph_compute` whose body consumes the block arguments as a single
/// `ValueRange`, which is convenient for variadic reductions/concats.
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto createSpatGraphCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
@@ -163,29 +163,29 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto createSpatGraphComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
if (mlir::failed(laneCountAttr))
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
auto batchOp = spatial::SpatGraphComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
@@ -218,20 +218,53 @@ auto createSpatComputeBatch(RewriterT& rewriter,
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(args);
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
}
else {
auto bodyResult = std::forward<BodyFn>(body)(args);
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(batchOp);
rewriter.eraseOp(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
}
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
}
}
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphCompute<NumInputs>(
rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
}
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphCompute(rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
}
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphComputeBatch(
rewriter, loc, resultTypes, laneCount, weights, inputs, std::forward<BodyFn>(body));
}
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
@@ -262,6 +295,6 @@ mlir::Value materializeOrComputeUnary(mlir::Value input,
return computeOp.getResult(0);
}
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::PatternRewriter& rewriter);
} // namespace onnx_mlir
@@ -83,7 +83,7 @@ SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int
}
SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(tensorToSlice);
assert("Invalid axis" && axis < shape.size());
@@ -129,7 +129,7 @@ SmallVector<Value> sliceTensor(
}
SmallVector<Value>
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
sliceVector(const Value& vectorToSlice, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(vectorToSlice);
assert("Not a vector" && isVectorShape(shape));
size_t axis = shape[0] != 1 ? 0 : 1;
@@ -137,7 +137,7 @@ sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewr
}
DenseMap<CoreId, SmallVector<Value>>
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, PatternRewriter& rewriter, Location loc) {
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
@@ -89,18 +89,18 @@ llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewr
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
/// Partitions one logical vector into per-core crossbar-sized slices using the
/// current PIM target geometry.
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc);
mlir::Value extractAxisSlice(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
@@ -0,0 +1,409 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static constexpr StringLiteral kDenseLayout = "dense_nchw";
static constexpr StringLiteral kRowStripLayout = "nchw_row_strip";
struct RowStripPhysicalValue {
Value physicalValue;
RankedTensorType logicalType;
SmallVector<int64_t, 16> fragmentOffsets;
SmallVector<int64_t, 16> fragmentSizes;
std::string indexMap;
};
static FailureOr<RowStripPhysicalValue> getRowStripValue(llvm::DenseMap<Value, RowStripPhysicalValue>& rowStripValues,
Value value) {
auto it = rowStripValues.find(value);
if (it == rowStripValues.end())
return failure();
return it->second;
}
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatReconciliatorOp reconciliator,
Value physicalValue) {
auto logicalType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
if (!logicalType)
return reconciliator.emitOpError("requires ranked logical output type"), failure();
RowStripPhysicalValue value;
value.physicalValue = physicalValue;
value.logicalType = logicalType;
value.fragmentOffsets.append(reconciliator.getFragmentOffsets().begin(), reconciliator.getFragmentOffsets().end());
value.fragmentSizes.append(reconciliator.getFragmentSizes().begin(), reconciliator.getFragmentSizes().end());
value.indexMap = reconciliator.getIndexMap().str();
return value;
}
static FailureOr<Value>
lowerRowStripRelu(const RowStripPhysicalValue& input, spatial::SpatReluPlanOp planOp, PatternRewriter& rewriter) {
auto packedType = cast<RankedTensorType>(input.physicalValue.getType());
auto computeOp =
createSpatCompute<1>(rewriter, planOp.getLoc(), TypeRange {packedType}, {}, input.physicalValue, [&](Value x) {
auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), packedType, x);
spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
});
return computeOp.getResult(0);
}
static FailureOr<Value>
materializeRowStripToDense(const RowStripPhysicalValue& rowStripValue, Location loc, PatternRewriter& rewriter) {
auto packedType = dyn_cast<RankedTensorType>(rowStripValue.physicalValue.getType());
if (!packedType || packedType.getRank() != 3 || !packedType.hasStaticShape())
return failure();
if (rowStripValue.logicalType.getRank() != 4 || !rowStripValue.logicalType.hasStaticShape())
return failure();
if (rowStripValue.indexMap != "packed_hwc_rows_to_nchw")
return failure();
const int64_t rank = rowStripValue.logicalType.getRank();
const int64_t fragmentCount = rowStripValue.fragmentOffsets.size() / rank;
const int64_t packedWidth = packedType.getDimSize(1);
const int64_t packedChannels = packedType.getDimSize(2);
if (fragmentCount != packedType.getDimSize(0))
return failure();
for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
if (rowStripValue.fragmentOffsets[fragmentIndex * rank + 0] != 0
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 1] != 0
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 2] != fragmentIndex
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 3] != 0)
return failure();
if (rowStripValue.fragmentSizes[fragmentIndex * rank + 0] != 1
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 1] != packedChannels
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 2] != 1
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 3] != packedWidth)
return failure();
}
auto packedSliceType =
RankedTensorType::get({1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
auto expandedType =
RankedTensorType::get({1, 1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
auto logicalFragmentType =
RankedTensorType::get({1, packedChannels, 1, packedWidth}, packedType.getElementType(), packedType.getEncoding());
auto batchOp = createSpatComputeBatch(
rewriter,
loc,
TypeRange {rowStripValue.logicalType},
fragmentCount,
{},
ValueRange {rowStripValue.physicalValue},
[&](detail::SpatComputeBatchBodyArgs args) {
SmallVector<OpFoldResult> packedOffsets {args.lane, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> packedSizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(packedWidth), rewriter.getIndexAttr(packedChannels)};
Value packedSlice = tensor::ExtractSliceOp::create(
rewriter, loc, packedSliceType, args.inputs.front(), packedOffsets, packedSizes, getUnitStrides(rewriter, 3));
Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedSlice,
SmallVector<ReassociationIndices> {
{0, 1},
{2},
{3}
});
Value transposeInit =
tensor::EmptyOp::create(rewriter, loc, logicalFragmentType.getShape(), logicalFragmentType.getElementType());
Value logicalFragment =
linalg::TransposeOp::create(rewriter, loc, expanded, transposeInit, SmallVector<int64_t> {0, 3, 1, 2})
.getResult()[0];
SmallVector<OpFoldResult> logicalOffsets {
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> logicalSizes {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(packedChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(packedWidth)};
createParallelInsertSliceIntoBatchOutput(rewriter,
loc,
logicalFragment,
args.outputs.front(),
logicalOffsets,
logicalSizes,
getUnitStrides(rewriter, 4));
return success();
});
if (failed(batchOp))
return failure();
return batchOp->getResult(0);
}
struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerSpatialPlansPass)
StringRef getArgument() const override { return "lower-spatial-plans"; }
StringRef getDescription() const override { return "Lower selected Spatial planning ops to low-level Spatial IR."; }
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext();
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
moduleOp.emitError("failed to locate the PIM entry function during LowerSpatialPlans");
signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
PatternRewriter rewriter(ctx);
llvm::DenseMap<Value, RowStripPhysicalValue> rowStripValues;
llvm::SmallPtrSet<Operation*, 16> eraseAfterLowering;
auto verifyLogicalPhase = [&](StringRef stage) -> bool {
if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc)))
return true;
moduleOp.emitError() << "RAPTOR_PHASE_CHECK logical Spatial graph verification failed " << stage;
signalPassFailure();
return false;
};
if (!verifyLogicalPhase("at the start of LowerSpatialPlans"))
return;
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
if (auto planOp = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
FailureOr<RowStripPhysicalValue> rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
auto rowStripReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
});
if (rowStripReconciliator != planOp.getResult().getUsers().end()) {
rewriter.setInsertionPoint(planOp);
FailureOr<Value> lowered = lowerSelectedConv2DPlan(
planOp,
succeeded(rowStripInput) ? std::optional<Value> {rowStripInput->physicalValue} : std::nullopt,
/*emitRowStripLayout=*/true,
rewriter);
if (failed(lowered)) {
planOp.emitOpError("failed to lower selected row-strip Spatial Conv plan");
signalPassFailure();
return;
}
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*rowStripReconciliator);
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(reconciliator, *lowered);
if (failed(rowStripValue)) {
signalPassFailure();
return;
}
rowStripValues[reconciliator.getResult()] = *rowStripValue;
eraseAfterLowering.insert(planOp);
eraseAfterLowering.insert(reconciliator);
continue;
}
rewriter.setInsertionPoint(planOp);
FailureOr<Value> lowered =
lowerSelectedConv2DPlan(planOp, std::nullopt, /*emitRowStripLayout=*/false, rewriter);
if (failed(lowered)) {
planOp.emitOpError("failed to lower selected Spatial Conv plan");
signalPassFailure();
return;
}
rewriter.replaceOp(planOp, *lowered);
continue;
}
if (auto planOp = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) {
auto outputReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
});
if (outputReconciliator == planOp.getResult().getUsers().end()) {
planOp.emitOpError("row-strip Relu plan requires a row-strip reconciliator result");
signalPassFailure();
return;
}
FailureOr<RowStripPhysicalValue> input = getRowStripValue(rowStripValues, planOp.getInput());
rewriter.setInsertionPoint(planOp);
FailureOr<Value> lowered = lowerRowStripRelu(*input, planOp, rewriter);
if (failed(lowered)) {
planOp.emitOpError("failed to lower selected row-strip Spatial Relu plan");
signalPassFailure();
return;
}
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*outputReconciliator);
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(reconciliator, *lowered);
if (failed(output)) {
signalPassFailure();
return;
}
rowStripValues[reconciliator.getResult()] = *output;
eraseAfterLowering.insert(planOp);
eraseAfterLowering.insert(reconciliator);
continue;
}
rewriter.setInsertionPoint(planOp);
auto computeOp = createSpatCompute<1>(
rewriter, planOp.getLoc(), planOp.getOutput().getType(), {}, planOp.getInput(), [&](Value x) {
auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), planOp.getOutput().getType(), x);
spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
});
rewriter.replaceOp(planOp, computeOp.getResults());
continue;
}
if (auto materializeOp = dyn_cast<spatial::SpatMaterializeLayoutOp>(&op)) {
if (materializeOp.getSourcePhysicalLayout() == kDenseLayout
&& materializeOp.getTargetPhysicalLayout() == kDenseLayout) {
rewriter.replaceOp(materializeOp, materializeOp.getInput());
continue;
}
if (materializeOp.getSourcePhysicalLayout() != kRowStripLayout
|| materializeOp.getTargetPhysicalLayout() != kDenseLayout) {
materializeOp.emitOpError("non-dense materialize_layout lowering is not supported yet");
signalPassFailure();
return;
}
FailureOr<RowStripPhysicalValue> rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
if (failed(rowStripValue)) {
materializeOp.emitOpError("expected a row-strip reconciliator input during row-strip materialization");
signalPassFailure();
return;
}
rewriter.setInsertionPoint(materializeOp);
FailureOr<Value> dense = materializeRowStripToDense(*rowStripValue, materializeOp.getLoc(), rewriter);
if (failed(dense)) {
materializeOp.emitOpError("failed to materialize selected row-strip layout back to dense NCHW");
signalPassFailure();
return;
}
rewriter.replaceOp(materializeOp, *dense);
continue;
}
if (auto reconciliatorOp = dyn_cast<spatial::SpatReconciliatorOp>(&op)) {
if (reconciliatorOp.getPhysicalLayout() == kDenseLayout) {
rewriter.replaceOp(reconciliatorOp, reconciliatorOp.getInput());
continue;
}
if (reconciliatorOp.getPhysicalLayout() != kRowStripLayout) {
reconciliatorOp.emitOpError("non-dense reconciliator lowering is not supported yet");
signalPassFailure();
return;
}
if (!eraseAfterLowering.contains(reconciliatorOp)) {
reconciliatorOp.emitOpError("unhandled row-strip reconciliator remained during LowerSpatialPlans");
signalPassFailure();
return;
}
}
}
bool erasedAny = true;
while (erasedAny) {
erasedAny = false;
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
if (!eraseAfterLowering.contains(&op))
continue;
if (!op.use_empty())
continue;
eraseAfterLowering.erase(&op);
rewriter.eraseOp(&op);
erasedAny = true;
}
}
if (!eraseAfterLowering.empty()) {
for (Operation& op : funcOp.getBody().front())
if (eraseAfterLowering.contains(&op))
op.emitOpError("selected row-strip planning op could not be fully eliminated during LowerSpatialPlans");
signalPassFailure();
return;
}
ConversionTarget helperTarget(*ctx);
helperTarget.addLegalDialect<spatial::SpatialDialect,
tensor::TensorDialect,
linalg::LinalgDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect,
func::FuncDialect>();
helperTarget.addLegalOp<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
helperTarget.addIllegalOp<ONNXGemmOp, ONNXTransposeOp>();
helperTarget.markOpRecursivelyLegal<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
RewritePatternSet helperPatterns(ctx);
populateGemmPatterns(helperPatterns, ctx);
populateTransposePatterns(helperPatterns, ctx);
if (failed(applyPartialConversion(moduleOp, helperTarget, std::move(helperPatterns)))) {
moduleOp.emitError("failed to lower helper ONNX ops emitted by selected Spatial plan lowering");
signalPassFailure();
return;
}
FrozenRewritePatternSet nestedHelperPatterns([&] {
RewritePatternSet patterns(ctx);
populateGemmPatterns(patterns, ctx);
populateTransposePatterns(patterns, ctx);
return patterns;
}());
ConversionTarget nestedHelperTarget(*ctx);
nestedHelperTarget.addLegalDialect<spatial::SpatialDialect,
tensor::TensorDialect,
linalg::LinalgDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect,
func::FuncDialect>();
nestedHelperTarget.addIllegalOp<ONNXGemmOp, ONNXTransposeOp>();
SmallVector<Operation*> computeLikeOps;
funcOp.walk([&](Operation* op) {
if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(op))
computeLikeOps.push_back(op);
});
for (Operation* op : computeLikeOps) {
if (failed(applyFullConversion(op, nestedHelperTarget, nestedHelperPatterns))) {
op->emitOpError("failed to lower nested helper ONNX ops emitted by selected Spatial plan lowering");
signalPassFailure();
return;
}
}
if (!verifyLogicalPhase("after nested helper conversions"))
return;
bool hasIllegalOps = false;
moduleOp.walk([&](Operation* op) {
if (isa<ONNXEntryPointOp>(op))
return;
if (isa<spatial::SpatConv2DPlanOp,
spatial::SpatReluPlanOp,
spatial::SpatReconciliatorOp,
spatial::SpatMaterializeLayoutOp>(op)
|| op->getDialect()->getNamespace() == "onnx") {
op->emitOpError("operation must not remain after LowerSpatialPlans");
hasIllegalOps = true;
}
});
if (hasIllegalOps)
signalPassFailure();
else
dumpModule(moduleOp, "spatial1_premerge");
if (!verifyLogicalPhase("at the end of LowerSpatialPlans"))
return;
}
};
} // namespace
std::unique_ptr<Pass> createLowerSpatialPlansPass() { return std::make_unique<LowerSpatialPlansPass>(); }
} // namespace onnx_mlir
@@ -18,6 +18,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "ONNXToSpatialVerifier.hpp"
using namespace mlir;
@@ -41,10 +42,16 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
if (!computes.empty() || !computeBatches.empty())
SmallVector<spatial::SpatGraphCompute> computes(funcOp.getOps<spatial::SpatGraphCompute>());
SmallVector<spatial::SpatGraphComputeBatch> computeBatches(funcOp.getOps<spatial::SpatGraphComputeBatch>());
SmallVector<spatial::SpatConv2DPlanOp> convPlans(funcOp.getOps<spatial::SpatConv2DPlanOp>());
SmallVector<spatial::SpatReluPlanOp> reluPlans(funcOp.getOps<spatial::SpatReluPlanOp>());
SmallVector<spatial::SpatReconciliatorOp> reconciliators(funcOp.getOps<spatial::SpatReconciliatorOp>());
SmallVector<spatial::SpatMaterializeLayoutOp> materializers(funcOp.getOps<spatial::SpatMaterializeLayoutOp>());
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !reconciliators.empty()
|| !materializers.empty()) {
return;
}
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
rewriter.setInsertionPoint(returnOp);
@@ -58,7 +65,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
sourceLocs.push_back(source.getLoc());
}
auto newCompute = spatial::SpatCompute::create(
auto newCompute = spatial::SpatGraphCompute::create(
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
@@ -67,7 +74,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : funcOp.getOps())
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op))
rewriter.clone(op, mapper);
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
@@ -75,7 +82,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op)) {
op.dropAllUses();
rewriter.eraseOp(&op);
}
@@ -152,6 +159,11 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX conversion");
signalPassFailure();
return;
}
ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
@@ -168,6 +180,11 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc);
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after weight annotation");
signalPassFailure();
return;
}
ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
@@ -176,11 +193,16 @@ void ONNXToSpatialPass::runOnOperation() {
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatGraphCompute>(
[](spatial::SpatGraphCompute computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatGraphComputeBatch>(
[](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed before post rewrites");
signalPassFailure();
return;
}
RewritePatternSet postPatterns(ctx);
populatePostPatterns(postPatterns, ctx);
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
@@ -191,6 +213,11 @@ void ONNXToSpatialPass::runOnOperation() {
populateEmptyFunction(*entryFunc);
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX-to-Spatial");
signalPassFailure();
return;
}
dumpModule(moduleOp, "spatial0");
if (failed(verifyONNXToSpatial(*entryFunc))) {
@@ -1,4 +1,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LLVM.h"
#include "Common/IR/WeightUtils.hpp"
@@ -13,6 +15,8 @@ namespace onnx_mlir {
namespace {
constexpr StringLiteral kPhaseMarker = "RAPTOR_PHASE_CHECK";
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
func.walk([&](Operation* op) {
if (!hasWeightAlways(op))
@@ -23,134 +27,174 @@ void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diag
continue;
diagnostics.report(op, [&](Operation* illegalOp) {
illegalOp->emitOpError(
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
illegalOp->emitOpError()
<< kPhaseMarker
<< " weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights";
});
return;
}
});
}
Region* getParentRegion(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent();
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion();
return nullptr;
bool isRegionOrAncestorOf(Region& region, Region* candidate) {
return candidate && (&region == candidate || region.isAncestor(candidate));
}
bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
bool isValueDefinedInsideRegion(Value value, Region& region) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return isRegionOrAncestorOf(region, blockArg.getOwner()->getParent());
if (Operation* definingOp = value.getDefiningOp())
return isRegionOrAncestorOf(region, definingOp->getParentRegion());
return false;
}
bool isLegalExternalCapture(Value value, Region& region) {
if (isValueDefinedInsideRegion(value, region))
return true;
Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
}
template <typename ComputeOpTy>
void verifyComputeBodyCaptures(ComputeOpTy compute, StringRef kind, pim::CappedDiagnosticReporter& diagnostics) {
Region& body = compute.getBody();
body.walk([&](Operation* nestedOp) {
for (OpOperand& operand : nestedOp->getOpOperands()) {
Value value = operand.get();
if (isLegalExternalCapture(value, body))
continue;
Operation* definingOp = value.getDefiningOp();
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
InFlightDiagnostic diag =
illegalOp->emitOpError() << kPhaseMarker << " " << kind << " body captures non-constant external operand #"
<< operand.getOperandNumber() << " used by " << nestedOp->getName().getStringRef();
diag << " (type " << value.getType() << ")";
if (definingOp)
diag.attachNote(definingOp->getLoc()) << "defining op is " << definingOp->getName().getStringRef();
else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
if (Operation* owner = blockArg.getOwner()->getParentOp())
diag.attachNote(owner->getLoc())
<< "external block argument belongs to " << owner->getName().getStringRef();
}
});
}
});
}
bool isLegalHostBackedValue(Value value) {
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return isa<BlockArgument>(value);
if (isa<spatial::SpatChannelReceiveOp>(definingOp))
return false;
return definingOp->getDialect()->getNamespace() != "spat";
}
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
ValueRange inputs,
bool allowChannelReceiveInputs,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
for (auto [inputIndex, input] : llvm::enumerate(inputs)) {
unsigned currentInputIndex = inputIndex;
template <typename ComputeOpTy>
void verifyScheduledInputs(ComputeOpTy compute,
bool allowChannelReceiveInputs,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
Operation* definingOp = input.getDefiningOp();
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
continue;
if (isLegalHostBackedValue(input))
continue;
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError()
<< kind << " input #" << currentInputIndex
<< (allowChannelReceiveInputs ? " must come from the host or an explicit "
"spat.channel_receive"
: " must come from the host");
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
InFlightDiagnostic diag = illegalOp->emitOpError()
<< kPhaseMarker << " " << kind << " input #" << inputIndex
<< (allowChannelReceiveInputs ? " must come from the host or explicit spat.channel_receive"
: " must come from the host");
if (definingOp)
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
diag.attachNote(definingOp->getLoc()) << "illegal producer is " << definingOp->getName().getStringRef();
});
return failure();
}
return success();
}
void verifyNoExternalTensorCaptures(Operation* ownerOp,
Region& region,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (!isa<TensorType>(value.getType()))
continue;
if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value))
continue;
void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
for (Operation& op : funcOp.getOps()) {
if (isa<func::ReturnOp,
spatial::SpatGraphCompute,
spatial::SpatGraphComputeBatch,
spatial::SpatConv2DPlanOp,
spatial::SpatReluPlanOp,
spatial::SpatReconciliatorOp,
spatial::SpatMaterializeLayoutOp>(&op)) {
continue;
}
if (isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch>(&op)) {
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << kPhaseMarker << " scheduled Spatial compute op is not allowed in logical graph phase";
});
continue;
}
if (isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelSendOp>(&op)) {
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << kPhaseMarker
<< " explicit channel communication is not expected before merge materialization";
});
continue;
}
if (isCompileTimeOp(&op))
continue;
Operation* definingOp = value.getDefiningOp();
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError()
<< kPhaseMarker << " non-foldable top-level runtime op remains in logical Spatial graph; lower it inside spat.graph_compute";
});
}
}
diagnostics.report(ownerOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
<< "values";
diagnostic.attachNote(op->getLoc())
<< "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
<< (definingOp ? definingOp->getName().getStringRef() : StringRef("<block argument>"));
void verifyScheduledTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
for (Operation& op : funcOp.getOps()) {
if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(&op)) {
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << kPhaseMarker << " graph Spatial compute op remained after merge materialization";
});
}
});
}
}
} // namespace
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
LogicalResult verifyNoComputeBodyCaptures(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
for (Operation& op : funcOp.getOps()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isCompileTimeOp(&op))
continue;
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError(
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
});
}
checkWeightUseChains(funcOp, diagnostics);
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
for (auto compute : funcOp.getOps<spatial::SpatGraphCompute>())
verifyComputeBodyCaptures(compute, "graph_compute", diagnostics);
for (auto batch : funcOp.getOps<spatial::SpatGraphComputeBatch>())
verifyComputeBodyCaptures(batch, "graph_compute_batch", diagnostics);
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
verifyComputeBodyCaptures(compute, "scheduled_compute", diagnostics);
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
verifyComputeBodyCaptures(batch, "scheduled_compute_batch", diagnostics);
diagnostics.emitSuppressedSummary(funcOp, "compute body capture verification failed");
return success(!diagnostics.hasFailure());
}
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { return verifyLogicalSpatialGraphInvariants(funcOp); }
LogicalResult verifyLogicalSpatialGraphInvariants(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
verifyLogicalTopLevelOps(funcOp, diagnostics);
checkWeightUseChains(funcOp, diagnostics);
if (failed(verifyNoComputeBodyCaptures(funcOp)))
return failure();
diagnostics.emitSuppressedSummary(funcOp, "logical Spatial graph verification failed");
return success(!diagnostics.hasFailure());
}
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
(void) verifyComputeLikeInputs(
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics);
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
}
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
(void) verifyComputeLikeInputs(computeBatchOp.getOperation(),
computeBatchOp.getInputs(),
/*allowChannelReceiveInputs=*/false,
"spat.compute_batch",
diagnostics);
verifyNoExternalTensorCaptures(
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
}
diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
verifyScheduledTopLevelOps(funcOp, diagnostics);
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
if (failed(verifyNoComputeBodyCaptures(funcOp)))
return failure();
diagnostics.emitSuppressedSummary(funcOp, "scheduled Spatial verification failed");
return success(!diagnostics.hasFailure());
}
@@ -6,6 +6,8 @@
namespace onnx_mlir {
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifySpatialCommunicationInvariants(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifyNoComputeBodyCaptures(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifyLogicalSpatialGraphInvariants(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifyScheduledSpatialInvariants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -33,8 +33,8 @@ void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext*
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
bool requiresPostRewrite(spatial::SpatGraphCompute computeOp);
bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
File diff suppressed because it is too large Load Diff
@@ -16,12 +16,9 @@ struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
Location loc = reluOp.getLoc();
Type resultType = reluOp.getResult().getType();
constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x);
spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
});
rewriter.replaceOp(reluOp, computeOp);
auto reluPlan = spatial::SpatReluPlanOp::create(
rewriter, loc, resultType, adaptor.getX(), rewriter.getStringAttr("nchw"));
rewriter.replaceOp(reluOp, reluPlan.getResult());
return success();
}
};
@@ -118,17 +118,17 @@ static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
}
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatGraphCompute> {
using OpRewritePattern<spatial::SpatGraphCompute>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(spatial::SpatGraphCompute compute, PatternRewriter& rewriter) const override {
auto promoted = computePromotedOperands(compute);
if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
Block& oldBlock = compute.getBody().front();
rewriter.setInsertionPointAfter(compute);
auto newCompute = spatial::SpatCompute::create(
auto newCompute = spatial::SpatGraphCompute::create(
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
@@ -182,10 +182,10 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
};
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatGraphComputeBatch> {
using OpRewritePattern<spatial::SpatGraphComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(spatial::SpatGraphComputeBatch compute, PatternRewriter& rewriter) const override {
auto promoted = computePromotedOperands(compute);
if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
@@ -197,7 +197,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
rewriter, compute, static_cast<uint64_t>(compute.getLaneCount()), "promoted compute_batch lane count");
if (failed(laneCountAttr))
return failure();
auto newCompute = spatial::SpatComputeBatch::create(
auto newCompute = spatial::SpatGraphComputeBatch::create(
rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs);
auto laneArg = compute.getLaneArgument();
if (!laneArg)
@@ -281,8 +281,8 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
});
}
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatGraphCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
} // namespace onnx_mlir
@@ -0,0 +1,21 @@
#pragma once
#include <optional>
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
mlir::FailureOr<mlir::Value>
lowerSelectedConv2DPlan(spatial::SpatConv2DPlanOp planOp,
std::optional<mlir::Value> rowStripInput,
bool emitRowStripLayout,
mlir::PatternRewriter& rewriter);
mlir::LogicalResult canLowerConvPlanToRowStrip(spatial::SpatConv2DPlanOp planOp);
mlir::LogicalResult canConsumeAndProduceRowStrip(spatial::SpatConv2DPlanOp planOp);
} // namespace onnx_mlir
@@ -0,0 +1,200 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
using namespace mlir;
namespace onnx_mlir {
namespace {
static constexpr StringLiteral kLogicalLayout = "nchw";
static constexpr StringLiteral kDenseLayout = "dense_nchw";
static constexpr StringLiteral kRowStripLayout = "nchw_row_strip";
static constexpr StringLiteral kRowStripIndexMap = "packed_hwc_rows_to_nchw";
enum class SelectedLayout {
DenseNchw,
NchwRowStrip,
};
static SelectedLayout getSelectedLayout(llvm::DenseMap<Value, SelectedLayout>& layouts, Value value) {
auto it = layouts.find(value);
return it == layouts.end() ? SelectedLayout::DenseNchw : it->second;
}
static bool usesSelectedRowStrip(Operation* user, llvm::DenseMap<Value, SelectedLayout>& layouts) {
if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(user))
return getSelectedLayout(layouts, reluPlan.getResult()) == SelectedLayout::NchwRowStrip;
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(user))
return getSelectedLayout(layouts, convPlan.getResult()) == SelectedLayout::NchwRowStrip;
return false;
}
static bool allUsersCanHandleRowStrip(Value value, llvm::DenseMap<Value, SelectedLayout>& layouts) {
for (Operation* user : value.getUsers()) {
if (usesSelectedRowStrip(user, layouts))
continue;
// Dense-only users must be materialized explicitly.
continue;
}
return true;
}
static std::pair<SmallVector<int64_t>, SmallVector<int64_t>> buildRowStripMetadata(RankedTensorType type) {
SmallVector<int64_t> offsets;
SmallVector<int64_t> sizes;
const int64_t channels = type.getDimSize(1);
const int64_t height = type.getDimSize(2);
const int64_t width = type.getDimSize(3);
offsets.reserve(height * 4);
sizes.reserve(height * 4);
for (int64_t row = 0; row < height; ++row) {
offsets.append({0, 0, row, 0});
sizes.append({1, channels, 1, width});
}
return {offsets, sizes};
}
static bool canSelectConvRowStrip(spatial::SpatConv2DPlanOp convPlan,
llvm::DenseMap<Value, SelectedLayout>& layouts) {
SelectedLayout inputLayout = getSelectedLayout(layouts, convPlan.getInput());
if (inputLayout == SelectedLayout::NchwRowStrip)
return succeeded(canConsumeAndProduceRowStrip(convPlan));
return succeeded(canLowerConvPlanToRowStrip(convPlan));
}
static SelectedLayout chooseConvLayout(spatial::SpatConv2DPlanOp convPlan,
llvm::DenseMap<Value, SelectedLayout>& layouts) {
if (!canSelectConvRowStrip(convPlan, layouts))
return SelectedLayout::DenseNchw;
if (!allUsersCanHandleRowStrip(convPlan.getResult(), layouts))
return SelectedLayout::DenseNchw;
return SelectedLayout::NchwRowStrip;
}
static SelectedLayout chooseReluLayout(spatial::SpatReluPlanOp reluPlan,
llvm::DenseMap<Value, SelectedLayout>& layouts) {
if (getSelectedLayout(layouts, reluPlan.getInput()) != SelectedLayout::NchwRowStrip)
return SelectedLayout::DenseNchw;
if (!allUsersCanHandleRowStrip(reluPlan.getResult(), layouts))
return SelectedLayout::DenseNchw;
return SelectedLayout::NchwRowStrip;
}
static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewriter, Value value) {
auto outputType = cast<RankedTensorType>(value.getType());
auto [offsets, sizes] = buildRowStripMetadata(outputType);
return spatial::SpatReconciliatorOp::create(rewriter,
value.getLoc(),
outputType,
value,
rewriter.getStringAttr(kLogicalLayout),
rewriter.getStringAttr(kRowStripLayout),
rewriter.getDenseI64ArrayAttr(offsets),
rewriter.getDenseI64ArrayAttr(sizes),
rewriter.getStringAttr(kRowStripIndexMap));
}
static void materializeDenseUses(IRRewriter& rewriter,
Value layoutValue,
llvm::DenseMap<Value, SelectedLayout>& layouts) {
SmallVector<OpOperand*> denseUses;
for (OpOperand& use : layoutValue.getUses()) {
if (usesSelectedRowStrip(use.getOwner(), layouts))
continue;
denseUses.push_back(&use);
}
for (OpOperand* use : denseUses) {
Operation* owner = use->getOwner();
rewriter.setInsertionPoint(owner);
auto materialized = spatial::SpatMaterializeLayoutOp::create(rewriter,
owner->getLoc(),
use->get().getType(),
use->get(),
rewriter.getStringAttr(kLogicalLayout),
rewriter.getStringAttr(kRowStripLayout),
rewriter.getStringAttr(kDenseLayout));
use->set(materialized.getResult());
}
}
struct SpatialLayoutPlanningPass final : PassWrapper<SpatialLayoutPlanningPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialLayoutPlanningPass)
StringRef getArgument() const override { return "spatial-layout-planning"; }
StringRef getDescription() const override { return "Select conservative Spatial layouts and insert reconciliation barriers."; }
void runOnOperation() override {
auto entryFunc = getPimEntryFunc(getOperation());
if (failed(entryFunc)) {
getOperation().emitError("failed to locate the PIM entry function during Spatial layout planning");
signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
llvm::DenseMap<Value, SelectedLayout> layouts;
bool changed = true;
while (changed) {
changed = false;
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
SelectedLayout selected = chooseConvLayout(convPlan, layouts);
if (layouts[convPlan.getResult()] != selected) {
layouts[convPlan.getResult()] = selected;
changed = true;
}
continue;
}
if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
SelectedLayout selected = chooseReluLayout(reluPlan, layouts);
if (layouts[reluPlan.getResult()] != selected) {
layouts[reluPlan.getResult()] = selected;
changed = true;
}
continue;
}
}
}
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
Value producedValue;
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(&op))
producedValue = convPlan.getResult();
else if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(&op))
producedValue = reluPlan.getResult();
else
continue;
if (getSelectedLayout(layouts, producedValue) != SelectedLayout::NchwRowStrip)
continue;
rewriter.setInsertionPointAfter(&op);
auto reconciliator = insertRowStripReconciliator(rewriter, producedValue);
rewriter.replaceAllUsesExcept(producedValue, reconciliator.getResult(), reconciliator);
materializeDenseUses(rewriter, reconciliator.getResult(), layouts);
}
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
getOperation().emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after SpatialLayoutPlanning");
signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<Pass> createSpatialLayoutPlanningPass() { return std::make_unique<SpatialLayoutPlanningPass>(); }
} // namespace onnx_mlir
@@ -102,7 +102,7 @@ static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
return mapper.lookup(value);
}
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp,
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
@@ -171,7 +171,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
} // namespace
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
IRRewriter& rewriter) {
Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front();
@@ -17,10 +17,10 @@ std::optional<unsigned> getDirectComputeLikeInputIndex(Operation* owner, unsigne
return operandNumber - inputBegin;
};
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner))
return getInputIndex(owner, compute.getInputs().size());
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner))
if (auto computeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(owner))
return getInputIndex(owner, computeBatch.getInputs().size());
return std::nullopt;
@@ -32,13 +32,13 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
Value replacement) {
Block& body = owner->getRegion(0).front();
BlockArgument bodyArgument;
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner)) {
auto computeArg = compute.getInputArgument(inputIndex);
assert(computeArg && "expected compute input block argument");
bodyArgument = *computeArg;
}
else {
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
auto batchArg = cast<spatial::SpatScheduledComputeBatch>(owner).getInputArgument(inputIndex);
assert(batchArg && "expected compute_batch input block argument");
bodyArgument = *batchArg;
}
@@ -46,10 +46,10 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
rewriter.startOpModification(owner);
bodyArgument.replaceAllUsesWith(replacement);
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner))
compute.getInputsMutable().erase(inputIndex);
else
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
cast<spatial::SpatScheduledComputeBatch>(owner).getInputsMutable().erase(inputIndex);
body.eraseArgument(bodyArgIndex);
rewriter.finalizeOpModification(owner);
}
@@ -55,7 +55,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
}
}
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatScheduledCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id");
auto checkedCoreId =
@@ -66,7 +66,7 @@ static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeO
return *checkedCoreId;
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
SmallVectorImpl<Operation*>& helperChain,
bool requireReturnUse = true) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
@@ -104,13 +104,13 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
return success();
}
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp,
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatScheduledCompute computeOp,
IRRewriter& rewriter,
OperationFolder& constantFolder) {
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
return false;
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
return isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
}))
return false;
@@ -145,7 +145,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
} // namespace
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatScheduledCompute computeOp,
IRRewriter& rewriter,
OperationFolder& constantFolder) {
Location loc = computeOp->getLoc();
@@ -10,6 +10,14 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static void copyRaptorDebugAttrs(Operation* source, Operation* target) {
for (NamedAttribute attr : source->getAttrs()) {
StringRef name = attr.getName().strref();
if (name.starts_with("raptor."))
target->setAttr(attr.getName(), attr.getValue());
}
}
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
using OpRewritePattern::OpRewritePattern;
@@ -17,7 +25,8 @@ struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getInput());
if (failed(sizeAttr))
return failure();
pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId());
auto send = pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId());
copyRaptorDebugAttrs(op.getOperation(), send.getOperation());
rewriter.eraseOp(op);
return success();
}
@@ -37,9 +46,10 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult());
if (failed(sizeAttr))
return failure();
Value received = pim::PimReceiveOp::create(
rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId())
.getOutput();
auto receive = pim::PimReceiveOp::create(
rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId());
copyRaptorDebugAttrs(op.getOperation(), receive.getOperation());
Value received = receive.getOutput();
rewriter.replaceOp(op, received);
return success();
}
@@ -59,7 +59,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
return failure();
for (auto& uses : extractSliceOp->getUses()) {
if (isa<spatial::SpatCompute>(uses.getOwner())) {
if (isa<spatial::SpatScheduledCompute>(uses.getOwner())) {
if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
return failure();
}
@@ -72,7 +72,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(uses.getOwner())) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
if (!inputIndex)
return failure();
@@ -92,7 +92,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
else if (auto spatComputeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(uses.getOwner())) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
if (!inputIndex)
return failure();
@@ -114,7 +114,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
}
else {
{
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatScheduledCompute>()) {
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
@@ -125,7 +125,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatScheduledComputeBatch>()) {
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
@@ -179,7 +179,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
auto argUser = argUses.getOwner();
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(argUser)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
if (!inputIndex)
return failure();
@@ -191,7 +191,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
}
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
else if (auto spatComputeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(argUser)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
if (!inputIndex)
return failure();
@@ -86,7 +86,7 @@ getCheckedByteOffset(int64_t elementOffset, size_t elementSize, Operation* ancho
return pim::checkedCast<int64_t>(*byteOffset, anchor, fieldName);
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
SmallVectorImpl<Operation*>& helperChain) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
return failure();
@@ -212,7 +212,7 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
}
SmallVector<Operation*> helperChain;
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
if (auto helperCompute = dyn_cast<spatial::SpatScheduledCompute>(currentUser)) {
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
return std::nullopt;
@@ -643,7 +643,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
}
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
spatial::SpatScheduledCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
}
@@ -656,7 +656,7 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
Operation* onlyUser = *op->getUsers().begin();
isExclusivelyOwnedByReturnChain =
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatScheduledCompute>(onlyUser)
|| isReturnHelperChainOp(onlyUser);
}
if (!isExclusivelyOwnedByReturnChain)
@@ -669,7 +669,7 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
return;
}
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
markOpToRemove(computeOp);
if (!computeOp.getInputs().empty())
for (Value input : computeOp.getInputs())
@@ -25,9 +25,11 @@
#include <cassert>
#include <utility>
#include "Common/IR/ShapeUtils.hpp"
#include "Common/IR/ConstantUtils.hpp"
#include "Common/PimCommon.hpp"
#include "Common/Support/CheckedArithmetic.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/Common.hpp"
#include "Conversion/SpatialToPim/Patterns.hpp"
@@ -97,6 +99,64 @@ static FailureOr<Value> createZeroedDeviceHVector(IRRewriter& rewriter,
.getOutput();
}
static bool isHostBackedMemRefValue(Value value) {
while (Operation* definingOp = value.getDefiningOp()) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
value = subviewOp.getSource();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
return isa<memref::GetGlobalOp>(definingOp);
}
return false;
}
static bool isHostBackedTensorValue(Value value) {
while (Operation* definingOp = value.getDefiningOp()) {
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto sourceType = dyn_cast<RankedTensorType>(extractSliceOp.getSource().getType());
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getResult().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return false;
if (!onnx_mlir::isContiguousSubviewWithDynamicOffsets(sourceType.getShape(),
extractSliceOp.getMixedOffsets(),
extractSliceOp.getStaticSizes(),
extractSliceOp.getStaticStrides())) {
return false;
}
value = extractSliceOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (auto castOp = dyn_cast<tensor::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(definingOp))
return isHostBackedMemRefValue(toTensorOp.getBuffer());
return false;
}
return false;
}
static FailureOr<Value>
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
auto vectorType = cast<RankedTensorType>(vector.getType());
@@ -120,6 +180,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
if (failed(sizeAttr))
return failure();
if (isHostBackedTensorValue(vector)) {
return PimMemCopyHostToDevOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr)
.getOutput();
}
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
}
@@ -137,6 +201,12 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
return;
}
func::FuncOp funcOp = *entryFunc;
if (failed(verifyScheduledSpatialInvariants(funcOp))) {
funcOp.emitOpError(
"RAPTOR_PHASE_CHECK scheduled Spatial verification failed at the start of SpatialToPim");
signalPassFailure();
return;
}
IRRewriter rewriter(&getContext());
OperationFolder constantFolder(&getContext());
@@ -176,19 +246,19 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
return;
}
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
for (auto computeOp : funcOp.getOps<spatial::SpatScheduledCompute>()) {
markOpToRemove(computeOp);
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
computeOp.emitOpError("failed to lower spat.compute to pim.core");
computeOp.emitOpError("failed to lower spat.scheduled_compute to pim.core");
signalPassFailure();
return;
}
}
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
for (auto computeBatchOp : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
markOpToRemove(computeBatchOp);
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
computeBatchOp.emitOpError("failed to lower spat.scheduled_compute_batch to pim.core_batch");
signalPassFailure();
return;
}
@@ -374,7 +444,7 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
};
for (auto& op : funcOp.getBody().getOps())
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
continue;
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
@@ -41,8 +41,11 @@ private:
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
mlir::LogicalResult
lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder);
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);
lowerComputeOp(spatial::SpatScheduledCompute computeOp,
mlir::IRRewriter& rewriter,
mlir::OperationFolder& constantFolder);
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
mlir::IRRewriter& rewriter);
enum class ReturnPathLoweringResult {
Handled,
@@ -51,7 +54,7 @@ private:
};
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatScheduledCompute computeOp,
mlir::OpResult result,
mlir::Value yieldValue,
mlir::IRRewriter& rewriter);