This commit is contained in:
ilgeco
2026-06-24 15:52:07 +02:00
parent 2b4115699a
commit 62dd40ee89
47 changed files with 7993 additions and 1100 deletions
@@ -258,24 +258,23 @@ where
let (memory, crossbars) = core.get_memory_crossbar(); let (memory, crossbars) = core.get_memory_crossbar();
let crossbar = crossbars.get_mut(group).unwrap(); let crossbar = crossbars.get_mut(group).unwrap();
let crossbar_stored_bytes = crossbar.stored_bytes();
let crossbar_byte_width = crossbar.width();
let crossbar_elem_width = crossbar_byte_width / size_of::<M>();
ensure!(
crossbar_byte_width % size_of::<M>() == 0,
"M not divisor of the crosbbar size"
);
let crossbar_height = crossbar.height(); let crossbar_height = crossbar.height();
let crossbar_byte_size = crossbar_byte_width * crossbar_height; let crossbar_stored_bytes = crossbar.stored_bytes();
let bytes_per_column = crossbar_height * size_of::<M>();
ensure!(bytes_per_column != 0, "crossbar height can not be zero");
ensure!(
crossbar_stored_bytes % bytes_per_column == 0,
"Stored crossbar bytes do not describe an integral number of columns"
);
let crossbar_elem_width = crossbar_stored_bytes / bytes_per_column;
ensure!(crossbar_elem_width != 0, "Crossbar contains no stored columns");
let loads = memory let loads = memory
.reserve_load(r1_val, crossbar_height * size_of::<F>())? .reserve_load(r1_val, crossbar_height * size_of::<F>())?
.execute_load::<F>()?; .execute_load::<F>()?;
let load = loads[0]; let load = loads[0];
let vec: Cow<[M]> = load.up(); let vec: Cow<[M]> = load.up();
let matrix = crossbar.load::<M>(crossbar_byte_size)?[0]; let matrix = crossbar.load::<M>(crossbar_stored_bytes)?[0];
// --- FAER IMPLEMENTATION --- // --- FAER IMPLEMENTATION ---
+61
View File
@@ -56,6 +56,22 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge); return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
} }
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (result) {
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands()) {
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size())
return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
}
return yieldedValue;
}
}
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge); return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
@@ -512,6 +528,24 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
continue; continue;
} }
if (auto ifOp = mlir::dyn_cast<mlir::scf::IfOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto condition = resolveIndexValueImpl(ifOp.getCondition(), knowledge);
if (failed(condition))
return mlir::failure();
mlir::Region& selectedRegion = *condition != 0 ? ifOp.getThenRegion() : ifOp.getElseRegion();
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(selectedRegion.front().getTerminator());
if (!yieldOp || result.getResultNumber() >= yieldOp.getNumOperands())
return mlir::failure();
value = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) { if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType()); auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType()); auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
@@ -622,6 +656,33 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
continue; continue;
} }
if (auto ifOp = mlir::dyn_cast<mlir::scf::IfOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto thenYield = mlir::dyn_cast<mlir::scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
auto elseYield = mlir::dyn_cast<mlir::scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());
if (!thenYield || !elseYield || result.getResultNumber() >= thenYield.getNumOperands()
|| result.getResultNumber() >= elseYield.getNumOperands()) {
return mlir::failure();
}
auto thenAddress = compileContiguousAddressExprImpl(thenYield.getOperand(result.getResultNumber()));
auto elseAddress = compileContiguousAddressExprImpl(elseYield.getOperand(result.getResultNumber()));
if (failed(thenAddress) || failed(elseAddress) || thenAddress->base != elseAddress->base)
return mlir::failure();
auto condition = compileIndexValueImpl(ifOp.getCondition());
if (failed(condition))
return mlir::failure();
CompiledIndexExprNode selectExpr;
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
selectExpr.operands = {*condition, thenAddress->byteOffset, elseAddress->byteOffset};
return CompiledAddressExpr {thenAddress->base, makeCompiledIndexExpr(std::move(selectExpr))};
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) { if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType()); auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType()); auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
+18
View File
@@ -96,6 +96,24 @@ llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
llvm::cl::init(false), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions)); llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> pimDetectCommunicationDeadlock(
"pim-detect-communication-deadlock",
llvm::cl::desc("Expensively simulate the statically expanded PIM send/receive order at verification time and fail if a blocking communication deadlock is found"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> pimMaterializeScalarFanoutGlobalOrder(
"pim-materialize-scalar-fanout-global-order",
llvm::cl::desc("Experimental expensive materializer mode: emit scalar-source fanout as globally ordered communication events instead of all-send fanout loops"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> pimTraceCommunicationMaterialization(
"pim-trace-communication-materialization",
llvm::cl::desc("Emit verbose materializer-time diagnostics and provenance attributes for every Spatial communication op"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<size_t> llvm::cl::opt<size_t>
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2)); crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
+3
View File
@@ -53,6 +53,9 @@ extern llvm::cl::opt<bool> pimDisableMemoryCoalescing;
extern llvm::cl::opt<bool> useExperimentalConvImpl; extern llvm::cl::opt<bool> useExperimentalConvImpl;
extern llvm::cl::opt<bool> pimEmitJson; extern llvm::cl::opt<bool> pimEmitJson;
extern llvm::cl::opt<bool> pimReportConvLowering; extern llvm::cl::opt<bool> pimReportConvLowering;
extern llvm::cl::opt<bool> pimDetectCommunicationDeadlock;
extern llvm::cl::opt<bool> pimMaterializeScalarFanoutGlobalOrder;
extern llvm::cl::opt<bool> pimTraceCommunicationMaterialization;
extern llvm::cl::opt<size_t> crossbarSize; extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore; extern llvm::cl::opt<size_t> crossbarCountInCore;
+2
View File
@@ -29,6 +29,8 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitSpatial) { if (pimEmissionTarget >= EmitSpatial) {
pm.addPass(createONNXToSpatialPass()); pm.addPass(createONNXToSpatialPass());
pm.addPass(createSpatialLayoutPlanningPass());
pm.addPass(createLowerSpatialPlansPass());
pm.addPass(createMergeComputeNodesPass()); pm.addPass(createMergeComputeNodesPass());
pm.addPass(createMessagePass("Onnx lowered to Spatial")); pm.addPass(createMessagePass("Onnx lowered to Spatial"));
} }
@@ -26,6 +26,8 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Split.cpp Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp ONNXToSpatialPass.cpp
SpatialLayoutPlanningPass.cpp
LowerSpatialPlansPass.cpp
Common/AttributeUtils.cpp Common/AttributeUtils.cpp
Common/ComputeRegionBuilder.cpp Common/ComputeRegionBuilder.cpp
Common/IndexingUtils.cpp Common/IndexingUtils.cpp
@@ -9,7 +9,7 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) { Value sumTensors(ArrayRef<Value> tensors, PatternRewriter& rewriter) {
if (tensors.size() == 1) if (tensors.size() == 1)
return tensors[0]; return tensors[0];
@@ -87,17 +87,17 @@ inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput(); return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
} }
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if /// Builds a `spat.graph_compute` with a fixed number of SSA inputs and erases it if
/// the body callback reports failure. /// the body callback reports failure.
template <size_t NumInputs, typename RewriterT, typename BodyFn> template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter, auto createSpatGraphCompute(RewriterT& rewriter,
mlir::Location loc, mlir::Location loc,
mlir::TypeRange resultTypes, mlir::TypeRange resultTypes,
mlir::ValueRange weights, mlir::ValueRange weights,
mlir::ValueRange inputs, mlir::ValueRange inputs,
BodyFn&& body) { BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values"); assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block(); auto* block = new mlir::Block();
for (mlir::Value weight : weights) for (mlir::Value weight : weights)
@@ -124,23 +124,23 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) { if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp); rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure()); return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
} }
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp); return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
} }
} }
/// Builds a `spat.compute` whose body consumes the block arguments as a single /// Builds a `spat.graph_compute` whose body consumes the block arguments as a single
/// `ValueRange`, which is convenient for variadic reductions/concats. /// `ValueRange`, which is convenient for variadic reductions/concats.
template <typename RewriterT, typename BodyFn> template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter, auto createSpatGraphCompute(RewriterT& rewriter,
mlir::Location loc, mlir::Location loc,
mlir::TypeRange resultTypes, mlir::TypeRange resultTypes,
mlir::ValueRange weights, mlir::ValueRange weights,
mlir::ValueRange inputs, mlir::ValueRange inputs,
BodyFn&& body) { BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block(); auto* block = new mlir::Block();
for (mlir::Value weight : weights) for (mlir::Value weight : weights)
@@ -163,29 +163,29 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) { if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp); rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure()); return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
} }
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp); return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
} }
} }
template <typename RewriterT, typename BodyFn> template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter, auto createSpatGraphComputeBatch(RewriterT& rewriter,
mlir::Location loc, mlir::Location loc,
mlir::TypeRange resultTypes, mlir::TypeRange resultTypes,
int64_t laneCount, int64_t laneCount,
mlir::ValueRange weights, mlir::ValueRange weights,
mlir::ValueRange inputs, mlir::ValueRange inputs,
BodyFn&& body) { BodyFn&& body) {
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max()) if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure()); return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count"); auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
if (mlir::failed(laneCountAttr)) if (mlir::failed(laneCountAttr))
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure()); return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs); auto batchOp = spatial::SpatGraphComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()}; mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
mlir::SmallVector<mlir::Location> blockArgLocs {loc}; mlir::SmallVector<mlir::Location> blockArgLocs {loc};
@@ -218,20 +218,53 @@ auto createSpatComputeBatch(RewriterT& rewriter,
if constexpr (std::is_same_v<BodyResult, void>) { if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(args); std::forward<BodyFn>(body)(args);
rewriter.setInsertionPointAfter(batchOp); rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp); return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
} }
else { else {
auto bodyResult = std::forward<BodyFn>(body)(args); auto bodyResult = std::forward<BodyFn>(body)(args);
if (mlir::failed(bodyResult)) { if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(batchOp); rewriter.setInsertionPointAfter(batchOp);
rewriter.eraseOp(batchOp); rewriter.eraseOp(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure()); return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
} }
rewriter.setInsertionPointAfter(batchOp); rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp); return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
} }
} }
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphCompute<NumInputs>(
rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
}
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphCompute(rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
}
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphComputeBatch(
rewriter, loc, resultTypes, laneCount, weights, inputs, std::forward<BodyFn>(body));
}
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter, inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
mlir::Location loc, mlir::Location loc,
mlir::Value source, mlir::Value source,
@@ -262,6 +295,6 @@ mlir::Value materializeOrComputeUnary(mlir::Value input,
return computeOp.getResult(0); return computeOp.getResult(0);
} }
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter); mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::PatternRewriter& rewriter);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -83,7 +83,7 @@ SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int
} }
SmallVector<Value> sliceTensor( SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) { const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(tensorToSlice); ArrayRef<long> shape = getTensorShape(tensorToSlice);
assert("Invalid axis" && axis < shape.size()); assert("Invalid axis" && axis < shape.size());
@@ -129,7 +129,7 @@ SmallVector<Value> sliceTensor(
} }
SmallVector<Value> SmallVector<Value>
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) { sliceVector(const Value& vectorToSlice, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(vectorToSlice); ArrayRef<long> shape = getTensorShape(vectorToSlice);
assert("Not a vector" && isVectorShape(shape)); assert("Not a vector" && isVectorShape(shape));
size_t axis = shape[0] != 1 ? 0 : 1; size_t axis = shape[0] != 1 ? 0 : 1;
@@ -137,7 +137,7 @@ sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewr
} }
DenseMap<CoreId, SmallVector<Value>> DenseMap<CoreId, SmallVector<Value>>
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) { sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, PatternRewriter& rewriter, Location loc) {
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc); SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
DenseMap<CoreId, SmallVector<Value>> slicesPerCore; DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) { for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
@@ -89,18 +89,18 @@ llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewr
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice, llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis, size_t axis,
int64_t sliceSize, int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter, mlir::PatternRewriter& rewriter,
mlir::Location loc); mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice, llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize, int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter, mlir::PatternRewriter& rewriter,
mlir::Location loc); mlir::Location loc);
/// Partitions one logical vector into per-core crossbar-sized slices using the /// Partitions one logical vector into per-core crossbar-sized slices using the
/// current PIM target geometry. /// current PIM target geometry.
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore( llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc);
mlir::Value extractAxisSlice( mlir::Value extractAxisSlice(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size); mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
@@ -0,0 +1,409 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static constexpr StringLiteral kDenseLayout = "dense_nchw";
static constexpr StringLiteral kRowStripLayout = "nchw_row_strip";
struct RowStripPhysicalValue {
Value physicalValue;
RankedTensorType logicalType;
SmallVector<int64_t, 16> fragmentOffsets;
SmallVector<int64_t, 16> fragmentSizes;
std::string indexMap;
};
static FailureOr<RowStripPhysicalValue> getRowStripValue(llvm::DenseMap<Value, RowStripPhysicalValue>& rowStripValues,
Value value) {
auto it = rowStripValues.find(value);
if (it == rowStripValues.end())
return failure();
return it->second;
}
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatReconciliatorOp reconciliator,
Value physicalValue) {
auto logicalType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
if (!logicalType)
return reconciliator.emitOpError("requires ranked logical output type"), failure();
RowStripPhysicalValue value;
value.physicalValue = physicalValue;
value.logicalType = logicalType;
value.fragmentOffsets.append(reconciliator.getFragmentOffsets().begin(), reconciliator.getFragmentOffsets().end());
value.fragmentSizes.append(reconciliator.getFragmentSizes().begin(), reconciliator.getFragmentSizes().end());
value.indexMap = reconciliator.getIndexMap().str();
return value;
}
static FailureOr<Value>
lowerRowStripRelu(const RowStripPhysicalValue& input, spatial::SpatReluPlanOp planOp, PatternRewriter& rewriter) {
auto packedType = cast<RankedTensorType>(input.physicalValue.getType());
auto computeOp =
createSpatCompute<1>(rewriter, planOp.getLoc(), TypeRange {packedType}, {}, input.physicalValue, [&](Value x) {
auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), packedType, x);
spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
});
return computeOp.getResult(0);
}
static FailureOr<Value>
materializeRowStripToDense(const RowStripPhysicalValue& rowStripValue, Location loc, PatternRewriter& rewriter) {
auto packedType = dyn_cast<RankedTensorType>(rowStripValue.physicalValue.getType());
if (!packedType || packedType.getRank() != 3 || !packedType.hasStaticShape())
return failure();
if (rowStripValue.logicalType.getRank() != 4 || !rowStripValue.logicalType.hasStaticShape())
return failure();
if (rowStripValue.indexMap != "packed_hwc_rows_to_nchw")
return failure();
const int64_t rank = rowStripValue.logicalType.getRank();
const int64_t fragmentCount = rowStripValue.fragmentOffsets.size() / rank;
const int64_t packedWidth = packedType.getDimSize(1);
const int64_t packedChannels = packedType.getDimSize(2);
if (fragmentCount != packedType.getDimSize(0))
return failure();
for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
if (rowStripValue.fragmentOffsets[fragmentIndex * rank + 0] != 0
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 1] != 0
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 2] != fragmentIndex
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 3] != 0)
return failure();
if (rowStripValue.fragmentSizes[fragmentIndex * rank + 0] != 1
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 1] != packedChannels
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 2] != 1
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 3] != packedWidth)
return failure();
}
auto packedSliceType =
RankedTensorType::get({1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
auto expandedType =
RankedTensorType::get({1, 1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
auto logicalFragmentType =
RankedTensorType::get({1, packedChannels, 1, packedWidth}, packedType.getElementType(), packedType.getEncoding());
auto batchOp = createSpatComputeBatch(
rewriter,
loc,
TypeRange {rowStripValue.logicalType},
fragmentCount,
{},
ValueRange {rowStripValue.physicalValue},
[&](detail::SpatComputeBatchBodyArgs args) {
SmallVector<OpFoldResult> packedOffsets {args.lane, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> packedSizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(packedWidth), rewriter.getIndexAttr(packedChannels)};
Value packedSlice = tensor::ExtractSliceOp::create(
rewriter, loc, packedSliceType, args.inputs.front(), packedOffsets, packedSizes, getUnitStrides(rewriter, 3));
Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedSlice,
SmallVector<ReassociationIndices> {
{0, 1},
{2},
{3}
});
Value transposeInit =
tensor::EmptyOp::create(rewriter, loc, logicalFragmentType.getShape(), logicalFragmentType.getElementType());
Value logicalFragment =
linalg::TransposeOp::create(rewriter, loc, expanded, transposeInit, SmallVector<int64_t> {0, 3, 1, 2})
.getResult()[0];
SmallVector<OpFoldResult> logicalOffsets {
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> logicalSizes {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(packedChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(packedWidth)};
createParallelInsertSliceIntoBatchOutput(rewriter,
loc,
logicalFragment,
args.outputs.front(),
logicalOffsets,
logicalSizes,
getUnitStrides(rewriter, 4));
return success();
});
if (failed(batchOp))
return failure();
return batchOp->getResult(0);
}
struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerSpatialPlansPass)
StringRef getArgument() const override { return "lower-spatial-plans"; }
StringRef getDescription() const override { return "Lower selected Spatial planning ops to low-level Spatial IR."; }
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext();
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
moduleOp.emitError("failed to locate the PIM entry function during LowerSpatialPlans");
signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
PatternRewriter rewriter(ctx);
llvm::DenseMap<Value, RowStripPhysicalValue> rowStripValues;
llvm::SmallPtrSet<Operation*, 16> eraseAfterLowering;
auto verifyLogicalPhase = [&](StringRef stage) -> bool {
if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc)))
return true;
moduleOp.emitError() << "RAPTOR_PHASE_CHECK logical Spatial graph verification failed " << stage;
signalPassFailure();
return false;
};
if (!verifyLogicalPhase("at the start of LowerSpatialPlans"))
return;
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
if (auto planOp = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
FailureOr<RowStripPhysicalValue> rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
auto rowStripReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
});
if (rowStripReconciliator != planOp.getResult().getUsers().end()) {
rewriter.setInsertionPoint(planOp);
FailureOr<Value> lowered = lowerSelectedConv2DPlan(
planOp,
succeeded(rowStripInput) ? std::optional<Value> {rowStripInput->physicalValue} : std::nullopt,
/*emitRowStripLayout=*/true,
rewriter);
if (failed(lowered)) {
planOp.emitOpError("failed to lower selected row-strip Spatial Conv plan");
signalPassFailure();
return;
}
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*rowStripReconciliator);
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(reconciliator, *lowered);
if (failed(rowStripValue)) {
signalPassFailure();
return;
}
rowStripValues[reconciliator.getResult()] = *rowStripValue;
eraseAfterLowering.insert(planOp);
eraseAfterLowering.insert(reconciliator);
continue;
}
rewriter.setInsertionPoint(planOp);
FailureOr<Value> lowered =
lowerSelectedConv2DPlan(planOp, std::nullopt, /*emitRowStripLayout=*/false, rewriter);
if (failed(lowered)) {
planOp.emitOpError("failed to lower selected Spatial Conv plan");
signalPassFailure();
return;
}
rewriter.replaceOp(planOp, *lowered);
continue;
}
if (auto planOp = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) {
auto outputReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
});
if (outputReconciliator == planOp.getResult().getUsers().end()) {
planOp.emitOpError("row-strip Relu plan requires a row-strip reconciliator result");
signalPassFailure();
return;
}
FailureOr<RowStripPhysicalValue> input = getRowStripValue(rowStripValues, planOp.getInput());
rewriter.setInsertionPoint(planOp);
FailureOr<Value> lowered = lowerRowStripRelu(*input, planOp, rewriter);
if (failed(lowered)) {
planOp.emitOpError("failed to lower selected row-strip Spatial Relu plan");
signalPassFailure();
return;
}
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*outputReconciliator);
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(reconciliator, *lowered);
if (failed(output)) {
signalPassFailure();
return;
}
rowStripValues[reconciliator.getResult()] = *output;
eraseAfterLowering.insert(planOp);
eraseAfterLowering.insert(reconciliator);
continue;
}
rewriter.setInsertionPoint(planOp);
auto computeOp = createSpatCompute<1>(
rewriter, planOp.getLoc(), planOp.getOutput().getType(), {}, planOp.getInput(), [&](Value x) {
auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), planOp.getOutput().getType(), x);
spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
});
rewriter.replaceOp(planOp, computeOp.getResults());
continue;
}
if (auto materializeOp = dyn_cast<spatial::SpatMaterializeLayoutOp>(&op)) {
if (materializeOp.getSourcePhysicalLayout() == kDenseLayout
&& materializeOp.getTargetPhysicalLayout() == kDenseLayout) {
rewriter.replaceOp(materializeOp, materializeOp.getInput());
continue;
}
if (materializeOp.getSourcePhysicalLayout() != kRowStripLayout
|| materializeOp.getTargetPhysicalLayout() != kDenseLayout) {
materializeOp.emitOpError("non-dense materialize_layout lowering is not supported yet");
signalPassFailure();
return;
}
FailureOr<RowStripPhysicalValue> rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
if (failed(rowStripValue)) {
materializeOp.emitOpError("expected a row-strip reconciliator input during row-strip materialization");
signalPassFailure();
return;
}
rewriter.setInsertionPoint(materializeOp);
FailureOr<Value> dense = materializeRowStripToDense(*rowStripValue, materializeOp.getLoc(), rewriter);
if (failed(dense)) {
materializeOp.emitOpError("failed to materialize selected row-strip layout back to dense NCHW");
signalPassFailure();
return;
}
rewriter.replaceOp(materializeOp, *dense);
continue;
}
if (auto reconciliatorOp = dyn_cast<spatial::SpatReconciliatorOp>(&op)) {
if (reconciliatorOp.getPhysicalLayout() == kDenseLayout) {
rewriter.replaceOp(reconciliatorOp, reconciliatorOp.getInput());
continue;
}
if (reconciliatorOp.getPhysicalLayout() != kRowStripLayout) {
reconciliatorOp.emitOpError("non-dense reconciliator lowering is not supported yet");
signalPassFailure();
return;
}
if (!eraseAfterLowering.contains(reconciliatorOp)) {
reconciliatorOp.emitOpError("unhandled row-strip reconciliator remained during LowerSpatialPlans");
signalPassFailure();
return;
}
}
}
bool erasedAny = true;
while (erasedAny) {
erasedAny = false;
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
if (!eraseAfterLowering.contains(&op))
continue;
if (!op.use_empty())
continue;
eraseAfterLowering.erase(&op);
rewriter.eraseOp(&op);
erasedAny = true;
}
}
if (!eraseAfterLowering.empty()) {
for (Operation& op : funcOp.getBody().front())
if (eraseAfterLowering.contains(&op))
op.emitOpError("selected row-strip planning op could not be fully eliminated during LowerSpatialPlans");
signalPassFailure();
return;
}
ConversionTarget helperTarget(*ctx);
helperTarget.addLegalDialect<spatial::SpatialDialect,
tensor::TensorDialect,
linalg::LinalgDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect,
func::FuncDialect>();
helperTarget.addLegalOp<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
helperTarget.addIllegalOp<ONNXGemmOp, ONNXTransposeOp>();
helperTarget.markOpRecursivelyLegal<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
RewritePatternSet helperPatterns(ctx);
populateGemmPatterns(helperPatterns, ctx);
populateTransposePatterns(helperPatterns, ctx);
if (failed(applyPartialConversion(moduleOp, helperTarget, std::move(helperPatterns)))) {
moduleOp.emitError("failed to lower helper ONNX ops emitted by selected Spatial plan lowering");
signalPassFailure();
return;
}
FrozenRewritePatternSet nestedHelperPatterns([&] {
RewritePatternSet patterns(ctx);
populateGemmPatterns(patterns, ctx);
populateTransposePatterns(patterns, ctx);
return patterns;
}());
ConversionTarget nestedHelperTarget(*ctx);
nestedHelperTarget.addLegalDialect<spatial::SpatialDialect,
tensor::TensorDialect,
linalg::LinalgDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect,
func::FuncDialect>();
nestedHelperTarget.addIllegalOp<ONNXGemmOp, ONNXTransposeOp>();
SmallVector<Operation*> computeLikeOps;
funcOp.walk([&](Operation* op) {
if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(op))
computeLikeOps.push_back(op);
});
for (Operation* op : computeLikeOps) {
if (failed(applyFullConversion(op, nestedHelperTarget, nestedHelperPatterns))) {
op->emitOpError("failed to lower nested helper ONNX ops emitted by selected Spatial plan lowering");
signalPassFailure();
return;
}
}
if (!verifyLogicalPhase("after nested helper conversions"))
return;
bool hasIllegalOps = false;
moduleOp.walk([&](Operation* op) {
if (isa<ONNXEntryPointOp>(op))
return;
if (isa<spatial::SpatConv2DPlanOp,
spatial::SpatReluPlanOp,
spatial::SpatReconciliatorOp,
spatial::SpatMaterializeLayoutOp>(op)
|| op->getDialect()->getNamespace() == "onnx") {
op->emitOpError("operation must not remain after LowerSpatialPlans");
hasIllegalOps = true;
}
});
if (hasIllegalOps)
signalPassFailure();
else
dumpModule(moduleOp, "spatial1_premerge");
if (!verifyLogicalPhase("at the end of LowerSpatialPlans"))
return;
}
};
} // namespace
std::unique_ptr<Pass> createLowerSpatialPlansPass() { return std::make_unique<LowerSpatialPlansPass>(); }
} // namespace onnx_mlir
@@ -18,6 +18,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#include "ONNXToSpatialVerifier.hpp"
using namespace mlir; using namespace mlir;
@@ -41,10 +42,16 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
static void populateEmptyFunction(func::FuncOp funcOp) { static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext()); IRRewriter rewriter(funcOp.getContext());
IRMapping mapper; IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>()); SmallVector<spatial::SpatGraphCompute> computes(funcOp.getOps<spatial::SpatGraphCompute>());
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>()); SmallVector<spatial::SpatGraphComputeBatch> computeBatches(funcOp.getOps<spatial::SpatGraphComputeBatch>());
if (!computes.empty() || !computeBatches.empty()) SmallVector<spatial::SpatConv2DPlanOp> convPlans(funcOp.getOps<spatial::SpatConv2DPlanOp>());
SmallVector<spatial::SpatReluPlanOp> reluPlans(funcOp.getOps<spatial::SpatReluPlanOp>());
SmallVector<spatial::SpatReconciliatorOp> reconciliators(funcOp.getOps<spatial::SpatReconciliatorOp>());
SmallVector<spatial::SpatMaterializeLayoutOp> materializers(funcOp.getOps<spatial::SpatMaterializeLayoutOp>());
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !reconciliators.empty()
|| !materializers.empty()) {
return; return;
}
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator()); auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
rewriter.setInsertionPoint(returnOp); rewriter.setInsertionPoint(returnOp);
@@ -58,7 +65,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
sourceLocs.push_back(source.getLoc()); sourceLocs.push_back(source.getLoc());
} }
auto newCompute = spatial::SpatCompute::create( auto newCompute = spatial::SpatGraphCompute::create(
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {}); rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs); auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands())) for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
@@ -67,7 +74,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
rewriter.setInsertionPointToEnd(newBlock); rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : funcOp.getOps()) for (Operation& op : funcOp.getOps())
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op))
rewriter.clone(op, mapper); rewriter.clone(op, mapper);
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands()); auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
@@ -75,7 +82,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i))); yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps())) for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) { if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op)) {
op.dropAllUses(); op.dropAllUses();
rewriter.eraseOp(&op); rewriter.eraseOp(&op);
} }
@@ -152,6 +159,11 @@ void ONNXToSpatialPass::runOnOperation() {
return; return;
} }
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX conversion");
signalPassFailure();
return;
}
ConversionTarget earlyPostTarget(*ctx); ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect, earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect, ONNXDialect,
@@ -168,6 +180,11 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc); annotateWeightsConstants(*entryFunc);
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after weight annotation");
signalPassFailure();
return;
}
ConversionTarget postTarget(*ctx); ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect, postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect, ONNXDialect,
@@ -176,11 +193,16 @@ void ONNXToSpatialPass::runOnOperation() {
affine::AffineDialect, affine::AffineDialect,
arith::ArithDialect, arith::ArithDialect,
scf::SCFDialect>(); scf::SCFDialect>();
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>( postTarget.addDynamicallyLegalOp<spatial::SpatGraphCompute>(
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); }); [](spatial::SpatGraphCompute computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>( postTarget.addDynamicallyLegalOp<spatial::SpatGraphComputeBatch>(
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); }); [](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed before post rewrites");
signalPassFailure();
return;
}
RewritePatternSet postPatterns(ctx); RewritePatternSet postPatterns(ctx);
populatePostPatterns(postPatterns, ctx); populatePostPatterns(postPatterns, ctx);
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) { if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
@@ -191,6 +213,11 @@ void ONNXToSpatialPass::runOnOperation() {
populateEmptyFunction(*entryFunc); populateEmptyFunction(*entryFunc);
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX-to-Spatial");
signalPassFailure();
return;
}
dumpModule(moduleOp, "spatial0"); dumpModule(moduleOp, "spatial0");
if (failed(verifyONNXToSpatial(*entryFunc))) { if (failed(verifyONNXToSpatial(*entryFunc))) {
@@ -1,4 +1,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "Common/IR/WeightUtils.hpp" #include "Common/IR/WeightUtils.hpp"
@@ -13,6 +15,8 @@ namespace onnx_mlir {
namespace { namespace {
constexpr StringLiteral kPhaseMarker = "RAPTOR_PHASE_CHECK";
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) { void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
func.walk([&](Operation* op) { func.walk([&](Operation* op) {
if (!hasWeightAlways(op)) if (!hasWeightAlways(op))
@@ -23,134 +27,174 @@ void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diag
continue; continue;
diagnostics.report(op, [&](Operation* illegalOp) { diagnostics.report(op, [&](Operation* illegalOp) {
illegalOp->emitOpError( illegalOp->emitOpError()
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights"); << kPhaseMarker
<< " weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights";
}); });
return; return;
} }
}); });
} }
Region* getParentRegion(Value value) { bool isRegionOrAncestorOf(Region& region, Region* candidate) {
if (auto blockArg = dyn_cast<BlockArgument>(value)) return candidate && (&region == candidate || region.isAncestor(candidate));
return blockArg.getOwner()->getParent();
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion();
return nullptr;
} }
bool isDefinedInsideRegion(Value value, Region& region) { bool isValueDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value); if (auto blockArg = dyn_cast<BlockArgument>(value))
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion)); return isRegionOrAncestorOf(region, blockArg.getOwner()->getParent());
if (Operation* definingOp = value.getDefiningOp())
return isRegionOrAncestorOf(region, definingOp->getParentRegion());
return false;
}
bool isLegalExternalCapture(Value value, Region& region) {
if (isValueDefinedInsideRegion(value, region))
return true;
Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
}
template <typename ComputeOpTy>
void verifyComputeBodyCaptures(ComputeOpTy compute, StringRef kind, pim::CappedDiagnosticReporter& diagnostics) {
Region& body = compute.getBody();
body.walk([&](Operation* nestedOp) {
for (OpOperand& operand : nestedOp->getOpOperands()) {
Value value = operand.get();
if (isLegalExternalCapture(value, body))
continue;
Operation* definingOp = value.getDefiningOp();
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
InFlightDiagnostic diag =
illegalOp->emitOpError() << kPhaseMarker << " " << kind << " body captures non-constant external operand #"
<< operand.getOperandNumber() << " used by " << nestedOp->getName().getStringRef();
diag << " (type " << value.getType() << ")";
if (definingOp)
diag.attachNote(definingOp->getLoc()) << "defining op is " << definingOp->getName().getStringRef();
else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
if (Operation* owner = blockArg.getOwner()->getParentOp())
diag.attachNote(owner->getLoc())
<< "external block argument belongs to " << owner->getName().getStringRef();
}
});
}
});
} }
bool isLegalHostBackedValue(Value value) { bool isLegalHostBackedValue(Value value) {
Operation* definingOp = value.getDefiningOp(); Operation* definingOp = value.getDefiningOp();
if (!definingOp) if (!definingOp)
return isa<BlockArgument>(value); return isa<BlockArgument>(value);
if (isa<spatial::SpatChannelReceiveOp>(definingOp))
return false;
return definingOp->getDialect()->getNamespace() != "spat"; return definingOp->getDialect()->getNamespace() != "spat";
} }
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp, template <typename ComputeOpTy>
ValueRange inputs, void verifyScheduledInputs(ComputeOpTy compute,
bool allowChannelReceiveInputs, bool allowChannelReceiveInputs,
StringRef kind, StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) { pim::CappedDiagnosticReporter& diagnostics) {
for (auto [inputIndex, input] : llvm::enumerate(inputs)) { for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
unsigned currentInputIndex = inputIndex;
Operation* definingOp = input.getDefiningOp(); Operation* definingOp = input.getDefiningOp();
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp)) if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
continue; continue;
if (isLegalHostBackedValue(input)) if (isLegalHostBackedValue(input))
continue; continue;
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) { diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError() InFlightDiagnostic diag = illegalOp->emitOpError()
<< kind << " input #" << currentInputIndex << kPhaseMarker << " " << kind << " input #" << inputIndex
<< (allowChannelReceiveInputs ? " must come from the host or an explicit " << (allowChannelReceiveInputs ? " must come from the host or explicit spat.channel_receive"
"spat.channel_receive" : " must come from the host");
: " must come from the host");
if (definingOp) if (definingOp)
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName(); diag.attachNote(definingOp->getLoc()) << "illegal producer is " << definingOp->getName().getStringRef();
}); });
return failure();
} }
return success();
} }
void verifyNoExternalTensorCaptures(Operation* ownerOp, void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
Region& region, for (Operation& op : funcOp.getOps()) {
StringRef kind, if (isa<func::ReturnOp,
pim::CappedDiagnosticReporter& diagnostics) { spatial::SpatGraphCompute,
region.walk([&](Operation* op) { spatial::SpatGraphComputeBatch,
for (OpOperand& operand : op->getOpOperands()) { spatial::SpatConv2DPlanOp,
Value value = operand.get(); spatial::SpatReluPlanOp,
if (!isa<TensorType>(value.getType())) spatial::SpatReconciliatorOp,
continue; spatial::SpatMaterializeLayoutOp>(&op)) {
if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value)) continue;
continue; }
if (isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch>(&op)) {
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << kPhaseMarker << " scheduled Spatial compute op is not allowed in logical graph phase";
});
continue;
}
if (isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelSendOp>(&op)) {
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << kPhaseMarker
<< " explicit channel communication is not expected before merge materialization";
});
continue;
}
if (isCompileTimeOp(&op))
continue;
Operation* definingOp = value.getDefiningOp(); diagnostics.report(&op, [&](Operation* illegalOp) {
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>()) illegalOp->emitOpError()
continue; << kPhaseMarker << " non-foldable top-level runtime op remains in logical Spatial graph; lower it inside spat.graph_compute";
});
}
}
diagnostics.report(ownerOp, [&](Operation* illegalOp) { void verifyScheduledTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor " for (Operation& op : funcOp.getOps()) {
<< "values"; if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(&op)) {
diagnostic.attachNote(op->getLoc()) diagnostics.report(&op, [&](Operation* illegalOp) {
<< "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by " illegalOp->emitOpError() << kPhaseMarker << " graph Spatial compute op remained after merge materialization";
<< (definingOp ? definingOp->getName().getStringRef() : StringRef("<block argument>"));
}); });
} }
}); }
} }
} // namespace } // namespace
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { LogicalResult verifyNoComputeBodyCaptures(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics; pim::CappedDiagnosticReporter diagnostics;
for (auto compute : funcOp.getOps<spatial::SpatGraphCompute>())
for (Operation& op : funcOp.getOps()) { verifyComputeBodyCaptures(compute, "graph_compute", diagnostics);
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op)) for (auto batch : funcOp.getOps<spatial::SpatGraphComputeBatch>())
continue; verifyComputeBodyCaptures(batch, "graph_compute_batch", diagnostics);
if (isCompileTimeOp(&op)) for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
continue; verifyComputeBodyCaptures(compute, "scheduled_compute", diagnostics);
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
diagnostics.report(&op, [](Operation* illegalOp) { verifyComputeBodyCaptures(batch, "scheduled_compute_batch", diagnostics);
illegalOp->emitOpError( diagnostics.emitSuppressedSummary(funcOp, "compute body capture verification failed");
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
});
}
checkWeightUseChains(funcOp, diagnostics);
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
return success(!diagnostics.hasFailure()); return success(!diagnostics.hasFailure());
} }
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) { LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { return verifyLogicalSpatialGraphInvariants(funcOp); }
LogicalResult verifyLogicalSpatialGraphInvariants(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics; pim::CappedDiagnosticReporter diagnostics;
verifyLogicalTopLevelOps(funcOp, diagnostics);
checkWeightUseChains(funcOp, diagnostics);
if (failed(verifyNoComputeBodyCaptures(funcOp)))
return failure();
diagnostics.emitSuppressedSummary(funcOp, "logical Spatial graph verification failed");
return success(!diagnostics.hasFailure());
}
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) { LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
(void) verifyComputeLikeInputs( pim::CappedDiagnosticReporter diagnostics;
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics); verifyScheduledTopLevelOps(funcOp, diagnostics);
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics); for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
} verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) { verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
(void) verifyComputeLikeInputs(computeBatchOp.getOperation(), if (failed(verifyNoComputeBodyCaptures(funcOp)))
computeBatchOp.getInputs(), return failure();
/*allowChannelReceiveInputs=*/false, diagnostics.emitSuppressedSummary(funcOp, "scheduled Spatial verification failed");
"spat.compute_batch",
diagnostics);
verifyNoExternalTensorCaptures(
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
}
diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
return success(!diagnostics.hasFailure()); return success(!diagnostics.hasFailure());
} }
@@ -6,6 +6,8 @@
namespace onnx_mlir { namespace onnx_mlir {
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp); mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifySpatialCommunicationInvariants(mlir::func::FuncOp funcOp); mlir::LogicalResult verifyNoComputeBodyCaptures(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifyLogicalSpatialGraphInvariants(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifyScheduledSpatialInvariants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -33,8 +33,8 @@ void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext*
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
bool requiresPostRewrite(spatial::SpatCompute computeOp); bool requiresPostRewrite(spatial::SpatGraphCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp); bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp);
void annotateWeightsConstants(mlir::func::FuncOp funcOp); void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir } // namespace onnx_mlir
File diff suppressed because it is too large Load Diff
@@ -16,12 +16,9 @@ struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
Location loc = reluOp.getLoc(); Location loc = reluOp.getLoc();
Type resultType = reluOp.getResult().getType(); Type resultType = reluOp.getResult().getType();
constexpr size_t numInputs = 1; auto reluPlan = spatial::SpatReluPlanOp::create(
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) { rewriter, loc, resultType, adaptor.getX(), rewriter.getStringAttr("nchw"));
auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x); rewriter.replaceOp(reluOp, reluPlan.getResult());
spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
});
rewriter.replaceOp(reluOp, computeOp);
return success(); return success();
} }
}; };
@@ -118,17 +118,17 @@ static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
} }
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs. // Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> { struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatGraphCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern; using OpRewritePattern<spatial::SpatGraphCompute>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(spatial::SpatGraphCompute compute, PatternRewriter& rewriter) const override {
auto promoted = computePromotedOperands(compute); auto promoted = computePromotedOperands(compute);
if (failed(promoted)) if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote"); return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
Block& oldBlock = compute.getBody().front(); Block& oldBlock = compute.getBody().front();
rewriter.setInsertionPointAfter(compute); rewriter.setInsertionPointAfter(compute);
auto newCompute = spatial::SpatCompute::create( auto newCompute = spatial::SpatGraphCompute::create(
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs); rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
SmallVector<Type> newBlockArgTypes; SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs; SmallVector<Location> newBlockArgLocs;
@@ -182,10 +182,10 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
}; };
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR. // Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> { struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatGraphComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern; using OpRewritePattern<spatial::SpatGraphComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(spatial::SpatGraphComputeBatch compute, PatternRewriter& rewriter) const override {
auto promoted = computePromotedOperands(compute); auto promoted = computePromotedOperands(compute);
if (failed(promoted)) if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote"); return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
@@ -197,7 +197,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
rewriter, compute, static_cast<uint64_t>(compute.getLaneCount()), "promoted compute_batch lane count"); rewriter, compute, static_cast<uint64_t>(compute.getLaneCount()), "promoted compute_batch lane count");
if (failed(laneCountAttr)) if (failed(laneCountAttr))
return failure(); return failure();
auto newCompute = spatial::SpatComputeBatch::create( auto newCompute = spatial::SpatGraphComputeBatch::create(
rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs); rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs);
auto laneArg = compute.getLaneArgument(); auto laneArg = compute.getLaneArgument();
if (!laneArg) if (!laneArg)
@@ -281,8 +281,8 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
}); });
} }
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); } bool requiresPostRewrite(spatial::SpatGraphCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); } bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -0,0 +1,21 @@
#pragma once
#include <optional>
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
mlir::FailureOr<mlir::Value>
lowerSelectedConv2DPlan(spatial::SpatConv2DPlanOp planOp,
std::optional<mlir::Value> rowStripInput,
bool emitRowStripLayout,
mlir::PatternRewriter& rewriter);
mlir::LogicalResult canLowerConvPlanToRowStrip(spatial::SpatConv2DPlanOp planOp);
mlir::LogicalResult canConsumeAndProduceRowStrip(spatial::SpatConv2DPlanOp planOp);
} // namespace onnx_mlir
@@ -0,0 +1,200 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
using namespace mlir;
namespace onnx_mlir {
namespace {
static constexpr StringLiteral kLogicalLayout = "nchw";
static constexpr StringLiteral kDenseLayout = "dense_nchw";
static constexpr StringLiteral kRowStripLayout = "nchw_row_strip";
static constexpr StringLiteral kRowStripIndexMap = "packed_hwc_rows_to_nchw";
enum class SelectedLayout {
DenseNchw,
NchwRowStrip,
};
static SelectedLayout getSelectedLayout(llvm::DenseMap<Value, SelectedLayout>& layouts, Value value) {
auto it = layouts.find(value);
return it == layouts.end() ? SelectedLayout::DenseNchw : it->second;
}
static bool usesSelectedRowStrip(Operation* user, llvm::DenseMap<Value, SelectedLayout>& layouts) {
if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(user))
return getSelectedLayout(layouts, reluPlan.getResult()) == SelectedLayout::NchwRowStrip;
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(user))
return getSelectedLayout(layouts, convPlan.getResult()) == SelectedLayout::NchwRowStrip;
return false;
}
static bool allUsersCanHandleRowStrip(Value value, llvm::DenseMap<Value, SelectedLayout>& layouts) {
for (Operation* user : value.getUsers()) {
if (usesSelectedRowStrip(user, layouts))
continue;
// Dense-only users must be materialized explicitly.
continue;
}
return true;
}
static std::pair<SmallVector<int64_t>, SmallVector<int64_t>> buildRowStripMetadata(RankedTensorType type) {
SmallVector<int64_t> offsets;
SmallVector<int64_t> sizes;
const int64_t channels = type.getDimSize(1);
const int64_t height = type.getDimSize(2);
const int64_t width = type.getDimSize(3);
offsets.reserve(height * 4);
sizes.reserve(height * 4);
for (int64_t row = 0; row < height; ++row) {
offsets.append({0, 0, row, 0});
sizes.append({1, channels, 1, width});
}
return {offsets, sizes};
}
static bool canSelectConvRowStrip(spatial::SpatConv2DPlanOp convPlan,
llvm::DenseMap<Value, SelectedLayout>& layouts) {
SelectedLayout inputLayout = getSelectedLayout(layouts, convPlan.getInput());
if (inputLayout == SelectedLayout::NchwRowStrip)
return succeeded(canConsumeAndProduceRowStrip(convPlan));
return succeeded(canLowerConvPlanToRowStrip(convPlan));
}
static SelectedLayout chooseConvLayout(spatial::SpatConv2DPlanOp convPlan,
llvm::DenseMap<Value, SelectedLayout>& layouts) {
if (!canSelectConvRowStrip(convPlan, layouts))
return SelectedLayout::DenseNchw;
if (!allUsersCanHandleRowStrip(convPlan.getResult(), layouts))
return SelectedLayout::DenseNchw;
return SelectedLayout::NchwRowStrip;
}
static SelectedLayout chooseReluLayout(spatial::SpatReluPlanOp reluPlan,
llvm::DenseMap<Value, SelectedLayout>& layouts) {
if (getSelectedLayout(layouts, reluPlan.getInput()) != SelectedLayout::NchwRowStrip)
return SelectedLayout::DenseNchw;
if (!allUsersCanHandleRowStrip(reluPlan.getResult(), layouts))
return SelectedLayout::DenseNchw;
return SelectedLayout::NchwRowStrip;
}
static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewriter, Value value) {
auto outputType = cast<RankedTensorType>(value.getType());
auto [offsets, sizes] = buildRowStripMetadata(outputType);
return spatial::SpatReconciliatorOp::create(rewriter,
value.getLoc(),
outputType,
value,
rewriter.getStringAttr(kLogicalLayout),
rewriter.getStringAttr(kRowStripLayout),
rewriter.getDenseI64ArrayAttr(offsets),
rewriter.getDenseI64ArrayAttr(sizes),
rewriter.getStringAttr(kRowStripIndexMap));
}
static void materializeDenseUses(IRRewriter& rewriter,
Value layoutValue,
llvm::DenseMap<Value, SelectedLayout>& layouts) {
SmallVector<OpOperand*> denseUses;
for (OpOperand& use : layoutValue.getUses()) {
if (usesSelectedRowStrip(use.getOwner(), layouts))
continue;
denseUses.push_back(&use);
}
for (OpOperand* use : denseUses) {
Operation* owner = use->getOwner();
rewriter.setInsertionPoint(owner);
auto materialized = spatial::SpatMaterializeLayoutOp::create(rewriter,
owner->getLoc(),
use->get().getType(),
use->get(),
rewriter.getStringAttr(kLogicalLayout),
rewriter.getStringAttr(kRowStripLayout),
rewriter.getStringAttr(kDenseLayout));
use->set(materialized.getResult());
}
}
struct SpatialLayoutPlanningPass final : PassWrapper<SpatialLayoutPlanningPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialLayoutPlanningPass)
StringRef getArgument() const override { return "spatial-layout-planning"; }
StringRef getDescription() const override { return "Select conservative Spatial layouts and insert reconciliation barriers."; }
void runOnOperation() override {
auto entryFunc = getPimEntryFunc(getOperation());
if (failed(entryFunc)) {
getOperation().emitError("failed to locate the PIM entry function during Spatial layout planning");
signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
llvm::DenseMap<Value, SelectedLayout> layouts;
bool changed = true;
while (changed) {
changed = false;
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
SelectedLayout selected = chooseConvLayout(convPlan, layouts);
if (layouts[convPlan.getResult()] != selected) {
layouts[convPlan.getResult()] = selected;
changed = true;
}
continue;
}
if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
SelectedLayout selected = chooseReluLayout(reluPlan, layouts);
if (layouts[reluPlan.getResult()] != selected) {
layouts[reluPlan.getResult()] = selected;
changed = true;
}
continue;
}
}
}
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
Value producedValue;
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(&op))
producedValue = convPlan.getResult();
else if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(&op))
producedValue = reluPlan.getResult();
else
continue;
if (getSelectedLayout(layouts, producedValue) != SelectedLayout::NchwRowStrip)
continue;
rewriter.setInsertionPointAfter(&op);
auto reconciliator = insertRowStripReconciliator(rewriter, producedValue);
rewriter.replaceAllUsesExcept(producedValue, reconciliator.getResult(), reconciliator);
materializeDenseUses(rewriter, reconciliator.getResult(), layouts);
}
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
getOperation().emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after SpatialLayoutPlanning");
signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<Pass> createSpatialLayoutPlanningPass() { return std::make_unique<SpatialLayoutPlanningPass>(); }
} // namespace onnx_mlir
@@ -102,7 +102,7 @@ static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
return mapper.lookup(value); return mapper.lookup(value);
} }
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
size_t& fallbackCoreId) { size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
@@ -171,7 +171,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
} // namespace } // namespace
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
IRRewriter& rewriter) { IRRewriter& rewriter) {
Location loc = computeBatchOp.getLoc(); Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front(); Block& oldBlock = computeBatchOp.getBody().front();
@@ -17,10 +17,10 @@ std::optional<unsigned> getDirectComputeLikeInputIndex(Operation* owner, unsigne
return operandNumber - inputBegin; return operandNumber - inputBegin;
}; };
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner))
return getInputIndex(owner, compute.getInputs().size()); return getInputIndex(owner, compute.getInputs().size());
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner)) if (auto computeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(owner))
return getInputIndex(owner, computeBatch.getInputs().size()); return getInputIndex(owner, computeBatch.getInputs().size());
return std::nullopt; return std::nullopt;
@@ -32,13 +32,13 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
Value replacement) { Value replacement) {
Block& body = owner->getRegion(0).front(); Block& body = owner->getRegion(0).front();
BlockArgument bodyArgument; BlockArgument bodyArgument;
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) { if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner)) {
auto computeArg = compute.getInputArgument(inputIndex); auto computeArg = compute.getInputArgument(inputIndex);
assert(computeArg && "expected compute input block argument"); assert(computeArg && "expected compute input block argument");
bodyArgument = *computeArg; bodyArgument = *computeArg;
} }
else { else {
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex); auto batchArg = cast<spatial::SpatScheduledComputeBatch>(owner).getInputArgument(inputIndex);
assert(batchArg && "expected compute_batch input block argument"); assert(batchArg && "expected compute_batch input block argument");
bodyArgument = *batchArg; bodyArgument = *batchArg;
} }
@@ -46,10 +46,10 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
rewriter.startOpModification(owner); rewriter.startOpModification(owner);
bodyArgument.replaceAllUsesWith(replacement); bodyArgument.replaceAllUsesWith(replacement);
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner))
compute.getInputsMutable().erase(inputIndex); compute.getInputsMutable().erase(inputIndex);
else else
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex); cast<spatial::SpatScheduledComputeBatch>(owner).getInputsMutable().erase(inputIndex);
body.eraseArgument(bodyArgIndex); body.eraseArgument(bodyArgIndex);
rewriter.finalizeOpModification(owner); rewriter.finalizeOpModification(owner);
} }
@@ -55,7 +55,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
} }
} }
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) { static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatScheduledCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id"); return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id");
auto checkedCoreId = auto checkedCoreId =
@@ -66,7 +66,7 @@ static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeO
return *checkedCoreId; return *checkedCoreId;
} }
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
SmallVectorImpl<Operation*>& helperChain, SmallVectorImpl<Operation*>& helperChain,
bool requireReturnUse = true) { bool requireReturnUse = true) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1) if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
@@ -104,13 +104,13 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
return success(); return success();
} }
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatScheduledCompute computeOp,
IRRewriter& rewriter, IRRewriter& rewriter,
OperationFolder& constantFolder) { OperationFolder& constantFolder) {
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1) if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
return false; return false;
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) { if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user); return isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
})) }))
return false; return false;
@@ -145,7 +145,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
} // namespace } // namespace
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp, LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatScheduledCompute computeOp,
IRRewriter& rewriter, IRRewriter& rewriter,
OperationFolder& constantFolder) { OperationFolder& constantFolder) {
Location loc = computeOp->getLoc(); Location loc = computeOp->getLoc();
@@ -10,6 +10,14 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static void copyRaptorDebugAttrs(Operation* source, Operation* target) {
for (NamedAttribute attr : source->getAttrs()) {
StringRef name = attr.getName().strref();
if (name.starts_with("raptor."))
target->setAttr(attr.getName(), attr.getValue());
}
}
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> { struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
@@ -17,7 +25,8 @@ struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getInput()); auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getInput());
if (failed(sizeAttr)) if (failed(sizeAttr))
return failure(); return failure();
pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId()); auto send = pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId());
copyRaptorDebugAttrs(op.getOperation(), send.getOperation());
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();
} }
@@ -37,9 +46,10 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult()); auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult());
if (failed(sizeAttr)) if (failed(sizeAttr))
return failure(); return failure();
Value received = pim::PimReceiveOp::create( auto receive = pim::PimReceiveOp::create(
rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId()) rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId());
.getOutput(); copyRaptorDebugAttrs(op.getOperation(), receive.getOperation());
Value received = receive.getOutput();
rewriter.replaceOp(op, received); rewriter.replaceOp(op, received);
return success(); return success();
} }
@@ -59,7 +59,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
return failure(); return failure();
for (auto& uses : extractSliceOp->getUses()) { for (auto& uses : extractSliceOp->getUses()) {
if (isa<spatial::SpatCompute>(uses.getOwner())) { if (isa<spatial::SpatScheduledCompute>(uses.getOwner())) {
if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber())) if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
return failure(); return failure();
} }
@@ -72,7 +72,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) { for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) { if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(uses.getOwner())) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber()); auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
if (!inputIndex) if (!inputIndex)
return failure(); return failure();
@@ -92,7 +92,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
replaceAndEraseDirectComputeLikeInput( replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]); rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
} }
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) { else if (auto spatComputeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(uses.getOwner())) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber()); auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
if (!inputIndex) if (!inputIndex)
return failure(); return failure();
@@ -114,7 +114,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
} }
else { else {
{ {
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) { if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatScheduledCompute>()) {
rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatToExtract.contains(spatCompute.getOperation())) { if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation()); auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
@@ -125,7 +125,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
uses.set(mapSpatToExtract[spatCompute.getOperation()]); uses.set(mapSpatToExtract[spatCompute.getOperation()]);
rewriter.finalizeOpModification(spatCompute.getOperation()); rewriter.finalizeOpModification(spatCompute.getOperation());
} }
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) { else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatScheduledComputeBatch>()) {
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) { if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation()); auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
@@ -179,7 +179,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) { for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
auto argUser = argUses.getOwner(); auto argUser = argUses.getOwner();
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) { if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(argUser)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber()); auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
if (!inputIndex) if (!inputIndex)
return failure(); return failure();
@@ -191,7 +191,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor); replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
} }
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) { else if (auto spatComputeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(argUser)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber()); auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
if (!inputIndex) if (!inputIndex)
return failure(); return failure();
@@ -86,7 +86,7 @@ getCheckedByteOffset(int64_t elementOffset, size_t elementSize, Operation* ancho
return pim::checkedCast<int64_t>(*byteOffset, anchor, fieldName); return pim::checkedCast<int64_t>(*byteOffset, anchor, fieldName);
} }
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
SmallVectorImpl<Operation*>& helperChain) { SmallVectorImpl<Operation*>& helperChain) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1) if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
return failure(); return failure();
@@ -212,7 +212,7 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
} }
SmallVector<Operation*> helperChain; SmallVector<Operation*> helperChain;
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) { if (auto helperCompute = dyn_cast<spatial::SpatScheduledCompute>(currentUser)) {
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue) if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
return std::nullopt; return std::nullopt;
@@ -643,7 +643,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
} }
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath( raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) { spatial::SpatScheduledCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter); return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
} }
@@ -656,7 +656,7 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) { if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
Operation* onlyUser = *op->getUsers().begin(); Operation* onlyUser = *op->getUsers().begin();
isExclusivelyOwnedByReturnChain = isExclusivelyOwnedByReturnChain =
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser) isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatScheduledCompute>(onlyUser)
|| isReturnHelperChainOp(onlyUser); || isReturnHelperChainOp(onlyUser);
} }
if (!isExclusivelyOwnedByReturnChain) if (!isExclusivelyOwnedByReturnChain)
@@ -669,7 +669,7 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
return; return;
} }
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) { if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
markOpToRemove(computeOp); markOpToRemove(computeOp);
if (!computeOp.getInputs().empty()) if (!computeOp.getInputs().empty())
for (Value input : computeOp.getInputs()) for (Value input : computeOp.getInputs())
@@ -25,9 +25,11 @@
#include <cassert> #include <cassert>
#include <utility> #include <utility>
#include "Common/IR/ShapeUtils.hpp"
#include "Common/IR/ConstantUtils.hpp" #include "Common/IR/ConstantUtils.hpp"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Common/Support/CheckedArithmetic.hpp" #include "Common/Support/CheckedArithmetic.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/Common.hpp" #include "Conversion/SpatialToPim/Common.hpp"
#include "Conversion/SpatialToPim/Patterns.hpp" #include "Conversion/SpatialToPim/Patterns.hpp"
@@ -97,6 +99,64 @@ static FailureOr<Value> createZeroedDeviceHVector(IRRewriter& rewriter,
.getOutput(); .getOutput();
} }
static bool isHostBackedMemRefValue(Value value) {
while (Operation* definingOp = value.getDefiningOp()) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
value = subviewOp.getSource();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
return isa<memref::GetGlobalOp>(definingOp);
}
return false;
}
static bool isHostBackedTensorValue(Value value) {
while (Operation* definingOp = value.getDefiningOp()) {
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto sourceType = dyn_cast<RankedTensorType>(extractSliceOp.getSource().getType());
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getResult().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return false;
if (!onnx_mlir::isContiguousSubviewWithDynamicOffsets(sourceType.getShape(),
extractSliceOp.getMixedOffsets(),
extractSliceOp.getStaticSizes(),
extractSliceOp.getStaticStrides())) {
return false;
}
value = extractSliceOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (auto castOp = dyn_cast<tensor::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(definingOp))
return isHostBackedMemRefValue(toTensorOp.getBuffer());
return false;
}
return false;
}
static FailureOr<Value> static FailureOr<Value>
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) { padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
auto vectorType = cast<RankedTensorType>(vector.getType()); auto vectorType = cast<RankedTensorType>(vector.getType());
@@ -120,6 +180,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size"); auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
if (failed(sizeAttr)) if (failed(sizeAttr))
return failure(); return failure();
if (isHostBackedTensorValue(vector)) {
return PimMemCopyHostToDevOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr)
.getOutput();
}
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput(); return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
} }
@@ -137,6 +201,12 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
return; return;
} }
func::FuncOp funcOp = *entryFunc; func::FuncOp funcOp = *entryFunc;
if (failed(verifyScheduledSpatialInvariants(funcOp))) {
funcOp.emitOpError(
"RAPTOR_PHASE_CHECK scheduled Spatial verification failed at the start of SpatialToPim");
signalPassFailure();
return;
}
IRRewriter rewriter(&getContext()); IRRewriter rewriter(&getContext());
OperationFolder constantFolder(&getContext()); OperationFolder constantFolder(&getContext());
@@ -176,19 +246,19 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
return; return;
} }
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) { for (auto computeOp : funcOp.getOps<spatial::SpatScheduledCompute>()) {
markOpToRemove(computeOp); markOpToRemove(computeOp);
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) { if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
computeOp.emitOpError("failed to lower spat.compute to pim.core"); computeOp.emitOpError("failed to lower spat.scheduled_compute to pim.core");
signalPassFailure(); signalPassFailure();
return; return;
} }
} }
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) { for (auto computeBatchOp : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
markOpToRemove(computeBatchOp); markOpToRemove(computeBatchOp);
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) { if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch"); computeBatchOp.emitOpError("failed to lower spat.scheduled_compute_batch to pim.core_batch");
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -374,7 +444,7 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
}; };
for (auto& op : funcOp.getBody().getOps()) for (auto& op : funcOp.getBody().getOps())
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) { if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0) if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
continue; continue;
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) { for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
@@ -41,8 +41,11 @@ private:
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter); mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
mlir::LogicalResult mlir::LogicalResult
lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder); lowerComputeOp(spatial::SpatScheduledCompute computeOp,
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter); mlir::IRRewriter& rewriter,
mlir::OperationFolder& constantFolder);
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
mlir::IRRewriter& rewriter);
enum class ReturnPathLoweringResult { enum class ReturnPathLoweringResult {
Handled, Handled,
@@ -51,7 +54,7 @@ private:
}; };
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter); void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp, ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatScheduledCompute computeOp,
mlir::OpResult result, mlir::OpResult result,
mlir::Value yieldValue, mlir::Value yieldValue,
mlir::IRRewriter& rewriter); mlir::IRRewriter& rewriter);
@@ -13,10 +13,13 @@ using namespace bufferization;
namespace onnx_mlir::pim { namespace onnx_mlir::pim {
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue,
Location loc,
RewriterBase& rewriter,
const StaticValueKnowledge& knowledge) {
bool isContiguous = bool isContiguous =
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)); succeeded(resolveContiguousAddress(memrefValue, knowledge)) || succeeded(compileContiguousAddressExpr(memrefValue));
if (isContiguous && isDeviceLocalPimAddress(memrefValue)) if (isContiguous && isDeviceLocalPimAddress(memrefValue, knowledge))
return memrefValue; return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType()); auto shapedType = cast<ShapedType>(memrefValue.getType());
@@ -32,7 +35,7 @@ FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location lo
if (failed(sizeAttr)) if (failed(sizeAttr))
return failure(); return failure();
if (isHostBackedPimAddress(memrefValue)) { if (isHostBackedPimAddress(memrefValue, knowledge)) {
return PimMemCopyHostToDevOp::create( return PimMemCopyHostToDevOp::create(
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr) rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
.getOutput(); .getOutput();
@@ -3,10 +3,15 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
namespace onnx_mlir::pim { namespace onnx_mlir::pim {
llvm::FailureOr<mlir::Value> llvm::FailureOr<mlir::Value>
materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); materializeContiguousInputMemRef(mlir::Value memrefValue,
mlir::Location loc,
mlir::RewriterBase& rewriter,
const onnx_mlir::StaticValueKnowledge& knowledge = {});
mlir::Value mlir::Value
allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
@@ -15,6 +15,26 @@ using namespace bufferization;
namespace onnx_mlir { namespace onnx_mlir {
namespace pim { namespace pim {
static StaticValueKnowledge getEnclosingBufferizationKnowledge(Operation* op) {
StaticValueKnowledge knowledge;
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
knowledge.indexValues[coreBatchOp.getLaneArgument()] = 0;
for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights()))
knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight;
for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs()))
knowledge.aliases[coreBatchOp.getInputArgument(index)] = input;
return knowledge;
}
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights()))
knowledge.aliases[coreOp.getWeightArgument(index)] = weight;
}
return knowledge;
}
struct MemCopyHostToDevOpInterface struct MemCopyHostToDevOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> { : DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
LogicalResult bufferize(Operation* op, LogicalResult bufferize(Operation* op,
@@ -148,7 +168,8 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
auto inputOpt = getBufferOrValue(rewriter, input, options, state); auto inputOpt = getBufferOrValue(rewriter, input, options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
auto contiguous = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); auto contiguous =
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguous)) if (failed(contiguous))
return failure(); return failure();
inputs.push_back(*contiguous); inputs.push_back(*contiguous);
@@ -182,7 +203,8 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state); auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); auto contiguousInput =
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousInput)) if (failed(contiguousInput))
return failure(); return failure();
@@ -410,7 +432,8 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); auto contiguousInput =
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousInput)) if (failed(contiguousInput))
return failure(); return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
@@ -456,7 +479,8 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); auto contiguousInput =
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousInput)) if (failed(contiguousInput))
return failure(); return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
@@ -497,10 +521,12 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter); auto contiguousLhs =
materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousLhs)) if (failed(contiguousLhs))
return failure(); return failure();
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter); auto contiguousRhs =
materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousRhs)) if (failed(contiguousRhs))
return failure(); return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
@@ -534,10 +560,12 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInter
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter); auto contiguousLhs =
materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousLhs)) if (failed(contiguousLhs))
return failure(); return failure();
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter); auto contiguousRhs =
materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousRhs)) if (failed(contiguousRhs))
return failure(); return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
@@ -574,7 +602,8 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter); auto contiguousInput =
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousInput)) if (failed(contiguousInput))
return failure(); return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
@@ -116,6 +116,36 @@ lowerMemRefCopyToPimCopy(memref::CopyOp copyOp, PatternRewriter& rewriter, const
return success(); return success();
} }
static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyHostToDevOp copyOp, const StaticValueKnowledge& knowledge) {
bool sourceIsHost = isHostBackedPimAddress(copyOp.getHostSource(), knowledge);
bool targetIsHost = isHostBackedPimAddress(copyOp.getDeviceTarget(), knowledge);
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getHostSource(), knowledge);
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getDeviceTarget(), knowledge);
if (!sourceIsHost || !targetIsDevice || targetIsHost || sourceIsDevice)
return copyOp.emitOpError("pim.memcp_hd requires a host-backed source and a device-local target");
return success();
}
static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyDevToHostOp copyOp, const StaticValueKnowledge& knowledge) {
bool sourceIsHost = isHostBackedPimAddress(copyOp.getDeviceSource(), knowledge);
bool targetIsHost = isHostBackedPimAddress(copyOp.getHostTarget(), knowledge);
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getDeviceSource(), knowledge);
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getHostTarget(), knowledge);
if (!targetIsHost || !sourceIsDevice || sourceIsHost || targetIsDevice)
return copyOp.emitOpError("pim.memcp_dh requires a device-local source and a host-backed target");
return success();
}
static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyOp copyOp, const StaticValueKnowledge& knowledge) {
bool sourceIsHost = isHostBackedPimAddress(copyOp.getSource(), knowledge);
bool targetIsHost = isHostBackedPimAddress(copyOp.getTarget(), knowledge);
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getSource(), knowledge);
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getTarget(), knowledge);
if (!sourceIsDevice || !targetIsDevice || sourceIsHost || targetIsHost)
return copyOp.emitOpError("pim.memcp requires device-local source and target operands");
return success();
}
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> { struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
StringRef getArgument() const override { return "bufferize-pim"; } StringRef getArgument() const override { return "bufferize-pim"; }
@@ -129,6 +159,7 @@ struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<Mo
private: private:
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const; void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const; LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
LogicalResult verifyPimCopyAddressSpaces(ModuleOp moduleOp) const;
}; };
static LogicalResult applyPatternsOnce(Operation* op, PatternApplicator& applicator, PatternRewriter& rewriter) { static LogicalResult applyPatternsOnce(Operation* op, PatternApplicator& applicator, PatternRewriter& rewriter) {
@@ -240,6 +271,10 @@ void PimBufferizationPass::runOnOperation() {
signalPassFailure(); signalPassFailure();
return; return;
} }
if (failed(verifyPimCopyAddressSpaces(moduleOp))) {
signalPassFailure();
return;
}
annotateWeightsMemrefs(moduleOp, funcOp); annotateWeightsMemrefs(moduleOp, funcOp);
@@ -346,6 +381,31 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod
return success(); return success();
} }
LogicalResult PimBufferizationPass::verifyPimCopyAddressSpaces(ModuleOp moduleOp) const {
bool hasFailure = false;
auto verifyWithKnowledge = [&](auto coreLikeOp, const StaticValueKnowledge& initialKnowledge) {
(void) walkPimCoreBlockStructurally(
coreLikeOp.getBody().front(), initialKnowledge, [&](Operation& op, const StaticValueKnowledge& knowledge) {
if (auto copyOp = dyn_cast<pim::PimMemCopyOp>(&op); copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge)))
hasFailure = true;
if (auto copyOp = dyn_cast<pim::PimMemCopyHostToDevOp>(&op);
copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge)))
hasFailure = true;
if (auto copyOp = dyn_cast<pim::PimMemCopyDevToHostOp>(&op);
copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge)))
hasFailure = true;
return success();
});
};
moduleOp.walk([&](pim::PimCoreOp coreOp) { verifyWithKnowledge(coreOp, seedCoreKnowledge(coreOp)); });
moduleOp.walk([&](pim::PimCoreBatchOp coreBatchOp) {
StaticValueKnowledge knowledge = seedCoreBatchKnowledge(coreBatchOp, 0);
verifyWithKnowledge(coreBatchOp, knowledge);
});
return success(!hasFailure);
}
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); } std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -96,8 +96,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>(); if (!mapOp->getParentOfType<pim::PimCoreOp>() && !mapOp->getParentOfType<pim::PimCoreBatchOp>())
if (!coreOp)
return failure(); return failure();
auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType()); auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
@@ -5,6 +5,7 @@ add_pim_library(OMPimVerification
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
OMPimCommon OMPimCommon
OMPimCompilerOptions
OMPimBufferization OMPimBufferization
PimOps PimOps
SpatialOps SpatialOps
@@ -5,12 +5,17 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <string>
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
@@ -143,6 +148,479 @@ static bool isHostAddressableValue(Value value, const StaticValueKnowledge& know
return isa_and_nonnull<memref::GetGlobalOp>(base.getDefiningOp()); return isa_and_nonnull<memref::GetGlobalOp>(base.getDefiningOp());
} }
enum class CommunicationEventKind { Send, Receive };
struct CommunicationEvent {
CommunicationEventKind kind = CommunicationEventKind::Send;
int64_t coreId = 0;
int64_t peerCoreId = 0;
int64_t size = 0;
uint64_t ordinal = 0;
std::optional<int64_t> minChannelId;
std::string materializer;
std::optional<int64_t> traceId;
std::optional<int64_t> commOrder;
std::optional<int64_t> traceClassId;
std::optional<int64_t> traceBlockOrdinal;
std::string traceKind;
std::string tracePhase;
std::string traceClassKind;
std::string tracePayload;
std::string traceMessages;
std::string tracePrevOp;
std::string traceNextOp;
Operation* op = nullptr;
};
using CommunicationEventVector = SmallVector<CommunicationEvent, 0>;
static StringRef getCommunicationEventKindName(CommunicationEventKind kind) {
return kind == CommunicationEventKind::Send ? "send" : "receive";
}
constexpr StringLiteral kRaptorMinChannelIdAttr = "raptor.min_channel_id";
constexpr StringLiteral kRaptorMaterializerAttr = "raptor.materializer";
constexpr StringLiteral kRaptorCommOrderAttr = "raptor.comm_order";
constexpr StringLiteral kRaptorCommTraceIdAttr = "raptor.comm_trace_id";
constexpr StringLiteral kRaptorCommTraceKindAttr = "raptor.comm_trace_kind";
constexpr StringLiteral kRaptorCommTracePhaseAttr = "raptor.comm_trace_phase";
constexpr StringLiteral kRaptorCommTraceClassIdAttr = "raptor.comm_trace_class_id";
constexpr StringLiteral kRaptorCommTraceClassKindAttr = "raptor.comm_trace_class_kind";
constexpr StringLiteral kRaptorCommTraceBlockOrdinalAttr = "raptor.comm_trace_block_ordinal";
constexpr StringLiteral kRaptorCommTracePayloadAttr = "raptor.comm_trace_payload";
constexpr StringLiteral kRaptorCommTraceMessagesAttr = "raptor.comm_trace_messages";
constexpr StringLiteral kRaptorCommTracePrevOpAttr = "raptor.comm_trace_prev_op";
constexpr StringLiteral kRaptorCommTraceNextOpAttr = "raptor.comm_trace_next_op";
static std::optional<int64_t> getNearestIntegerAttr(Operation* op, StringRef name) {
for (Operation* current = op; current; current = current->getParentOp())
if (auto attr = current->getAttrOfType<IntegerAttr>(name))
return attr.getInt();
return std::nullopt;
}
static std::string getNearestStringAttr(Operation* op, StringRef name) {
for (Operation* current = op; current; current = current->getParentOp())
if (auto attr = current->getAttrOfType<StringAttr>(name))
return attr.getValue().str();
return "";
}
static std::string formatLocation(Location loc) {
std::string text;
llvm::raw_string_ostream os(text);
loc.print(os);
return os.str();
}
static std::string formatOperationSummary(Operation* op) {
std::string text;
llvm::raw_string_ostream os(text);
OpPrintingFlags flags;
flags.skipRegions();
flags.elideLargeElementsAttrs();
op->print(os, flags);
return os.str();
}
static std::string formatCommunicationEvent(const CommunicationEvent& event) {
std::string text;
llvm::raw_string_ostream os(text);
os << "core " << event.coreId << " " << getCommunicationEventKindName(event.kind) << " "
<< (event.kind == CommunicationEventKind::Send ? "to" : "from") << " " << event.peerCoreId
<< " size " << event.size << "B ordinal " << event.ordinal;
if (event.minChannelId)
os << " min_channel " << *event.minChannelId;
if (event.commOrder)
os << " comm_order " << *event.commOrder;
if (!event.materializer.empty())
os << " materializer " << event.materializer;
if (event.traceId)
os << " trace#" << *event.traceId;
if (!event.tracePhase.empty())
os << " phase " << event.tracePhase;
if (event.traceClassId)
os << " class " << event.traceClassKind << "#" << *event.traceClassId;
if (event.traceBlockOrdinal)
os << " block_ordinal " << *event.traceBlockOrdinal;
if (!event.tracePayload.empty())
os << " payload " << event.tracePayload;
if (!event.traceMessages.empty())
os << " messages {" << event.traceMessages << "}";
if (!event.tracePrevOp.empty() || !event.traceNextOp.empty())
os << " inserted_between [" << event.tracePrevOp << " | " << event.traceNextOp << "]";
return os.str();
}
static bool areMatchedCommunicationEvents(const CommunicationEvent& lhs, const CommunicationEvent& rhs) {
if (lhs.coreId != rhs.peerCoreId || lhs.peerCoreId != rhs.coreId || lhs.size != rhs.size)
return false;
return (lhs.kind == CommunicationEventKind::Send && rhs.kind == CommunicationEventKind::Receive)
|| (lhs.kind == CommunicationEventKind::Receive && rhs.kind == CommunicationEventKind::Send);
}
static std::optional<size_t> findMatchingCounterpartIndex(const CommunicationEventVector& events,
const CommunicationEvent& event,
size_t begin) {
for (size_t index = begin; index < events.size(); ++index)
if (areMatchedCommunicationEvents(event, events[index]))
return index;
return std::nullopt;
}
static void printCounterpartProbe(llvm::raw_ostream& os,
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
const DenseMap<int64_t, size_t>& programCounters,
const CommunicationEvent& blockedEvent) {
auto peerEventsIt = coreEvents.find(blockedEvent.peerCoreId);
if (peerEventsIt == coreEvents.end()) {
os << " no local stream was collected for peer core " << blockedEvent.peerCoreId << "\n";
return;
}
const CommunicationEventVector& peerEvents = peerEventsIt->second;
size_t peerPc = 0;
auto peerPcIt = programCounters.find(blockedEvent.peerCoreId);
if (peerPcIt != programCounters.end())
peerPc = peerPcIt->second;
os << " counterpart probe for " << formatCommunicationEvent(blockedEvent) << "\n";
os << " peer core " << blockedEvent.peerCoreId << " current pc " << peerPc << " of " << peerEvents.size()
<< "\n";
std::optional<size_t> nextMatch = findMatchingCounterpartIndex(peerEvents, blockedEvent, peerPc);
std::optional<size_t> anyMatch = findMatchingCounterpartIndex(peerEvents, blockedEvent, 0);
if (!nextMatch && !anyMatch) {
os << " no matching counterpart exists anywhere in the peer stream\n";
return;
}
if (!nextMatch && anyMatch) {
os << " matching counterpart exists only before the peer pc at ordinal " << *anyMatch
<< "; this usually means the static stream expansion or ordering metadata is inconsistent\n";
os << " " << formatCommunicationEvent(peerEvents[*anyMatch]) << "\n";
os << " op: " << formatOperationSummary(peerEvents[*anyMatch].op) << "\n";
return;
}
const CommunicationEvent& match = peerEvents[*nextMatch];
os << " next matching counterpart is at peer ordinal " << *nextMatch << " (distance +"
<< (*nextMatch >= peerPc ? *nextMatch - peerPc : 0) << ")\n";
os << " " << formatCommunicationEvent(match) << "\n";
os << " op: " << formatOperationSummary(match.op) << "\n";
if (*nextMatch == peerPc)
return;
os << " peer operations blocking before that counterpart:\n";
size_t end = std::min(peerEvents.size(), std::min(*nextMatch + static_cast<size_t>(1), peerPc + static_cast<size_t>(12)));
for (size_t index = peerPc; index < end; ++index) {
os << (index == peerPc ? " pc => " : " ") << "#" << index << " "
<< formatCommunicationEvent(peerEvents[index]) << "\n";
os << " op: " << formatOperationSummary(peerEvents[index].op) << "\n";
}
if (end <= *nextMatch)
os << " ... " << (*nextMatch - end + 1) << " more peer communication event(s) before the counterpart\n";
}
static CommunicationEvent makeCommunicationEvent(CommunicationEventKind kind,
int64_t coreId,
int64_t peerCoreId,
int64_t size,
uint64_t ordinal,
Operation* op) {
return CommunicationEvent {kind,
coreId,
peerCoreId,
size,
ordinal,
getNearestIntegerAttr(op, kRaptorMinChannelIdAttr),
getNearestStringAttr(op, kRaptorMaterializerAttr),
getNearestIntegerAttr(op, kRaptorCommTraceIdAttr),
getNearestIntegerAttr(op, kRaptorCommOrderAttr),
getNearestIntegerAttr(op, kRaptorCommTraceClassIdAttr),
getNearestIntegerAttr(op, kRaptorCommTraceBlockOrdinalAttr),
getNearestStringAttr(op, kRaptorCommTraceKindAttr),
getNearestStringAttr(op, kRaptorCommTracePhaseAttr),
getNearestStringAttr(op, kRaptorCommTraceClassKindAttr),
getNearestStringAttr(op, kRaptorCommTracePayloadAttr),
getNearestStringAttr(op, kRaptorCommTraceMessagesAttr),
getNearestStringAttr(op, kRaptorCommTracePrevOpAttr),
getNearestStringAttr(op, kRaptorCommTraceNextOpAttr),
op};
}
static LogicalResult appendCoreCommunicationEvents(Block& block,
int64_t coreId,
const StaticValueKnowledge& initialKnowledge,
SmallVectorImpl<CommunicationEvent>& events,
pim::CappedDiagnosticReporter& diagnostics) {
return walkPimCoreBlock(block, initialKnowledge, [&](Operation& op, const StaticValueKnowledge& knowledge) {
if (auto sendOp = dyn_cast<pim::PimSendOp>(&op)) {
auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
if (failed(targetCoreId)) {
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("cannot statically resolve send target core for PIM communication deadlock check");
});
return failure();
}
events.push_back(makeCommunicationEvent(CommunicationEventKind::Send,
coreId,
*targetCoreId,
sendOp.getSize(),
static_cast<uint64_t>(events.size()),
&op));
return success();
}
if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(&op)) {
auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
if (failed(sourceCoreId)) {
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("cannot statically resolve receive source core for PIM communication deadlock check");
});
return failure();
}
events.push_back(makeCommunicationEvent(CommunicationEventKind::Receive,
coreId,
*sourceCoreId,
receiveOp.getSize(),
static_cast<uint64_t>(events.size()),
&op));
return success();
}
return success();
});
}
static void printCommunicationWindow(llvm::raw_ostream& os,
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
int64_t coreId,
size_t pc,
unsigned radius = 4) {
auto eventsIt = coreEvents.find(coreId);
if (eventsIt == coreEvents.end())
return;
const CommunicationEventVector& events = eventsIt->second;
size_t begin = pc > radius ? pc - radius : 0;
size_t end = std::min(events.size(), pc + static_cast<size_t>(radius) + 1);
os << " local stream for core " << coreId << " around pc " << pc << " of " << events.size() << ":\n";
for (size_t index = begin; index < end; ++index) {
os << (index == pc ? " => " : " ") << formatCommunicationEvent(events[index]) << "\n";
os << " op: " << formatOperationSummary(events[index].op) << "\n";
}
}
static void printCommunicationDeadlockReport(const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
const DenseMap<int64_t, size_t>& programCounters,
ArrayRef<int64_t> cycle) {
llvm::errs() << "\n=== PIM static communication deadlock report ===\n";
llvm::errs() << "wait cycle:";
for (int64_t coreId : cycle)
llvm::errs() << " " << coreId;
if (!cycle.empty())
llvm::errs() << " -> " << cycle.front();
llvm::errs() << "\n\nblocked heads:\n";
for (int64_t coreId : cycle) {
auto eventsIt = coreEvents.find(coreId);
auto pcIt = programCounters.find(coreId);
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
continue;
const CommunicationEvent& event = eventsIt->second[pcIt->second];
llvm::errs() << " " << formatCommunicationEvent(event) << "\n";
llvm::errs() << " loc: " << formatLocation(event.op->getLoc()) << "\n";
llvm::errs() << " op : " << formatOperationSummary(event.op) << "\n";
}
llvm::errs() << "\npeer counterpart probes:\n";
for (int64_t coreId : cycle) {
auto eventsIt = coreEvents.find(coreId);
auto pcIt = programCounters.find(coreId);
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
continue;
printCounterpartProbe(llvm::errs(), coreEvents, programCounters, eventsIt->second[pcIt->second]);
}
llvm::errs() << "\nlocal communication streams:\n";
for (int64_t coreId : cycle) {
auto pcIt = programCounters.find(coreId);
if (pcIt == programCounters.end())
continue;
printCommunicationWindow(llvm::errs(), coreEvents, coreId, pcIt->second);
}
llvm::errs() << "=== end PIM static communication deadlock report ===\n\n";
}
static void emitCommunicationDeadlockCycle(ModuleOp moduleOp,
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
const DenseMap<int64_t, size_t>& programCounters,
ArrayRef<int64_t> cycle) {
printCommunicationDeadlockReport(coreEvents, programCounters, cycle);
auto diagnostic = moduleOp.emitError()
<< "PIM communication deadlock check found a blocking send/receive cycle while statically simulating the "
"expanded per-core communication streams; see the PIM static communication deadlock report above";
for (int64_t coreId : cycle) {
auto eventsIt = coreEvents.find(coreId);
auto pcIt = programCounters.find(coreId);
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
continue;
const CommunicationEvent& event = eventsIt->second[pcIt->second];
Diagnostic& note = diagnostic.attachNote(event.op->getLoc());
note << formatCommunicationEvent(event);
if (!event.materializer.empty())
note << " emitted by " << event.materializer;
}
}
static FailureOr<SmallVector<int64_t>> findCommunicationWaitCycle(
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
const DenseMap<int64_t, size_t>& programCounters) {
for (const auto& [startCoreId, events] : coreEvents) {
auto startPcIt = programCounters.find(startCoreId);
if (startPcIt == programCounters.end() || startPcIt->second >= events.size())
continue;
DenseMap<int64_t, size_t> positionInPath;
SmallVector<int64_t, 8> path;
int64_t currentCoreId = startCoreId;
while (true) {
auto eventsIt = coreEvents.find(currentCoreId);
auto pcIt = programCounters.find(currentCoreId);
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
break;
auto positionIt = positionInPath.find(currentCoreId);
if (positionIt != positionInPath.end()) {
SmallVector<int64_t> cycle;
for (size_t index = positionIt->second; index < path.size(); ++index)
cycle.push_back(path[index]);
return cycle;
}
positionInPath[currentCoreId] = path.size();
path.push_back(currentCoreId);
currentCoreId = eventsIt->second[pcIt->second].peerCoreId;
}
}
return failure();
}
static LogicalResult verifyNoStaticCommunicationDeadlock(ModuleOp moduleOp,
pim::CappedDiagnosticReporter& diagnostics) {
DenseMap<int64_t, CommunicationEventVector> coreEvents;
bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
if (funcOp.isExternal())
continue;
for (Operation& op : funcOp.getBody().front().getOperations()) {
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
int64_t coreId = coreOp.getCoreId();
if (failed(appendCoreCommunicationEvents(
coreOp.getBody().front(), coreId, StaticValueKnowledge {}, coreEvents[coreId], diagnostics)))
hasFailure = true;
continue;
}
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
SmallVector<int32_t> coreIds = getBatchCoreIds(coreBatchOp);
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
for (size_t lane = 0; lane < laneCount; ++lane) {
StaticValueKnowledge laneKnowledge;
laneKnowledge.indexValues[coreBatchOp.getLaneArgument()] = static_cast<int64_t>(lane);
for (unsigned inputIndex = 0; inputIndex < coreBatchOp.getInputs().size(); ++inputIndex)
laneKnowledge.aliases[coreBatchOp.getInputArgument(inputIndex)] = coreBatchOp.getInputs()[inputIndex];
SmallVector<int32_t> laneCoreIds = getLaneChunkCoreIds(coreIds, laneCount, static_cast<unsigned>(lane));
for (int32_t coreId : laneCoreIds) {
if (failed(appendCoreCommunicationEvents(
coreBatchOp.getBody().front(), coreId, laneKnowledge, coreEvents[coreId], diagnostics)))
hasFailure = true;
}
}
}
}
}
if (hasFailure)
return failure();
DenseMap<int64_t, size_t> programCounters;
for (const auto& [coreId, events] : coreEvents)
programCounters[coreId] = 0;
while (true) {
bool madeProgress = false;
for (const auto& [coreId, events] : coreEvents) {
size_t pc = programCounters[coreId];
if (pc >= events.size())
continue;
const CommunicationEvent& event = events[pc];
auto peerEventsIt = coreEvents.find(event.peerCoreId);
if (peerEventsIt == coreEvents.end())
continue;
size_t peerPc = programCounters[event.peerCoreId];
if (peerPc >= peerEventsIt->second.size())
continue;
const CommunicationEvent& peerEvent = peerEventsIt->second[peerPc];
if (!areMatchedCommunicationEvents(event, peerEvent))
continue;
++programCounters[coreId];
++programCounters[event.peerCoreId];
madeProgress = true;
}
if (madeProgress)
continue;
bool allDone = true;
for (const auto& [coreId, events] : coreEvents) {
if (programCounters[coreId] < events.size()) {
allDone = false;
break;
}
}
if (allDone)
return success();
auto cycle = findCommunicationWaitCycle(coreEvents, programCounters);
if (succeeded(cycle)) {
emitCommunicationDeadlockCycle(moduleOp, coreEvents, programCounters, *cycle);
return failure();
}
auto diagnostic = moduleOp.emitError()
<< "PIM communication deadlock check stalled without finding a closed wait cycle; this usually means a "
"send/receive peer is missing or ordered after a finished core";
for (const auto& [coreId, events] : coreEvents) {
size_t pc = programCounters[coreId];
if (pc >= events.size())
continue;
const CommunicationEvent& event = events[pc];
diagnostic.attachNote(event.op->getLoc()) << formatCommunicationEvent(event);
}
return failure();
}
}
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> { struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
@@ -212,11 +690,18 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
} }
} }
bool hasFailure = false;
if (pimDetectCommunicationDeadlock && failed(verifyNoStaticCommunicationDeadlock(moduleOp, diagnostics)))
hasFailure = true;
if (diagnostics.hasFailure()) { if (diagnostics.hasFailure()) {
diagnostics.emitSuppressedSummary(moduleOp, "verification failures"); diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
moduleOp.emitError("PIM codegen verification failed; see diagnostics above"); moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
signalPassFailure(); hasFailure = true;
} }
if (hasFailure)
signalPassFailure();
} }
private: private:
+121 -12
View File
@@ -26,7 +26,7 @@ def SpatTensor :
// Execution // Execution
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def SpatCompute : SpatOp<"compute", class SpatComputeLikeBase<string mnemonic> : SpatOp<mnemonic,
[SingleBlock, AttrSizedOperandSegments, [SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> { DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Compute region with attached constant weights"; let summary = "Compute region with attached constant weights";
@@ -42,6 +42,12 @@ def SpatCompute : SpatOp<"compute",
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatGraphCompute : SpatComputeLikeBase<"graph_compute"> {
let extraClassDeclaration = [{ let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx); std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx); std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
@@ -50,16 +56,26 @@ def SpatCompute : SpatOp<"compute",
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>> std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc); insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights(); ::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>> ::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatGraphCompute>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc); insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}]; }];
let hasVerifier = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
} }
def SpatComputeBatch : SpatOp<"compute_batch", def SpatScheduledCompute : SpatComputeLikeBase<"scheduled_compute"> {
let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatScheduledCompute>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
}
class SpatComputeBatchLikeBase<string mnemonic> : SpatOp<mnemonic,
[SingleBlock, AttrSizedOperandSegments, [SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> { DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs"; let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
@@ -76,6 +92,11 @@ def SpatComputeBatch : SpatOp<"compute_batch",
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatGraphComputeBatch : SpatComputeBatchLikeBase<"graph_compute_batch"> {
let extraClassDeclaration = [{ let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getLaneArgument(); std::optional<::mlir::BlockArgument> getLaneArgument();
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx); std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
@@ -86,21 +107,33 @@ def SpatComputeBatch : SpatOp<"compute_batch",
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>> std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc); insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights(); ::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>> ::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatGraphComputeBatch>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc); insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}]; }];
}
let hasVerifier = 1; def SpatScheduledComputeBatch : SpatComputeBatchLikeBase<"scheduled_compute_batch"> {
let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getLaneArgument();
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatScheduledComputeBatch>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
} }
def SpatInParallelOp : SpatOp<"in_parallel", [ def SpatInParallelOp : SpatOp<"in_parallel", [
Pure, Pure,
Terminator, Terminator,
DeclareOpInterfaceMethods<InParallelOpInterface>, DeclareOpInterfaceMethods<InParallelOpInterface>,
HasParent<"SpatComputeBatch">,
] # GraphRegionNoTerminator.traits> { ] # GraphRegionNoTerminator.traits> {
let summary = "Parallel combining terminator for resultful spat.compute_batch"; let summary = "Parallel combining terminator for resultful Spatial compute batches";
let regions = (region SizedRegion<1>:$region); let regions = (region SizedRegion<1>:$region);
@@ -159,6 +192,82 @@ def SpatConcatOp : SpatOp<"concat", []> {
let hasCustomAssemblyFormat = 1; let hasCustomAssemblyFormat = 1;
} }
//===----------------------------------------------------------------------===//
// Planning
//===----------------------------------------------------------------------===//
def SpatConv2DPlanOp : SpatOp<"conv2d_plan", []> {
let summary = "Structured Conv2D planning op that preserves logical ONNX geometry";
let arguments = (ins
SpatTensor:$input,
SpatTensor:$weight,
Optional<SpatTensor>:$bias,
DenseI64ArrayAttr:$pads,
DenseI64ArrayAttr:$strides,
DenseI64ArrayAttr:$dilations,
I64Attr:$group,
StrAttr:$logicalLayout
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
def SpatReluPlanOp : SpatOp<"relu_plan", []> {
let summary = "Layout-aware ReLU planning op";
let arguments = (ins
SpatTensor:$input,
StrAttr:$logicalLayout
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
def SpatReconciliatorOp : SpatOp<"reconciliator", []> {
let summary = "Passive logical-to-physical layout selection record";
let arguments = (ins
SpatTensor:$input,
StrAttr:$logicalLayout,
StrAttr:$physicalLayout,
DenseI64ArrayAttr:$fragmentOffsets,
DenseI64ArrayAttr:$fragmentSizes,
StrAttr:$indexMap
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
def SpatMaterializeLayoutOp : SpatOp<"materialize_layout", []> {
let summary = "Explicit layout conversion or materialization barrier";
let arguments = (ins
SpatTensor:$input,
StrAttr:$logicalLayout,
StrAttr:$sourcePhysicalLayout,
StrAttr:$targetPhysicalLayout
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Communication // Communication
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
+235 -118
View File
@@ -29,11 +29,19 @@ std::optional<BlockArgument> insertBlockArgument(Region& body, unsigned argIdx,
} }
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) { void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
if (auto compute = dyn_cast<SpatCompute>(op)) { if (auto compute = dyn_cast<SpatGraphCompute>(op)) {
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount}); compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
return; return;
} }
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount}); if (auto compute = dyn_cast<SpatScheduledCompute>(op)) {
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
return;
}
if (auto batch = dyn_cast<SpatGraphComputeBatch>(op)) {
batch.getProperties().setOperandSegmentSizes({weightCount, inputCount});
return;
}
cast<SpatScheduledComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
} }
using CrossbarWeightSet = llvm::SetVector<Value, llvm::SmallVector<Value, 4>, llvm::SmallDenseSet<Value, 4>>; using CrossbarWeightSet = llvm::SetVector<Value, llvm::SmallVector<Value, 4>, llvm::SmallDenseSet<Value, 4>>;
@@ -47,116 +55,205 @@ CrossbarWeightSet collectCrossbarWeights(Region& body) {
return weights; return weights;
} }
} // namespace template <typename ComputeOpTy>
std::optional<BlockArgument> getComputeWeightArgument(ComputeOpTy compute, unsigned idx) {
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); } return getBlockArgument(compute.getBody(), idx);
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
return getBlockArgument(getBody(), getWeights().size() + idx);
} }
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) { template <typename ComputeOpTy>
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) { std::optional<BlockArgument> getComputeInputArgument(ComputeOpTy compute, unsigned idx) {
auto index = std::distance(getWeights().begin(), existing); return getBlockArgument(compute.getBody(), compute.getWeights().size() + idx);
return { }
{*existing, *getWeightArgument(index)}
}; template <typename ComputeOpTy>
std::optional<std::tuple<Value, BlockArgument>>
insertComputeWeight(ComputeOpTy compute, unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(compute.getWeights(), weight); existing != compute.getWeights().end()) {
auto index = std::distance(compute.getWeights().begin(), existing);
return {{*existing, *getComputeWeightArgument(compute, index)}};
} }
unsigned weightCount = getWeights().size(); unsigned weightCount = compute.getWeights().size();
unsigned inputCount = getInputs().size(); unsigned inputCount = compute.getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight}); compute.getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes( setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount)); compute.getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBlockArgument(getBody(), idx, weight.getType(), loc); auto blockArg = insertBlockArgument(compute.getBody(), idx, weight.getType(), loc);
if (!blockArg) if (!blockArg)
return std::nullopt; return std::nullopt;
return std::make_tuple(getOperation()->getOperand(idx), *blockArg); return std::make_tuple(compute.getOperation()->getOperand(idx), *blockArg);
} }
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigned idx, Value input, Location loc) { template <typename ComputeBatchOpTy>
unsigned weightCount = getWeights().size(); std::optional<std::tuple<Value, BlockArgument>>
unsigned inputCount = getInputs().size(); insertComputeBatchWeight(ComputeBatchOpTy batch, unsigned idx, Value weight, Location loc) {
getOperation()->insertOperands(weightCount + idx, ValueRange {input}); if (auto existing = llvm::find(batch.getWeights(), weight); existing != batch.getWeights().end()) {
auto index = std::distance(batch.getWeights().begin(), existing);
return {{*existing, *batch.getWeightArgument(index)}};
}
unsigned weightCount = batch.getWeights().size();
unsigned inputCount = batch.getInputs().size();
batch.getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes( setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1)); batch.getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBlockArgument(getBody(), weightCount + idx, input.getType(), loc);
auto blockArg = insertBlockArgument(batch.getBody(), 1 + idx, weight.getType(), loc);
if (!blockArg) if (!blockArg)
return std::nullopt; return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg); return std::make_tuple(batch.getOperation()->getOperand(idx), *blockArg);
} }
CrossbarWeightSet SpatCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); } template <typename ComputeOpTy>
std::optional<std::tuple<Value, BlockArgument>>
FailureOr<std::tuple<OpResult, SpatCompute>> insertComputeInput(ComputeOpTy compute, unsigned idx, Value input, Location loc) {
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) { unsigned weightCount = compute.getWeights().size();
if (idx > getNumResults()) unsigned inputCount = compute.getInputs().size();
return failure(); compute.getOperation()->insertOperands(weightCount + idx, ValueRange {input});
setComputeOperandSegmentSizes(
rewriter.setInsertionPoint(getOperation()); compute.getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end()); auto blockArg = insertBlockArgument(compute.getBody(), weightCount + idx, input.getType(), loc);
resultTypes.insert(resultTypes.begin() + idx, type); if (!blockArg)
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs()); return std::nullopt;
newCompute->setAttrs((*this)->getAttrs()); return std::make_tuple(compute.getOperation()->getOperand(weightCount + idx), *blockArg);
setComputeOperandSegmentSizes(newCompute.getOperation(),
static_cast<int32_t>(newCompute.getWeights().size()),
static_cast<int32_t>(newCompute.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx)
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
} }
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { template <typename ComputeOpTy>
void setComputeAsmBlockArgumentNames(ComputeOpTy compute, Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty()) if (region.empty())
return; return;
for (unsigned index = 0; index < getWeights().size(); ++index) for (unsigned index = 0; index < compute.getWeights().size(); ++index)
if (auto weightArg = getWeightArgument(index)) if (auto weightArg = compute.getWeightArgument(index))
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str()); setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index) for (unsigned index = 0; index < compute.getInputs().size(); ++index)
if (auto inputArg = getInputArgument(index)) if (auto inputArg = compute.getInputArgument(index))
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str()); setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
} }
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); } template <typename ComputeOpTy>
FailureOr<std::tuple<OpResult, ComputeOpTy>>
insertComputeOutput(ComputeOpTy compute, RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > compute.getNumResults())
return failure();
std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) { rewriter.setInsertionPoint(compute.getOperation());
SmallVector<Type> resultTypes(compute.getResultTypes().begin(), compute.getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newCompute =
ComputeOpTy::create(rewriter, compute.getLoc(), TypeRange(resultTypes), compute.getWeights(), compute.getInputs());
newCompute->setAttrs(compute->getAttrs());
setComputeOperandSegmentSizes(newCompute.getOperation(),
static_cast<int32_t>(newCompute.getWeights().size()),
static_cast<int32_t>(newCompute.getInputs().size()));
rewriter.inlineRegionBefore(compute.getBody(), newCompute.getBody(), newCompute.getBody().end());
for (unsigned oldResultIdx = 0; oldResultIdx < compute.getNumResults(); ++oldResultIdx)
compute.getResult(oldResultIdx)
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(compute.getOperation());
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
}
template <typename ComputeBatchOpTy>
FailureOr<std::tuple<OpResult, BlockArgument, ComputeBatchOpTy>>
insertComputeBatchOutput(ComputeBatchOpTy batch, RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > batch.getNumResults())
return failure();
rewriter.setInsertionPoint(batch.getOperation());
SmallVector<Type> resultTypes(batch.getResultTypes().begin(), batch.getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newBatch =
ComputeBatchOpTy::create(rewriter, batch.getLoc(), TypeRange(resultTypes), batch.getLaneCountAttr(), batch.getWeights(), batch.getInputs());
newBatch->setAttrs(batch->getAttrs());
setComputeOperandSegmentSizes(newBatch.getOperation(),
static_cast<int32_t>(newBatch.getWeights().size()),
static_cast<int32_t>(newBatch.getInputs().size()));
rewriter.inlineRegionBefore(batch.getBody(), newBatch.getBody(), newBatch.getBody().end());
if (newBatch.getBody().empty()) {
rewriter.eraseOp(newBatch);
return failure();
}
auto blockArg = newBatch.getBody().front().insertArgument(
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
for (unsigned oldResultIdx = 0; oldResultIdx < batch.getNumResults(); ++oldResultIdx)
batch.getResult(oldResultIdx)
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(batch.getOperation());
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
}
} // namespace
bool isGraphComputeLike(Operation* op) { return isa<SpatGraphCompute, SpatGraphComputeBatch>(op); }
bool isGraphBatchComputeLike(Operation* op) { return isa<SpatGraphComputeBatch>(op); }
bool isScheduledComputeLike(Operation* op) { return isa<SpatScheduledCompute, SpatScheduledComputeBatch>(op); }
bool isScheduledBatchComputeLike(Operation* op) { return isa<SpatScheduledComputeBatch>(op); }
bool isAnySpatialComputeLike(Operation* op) {
return isa<SpatGraphCompute, SpatGraphComputeBatch, SpatScheduledCompute, SpatScheduledComputeBatch>(op);
}
bool isAnySpatialComputeBatchLike(Operation* op) { return isa<SpatGraphComputeBatch, SpatScheduledComputeBatch>(op); }
std::optional<BlockArgument> SpatGraphCompute::getWeightArgument(unsigned idx) { return getComputeWeightArgument(*this, idx); }
std::optional<BlockArgument> SpatGraphCompute::getInputArgument(unsigned idx) { return getComputeInputArgument(*this, idx); }
std::optional<std::tuple<Value, BlockArgument>> SpatGraphCompute::insertWeight(unsigned idx, Value weight, Location loc) {
return insertComputeWeight(*this, idx, weight, loc);
}
std::optional<std::tuple<Value, BlockArgument>> SpatGraphCompute::insertInput(unsigned idx, Value input, Location loc) {
return insertComputeInput(*this, idx, input, loc);
}
CrossbarWeightSet SpatGraphCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, SpatGraphCompute>>
SpatGraphCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
return insertComputeOutput(*this, rewriter, idx, type, loc);
}
void SpatGraphCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
}
std::optional<BlockArgument> SpatScheduledCompute::getWeightArgument(unsigned idx) {
return getComputeWeightArgument(*this, idx);
}
std::optional<BlockArgument> SpatScheduledCompute::getInputArgument(unsigned idx) { return getComputeInputArgument(*this, idx); }
std::optional<std::tuple<Value, BlockArgument>>
SpatScheduledCompute::insertWeight(unsigned idx, Value weight, Location loc) {
return insertComputeWeight(*this, idx, weight, loc);
}
std::optional<std::tuple<Value, BlockArgument>>
SpatScheduledCompute::insertInput(unsigned idx, Value input, Location loc) {
return insertComputeInput(*this, idx, input, loc);
}
CrossbarWeightSet SpatScheduledCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, SpatScheduledCompute>>
SpatScheduledCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
return insertComputeOutput(*this, rewriter, idx, type, loc);
}
void SpatScheduledCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
}
std::optional<BlockArgument> SpatGraphComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
std::optional<BlockArgument> SpatGraphComputeBatch::getWeightArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + idx); return getBlockArgument(getBody(), 1 + idx);
} }
std::optional<BlockArgument> SpatGraphComputeBatch::getInputArgument(unsigned idx) {
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + getWeights().size() + idx); return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
} }
std::optional<BlockArgument> SpatGraphComputeBatch::getOutputArgument(unsigned idx) {
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx); return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
} }
std::optional<std::tuple<Value, BlockArgument>> std::optional<std::tuple<Value, BlockArgument>>
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) { SpatGraphComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) { return insertComputeBatchWeight(*this, idx, weight, loc);
auto index = std::distance(getWeights().begin(), existing);
return {
{*existing, *getWeightArgument(index)}
};
}
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBlockArgument(getBody(), 1 + idx, weight.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
} }
std::optional<std::tuple<Value, BlockArgument>>
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) { SpatGraphComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size(); unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size(); unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input}); getOperation()->insertOperands(weightCount + idx, ValueRange {input});
@@ -167,52 +264,68 @@ std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(un
return std::nullopt; return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg); return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
} }
CrossbarWeightSet SpatGraphComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
CrossbarWeightSet SpatComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); } FailureOr<std::tuple<OpResult, BlockArgument, SpatGraphComputeBatch>>
SpatGraphComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>> return insertComputeBatchOutput(*this, rewriter, idx, type, loc);
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > getNumResults())
return failure();
rewriter.setInsertionPoint(getOperation());
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newBatch =
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
newBatch->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes(newBatch.getOperation(),
static_cast<int32_t>(newBatch.getWeights().size()),
static_cast<int32_t>(newBatch.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
if (newBatch.getBody().empty()) {
rewriter.eraseOp(newBatch);
return failure();
}
auto blockArg = newBatch.getBody().front().insertArgument(
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx)
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
} }
void SpatGraphComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty()) if (region.empty())
return; return;
if (auto laneArg = getLaneArgument()) if (auto laneArg = getLaneArgument())
setNameFn(*laneArg, "lane"); setNameFn(*laneArg, "lane");
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
continue;
if (index == 0) {
setNameFn(*outputArg, "out");
continue;
}
setNameFn(*outputArg, ("out" + std::to_string(index)).c_str());
}
}
for (unsigned index = 0; index < getWeights().size(); ++index) std::optional<BlockArgument> SpatScheduledComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
if (auto weightArg = getWeightArgument(index)) std::optional<BlockArgument> SpatScheduledComputeBatch::getWeightArgument(unsigned idx) {
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str()); return getBlockArgument(getBody(), 1 + idx);
}
for (unsigned index = 0; index < getInputs().size(); ++index) std::optional<BlockArgument> SpatScheduledComputeBatch::getInputArgument(unsigned idx) {
if (auto inputArg = getInputArgument(index)) return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str()); }
std::optional<BlockArgument> SpatScheduledComputeBatch::getOutputArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
}
std::optional<std::tuple<Value, BlockArgument>>
SpatScheduledComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
return insertComputeBatchWeight(*this, idx, weight, loc);
}
std::optional<std::tuple<Value, BlockArgument>>
SpatScheduledComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
auto blockArg = insertBlockArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
}
CrossbarWeightSet SpatScheduledComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, BlockArgument, SpatScheduledComputeBatch>>
SpatScheduledComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
return insertComputeBatchOutput(*this, rewriter, idx, type, loc);
}
void SpatScheduledComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
if (auto laneArg = getLaneArgument())
setNameFn(*laneArg, "lane");
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
for (unsigned index = 0; index < getNumResults(); ++index) { for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index); auto outputArg = getOutputArgument(index);
if (!outputArg) if (!outputArg)
@@ -231,7 +344,11 @@ void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
builder.createBlock(bodyRegion); builder.createBlock(bodyRegion);
} }
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); } OpResult SpatInParallelOp::getParentResult(int64_t idx) {
Operation* parent = getOperation()->getParentOp();
assert(isAnySpatialComputeBatchLike(parent) && "expected Spatial compute batch parent");
return parent->getResult(idx);
}
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); } llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
+16
View File
@@ -26,3 +26,19 @@
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc"
namespace onnx_mlir {
namespace spatial {
bool isGraphComputeLike(mlir::Operation* op);
bool isGraphBatchComputeLike(mlir::Operation* op);
bool isScheduledComputeLike(mlir::Operation* op);
bool isScheduledBatchComputeLike(mlir::Operation* op);
bool isAnySpatialComputeLike(mlir::Operation* op);
bool isAnySpatialComputeBatchLike(mlir::Operation* op);
using SpatCompute = SpatGraphCompute;
using SpatComputeBatch = SpatGraphComputeBatch;
} // namespace spatial
} // namespace onnx_mlir
+260 -251
View File
@@ -115,6 +115,254 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
return success(); return success();
} }
template <typename ComputeOpTy>
void printComputeLikeOp(ComputeOpTy op, OpAsmPrinter& printer) {
SmallVector<Value> weightArgs;
weightArgs.reserve(op.getWeights().size());
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
auto weightArg = op.getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(op.getInputs().size());
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
auto inputArg = op.getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
printer << " ";
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
if (auto coreIdAttr = op->template getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
printer << " coreId " << coreIdAttr.getInt();
printer << " crossbarWeights " << collectDistinctCrossbarWeights(op.getOperation()).size();
printer.printOptionalAttrDict(op->getAttrs(), {op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, op.getResultTypes());
printer << " ";
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
}
template <typename ComputeOpTy>
ParseResult parseComputeLikeOp(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
int32_t coreId = 0;
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
SmallVector<OpAsmParser::Argument> inputArgs;
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
if (hasCoreId && parser.parseInteger(coreId))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreId cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreId)
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyArgumentTypes(weightTypes, weightArgs);
applyArgumentTypes(inputTypes, inputArgs);
llvm::append_range(regionArgs, weightArgs);
llvm::append_range(regionArgs, inputArgs);
return parser.parseRegion(*body, regionArgs);
}
template <typename ComputeBatchOpTy>
void printComputeBatchLikeOp(ComputeBatchOpTy op, OpAsmPrinter& printer) {
auto laneArg = op.getLaneArgument();
SmallVector<Value> weightArgs;
weightArgs.reserve(op.getWeights().size());
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
auto weightArg = op.getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(op.getInputs().size());
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
auto inputArg = op.getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
SmallVector<BlockArgument> outputArgs;
if (!laneArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
if (op.getNumResults() != 0) {
outputArgs.reserve(op.getNumResults());
for (unsigned index = 0; index < op.getNumResults(); ++index) {
auto outputArg = op.getOutputArgument(index);
if (!outputArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
outputArgs.push_back(*outputArg);
}
}
printer << " ";
printer.printOperand(*laneArg);
printer << " = 0 to " << op.getLaneCount();
printer << " ";
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
if (op.getNumResults() != 0) {
printer << " shared_outs";
printBlockArgumentList(printer, outputArgs);
}
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({op.getOperation(), 0, op.getLaneCount()}).size();
if (auto coreIdsAttr = op->template getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
printer << " coreIds ";
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
}
printer.printOptionalAttrDict(
op->getAttrs(),
{op.getLaneCountAttrName().getValue(), op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, op.getResultTypes());
printer << " ";
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
}
template <typename ComputeBatchOpTy>
ParseResult parseComputeBatchLikeOp(OpAsmParser& parser, OperationState& result) {
int64_t lowerBound = 0;
int32_t laneCount = 0;
OpAsmParser::Argument laneArg;
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> inputArgs;
SmallVector<OpAsmParser::Argument> outputArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
SmallVector<int32_t> coreIds;
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
return failure();
if (lowerBound != 0)
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
if (parseBlockArgumentList(parser, outputArgs))
return failure();
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (outputArgs.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(),
"number of shared output bindings and result types must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreIds cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreIds)
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyBatchRegionArgumentTypes(
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
return parser.parseRegion(*body, regionArgs);
}
} // namespace } // namespace
void SpatYieldOp::print(OpAsmPrinter& printer) { void SpatYieldOp::print(OpAsmPrinter& printer) {
@@ -218,260 +466,21 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
return success(); return success();
} }
void SpatCompute::print(OpAsmPrinter& printer) { void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
SmallVector<Value> weightArgs; ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) {
weightArgs.reserve(getWeights().size()); return parseComputeLikeOp<SpatGraphCompute>(parser, result);
for (unsigned index = 0; index < getWeights().size(); ++index) {
auto weightArg = getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index) {
auto inputArg = getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
printer << " ";
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
printer << " coreId " << coreIdAttr.getInt();
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
printer.printOptionalAttrDict((*this)->getAttrs(),
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, getResultTypes());
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
} }
void SpatScheduledCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { ParseResult SpatScheduledCompute::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> weightArgs; return parseComputeLikeOp<SpatScheduledCompute>(parser, result);
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
int32_t coreId = 0;
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
SmallVector<OpAsmParser::Argument> inputArgs;
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
if (hasCoreId && parser.parseInteger(coreId))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreId cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreId)
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyArgumentTypes(weightTypes, weightArgs);
applyArgumentTypes(inputTypes, inputArgs);
llvm::append_range(regionArgs, weightArgs);
llvm::append_range(regionArgs, inputArgs);
return parser.parseRegion(*body, regionArgs);
} }
void SpatGraphComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
void SpatComputeBatch::print(OpAsmPrinter& printer) { ParseResult SpatGraphComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
auto laneArg = getLaneArgument(); return parseComputeBatchLikeOp<SpatGraphComputeBatch>(parser, result);
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index) {
auto weightArg = getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index) {
auto inputArg = getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
SmallVector<BlockArgument> outputArgs;
if (!laneArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
if (getNumResults() != 0) {
outputArgs.reserve(getNumResults());
for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
outputArgs.push_back(*outputArg);
}
}
printer << " ";
printer.printOperand(*laneArg);
printer << " = 0 to " << getLaneCount();
printer << " ";
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (getNumResults() != 0) {
printer << " shared_outs";
printBlockArgumentList(printer, outputArgs);
}
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({getOperation(), 0, getLaneCount()}).size();
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
printer << " coreIds ";
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
}
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, getResultTypes());
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
} }
void SpatScheduledComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) { ParseResult SpatScheduledComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
int64_t lowerBound = 0; return parseComputeBatchLikeOp<SpatScheduledComputeBatch>(parser, result);
int32_t laneCount = 0;
OpAsmParser::Argument laneArg;
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> inputArgs;
SmallVector<OpAsmParser::Argument> outputArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
SmallVector<int32_t> coreIds;
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
return failure();
if (lowerBound != 0)
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
if (parseBlockArgumentList(parser, outputArgs))
return failure();
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (outputArgs.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(),
"number of shared output bindings and result types must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreIds cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreIds)
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyBatchRegionArgumentTypes(
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
return parser.parseRegion(*body, regionArgs);
} }
void SpatInParallelOp::print(OpAsmPrinter& printer) { void SpatInParallelOp::print(OpAsmPrinter& printer) {
@@ -10,8 +10,9 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace spatial { namespace spatial {
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) { template <typename ComputeOpTy>
Block& block = getBody().front(); LogicalResult foldComputeLike(ComputeOpTy compute, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
Block& block = compute.getBody().front();
if (!llvm::hasSingleElement(block)) if (!llvm::hasSingleElement(block))
return failure(); return failure();
@@ -22,7 +23,7 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m
for (Value yieldedValue : yieldOp.getOperands()) { for (Value yieldedValue : yieldOp.getOperands()) {
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) { if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
if (blockArg.getOwner() == &block) { if (blockArg.getOwner() == &block) {
results.push_back(getOperand(blockArg.getArgNumber())); results.push_back(compute.getOperand(blockArg.getArgNumber()));
continue; continue;
} }
} }
@@ -31,5 +32,13 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m
return success(); return success();
} }
LogicalResult SpatGraphCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
return foldComputeLike(*this, results);
}
LogicalResult SpatScheduledCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
return foldComputeLike(*this, results);
}
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir
+237 -76
View File
@@ -35,7 +35,8 @@ static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
return shapedType.getShape(); return shapedType.getShape();
} }
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) { template <typename ComputeBatchOpTy>
static bool isBatchOutputArgument(ComputeBatchOpTy batchOp, Value value) {
if (batchOp.getNumResults() == 0) if (batchOp.getNumResults() == 0)
return false; return false;
auto blockArg = dyn_cast<BlockArgument>(value); auto blockArg = dyn_cast<BlockArgument>(value);
@@ -58,8 +59,28 @@ static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind)
return success(); return success();
} }
static bool isStaticIndexExpr(Value value) {
if (matchConstantIndexValue(value))
return true;
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
if (affineApply) {
if (!isSingleResultSymbolFreeAffineMap(affineApply.getAffineMap()))
return false;
return llvm::all_of(affineApply.getMapOperands(), isStaticIndexExpr);
}
if (auto addOp = value.getDefiningOp<arith::AddIOp>())
return isStaticIndexExpr(addOp.getLhs()) && isStaticIndexExpr(addOp.getRhs());
if (auto mulOp = value.getDefiningOp<arith::MulIOp>())
return isStaticIndexExpr(mulOp.getLhs()) && isStaticIndexExpr(mulOp.getRhs());
return false;
}
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) { static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
if (value == laneArg || matchConstantIndexValue(value)) if (value == laneArg || isStaticIndexExpr(value))
return true; return true;
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>(); auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
@@ -83,10 +104,15 @@ static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
} }
auto addOp = value.getDefiningOp<arith::AddIOp>(); auto addOp = value.getDefiningOp<arith::AddIOp>();
if (!addOp) if (addOp)
return (isSupportedLaneOffsetExpr(addOp.getLhs(), laneArg) && isStaticIndexExpr(addOp.getRhs()))
|| (isSupportedLaneOffsetExpr(addOp.getRhs(), laneArg) && isStaticIndexExpr(addOp.getLhs()));
auto mulOp = value.getDefiningOp<arith::MulIOp>();
if (!mulOp)
return false; return false;
return (addOp.getLhs() == laneArg && matchConstantIndexValue(addOp.getRhs())) return (isSupportedLaneOffsetExpr(mulOp.getLhs(), laneArg) && isStaticIndexExpr(mulOp.getRhs()))
|| (addOp.getRhs() == laneArg && matchConstantIndexValue(addOp.getLhs())); || (isSupportedLaneOffsetExpr(mulOp.getRhs(), laneArg) && isStaticIndexExpr(mulOp.getLhs()));
} }
static LogicalResult static LogicalResult
@@ -158,17 +184,27 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value)) if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
continue; continue;
InFlightDiagnostic diagnostic = ownerOp->emitOpError() InFlightDiagnostic diagnostic =
<< kind << " body may only directly reference external constants"; ownerOp->emitOpError() << kind << " body may not capture external values";
diagnostic.attachNote(op->getLoc()) diagnostic.attachNote(op->getLoc())
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName(); << "owner='" << ownerOp->getName() << "' nestedOp='" << op->getName() << "' operand#"
<< operand.getOperandNumber() << " type=" << value.getType()
<< " category=" << (isa<TensorType>(value.getType()) ? "tensor" : (value.getType().isIndex() ? "index"
: "scalar"));
if (Operation* definingOp = value.getDefiningOp())
diagnostic.attachNote(definingOp->getLoc()) << "defining op is '" << definingOp->getName() << "'";
else if (auto blockArg = dyn_cast<BlockArgument>(value))
diagnostic.attachNote(blockArg.getOwner()->getParentOp()->getLoc())
<< "value is block argument #" << blockArg.getArgNumber() << " of '"
<< blockArg.getOwner()->getParentOp()->getName() << "'";
hasFailure = true; hasFailure = true;
} }
}); });
return success(!hasFailure); return success(!hasFailure);
} }
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) { template <typename ComputeBatchOpTy>
static LogicalResult verifyBatchBody(ComputeBatchOpTy batchOp, Block& block) {
if (batchOp.getNumResults() == 0) { if (batchOp.getNumResults() == 0) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator()); auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp) if (!yieldOp)
@@ -344,144 +380,266 @@ LogicalResult SpatConcatOp::verify() {
return success(); return success();
} }
LogicalResult verifyComputeResultsUses(Operation* op) { static bool isKnownLogicalLayout(StringRef layout) { return layout == "nchw"; }
if (!isa<SpatCompute, SpatComputeBatch>(op))
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation"); static bool isKnownPhysicalLayout(StringRef layout) {
if (!llvm::all_of(op->getResults(), [](Value result) { return layout == "dense_nchw" || layout == "nchw_row_strip";
return llvm::all_of(result.getUsers(), [](Operation* op) { }
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
}); static LogicalResult verifyPlanTensorTypes(Operation* op, Value input, Value output, StringRef kind) {
})) { auto inputType = dyn_cast<RankedTensorType>(input.getType());
return op->emitError("ComputeResult used directly inside another Compute"); auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!inputType || !outputType)
return op->emitOpError() << kind << " requires ranked tensor input and output types";
if (inputType.getElementType() != outputType.getElementType())
return op->emitOpError() << kind << " requires matching input/output element types";
return success();
}
LogicalResult SpatConv2DPlanOp::verify() {
auto inputType = dyn_cast<RankedTensorType>(getInput().getType());
auto weightType = dyn_cast<RankedTensorType>(getWeight().getType());
auto outputType = dyn_cast<RankedTensorType>(getOutput().getType());
if (!inputType || !weightType || !outputType)
return emitError("requires ranked tensor input, weight, and output");
if (inputType.getRank() != 4 || weightType.getRank() != 4 || outputType.getRank() != 4)
return emitError("requires rank-4 input, weight, and output tensors");
if (!isKnownLogicalLayout(getLogicalLayout()))
return emitError("requires a known logical layout");
if (getPads().size() != 4)
return emitError("requires exactly four pad values");
if (getStrides().size() != 2)
return emitError("requires exactly two stride values");
if (getDilations().size() != 2)
return emitError("requires exactly two dilation values");
if (getGroup() < 1)
return emitError("requires group >= 1");
if (inputType.getElementType() != weightType.getElementType()
|| inputType.getElementType() != outputType.getElementType()) {
return emitError("requires matching input, weight, and output element types");
}
if (getBias()) {
auto biasType = dyn_cast<RankedTensorType>(getBias().getType());
if (!biasType)
return emitError("requires ranked tensor bias type");
if (biasType.getElementType() != outputType.getElementType())
return emitError("requires bias element type to match output element type");
} }
return success(); return success();
} }
LogicalResult SpatCompute::verify() { LogicalResult SpatReluPlanOp::verify() {
auto& block = getBody().front(); if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.relu_plan")))
unsigned expectedArgCount = getWeights().size() + getInputs().size(); return failure();
if (block.getNumArguments() != expectedArgCount) if (!isKnownLogicalLayout(getLogicalLayout()))
return emitError("compute body must have weight and input block arguments"); return emitError("requires a known logical layout");
return success();
}
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { LogicalResult SpatReconciliatorOp::verify() {
auto blockArg = getWeightArgument(weightIndex); if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.reconciliator")))
if (!blockArg || blockArg->getType() != weight.getType()) return failure();
return emitError("compute weight block argument types must match weight operand types exactly"); if (!isKnownLogicalLayout(getLogicalLayout()))
return emitError("requires a known logical layout");
if (!isKnownPhysicalLayout(getPhysicalLayout()))
return emitError("requires a known physical layout");
auto logicalType = dyn_cast<RankedTensorType>(getOutput().getType());
if (!logicalType)
return emitError("requires ranked tensor output");
auto offsets = getFragmentOffsets();
auto sizes = getFragmentSizes();
if (offsets.size() != sizes.size())
return emitError("fragment offset and size arrays must have the same length");
if (offsets.empty())
return success();
int64_t rank = logicalType.getRank();
if (rank <= 0 || offsets.size() % rank != 0)
return emitError("fragment metadata must be a whole number of rank-sized fragments");
ArrayRef<int64_t> shape = logicalType.getShape();
for (int64_t index = 0; index < static_cast<int64_t>(offsets.size()); ++index) {
int64_t dim = index % rank;
int64_t offset = offsets[index];
int64_t size = sizes[index];
if (offset < 0 || size < 0)
return emitError("fragment offsets and sizes must be non-negative");
int64_t logicalDim = shape[dim];
if (!ShapedType::isDynamic(logicalDim) && offset + size > logicalDim)
return emitError("fragment bounds must stay within the logical tensor shape");
} }
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
auto blockArg = getInputArgument(inputIndex); return success();
}
LogicalResult SpatMaterializeLayoutOp::verify() {
if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.materialize_layout")))
return failure();
if (!isKnownLogicalLayout(getLogicalLayout()))
return emitError("requires a known logical layout");
if (!isKnownPhysicalLayout(getSourcePhysicalLayout()))
return emitError("requires a known source physical layout");
if (!isKnownPhysicalLayout(getTargetPhysicalLayout()))
return emitError("requires a known target physical layout");
return success();
}
LogicalResult verifyComputeResultsUses(Operation* op) {
if (!isAnySpatialComputeLike(op))
return op->emitError("verifyComputeResultUses: op is not a Spatial compute-like operation");
if (!llvm::all_of(op->getResults(), [](Value result) {
return llvm::all_of(result.getUsers(), [](Operation* op) {
return !isAnySpatialComputeLike(op->getParentOp());
});
})) {
return op->emitError("compute result used directly inside another Spatial compute body");
}
return success();
}
template <typename ComputeOpTy>
LogicalResult verifyComputeLikeOp(ComputeOpTy compute, StringRef opName) {
auto& block = compute.getBody().front();
unsigned expectedArgCount = compute.getWeights().size() + compute.getInputs().size();
if (block.getNumArguments() != expectedArgCount)
return compute.emitOpError("compute body must have weight and input block arguments");
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
auto blockArg = compute.getWeightArgument(weightIndex);
if (!blockArg || blockArg->getType() != weight.getType())
return compute.emitOpError("compute weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
auto blockArg = compute.getInputArgument(inputIndex);
if (!blockArg || blockArg->getType() != input.getType()) if (!blockArg || blockArg->getType() != input.getType())
return emitError("compute input block argument types must match input operand types exactly"); return compute.emitOpError("compute input block argument types must match input operand types exactly");
} }
if (block.mightHaveTerminator()) { if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator()); auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp) if (!yieldOp)
return emitError("ComputeOp must have a single yield operation"); return compute.emitOpError("ComputeOp must have a single yield operation");
auto resultTypes = getResultTypes(); auto resultTypes = compute.getResultTypes();
auto yieldTypes = yieldOp->getOperandTypes(); auto yieldTypes = yieldOp->getOperandTypes();
if (resultTypes.size() != yieldTypes.size()) if (resultTypes.size() != yieldTypes.size())
return emitError("ComputeOp must have same number of results as yieldOp operands"); return compute.emitOpError("ComputeOp must have same number of results as yieldOp operands");
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) { for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
auto resultType = std::get<0>(it); auto resultType = std::get<0>(it);
auto yieldType = std::get<1>(it); auto yieldType = std::get<1>(it);
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType)))
return emitError("ComputeOp output must be of the same type as yieldOp operand"); return compute.emitOpError("ComputeOp output must be of the same type as yieldOp operand");
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) { if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) { if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) if (resultRankedType.getEncoding() != yieldRankedType.getEncoding())
return emitError("ComputeOp output must have the same encoding as yieldOp operand"); return compute.emitOpError("ComputeOp output must have the same encoding as yieldOp operand");
} }
else { else {
return emitError("ComputeOp output has an encoding while yieldOp operand does not have one"); return compute.emitOpError("ComputeOp output has an encoding while yieldOp operand does not have one");
} }
} }
else if (dyn_cast<RankedTensorType>(yieldType)) { else if (dyn_cast<RankedTensorType>(yieldType)) {
return emitError("ComputeOp output must not have an encoding if yieldOp operand has one"); return compute.emitOpError("ComputeOp output must not have an encoding if yieldOp operand has one");
} }
} }
} }
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex) for (unsigned inputIndex = 0; inputIndex < compute.getInputs().size(); ++inputIndex)
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty()) if (auto inputArg = compute.getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
return emitError("ComputeOp block argument is not used"); return compute.emitOpError("ComputeOp block argument is not used");
if (failed(verifyStaticWeights(*this, "compute"))) if (failed(verifyStaticWeights(compute, opName)))
return failure(); return failure();
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute"))) if (failed(verifyOnlyConstantExternalValues(compute.getOperation(), compute.getBody(), opName)))
return failure(); return failure();
if (failed(verifyComputeResultsUses(this->getOperation()))) if (failed(verifyComputeResultsUses(compute.getOperation())))
return failure(); return failure();
return success(); return success();
} }
LogicalResult SpatComputeBatch::verify() { LogicalResult SpatGraphCompute::verify() { return verifyComputeLikeOp(*this, "spat.graph_compute"); }
int32_t count = getLaneCount();
LogicalResult SpatScheduledCompute::verify() { return verifyComputeLikeOp(*this, "spat.scheduled_compute"); }
template <typename ComputeBatchOpTy>
LogicalResult verifyComputeBatchLikeOp(ComputeBatchOpTy batch, StringRef opName) {
int32_t count = batch.getLaneCount();
if (count <= 0) if (count <= 0)
return emitError("laneCount must be positive"); return batch.emitOpError("laneCount must be positive");
auto laneCountSz = static_cast<size_t>(count); auto laneCountSz = static_cast<size_t>(count);
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) { if (auto coreIdAttr = batch->getAttr(kCoreIdsAttrName)) {
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr); auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
if (!coreIdsAttr) if (!coreIdsAttr)
return emitError("compute_batch coreIds attribute must be a dense i32 array"); return batch.emitOpError("compute_batch coreIds attribute must be a dense i32 array");
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz)) if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
return emitError("compute_batch coreIds array length must match laneCount"); return batch.emitOpError("compute_batch coreIds array length must match laneCount");
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; })) if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
return emitError("compute_batch coreIds values must be non-negative"); return batch.emitOpError("compute_batch coreIds values must be non-negative");
DenseSet<int32_t> seenCoreIds; DenseSet<int32_t> seenCoreIds;
for (int32_t coreId : coreIdsAttr.asArrayRef()) for (int32_t coreId : coreIdsAttr.asArrayRef())
if (!seenCoreIds.insert(coreId).second) if (!seenCoreIds.insert(coreId).second)
return emitError("compute_batch coreIds values must be unique"); return batch.emitOpError("compute_batch coreIds values must be unique");
} }
Block& block = getBody().front(); Block& block = batch.getBody().front();
if (block.getNumArguments() == 0) if (block.getNumArguments() == 0)
return emitError("compute_batch body must have exactly one lane block argument"); return batch.emitOpError("compute_batch body must have exactly one lane block argument");
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults(); unsigned expectedArgCount = 1 + batch.getWeights().size() + batch.getInputs().size() + batch.getNumResults();
if (block.getNumArguments() != expectedArgCount) if (block.getNumArguments() != expectedArgCount)
return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results"); return batch.emitOpError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
auto laneArg = getLaneArgument(); auto laneArg = batch.getLaneArgument();
if (!laneArg || !laneArg->getType().isIndex()) if (!laneArg || !laneArg->getType().isIndex())
return emitError("compute_batch first block argument must have index type"); return batch.emitOpError("compute_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { for (auto [weightIndex, weight] : llvm::enumerate(batch.getWeights())) {
auto blockArg = getWeightArgument(weightIndex); auto blockArg = batch.getWeightArgument(weightIndex);
if (!blockArg || blockArg->getType() != weight.getType()) if (!blockArg || blockArg->getType() != weight.getType())
return emitError("compute_batch weight block argument types must match weight operand types exactly"); return batch.emitOpError("compute_batch weight block argument types must match weight operand types exactly");
} }
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) {
auto blockArg = getInputArgument(inputIndex); auto blockArg = batch.getInputArgument(inputIndex);
if (!blockArg || blockArg->getType() != input.getType()) if (!blockArg || blockArg->getType() != input.getType())
return emitError("compute_batch input block argument types must match input operand types exactly"); return batch.emitOpError("compute_batch input block argument types must match input operand types exactly");
} }
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) { for (auto [resultIndex, resultType] : llvm::enumerate(batch.getResultTypes())) {
auto blockArg = getOutputArgument(resultIndex); auto blockArg = batch.getOutputArgument(resultIndex);
if (!blockArg || blockArg->getType() != resultType) if (!blockArg || blockArg->getType() != resultType)
return emitError("compute_batch output block argument types must match result types exactly"); return batch.emitOpError("compute_batch output block argument types must match result types exactly");
} }
if (failed(verifyComputeResultsUses(this->getOperation()))) if (failed(verifyComputeResultsUses(batch.getOperation())))
return failure(); return failure();
if (failed(verifyStaticWeights(*this, "compute_batch"))) if (failed(verifyStaticWeights(batch, opName)))
return failure(); return failure();
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch"))) if (failed(verifyOnlyConstantExternalValues(batch.getOperation(), batch.getBody(), opName)))
return failure(); return failure();
return verifyBatchBody(*this, block); return verifyBatchBody(batch, block);
}
LogicalResult SpatGraphComputeBatch::verify() { return verifyComputeBatchLikeOp(*this, "spat.graph_compute_batch"); }
LogicalResult SpatScheduledComputeBatch::verify() {
return verifyComputeBatchLikeOp(*this, "spat.scheduled_compute_batch");
} }
LogicalResult SpatInParallelOp::verify() { LogicalResult SpatInParallelOp::verify() {
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>(); Operation* parent = getOperation()->getParentOp();
if (!batchOp) if (!isAnySpatialComputeBatchLike(parent))
return emitOpError("expected spat.compute_batch parent"); return emitOpError("expected spat.graph_compute_batch or spat.scheduled_compute_batch parent");
if (batchOp.getNumResults() == 0) if (parent->getNumResults() == 0)
return emitOpError("requires a resultful spat.compute_batch parent"); return emitOpError("requires a resultful spat.compute_batch parent");
auto laneArg = batchOp.getLaneArgument(); std::optional<BlockArgument> laneArg;
if (auto graphBatch = dyn_cast<SpatGraphComputeBatch>(parent))
laneArg = graphBatch.getLaneArgument();
else
laneArg = cast<SpatScheduledComputeBatch>(parent).getLaneArgument();
if (!laneArg) if (!laneArg)
return emitOpError("expected compute_batch lane block argument"); return emitOpError("expected compute_batch lane block argument");
for (Operation& op : getRegion().front().getOperations()) { for (Operation& op : getRegion().front().getOperations()) {
@@ -494,7 +652,10 @@ LogicalResult SpatInParallelOp::verify() {
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations(); MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
for (OpOperand& destination : destinations) for (OpOperand& destination : destinations)
if (!isBatchOutputArgument(batchOp, destination.get())) if ((isa<SpatGraphComputeBatch>(parent)
&& !isBatchOutputArgument(cast<SpatGraphComputeBatch>(parent), destination.get()))
|| (isa<SpatScheduledComputeBatch>(parent)
&& !isBatchOutputArgument(cast<SpatScheduledComputeBatch>(parent), destination.get())))
return op.emitOpError("may only insert into a compute_batch output block argument"); return op.emitOpError("may only insert into a compute_batch output block argument");
} }
File diff suppressed because it is too large Load Diff
@@ -40,10 +40,10 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
using namespace onnx_mlir::compact_asm; using namespace onnx_mlir::compact_asm;
using SpatCompute = spatial::SpatCompute; using SpatCompute = spatial::SpatGraphCompute;
using SpatComputeBatch = spatial::SpatComputeBatch; using SpatComputeBatch = spatial::SpatGraphComputeBatch;
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) { static std::optional<int32_t> getComputeCoreId(spatial::SpatScheduledCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) { if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id"); auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
if (failed(checkedCoreId)) if (failed(checkedCoreId))
@@ -209,7 +209,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
}; };
for (Operation& op : funcOp.getBody().front()) { for (Operation& op : funcOp.getBody().front()) {
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) { if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(&op)) {
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody()); uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation()); uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
SmallVector<int32_t> coreIds; SmallVector<int32_t> coreIds;
@@ -229,7 +229,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
totalCrossbarCount += perInstanceCrossbarCount; totalCrossbarCount += perInstanceCrossbarCount;
continue; continue;
} }
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) { if (auto batch = dyn_cast<spatial::SpatScheduledComputeBatch>(&op)) {
uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody()); uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody());
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount()); uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation()); uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
@@ -353,7 +353,17 @@ public:
void runOnOperation() override { void runOnOperation() override {
func::FuncOp func = getOperation(); func::FuncOp func = getOperation();
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed at the start of MergeComputeNodes");
signalPassFailure();
return;
}
mergeTriviallyConnectedComputes(func); mergeTriviallyConnectedComputes(func);
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after trivial merge simplification");
signalPassFailure();
return;
}
const spatial::MergeScheduleResult* analysisResult = nullptr; const spatial::MergeScheduleResult* analysisResult = nullptr;
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult(); analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
@@ -367,8 +377,8 @@ public:
signalPassFailure(); signalPassFailure();
return; return;
} }
if (failed(verifySpatialCommunicationInvariants(func))) { if (failed(verifyScheduledSpatialInvariants(func))) {
func.emitOpError("merged Spatial communication invariant verification failed"); func.emitOpError("RAPTOR_PHASE_CHECK scheduled Spatial verification failed after merge materialization");
signalPassFailure(); signalPassFailure();
return; return;
} }
+2
View File
@@ -8,6 +8,8 @@
namespace onnx_mlir { namespace onnx_mlir {
std::unique_ptr<mlir::Pass> createONNXToSpatialPass(); std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
std::unique_ptr<mlir::Pass> createSpatialLayoutPlanningPass();
std::unique_ptr<mlir::Pass> createLowerSpatialPlansPass();
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass(); std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
+2
View File
@@ -72,6 +72,8 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
void PimAccelerator::registerPasses(int optLevel) const { void PimAccelerator::registerPasses(int optLevel) const {
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n"); LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
registerPass(createONNXToSpatialPass); registerPass(createONNXToSpatialPass);
registerPass(createSpatialLayoutPlanningPass);
registerPass(createLowerSpatialPlansPass);
registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPimPass); registerPass(createSpatialToPimPass);
registerPass(createPimBufferizationPass); registerPass(createPimBufferizationPass);