DeadLock
This commit is contained in:
@@ -26,6 +26,8 @@ add_pim_library(OMONNXToSpatial
|
||||
Patterns/Tensor/Split.cpp
|
||||
Patterns/Tensor/Transpose.cpp
|
||||
ONNXToSpatialPass.cpp
|
||||
SpatialLayoutPlanningPass.cpp
|
||||
LowerSpatialPlansPass.cpp
|
||||
Common/AttributeUtils.cpp
|
||||
Common/ComputeRegionBuilder.cpp
|
||||
Common/IndexingUtils.cpp
|
||||
|
||||
@@ -9,7 +9,7 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
||||
Value sumTensors(ArrayRef<Value> tensors, PatternRewriter& rewriter) {
|
||||
if (tensors.size() == 1)
|
||||
return tensors[0];
|
||||
|
||||
|
||||
@@ -87,17 +87,17 @@ inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int
|
||||
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
||||
}
|
||||
|
||||
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
|
||||
/// Builds a `spat.graph_compute` with a fixed number of SSA inputs and erases it if
|
||||
/// the body callback reports failure.
|
||||
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||
auto createSpatCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
auto createSpatGraphCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
@@ -124,23 +124,23 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a `spat.compute` whose body consumes the block arguments as a single
|
||||
/// Builds a `spat.graph_compute` whose body consumes the block arguments as a single
|
||||
/// `ValueRange`, which is convenient for variadic reductions/concats.
|
||||
template <typename RewriterT, typename BodyFn>
|
||||
auto createSpatCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
auto createSpatGraphCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
@@ -163,29 +163,29 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RewriterT, typename BodyFn>
|
||||
auto createSpatComputeBatch(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
int64_t laneCount,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
auto createSpatGraphComputeBatch(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
int64_t laneCount,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
|
||||
|
||||
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
|
||||
if (mlir::failed(laneCountAttr))
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
|
||||
|
||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
|
||||
auto batchOp = spatial::SpatGraphComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
|
||||
|
||||
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
|
||||
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
|
||||
@@ -218,20 +218,53 @@ auto createSpatComputeBatch(RewriterT& rewriter,
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
std::forward<BodyFn>(body)(args);
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
|
||||
}
|
||||
else {
|
||||
auto bodyResult = std::forward<BodyFn>(body)(args);
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
rewriter.eraseOp(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||
auto createSpatCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
return createSpatGraphCompute<NumInputs>(
|
||||
rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
|
||||
}
|
||||
|
||||
template <typename RewriterT, typename BodyFn>
|
||||
auto createSpatCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
return createSpatGraphCompute(rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
|
||||
}
|
||||
|
||||
template <typename RewriterT, typename BodyFn>
|
||||
auto createSpatComputeBatch(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
int64_t laneCount,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
return createSpatGraphComputeBatch(
|
||||
rewriter, loc, resultTypes, laneCount, weights, inputs, std::forward<BodyFn>(body));
|
||||
}
|
||||
|
||||
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
@@ -262,6 +295,6 @@ mlir::Value materializeOrComputeUnary(mlir::Value input,
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::PatternRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -83,7 +83,7 @@ SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int
|
||||
}
|
||||
|
||||
SmallVector<Value> sliceTensor(
|
||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||
assert("Invalid axis" && axis < shape.size());
|
||||
|
||||
@@ -129,7 +129,7 @@ SmallVector<Value> sliceTensor(
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
sliceVector(const Value& vectorToSlice, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(vectorToSlice);
|
||||
assert("Not a vector" && isVectorShape(shape));
|
||||
size_t axis = shape[0] != 1 ? 0 : 1;
|
||||
@@ -137,7 +137,7 @@ sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewr
|
||||
}
|
||||
|
||||
DenseMap<CoreId, SmallVector<Value>>
|
||||
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, PatternRewriter& rewriter, Location loc) {
|
||||
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
|
||||
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
|
||||
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
|
||||
|
||||
@@ -89,18 +89,18 @@ llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewr
|
||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||
size_t axis,
|
||||
int64_t sliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||
int64_t sliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
/// Partitions one logical vector into per-core crossbar-sized slices using the
|
||||
/// current PIM target geometry.
|
||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
||||
const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc);
|
||||
|
||||
mlir::Value extractAxisSlice(
|
||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||
|
||||
@@ -0,0 +1,409 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static constexpr StringLiteral kDenseLayout = "dense_nchw";
|
||||
static constexpr StringLiteral kRowStripLayout = "nchw_row_strip";
|
||||
|
||||
struct RowStripPhysicalValue {
|
||||
Value physicalValue;
|
||||
RankedTensorType logicalType;
|
||||
SmallVector<int64_t, 16> fragmentOffsets;
|
||||
SmallVector<int64_t, 16> fragmentSizes;
|
||||
std::string indexMap;
|
||||
};
|
||||
|
||||
static FailureOr<RowStripPhysicalValue> getRowStripValue(llvm::DenseMap<Value, RowStripPhysicalValue>& rowStripValues,
|
||||
Value value) {
|
||||
auto it = rowStripValues.find(value);
|
||||
if (it == rowStripValues.end())
|
||||
return failure();
|
||||
return it->second;
|
||||
}
|
||||
|
||||
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatReconciliatorOp reconciliator,
|
||||
Value physicalValue) {
|
||||
auto logicalType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
|
||||
if (!logicalType)
|
||||
return reconciliator.emitOpError("requires ranked logical output type"), failure();
|
||||
RowStripPhysicalValue value;
|
||||
value.physicalValue = physicalValue;
|
||||
value.logicalType = logicalType;
|
||||
value.fragmentOffsets.append(reconciliator.getFragmentOffsets().begin(), reconciliator.getFragmentOffsets().end());
|
||||
value.fragmentSizes.append(reconciliator.getFragmentSizes().begin(), reconciliator.getFragmentSizes().end());
|
||||
value.indexMap = reconciliator.getIndexMap().str();
|
||||
return value;
|
||||
}
|
||||
|
||||
static FailureOr<Value>
|
||||
lowerRowStripRelu(const RowStripPhysicalValue& input, spatial::SpatReluPlanOp planOp, PatternRewriter& rewriter) {
|
||||
auto packedType = cast<RankedTensorType>(input.physicalValue.getType());
|
||||
auto computeOp =
|
||||
createSpatCompute<1>(rewriter, planOp.getLoc(), TypeRange {packedType}, {}, input.physicalValue, [&](Value x) {
|
||||
auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), packedType, x);
|
||||
spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static FailureOr<Value>
|
||||
materializeRowStripToDense(const RowStripPhysicalValue& rowStripValue, Location loc, PatternRewriter& rewriter) {
|
||||
auto packedType = dyn_cast<RankedTensorType>(rowStripValue.physicalValue.getType());
|
||||
if (!packedType || packedType.getRank() != 3 || !packedType.hasStaticShape())
|
||||
return failure();
|
||||
if (rowStripValue.logicalType.getRank() != 4 || !rowStripValue.logicalType.hasStaticShape())
|
||||
return failure();
|
||||
if (rowStripValue.indexMap != "packed_hwc_rows_to_nchw")
|
||||
return failure();
|
||||
|
||||
const int64_t rank = rowStripValue.logicalType.getRank();
|
||||
const int64_t fragmentCount = rowStripValue.fragmentOffsets.size() / rank;
|
||||
const int64_t packedWidth = packedType.getDimSize(1);
|
||||
const int64_t packedChannels = packedType.getDimSize(2);
|
||||
if (fragmentCount != packedType.getDimSize(0))
|
||||
return failure();
|
||||
for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
|
||||
if (rowStripValue.fragmentOffsets[fragmentIndex * rank + 0] != 0
|
||||
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 1] != 0
|
||||
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 2] != fragmentIndex
|
||||
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 3] != 0)
|
||||
return failure();
|
||||
if (rowStripValue.fragmentSizes[fragmentIndex * rank + 0] != 1
|
||||
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 1] != packedChannels
|
||||
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 2] != 1
|
||||
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 3] != packedWidth)
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto packedSliceType =
|
||||
RankedTensorType::get({1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
|
||||
auto expandedType =
|
||||
RankedTensorType::get({1, 1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
|
||||
auto logicalFragmentType =
|
||||
RankedTensorType::get({1, packedChannels, 1, packedWidth}, packedType.getElementType(), packedType.getEncoding());
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter,
|
||||
loc,
|
||||
TypeRange {rowStripValue.logicalType},
|
||||
fragmentCount,
|
||||
{},
|
||||
ValueRange {rowStripValue.physicalValue},
|
||||
[&](detail::SpatComputeBatchBodyArgs args) {
|
||||
SmallVector<OpFoldResult> packedOffsets {args.lane, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> packedSizes {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(packedWidth), rewriter.getIndexAttr(packedChannels)};
|
||||
Value packedSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, packedSliceType, args.inputs.front(), packedOffsets, packedSizes, getUnitStrides(rewriter, 3));
|
||||
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
packedSlice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2},
|
||||
{3}
|
||||
});
|
||||
Value transposeInit =
|
||||
tensor::EmptyOp::create(rewriter, loc, logicalFragmentType.getShape(), logicalFragmentType.getElementType());
|
||||
Value logicalFragment =
|
||||
linalg::TransposeOp::create(rewriter, loc, expanded, transposeInit, SmallVector<int64_t> {0, 3, 1, 2})
|
||||
.getResult()[0];
|
||||
|
||||
SmallVector<OpFoldResult> logicalOffsets {
|
||||
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> logicalSizes {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(packedChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(packedWidth)};
|
||||
createParallelInsertSliceIntoBatchOutput(rewriter,
|
||||
loc,
|
||||
logicalFragment,
|
||||
args.outputs.front(),
|
||||
logicalOffsets,
|
||||
logicalSizes,
|
||||
getUnitStrides(rewriter, 4));
|
||||
return success();
|
||||
});
|
||||
if (failed(batchOp))
|
||||
return failure();
|
||||
return batchOp->getResult(0);
|
||||
}
|
||||
|
||||
struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerSpatialPlansPass)
|
||||
|
||||
StringRef getArgument() const override { return "lower-spatial-plans"; }
|
||||
StringRef getDescription() const override { return "Lower selected Spatial planning ops to low-level Spatial IR."; }
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
moduleOp.emitError("failed to locate the PIM entry function during LowerSpatialPlans");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
func::FuncOp funcOp = *entryFunc;
|
||||
PatternRewriter rewriter(ctx);
|
||||
llvm::DenseMap<Value, RowStripPhysicalValue> rowStripValues;
|
||||
llvm::SmallPtrSet<Operation*, 16> eraseAfterLowering;
|
||||
auto verifyLogicalPhase = [&](StringRef stage) -> bool {
|
||||
if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc)))
|
||||
return true;
|
||||
moduleOp.emitError() << "RAPTOR_PHASE_CHECK logical Spatial graph verification failed " << stage;
|
||||
signalPassFailure();
|
||||
return false;
|
||||
};
|
||||
|
||||
if (!verifyLogicalPhase("at the start of LowerSpatialPlans"))
|
||||
return;
|
||||
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||
if (auto planOp = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
|
||||
FailureOr<RowStripPhysicalValue> rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
|
||||
auto rowStripReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
|
||||
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
|
||||
});
|
||||
if (rowStripReconciliator != planOp.getResult().getUsers().end()) {
|
||||
rewriter.setInsertionPoint(planOp);
|
||||
FailureOr<Value> lowered = lowerSelectedConv2DPlan(
|
||||
planOp,
|
||||
succeeded(rowStripInput) ? std::optional<Value> {rowStripInput->physicalValue} : std::nullopt,
|
||||
/*emitRowStripLayout=*/true,
|
||||
rewriter);
|
||||
if (failed(lowered)) {
|
||||
planOp.emitOpError("failed to lower selected row-strip Spatial Conv plan");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*rowStripReconciliator);
|
||||
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(reconciliator, *lowered);
|
||||
if (failed(rowStripValue)) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
rowStripValues[reconciliator.getResult()] = *rowStripValue;
|
||||
eraseAfterLowering.insert(planOp);
|
||||
eraseAfterLowering.insert(reconciliator);
|
||||
continue;
|
||||
}
|
||||
rewriter.setInsertionPoint(planOp);
|
||||
FailureOr<Value> lowered =
|
||||
lowerSelectedConv2DPlan(planOp, std::nullopt, /*emitRowStripLayout=*/false, rewriter);
|
||||
if (failed(lowered)) {
|
||||
planOp.emitOpError("failed to lower selected Spatial Conv plan");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
rewriter.replaceOp(planOp, *lowered);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto planOp = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
|
||||
if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) {
|
||||
auto outputReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
|
||||
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
|
||||
});
|
||||
if (outputReconciliator == planOp.getResult().getUsers().end()) {
|
||||
planOp.emitOpError("row-strip Relu plan requires a row-strip reconciliator result");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
FailureOr<RowStripPhysicalValue> input = getRowStripValue(rowStripValues, planOp.getInput());
|
||||
rewriter.setInsertionPoint(planOp);
|
||||
FailureOr<Value> lowered = lowerRowStripRelu(*input, planOp, rewriter);
|
||||
if (failed(lowered)) {
|
||||
planOp.emitOpError("failed to lower selected row-strip Spatial Relu plan");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*outputReconciliator);
|
||||
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(reconciliator, *lowered);
|
||||
if (failed(output)) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
rowStripValues[reconciliator.getResult()] = *output;
|
||||
eraseAfterLowering.insert(planOp);
|
||||
eraseAfterLowering.insert(reconciliator);
|
||||
continue;
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(planOp);
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, planOp.getLoc(), planOp.getOutput().getType(), {}, planOp.getInput(), [&](Value x) {
|
||||
auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), planOp.getOutput().getType(), x);
|
||||
spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
|
||||
});
|
||||
rewriter.replaceOp(planOp, computeOp.getResults());
|
||||
continue;
|
||||
}
|
||||
if (auto materializeOp = dyn_cast<spatial::SpatMaterializeLayoutOp>(&op)) {
|
||||
if (materializeOp.getSourcePhysicalLayout() == kDenseLayout
|
||||
&& materializeOp.getTargetPhysicalLayout() == kDenseLayout) {
|
||||
rewriter.replaceOp(materializeOp, materializeOp.getInput());
|
||||
continue;
|
||||
}
|
||||
if (materializeOp.getSourcePhysicalLayout() != kRowStripLayout
|
||||
|| materializeOp.getTargetPhysicalLayout() != kDenseLayout) {
|
||||
materializeOp.emitOpError("non-dense materialize_layout lowering is not supported yet");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
FailureOr<RowStripPhysicalValue> rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
|
||||
if (failed(rowStripValue)) {
|
||||
materializeOp.emitOpError("expected a row-strip reconciliator input during row-strip materialization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
rewriter.setInsertionPoint(materializeOp);
|
||||
FailureOr<Value> dense = materializeRowStripToDense(*rowStripValue, materializeOp.getLoc(), rewriter);
|
||||
if (failed(dense)) {
|
||||
materializeOp.emitOpError("failed to materialize selected row-strip layout back to dense NCHW");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
rewriter.replaceOp(materializeOp, *dense);
|
||||
continue;
|
||||
}
|
||||
if (auto reconciliatorOp = dyn_cast<spatial::SpatReconciliatorOp>(&op)) {
|
||||
if (reconciliatorOp.getPhysicalLayout() == kDenseLayout) {
|
||||
rewriter.replaceOp(reconciliatorOp, reconciliatorOp.getInput());
|
||||
continue;
|
||||
}
|
||||
if (reconciliatorOp.getPhysicalLayout() != kRowStripLayout) {
|
||||
reconciliatorOp.emitOpError("non-dense reconciliator lowering is not supported yet");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
if (!eraseAfterLowering.contains(reconciliatorOp)) {
|
||||
reconciliatorOp.emitOpError("unhandled row-strip reconciliator remained during LowerSpatialPlans");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
bool erasedAny = true;
|
||||
while (erasedAny) {
|
||||
erasedAny = false;
|
||||
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||
if (!eraseAfterLowering.contains(&op))
|
||||
continue;
|
||||
if (!op.use_empty())
|
||||
continue;
|
||||
eraseAfterLowering.erase(&op);
|
||||
rewriter.eraseOp(&op);
|
||||
erasedAny = true;
|
||||
}
|
||||
}
|
||||
if (!eraseAfterLowering.empty()) {
|
||||
for (Operation& op : funcOp.getBody().front())
|
||||
if (eraseAfterLowering.contains(&op))
|
||||
op.emitOpError("selected row-strip planning op could not be fully eliminated during LowerSpatialPlans");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
ConversionTarget helperTarget(*ctx);
|
||||
helperTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
tensor::TensorDialect,
|
||||
linalg::LinalgDialect,
|
||||
affine::AffineDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect,
|
||||
func::FuncDialect>();
|
||||
helperTarget.addLegalOp<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
|
||||
helperTarget.addIllegalOp<ONNXGemmOp, ONNXTransposeOp>();
|
||||
helperTarget.markOpRecursivelyLegal<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
|
||||
|
||||
RewritePatternSet helperPatterns(ctx);
|
||||
populateGemmPatterns(helperPatterns, ctx);
|
||||
populateTransposePatterns(helperPatterns, ctx);
|
||||
if (failed(applyPartialConversion(moduleOp, helperTarget, std::move(helperPatterns)))) {
|
||||
moduleOp.emitError("failed to lower helper ONNX ops emitted by selected Spatial plan lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
FrozenRewritePatternSet nestedHelperPatterns([&] {
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateGemmPatterns(patterns, ctx);
|
||||
populateTransposePatterns(patterns, ctx);
|
||||
return patterns;
|
||||
}());
|
||||
ConversionTarget nestedHelperTarget(*ctx);
|
||||
nestedHelperTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
tensor::TensorDialect,
|
||||
linalg::LinalgDialect,
|
||||
affine::AffineDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect,
|
||||
func::FuncDialect>();
|
||||
nestedHelperTarget.addIllegalOp<ONNXGemmOp, ONNXTransposeOp>();
|
||||
SmallVector<Operation*> computeLikeOps;
|
||||
funcOp.walk([&](Operation* op) {
|
||||
if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(op))
|
||||
computeLikeOps.push_back(op);
|
||||
});
|
||||
for (Operation* op : computeLikeOps) {
|
||||
if (failed(applyFullConversion(op, nestedHelperTarget, nestedHelperPatterns))) {
|
||||
op->emitOpError("failed to lower nested helper ONNX ops emitted by selected Spatial plan lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (!verifyLogicalPhase("after nested helper conversions"))
|
||||
return;
|
||||
bool hasIllegalOps = false;
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
if (isa<ONNXEntryPointOp>(op))
|
||||
return;
|
||||
if (isa<spatial::SpatConv2DPlanOp,
|
||||
spatial::SpatReluPlanOp,
|
||||
spatial::SpatReconciliatorOp,
|
||||
spatial::SpatMaterializeLayoutOp>(op)
|
||||
|| op->getDialect()->getNamespace() == "onnx") {
|
||||
op->emitOpError("operation must not remain after LowerSpatialPlans");
|
||||
hasIllegalOps = true;
|
||||
}
|
||||
});
|
||||
if (hasIllegalOps)
|
||||
signalPassFailure();
|
||||
else
|
||||
dumpModule(moduleOp, "spatial1_premerge");
|
||||
|
||||
if (!verifyLogicalPhase("at the end of LowerSpatialPlans"))
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createLowerSpatialPlansPass() { return std::make_unique<LowerSpatialPlansPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
#include "ONNXToSpatialVerifier.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -41,10 +42,16 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
||||
static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
IRMapping mapper;
|
||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
|
||||
if (!computes.empty() || !computeBatches.empty())
|
||||
SmallVector<spatial::SpatGraphCompute> computes(funcOp.getOps<spatial::SpatGraphCompute>());
|
||||
SmallVector<spatial::SpatGraphComputeBatch> computeBatches(funcOp.getOps<spatial::SpatGraphComputeBatch>());
|
||||
SmallVector<spatial::SpatConv2DPlanOp> convPlans(funcOp.getOps<spatial::SpatConv2DPlanOp>());
|
||||
SmallVector<spatial::SpatReluPlanOp> reluPlans(funcOp.getOps<spatial::SpatReluPlanOp>());
|
||||
SmallVector<spatial::SpatReconciliatorOp> reconciliators(funcOp.getOps<spatial::SpatReconciliatorOp>());
|
||||
SmallVector<spatial::SpatMaterializeLayoutOp> materializers(funcOp.getOps<spatial::SpatMaterializeLayoutOp>());
|
||||
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !reconciliators.empty()
|
||||
|| !materializers.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
@@ -58,7 +65,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
sourceLocs.push_back(source.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute = spatial::SpatCompute::create(
|
||||
auto newCompute = spatial::SpatGraphCompute::create(
|
||||
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
||||
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
|
||||
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
|
||||
@@ -67,7 +74,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : funcOp.getOps())
|
||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
|
||||
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op))
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
||||
@@ -75,7 +82,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
|
||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
|
||||
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op)) {
|
||||
op.dropAllUses();
|
||||
rewriter.eraseOp(&op);
|
||||
}
|
||||
@@ -152,6 +159,11 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX conversion");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
ConversionTarget earlyPostTarget(*ctx);
|
||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
@@ -168,6 +180,11 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after weight annotation");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
ConversionTarget postTarget(*ctx);
|
||||
postTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
@@ -176,11 +193,16 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
affine::AffineDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
|
||||
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
||||
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatGraphCompute>(
|
||||
[](spatial::SpatGraphCompute computeOp) { return !requiresPostRewrite(computeOp); });
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatGraphComputeBatch>(
|
||||
[](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
||||
|
||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed before post rewrites");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
RewritePatternSet postPatterns(ctx);
|
||||
populatePostPatterns(postPatterns, ctx);
|
||||
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
|
||||
@@ -191,6 +213,11 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
populateEmptyFunction(*entryFunc);
|
||||
|
||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX-to-Spatial");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
dumpModule(moduleOp, "spatial0");
|
||||
|
||||
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "Common/IR/WeightUtils.hpp"
|
||||
@@ -13,6 +15,8 @@ namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr StringLiteral kPhaseMarker = "RAPTOR_PHASE_CHECK";
|
||||
|
||||
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
func.walk([&](Operation* op) {
|
||||
if (!hasWeightAlways(op))
|
||||
@@ -23,134 +27,174 @@ void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diag
|
||||
continue;
|
||||
|
||||
diagnostics.report(op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError(
|
||||
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
|
||||
illegalOp->emitOpError()
|
||||
<< kPhaseMarker
|
||||
<< " weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights";
|
||||
});
|
||||
return;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Region* getParentRegion(Value value) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
return blockArg.getOwner()->getParent();
|
||||
if (Operation* definingOp = value.getDefiningOp())
|
||||
return definingOp->getParentRegion();
|
||||
return nullptr;
|
||||
bool isRegionOrAncestorOf(Region& region, Region* candidate) {
|
||||
return candidate && (®ion == candidate || region.isAncestor(candidate));
|
||||
}
|
||||
|
||||
bool isDefinedInsideRegion(Value value, Region& region) {
|
||||
Region* parentRegion = getParentRegion(value);
|
||||
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
||||
bool isValueDefinedInsideRegion(Value value, Region& region) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
return isRegionOrAncestorOf(region, blockArg.getOwner()->getParent());
|
||||
if (Operation* definingOp = value.getDefiningOp())
|
||||
return isRegionOrAncestorOf(region, definingOp->getParentRegion());
|
||||
return false;
|
||||
}
|
||||
|
||||
bool isLegalExternalCapture(Value value, Region& region) {
|
||||
if (isValueDefinedInsideRegion(value, region))
|
||||
return true;
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
void verifyComputeBodyCaptures(ComputeOpTy compute, StringRef kind, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
Region& body = compute.getBody();
|
||||
body.walk([&](Operation* nestedOp) {
|
||||
for (OpOperand& operand : nestedOp->getOpOperands()) {
|
||||
Value value = operand.get();
|
||||
if (isLegalExternalCapture(value, body))
|
||||
continue;
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
|
||||
InFlightDiagnostic diag =
|
||||
illegalOp->emitOpError() << kPhaseMarker << " " << kind << " body captures non-constant external operand #"
|
||||
<< operand.getOperandNumber() << " used by " << nestedOp->getName().getStringRef();
|
||||
diag << " (type " << value.getType() << ")";
|
||||
if (definingOp)
|
||||
diag.attachNote(definingOp->getLoc()) << "defining op is " << definingOp->getName().getStringRef();
|
||||
else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
|
||||
if (Operation* owner = blockArg.getOwner()->getParentOp())
|
||||
diag.attachNote(owner->getLoc())
|
||||
<< "external block argument belongs to " << owner->getName().getStringRef();
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
bool isLegalHostBackedValue(Value value) {
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return isa<BlockArgument>(value);
|
||||
|
||||
if (isa<spatial::SpatChannelReceiveOp>(definingOp))
|
||||
return false;
|
||||
|
||||
return definingOp->getDialect()->getNamespace() != "spat";
|
||||
}
|
||||
|
||||
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
|
||||
ValueRange inputs,
|
||||
bool allowChannelReceiveInputs,
|
||||
StringRef kind,
|
||||
pim::CappedDiagnosticReporter& diagnostics) {
|
||||
for (auto [inputIndex, input] : llvm::enumerate(inputs)) {
|
||||
unsigned currentInputIndex = inputIndex;
|
||||
template <typename ComputeOpTy>
|
||||
void verifyScheduledInputs(ComputeOpTy compute,
|
||||
bool allowChannelReceiveInputs,
|
||||
StringRef kind,
|
||||
pim::CappedDiagnosticReporter& diagnostics) {
|
||||
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
|
||||
Operation* definingOp = input.getDefiningOp();
|
||||
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
|
||||
continue;
|
||||
if (isLegalHostBackedValue(input))
|
||||
continue;
|
||||
|
||||
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
|
||||
InFlightDiagnostic diagnostic = illegalOp->emitOpError()
|
||||
<< kind << " input #" << currentInputIndex
|
||||
<< (allowChannelReceiveInputs ? " must come from the host or an explicit "
|
||||
"spat.channel_receive"
|
||||
: " must come from the host");
|
||||
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
|
||||
InFlightDiagnostic diag = illegalOp->emitOpError()
|
||||
<< kPhaseMarker << " " << kind << " input #" << inputIndex
|
||||
<< (allowChannelReceiveInputs ? " must come from the host or explicit spat.channel_receive"
|
||||
: " must come from the host");
|
||||
if (definingOp)
|
||||
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
|
||||
diag.attachNote(definingOp->getLoc()) << "illegal producer is " << definingOp->getName().getStringRef();
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void verifyNoExternalTensorCaptures(Operation* ownerOp,
|
||||
Region& region,
|
||||
StringRef kind,
|
||||
pim::CappedDiagnosticReporter& diagnostics) {
|
||||
region.walk([&](Operation* op) {
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value value = operand.get();
|
||||
if (!isa<TensorType>(value.getType()))
|
||||
continue;
|
||||
if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value))
|
||||
continue;
|
||||
void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
for (Operation& op : funcOp.getOps()) {
|
||||
if (isa<func::ReturnOp,
|
||||
spatial::SpatGraphCompute,
|
||||
spatial::SpatGraphComputeBatch,
|
||||
spatial::SpatConv2DPlanOp,
|
||||
spatial::SpatReluPlanOp,
|
||||
spatial::SpatReconciliatorOp,
|
||||
spatial::SpatMaterializeLayoutOp>(&op)) {
|
||||
continue;
|
||||
}
|
||||
if (isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch>(&op)) {
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << kPhaseMarker << " scheduled Spatial compute op is not allowed in logical graph phase";
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelSendOp>(&op)) {
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << kPhaseMarker
|
||||
<< " explicit channel communication is not expected before merge materialization";
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (isCompileTimeOp(&op))
|
||||
continue;
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
continue;
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError()
|
||||
<< kPhaseMarker << " non-foldable top-level runtime op remains in logical Spatial graph; lower it inside spat.graph_compute";
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
diagnostics.report(ownerOp, [&](Operation* illegalOp) {
|
||||
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
|
||||
<< "values";
|
||||
diagnostic.attachNote(op->getLoc())
|
||||
<< "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
|
||||
<< (definingOp ? definingOp->getName().getStringRef() : StringRef("<block argument>"));
|
||||
void verifyScheduledTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
for (Operation& op : funcOp.getOps()) {
|
||||
if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(&op)) {
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << kPhaseMarker << " graph Spatial compute op remained after merge materialization";
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
||||
LogicalResult verifyNoComputeBodyCaptures(func::FuncOp funcOp) {
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
for (Operation& op : funcOp.getOps()) {
|
||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||
continue;
|
||||
if (isCompileTimeOp(&op))
|
||||
continue;
|
||||
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError(
|
||||
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||
});
|
||||
}
|
||||
checkWeightUseChains(funcOp, diagnostics);
|
||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
|
||||
|
||||
for (auto compute : funcOp.getOps<spatial::SpatGraphCompute>())
|
||||
verifyComputeBodyCaptures(compute, "graph_compute", diagnostics);
|
||||
for (auto batch : funcOp.getOps<spatial::SpatGraphComputeBatch>())
|
||||
verifyComputeBodyCaptures(batch, "graph_compute_batch", diagnostics);
|
||||
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
|
||||
verifyComputeBodyCaptures(compute, "scheduled_compute", diagnostics);
|
||||
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
|
||||
verifyComputeBodyCaptures(batch, "scheduled_compute_batch", diagnostics);
|
||||
diagnostics.emitSuppressedSummary(funcOp, "compute body capture verification failed");
|
||||
return success(!diagnostics.hasFailure());
|
||||
}
|
||||
|
||||
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
|
||||
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { return verifyLogicalSpatialGraphInvariants(funcOp); }
|
||||
|
||||
LogicalResult verifyLogicalSpatialGraphInvariants(func::FuncOp funcOp) {
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
verifyLogicalTopLevelOps(funcOp, diagnostics);
|
||||
checkWeightUseChains(funcOp, diagnostics);
|
||||
if (failed(verifyNoComputeBodyCaptures(funcOp)))
|
||||
return failure();
|
||||
diagnostics.emitSuppressedSummary(funcOp, "logical Spatial graph verification failed");
|
||||
return success(!diagnostics.hasFailure());
|
||||
}
|
||||
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
(void) verifyComputeLikeInputs(
|
||||
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics);
|
||||
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
|
||||
}
|
||||
|
||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||
(void) verifyComputeLikeInputs(computeBatchOp.getOperation(),
|
||||
computeBatchOp.getInputs(),
|
||||
/*allowChannelReceiveInputs=*/false,
|
||||
"spat.compute_batch",
|
||||
diagnostics);
|
||||
verifyNoExternalTensorCaptures(
|
||||
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
|
||||
}
|
||||
|
||||
diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
|
||||
LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
verifyScheduledTopLevelOps(funcOp, diagnostics);
|
||||
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
|
||||
verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
|
||||
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
|
||||
verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
|
||||
if (failed(verifyNoComputeBodyCaptures(funcOp)))
|
||||
return failure();
|
||||
diagnostics.emitSuppressedSummary(funcOp, "scheduled Spatial verification failed");
|
||||
return success(!diagnostics.hasFailure());
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
|
||||
mlir::LogicalResult verifySpatialCommunicationInvariants(mlir::func::FuncOp funcOp);
|
||||
mlir::LogicalResult verifyNoComputeBodyCaptures(mlir::func::FuncOp funcOp);
|
||||
mlir::LogicalResult verifyLogicalSpatialGraphInvariants(mlir::func::FuncOp funcOp);
|
||||
mlir::LogicalResult verifyScheduledSpatialInvariants(mlir::func::FuncOp funcOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -33,8 +33,8 @@ void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext*
|
||||
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
||||
bool requiresPostRewrite(spatial::SpatGraphCompute computeOp);
|
||||
bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp);
|
||||
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -16,12 +16,9 @@ struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
|
||||
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
|
||||
Location loc = reluOp.getLoc();
|
||||
Type resultType = reluOp.getResult().getType();
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
|
||||
auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
|
||||
});
|
||||
rewriter.replaceOp(reluOp, computeOp);
|
||||
auto reluPlan = spatial::SpatReluPlanOp::create(
|
||||
rewriter, loc, resultType, adaptor.getX(), rewriter.getStringAttr("nchw"));
|
||||
rewriter.replaceOp(reluOp, reluPlan.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -118,17 +118,17 @@ static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
|
||||
}
|
||||
|
||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
||||
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatGraphCompute> {
|
||||
using OpRewritePattern<spatial::SpatGraphCompute>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
|
||||
LogicalResult matchAndRewrite(spatial::SpatGraphCompute compute, PatternRewriter& rewriter) const override {
|
||||
auto promoted = computePromotedOperands(compute);
|
||||
if (failed(promoted))
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
auto newCompute = spatial::SpatCompute::create(
|
||||
auto newCompute = spatial::SpatGraphCompute::create(
|
||||
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
@@ -182,10 +182,10 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
};
|
||||
|
||||
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
|
||||
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatGraphComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatGraphComputeBatch>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
|
||||
LogicalResult matchAndRewrite(spatial::SpatGraphComputeBatch compute, PatternRewriter& rewriter) const override {
|
||||
auto promoted = computePromotedOperands(compute);
|
||||
if (failed(promoted))
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
||||
@@ -197,7 +197,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
rewriter, compute, static_cast<uint64_t>(compute.getLaneCount()), "promoted compute_batch lane count");
|
||||
if (failed(laneCountAttr))
|
||||
return failure();
|
||||
auto newCompute = spatial::SpatComputeBatch::create(
|
||||
auto newCompute = spatial::SpatGraphComputeBatch::create(
|
||||
rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs);
|
||||
auto laneArg = compute.getLaneArgument();
|
||||
if (!laneArg)
|
||||
@@ -281,8 +281,8 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
|
||||
});
|
||||
}
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
bool requiresPostRewrite(spatial::SpatGraphCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::FailureOr<mlir::Value>
|
||||
lowerSelectedConv2DPlan(spatial::SpatConv2DPlanOp planOp,
|
||||
std::optional<mlir::Value> rowStripInput,
|
||||
bool emitRowStripLayout,
|
||||
mlir::PatternRewriter& rewriter);
|
||||
|
||||
mlir::LogicalResult canLowerConvPlanToRowStrip(spatial::SpatConv2DPlanOp planOp);
|
||||
mlir::LogicalResult canConsumeAndProduceRowStrip(spatial::SpatConv2DPlanOp planOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,200 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static constexpr StringLiteral kLogicalLayout = "nchw";
|
||||
static constexpr StringLiteral kDenseLayout = "dense_nchw";
|
||||
static constexpr StringLiteral kRowStripLayout = "nchw_row_strip";
|
||||
static constexpr StringLiteral kRowStripIndexMap = "packed_hwc_rows_to_nchw";
|
||||
|
||||
enum class SelectedLayout {
|
||||
DenseNchw,
|
||||
NchwRowStrip,
|
||||
};
|
||||
|
||||
static SelectedLayout getSelectedLayout(llvm::DenseMap<Value, SelectedLayout>& layouts, Value value) {
|
||||
auto it = layouts.find(value);
|
||||
return it == layouts.end() ? SelectedLayout::DenseNchw : it->second;
|
||||
}
|
||||
|
||||
static bool usesSelectedRowStrip(Operation* user, llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||
if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(user))
|
||||
return getSelectedLayout(layouts, reluPlan.getResult()) == SelectedLayout::NchwRowStrip;
|
||||
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(user))
|
||||
return getSelectedLayout(layouts, convPlan.getResult()) == SelectedLayout::NchwRowStrip;
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool allUsersCanHandleRowStrip(Value value, llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||
for (Operation* user : value.getUsers()) {
|
||||
if (usesSelectedRowStrip(user, layouts))
|
||||
continue;
|
||||
// Dense-only users must be materialized explicitly.
|
||||
continue;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::pair<SmallVector<int64_t>, SmallVector<int64_t>> buildRowStripMetadata(RankedTensorType type) {
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> sizes;
|
||||
const int64_t channels = type.getDimSize(1);
|
||||
const int64_t height = type.getDimSize(2);
|
||||
const int64_t width = type.getDimSize(3);
|
||||
offsets.reserve(height * 4);
|
||||
sizes.reserve(height * 4);
|
||||
for (int64_t row = 0; row < height; ++row) {
|
||||
offsets.append({0, 0, row, 0});
|
||||
sizes.append({1, channels, 1, width});
|
||||
}
|
||||
return {offsets, sizes};
|
||||
}
|
||||
|
||||
static bool canSelectConvRowStrip(spatial::SpatConv2DPlanOp convPlan,
|
||||
llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||
SelectedLayout inputLayout = getSelectedLayout(layouts, convPlan.getInput());
|
||||
if (inputLayout == SelectedLayout::NchwRowStrip)
|
||||
return succeeded(canConsumeAndProduceRowStrip(convPlan));
|
||||
return succeeded(canLowerConvPlanToRowStrip(convPlan));
|
||||
}
|
||||
|
||||
static SelectedLayout chooseConvLayout(spatial::SpatConv2DPlanOp convPlan,
|
||||
llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||
if (!canSelectConvRowStrip(convPlan, layouts))
|
||||
return SelectedLayout::DenseNchw;
|
||||
if (!allUsersCanHandleRowStrip(convPlan.getResult(), layouts))
|
||||
return SelectedLayout::DenseNchw;
|
||||
return SelectedLayout::NchwRowStrip;
|
||||
}
|
||||
|
||||
static SelectedLayout chooseReluLayout(spatial::SpatReluPlanOp reluPlan,
|
||||
llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||
if (getSelectedLayout(layouts, reluPlan.getInput()) != SelectedLayout::NchwRowStrip)
|
||||
return SelectedLayout::DenseNchw;
|
||||
if (!allUsersCanHandleRowStrip(reluPlan.getResult(), layouts))
|
||||
return SelectedLayout::DenseNchw;
|
||||
return SelectedLayout::NchwRowStrip;
|
||||
}
|
||||
|
||||
static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewriter, Value value) {
|
||||
auto outputType = cast<RankedTensorType>(value.getType());
|
||||
auto [offsets, sizes] = buildRowStripMetadata(outputType);
|
||||
return spatial::SpatReconciliatorOp::create(rewriter,
|
||||
value.getLoc(),
|
||||
outputType,
|
||||
value,
|
||||
rewriter.getStringAttr(kLogicalLayout),
|
||||
rewriter.getStringAttr(kRowStripLayout),
|
||||
rewriter.getDenseI64ArrayAttr(offsets),
|
||||
rewriter.getDenseI64ArrayAttr(sizes),
|
||||
rewriter.getStringAttr(kRowStripIndexMap));
|
||||
}
|
||||
|
||||
static void materializeDenseUses(IRRewriter& rewriter,
|
||||
Value layoutValue,
|
||||
llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||
SmallVector<OpOperand*> denseUses;
|
||||
for (OpOperand& use : layoutValue.getUses()) {
|
||||
if (usesSelectedRowStrip(use.getOwner(), layouts))
|
||||
continue;
|
||||
denseUses.push_back(&use);
|
||||
}
|
||||
|
||||
for (OpOperand* use : denseUses) {
|
||||
Operation* owner = use->getOwner();
|
||||
rewriter.setInsertionPoint(owner);
|
||||
auto materialized = spatial::SpatMaterializeLayoutOp::create(rewriter,
|
||||
owner->getLoc(),
|
||||
use->get().getType(),
|
||||
use->get(),
|
||||
rewriter.getStringAttr(kLogicalLayout),
|
||||
rewriter.getStringAttr(kRowStripLayout),
|
||||
rewriter.getStringAttr(kDenseLayout));
|
||||
use->set(materialized.getResult());
|
||||
}
|
||||
}
|
||||
|
||||
struct SpatialLayoutPlanningPass final : PassWrapper<SpatialLayoutPlanningPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialLayoutPlanningPass)
|
||||
|
||||
StringRef getArgument() const override { return "spatial-layout-planning"; }
|
||||
StringRef getDescription() const override { return "Select conservative Spatial layouts and insert reconciliation barriers."; }
|
||||
|
||||
void runOnOperation() override {
|
||||
auto entryFunc = getPimEntryFunc(getOperation());
|
||||
if (failed(entryFunc)) {
|
||||
getOperation().emitError("failed to locate the PIM entry function during Spatial layout planning");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
func::FuncOp funcOp = *entryFunc;
|
||||
IRRewriter rewriter(&getContext());
|
||||
llvm::DenseMap<Value, SelectedLayout> layouts;
|
||||
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
|
||||
SelectedLayout selected = chooseConvLayout(convPlan, layouts);
|
||||
if (layouts[convPlan.getResult()] != selected) {
|
||||
layouts[convPlan.getResult()] = selected;
|
||||
changed = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
|
||||
SelectedLayout selected = chooseReluLayout(reluPlan, layouts);
|
||||
if (layouts[reluPlan.getResult()] != selected) {
|
||||
layouts[reluPlan.getResult()] = selected;
|
||||
changed = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||
Value producedValue;
|
||||
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(&op))
|
||||
producedValue = convPlan.getResult();
|
||||
else if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(&op))
|
||||
producedValue = reluPlan.getResult();
|
||||
else
|
||||
continue;
|
||||
|
||||
if (getSelectedLayout(layouts, producedValue) != SelectedLayout::NchwRowStrip)
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPointAfter(&op);
|
||||
auto reconciliator = insertRowStripReconciliator(rewriter, producedValue);
|
||||
rewriter.replaceAllUsesExcept(producedValue, reconciliator.getResult(), reconciliator);
|
||||
materializeDenseUses(rewriter, reconciliator.getResult(), layouts);
|
||||
}
|
||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||
getOperation().emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after SpatialLayoutPlanning");
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createSpatialLayoutPlanningPass() { return std::make_unique<SpatialLayoutPlanningPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -102,7 +102,7 @@ static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
|
||||
return mapper.lookup(value);
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||
size_t& fallbackCoreId) {
|
||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
@@ -171,7 +171,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||
IRRewriter& rewriter) {
|
||||
Location loc = computeBatchOp.getLoc();
|
||||
Block& oldBlock = computeBatchOp.getBody().front();
|
||||
|
||||
@@ -17,10 +17,10 @@ std::optional<unsigned> getDirectComputeLikeInputIndex(Operation* owner, unsigne
|
||||
return operandNumber - inputBegin;
|
||||
};
|
||||
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
|
||||
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner))
|
||||
return getInputIndex(owner, compute.getInputs().size());
|
||||
|
||||
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner))
|
||||
if (auto computeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(owner))
|
||||
return getInputIndex(owner, computeBatch.getInputs().size());
|
||||
|
||||
return std::nullopt;
|
||||
@@ -32,13 +32,13 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
||||
Value replacement) {
|
||||
Block& body = owner->getRegion(0).front();
|
||||
BlockArgument bodyArgument;
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner)) {
|
||||
auto computeArg = compute.getInputArgument(inputIndex);
|
||||
assert(computeArg && "expected compute input block argument");
|
||||
bodyArgument = *computeArg;
|
||||
}
|
||||
else {
|
||||
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
|
||||
auto batchArg = cast<spatial::SpatScheduledComputeBatch>(owner).getInputArgument(inputIndex);
|
||||
assert(batchArg && "expected compute_batch input block argument");
|
||||
bodyArgument = *batchArg;
|
||||
}
|
||||
@@ -46,10 +46,10 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
||||
|
||||
rewriter.startOpModification(owner);
|
||||
bodyArgument.replaceAllUsesWith(replacement);
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
|
||||
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner))
|
||||
compute.getInputsMutable().erase(inputIndex);
|
||||
else
|
||||
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
|
||||
cast<spatial::SpatScheduledComputeBatch>(owner).getInputsMutable().erase(inputIndex);
|
||||
body.eraseArgument(bodyArgIndex);
|
||||
rewriter.finalizeOpModification(owner);
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
|
||||
}
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
||||
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatScheduledCompute computeOp, size_t& fallbackCoreId) {
|
||||
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id");
|
||||
auto checkedCoreId =
|
||||
@@ -66,7 +66,7 @@ static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeO
|
||||
return *checkedCoreId;
|
||||
}
|
||||
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
|
||||
SmallVectorImpl<Operation*>& helperChain,
|
||||
bool requireReturnUse = true) {
|
||||
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
||||
@@ -104,13 +104,13 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp,
|
||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatScheduledCompute computeOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder) {
|
||||
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||
return false;
|
||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
||||
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||
return isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||
}))
|
||||
return false;
|
||||
|
||||
@@ -145,7 +145,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatScheduledCompute computeOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder) {
|
||||
Location loc = computeOp->getLoc();
|
||||
|
||||
@@ -10,6 +10,14 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static void copyRaptorDebugAttrs(Operation* source, Operation* target) {
|
||||
for (NamedAttribute attr : source->getAttrs()) {
|
||||
StringRef name = attr.getName().strref();
|
||||
if (name.starts_with("raptor."))
|
||||
target->setAttr(attr.getName(), attr.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -17,7 +25,8 @@ struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getInput());
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId());
|
||||
auto send = pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId());
|
||||
copyRaptorDebugAttrs(op.getOperation(), send.getOperation());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
@@ -37,9 +46,10 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult());
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
Value received = pim::PimReceiveOp::create(
|
||||
rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId())
|
||||
.getOutput();
|
||||
auto receive = pim::PimReceiveOp::create(
|
||||
rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId());
|
||||
copyRaptorDebugAttrs(op.getOperation(), receive.getOperation());
|
||||
Value received = receive.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
return failure();
|
||||
|
||||
for (auto& uses : extractSliceOp->getUses()) {
|
||||
if (isa<spatial::SpatCompute>(uses.getOwner())) {
|
||||
if (isa<spatial::SpatScheduledCompute>(uses.getOwner())) {
|
||||
if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
|
||||
return failure();
|
||||
}
|
||||
@@ -72,7 +72,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
|
||||
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
||||
|
||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
|
||||
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(uses.getOwner())) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
@@ -92,7 +92,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(uses.getOwner())) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
@@ -114,7 +114,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
else {
|
||||
{
|
||||
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
||||
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatScheduledCompute>()) {
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||
@@ -125,7 +125,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
||||
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatScheduledComputeBatch>()) {
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||
@@ -179,7 +179,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
|
||||
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
||||
auto argUser = argUses.getOwner();
|
||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
|
||||
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(argUser)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
@@ -191,7 +191,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(argUser)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
|
||||
@@ -86,7 +86,7 @@ getCheckedByteOffset(int64_t elementOffset, size_t elementSize, Operation* ancho
|
||||
return pim::checkedCast<int64_t>(*byteOffset, anchor, fieldName);
|
||||
}
|
||||
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
|
||||
SmallVectorImpl<Operation*>& helperChain) {
|
||||
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
||||
return failure();
|
||||
@@ -212,7 +212,7 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||
}
|
||||
|
||||
SmallVector<Operation*> helperChain;
|
||||
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
|
||||
if (auto helperCompute = dyn_cast<spatial::SpatScheduledCompute>(currentUser)) {
|
||||
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
|
||||
return std::nullopt;
|
||||
|
||||
@@ -643,7 +643,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
}
|
||||
|
||||
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
|
||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
|
||||
spatial::SpatScheduledCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
|
||||
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
||||
}
|
||||
|
||||
@@ -656,7 +656,7 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
|
||||
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
|
||||
Operation* onlyUser = *op->getUsers().begin();
|
||||
isExclusivelyOwnedByReturnChain =
|
||||
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|
||||
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatScheduledCompute>(onlyUser)
|
||||
|| isReturnHelperChainOp(onlyUser);
|
||||
}
|
||||
if (!isExclusivelyOwnedByReturnChain)
|
||||
@@ -669,7 +669,7 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||
if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
|
||||
markOpToRemove(computeOp);
|
||||
if (!computeOp.getInputs().empty())
|
||||
for (Value input : computeOp.getInputs())
|
||||
|
||||
@@ -25,9 +25,11 @@
|
||||
#include <cassert>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/IR/ShapeUtils.hpp"
|
||||
#include "Common/IR/ConstantUtils.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Common/Support/CheckedArithmetic.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/Patterns.hpp"
|
||||
@@ -97,6 +99,64 @@ static FailureOr<Value> createZeroedDeviceHVector(IRRewriter& rewriter,
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
static bool isHostBackedMemRefValue(Value value) {
|
||||
while (Operation* definingOp = value.getDefiningOp()) {
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
return isa<memref::GetGlobalOp>(definingOp);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool isHostBackedTensorValue(Value value) {
|
||||
while (Operation* definingOp = value.getDefiningOp()) {
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(extractSliceOp.getSource().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getResult().getType());
|
||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return false;
|
||||
if (!onnx_mlir::isContiguousSubviewWithDynamicOffsets(sourceType.getShape(),
|
||||
extractSliceOp.getMixedOffsets(),
|
||||
extractSliceOp.getStaticSizes(),
|
||||
extractSliceOp.getStaticStrides())) {
|
||||
return false;
|
||||
}
|
||||
value = extractSliceOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto castOp = dyn_cast<tensor::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(definingOp))
|
||||
return isHostBackedMemRefValue(toTensorOp.getBuffer());
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static FailureOr<Value>
|
||||
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
|
||||
auto vectorType = cast<RankedTensorType>(vector.getType());
|
||||
@@ -120,6 +180,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
|
||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
|
||||
if (failed(sizeAttr))
|
||||
return failure();
|
||||
if (isHostBackedTensorValue(vector)) {
|
||||
return PimMemCopyHostToDevOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
|
||||
}
|
||||
|
||||
@@ -137,6 +201,12 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
func::FuncOp funcOp = *entryFunc;
|
||||
if (failed(verifyScheduledSpatialInvariants(funcOp))) {
|
||||
funcOp.emitOpError(
|
||||
"RAPTOR_PHASE_CHECK scheduled Spatial verification failed at the start of SpatialToPim");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
IRRewriter rewriter(&getContext());
|
||||
OperationFolder constantFolder(&getContext());
|
||||
@@ -176,19 +246,19 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatScheduledCompute>()) {
|
||||
markOpToRemove(computeOp);
|
||||
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
|
||||
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
||||
computeOp.emitOpError("failed to lower spat.scheduled_compute to pim.core");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
|
||||
markOpToRemove(computeBatchOp);
|
||||
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
|
||||
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
||||
computeBatchOp.emitOpError("failed to lower spat.scheduled_compute_batch to pim.core_batch");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -374,7 +444,7 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
|
||||
};
|
||||
|
||||
for (auto& op : funcOp.getBody().getOps())
|
||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||
if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
|
||||
if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
|
||||
continue;
|
||||
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
|
||||
|
||||
@@ -41,8 +41,11 @@ private:
|
||||
|
||||
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
mlir::LogicalResult
|
||||
lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder);
|
||||
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);
|
||||
lowerComputeOp(spatial::SpatScheduledCompute computeOp,
|
||||
mlir::IRRewriter& rewriter,
|
||||
mlir::OperationFolder& constantFolder);
|
||||
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||
mlir::IRRewriter& rewriter);
|
||||
|
||||
enum class ReturnPathLoweringResult {
|
||||
Handled,
|
||||
@@ -51,7 +54,7 @@ private:
|
||||
};
|
||||
|
||||
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatScheduledCompute computeOp,
|
||||
mlir::OpResult result,
|
||||
mlir::Value yieldValue,
|
||||
mlir::IRRewriter& rewriter);
|
||||
|
||||
Reference in New Issue
Block a user