DeadLock
This commit is contained in:
@@ -258,24 +258,23 @@ where
|
|||||||
|
|
||||||
let (memory, crossbars) = core.get_memory_crossbar();
|
let (memory, crossbars) = core.get_memory_crossbar();
|
||||||
let crossbar = crossbars.get_mut(group).unwrap();
|
let crossbar = crossbars.get_mut(group).unwrap();
|
||||||
let crossbar_stored_bytes = crossbar.stored_bytes();
|
|
||||||
let crossbar_byte_width = crossbar.width();
|
|
||||||
|
|
||||||
let crossbar_elem_width = crossbar_byte_width / size_of::<M>();
|
|
||||||
ensure!(
|
|
||||||
crossbar_byte_width % size_of::<M>() == 0,
|
|
||||||
"M not divisor of the crosbbar size"
|
|
||||||
);
|
|
||||||
|
|
||||||
let crossbar_height = crossbar.height();
|
let crossbar_height = crossbar.height();
|
||||||
let crossbar_byte_size = crossbar_byte_width * crossbar_height;
|
let crossbar_stored_bytes = crossbar.stored_bytes();
|
||||||
|
let bytes_per_column = crossbar_height * size_of::<M>();
|
||||||
|
ensure!(bytes_per_column != 0, "crossbar height can not be zero");
|
||||||
|
ensure!(
|
||||||
|
crossbar_stored_bytes % bytes_per_column == 0,
|
||||||
|
"Stored crossbar bytes do not describe an integral number of columns"
|
||||||
|
);
|
||||||
|
let crossbar_elem_width = crossbar_stored_bytes / bytes_per_column;
|
||||||
|
ensure!(crossbar_elem_width != 0, "Crossbar contains no stored columns");
|
||||||
|
|
||||||
let loads = memory
|
let loads = memory
|
||||||
.reserve_load(r1_val, crossbar_height * size_of::<F>())?
|
.reserve_load(r1_val, crossbar_height * size_of::<F>())?
|
||||||
.execute_load::<F>()?;
|
.execute_load::<F>()?;
|
||||||
let load = loads[0];
|
let load = loads[0];
|
||||||
let vec: Cow<[M]> = load.up();
|
let vec: Cow<[M]> = load.up();
|
||||||
let matrix = crossbar.load::<M>(crossbar_byte_size)?[0];
|
let matrix = crossbar.load::<M>(crossbar_stored_bytes)?[0];
|
||||||
|
|
||||||
// --- FAER IMPLEMENTATION ---
|
// --- FAER IMPLEMENTATION ---
|
||||||
|
|
||||||
|
|||||||
Submodule backend-simulators/pim/pimsim-nn updated: 6d3b898e6b...3e3442b663
+1
-1
Submodule onnx-mlir updated: eb54c2afc4...82018d7ce5
@@ -56,6 +56,22 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
|
|||||||
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
|
||||||
|
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||||
|
if (result) {
|
||||||
|
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||||
|
if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands()) {
|
||||||
|
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||||
|
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
||||||
|
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||||
|
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size())
|
||||||
|
return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
||||||
|
}
|
||||||
|
return yieldedValue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
|
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
|
||||||
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
||||||
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
|
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
|
||||||
@@ -512,6 +528,24 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto ifOp = mlir::dyn_cast<mlir::scf::IfOp>(definingOp)) {
|
||||||
|
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||||
|
if (!result)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto condition = resolveIndexValueImpl(ifOp.getCondition(), knowledge);
|
||||||
|
if (failed(condition))
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
mlir::Region& selectedRegion = *condition != 0 ? ifOp.getThenRegion() : ifOp.getElseRegion();
|
||||||
|
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(selectedRegion.front().getTerminator());
|
||||||
|
if (!yieldOp || result.getResultNumber() >= yieldOp.getNumOperands())
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
value = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
||||||
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
||||||
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
||||||
@@ -622,6 +656,33 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto ifOp = mlir::dyn_cast<mlir::scf::IfOp>(definingOp)) {
|
||||||
|
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||||
|
if (!result)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto thenYield = mlir::dyn_cast<mlir::scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
|
||||||
|
auto elseYield = mlir::dyn_cast<mlir::scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());
|
||||||
|
if (!thenYield || !elseYield || result.getResultNumber() >= thenYield.getNumOperands()
|
||||||
|
|| result.getResultNumber() >= elseYield.getNumOperands()) {
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto thenAddress = compileContiguousAddressExprImpl(thenYield.getOperand(result.getResultNumber()));
|
||||||
|
auto elseAddress = compileContiguousAddressExprImpl(elseYield.getOperand(result.getResultNumber()));
|
||||||
|
if (failed(thenAddress) || failed(elseAddress) || thenAddress->base != elseAddress->base)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto condition = compileIndexValueImpl(ifOp.getCondition());
|
||||||
|
if (failed(condition))
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
CompiledIndexExprNode selectExpr;
|
||||||
|
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
|
||||||
|
selectExpr.operands = {*condition, thenAddress->byteOffset, elseAddress->byteOffset};
|
||||||
|
return CompiledAddressExpr {thenAddress->base, makeCompiledIndexExpr(std::move(selectExpr))};
|
||||||
|
}
|
||||||
|
|
||||||
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
||||||
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
||||||
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
||||||
|
|||||||
@@ -96,6 +96,24 @@ llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
|||||||
llvm::cl::init(false),
|
llvm::cl::init(false),
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
llvm::cl::opt<bool> pimDetectCommunicationDeadlock(
|
||||||
|
"pim-detect-communication-deadlock",
|
||||||
|
llvm::cl::desc("Expensively simulate the statically expanded PIM send/receive order at verification time and fail if a blocking communication deadlock is found"),
|
||||||
|
llvm::cl::init(false),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
llvm::cl::opt<bool> pimMaterializeScalarFanoutGlobalOrder(
|
||||||
|
"pim-materialize-scalar-fanout-global-order",
|
||||||
|
llvm::cl::desc("Experimental expensive materializer mode: emit scalar-source fanout as globally ordered communication events instead of all-send fanout loops"),
|
||||||
|
llvm::cl::init(false),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
llvm::cl::opt<bool> pimTraceCommunicationMaterialization(
|
||||||
|
"pim-trace-communication-materialization",
|
||||||
|
llvm::cl::desc("Emit verbose materializer-time diagnostics and provenance attributes for every Spatial communication op"),
|
||||||
|
llvm::cl::init(false),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<size_t>
|
llvm::cl::opt<size_t>
|
||||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
|
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
|
||||||
|
|
||||||
|
|||||||
@@ -53,6 +53,9 @@ extern llvm::cl::opt<bool> pimDisableMemoryCoalescing;
|
|||||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||||
extern llvm::cl::opt<bool> pimEmitJson;
|
extern llvm::cl::opt<bool> pimEmitJson;
|
||||||
extern llvm::cl::opt<bool> pimReportConvLowering;
|
extern llvm::cl::opt<bool> pimReportConvLowering;
|
||||||
|
extern llvm::cl::opt<bool> pimDetectCommunicationDeadlock;
|
||||||
|
extern llvm::cl::opt<bool> pimMaterializeScalarFanoutGlobalOrder;
|
||||||
|
extern llvm::cl::opt<bool> pimTraceCommunicationMaterialization;
|
||||||
|
|
||||||
extern llvm::cl::opt<size_t> crossbarSize;
|
extern llvm::cl::opt<size_t> crossbarSize;
|
||||||
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
|
|
||||||
if (pimEmissionTarget >= EmitSpatial) {
|
if (pimEmissionTarget >= EmitSpatial) {
|
||||||
pm.addPass(createONNXToSpatialPass());
|
pm.addPass(createONNXToSpatialPass());
|
||||||
|
pm.addPass(createSpatialLayoutPlanningPass());
|
||||||
|
pm.addPass(createLowerSpatialPlansPass());
|
||||||
pm.addPass(createMergeComputeNodesPass());
|
pm.addPass(createMergeComputeNodesPass());
|
||||||
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
|
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ add_pim_library(OMONNXToSpatial
|
|||||||
Patterns/Tensor/Split.cpp
|
Patterns/Tensor/Split.cpp
|
||||||
Patterns/Tensor/Transpose.cpp
|
Patterns/Tensor/Transpose.cpp
|
||||||
ONNXToSpatialPass.cpp
|
ONNXToSpatialPass.cpp
|
||||||
|
SpatialLayoutPlanningPass.cpp
|
||||||
|
LowerSpatialPlansPass.cpp
|
||||||
Common/AttributeUtils.cpp
|
Common/AttributeUtils.cpp
|
||||||
Common/ComputeRegionBuilder.cpp
|
Common/ComputeRegionBuilder.cpp
|
||||||
Common/IndexingUtils.cpp
|
Common/IndexingUtils.cpp
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
Value sumTensors(ArrayRef<Value> tensors, PatternRewriter& rewriter) {
|
||||||
if (tensors.size() == 1)
|
if (tensors.size() == 1)
|
||||||
return tensors[0];
|
return tensors[0];
|
||||||
|
|
||||||
|
|||||||
@@ -87,17 +87,17 @@ inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int
|
|||||||
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
|
/// Builds a `spat.graph_compute` with a fixed number of SSA inputs and erases it if
|
||||||
/// the body callback reports failure.
|
/// the body callback reports failure.
|
||||||
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||||
auto createSpatCompute(RewriterT& rewriter,
|
auto createSpatGraphCompute(RewriterT& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
mlir::TypeRange resultTypes,
|
mlir::TypeRange resultTypes,
|
||||||
mlir::ValueRange weights,
|
mlir::ValueRange weights,
|
||||||
mlir::ValueRange inputs,
|
mlir::ValueRange inputs,
|
||||||
BodyFn&& body) {
|
BodyFn&& body) {
|
||||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
auto* block = new mlir::Block();
|
||||||
for (mlir::Value weight : weights)
|
for (mlir::Value weight : weights)
|
||||||
@@ -124,23 +124,23 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
if (mlir::failed(bodyResult)) {
|
if (mlir::failed(bodyResult)) {
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
rewriter.eraseOp(computeOp);
|
rewriter.eraseOp(computeOp);
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
|
||||||
}
|
}
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Builds a `spat.compute` whose body consumes the block arguments as a single
|
/// Builds a `spat.graph_compute` whose body consumes the block arguments as a single
|
||||||
/// `ValueRange`, which is convenient for variadic reductions/concats.
|
/// `ValueRange`, which is convenient for variadic reductions/concats.
|
||||||
template <typename RewriterT, typename BodyFn>
|
template <typename RewriterT, typename BodyFn>
|
||||||
auto createSpatCompute(RewriterT& rewriter,
|
auto createSpatGraphCompute(RewriterT& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
mlir::TypeRange resultTypes,
|
mlir::TypeRange resultTypes,
|
||||||
mlir::ValueRange weights,
|
mlir::ValueRange weights,
|
||||||
mlir::ValueRange inputs,
|
mlir::ValueRange inputs,
|
||||||
BodyFn&& body) {
|
BodyFn&& body) {
|
||||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
auto* block = new mlir::Block();
|
||||||
for (mlir::Value weight : weights)
|
for (mlir::Value weight : weights)
|
||||||
@@ -163,29 +163,29 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
if (mlir::failed(bodyResult)) {
|
if (mlir::failed(bodyResult)) {
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
rewriter.eraseOp(computeOp);
|
rewriter.eraseOp(computeOp);
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
|
||||||
}
|
}
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename RewriterT, typename BodyFn>
|
template <typename RewriterT, typename BodyFn>
|
||||||
auto createSpatComputeBatch(RewriterT& rewriter,
|
auto createSpatGraphComputeBatch(RewriterT& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
mlir::TypeRange resultTypes,
|
mlir::TypeRange resultTypes,
|
||||||
int64_t laneCount,
|
int64_t laneCount,
|
||||||
mlir::ValueRange weights,
|
mlir::ValueRange weights,
|
||||||
mlir::ValueRange inputs,
|
mlir::ValueRange inputs,
|
||||||
BodyFn&& body) {
|
BodyFn&& body) {
|
||||||
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
|
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
|
||||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
|
||||||
|
|
||||||
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
|
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
|
||||||
if (mlir::failed(laneCountAttr))
|
if (mlir::failed(laneCountAttr))
|
||||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
|
||||||
|
|
||||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
|
auto batchOp = spatial::SpatGraphComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
|
||||||
|
|
||||||
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
|
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
|
||||||
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
|
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
|
||||||
@@ -218,20 +218,53 @@ auto createSpatComputeBatch(RewriterT& rewriter,
|
|||||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||||
std::forward<BodyFn>(body)(args);
|
std::forward<BodyFn>(body)(args);
|
||||||
rewriter.setInsertionPointAfter(batchOp);
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto bodyResult = std::forward<BodyFn>(body)(args);
|
auto bodyResult = std::forward<BodyFn>(body)(args);
|
||||||
if (mlir::failed(bodyResult)) {
|
if (mlir::failed(bodyResult)) {
|
||||||
rewriter.setInsertionPointAfter(batchOp);
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
rewriter.eraseOp(batchOp);
|
rewriter.eraseOp(batchOp);
|
||||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
|
||||||
}
|
}
|
||||||
rewriter.setInsertionPointAfter(batchOp);
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatCompute(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
return createSpatGraphCompute<NumInputs>(
|
||||||
|
rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatCompute(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
return createSpatGraphCompute(rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatComputeBatch(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
int64_t laneCount,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
return createSpatGraphComputeBatch(
|
||||||
|
rewriter, loc, resultTypes, laneCount, weights, inputs, std::forward<BodyFn>(body));
|
||||||
|
}
|
||||||
|
|
||||||
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
|
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
mlir::Value source,
|
mlir::Value source,
|
||||||
@@ -262,6 +295,6 @@ mlir::Value materializeOrComputeUnary(mlir::Value input,
|
|||||||
return computeOp.getResult(0);
|
return computeOp.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::PatternRewriter& rewriter);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int
|
|||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> sliceTensor(
|
SmallVector<Value> sliceTensor(
|
||||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
|
||||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||||
assert("Invalid axis" && axis < shape.size());
|
assert("Invalid axis" && axis < shape.size());
|
||||||
|
|
||||||
@@ -129,7 +129,7 @@ SmallVector<Value> sliceTensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value>
|
SmallVector<Value>
|
||||||
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
sliceVector(const Value& vectorToSlice, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
|
||||||
ArrayRef<long> shape = getTensorShape(vectorToSlice);
|
ArrayRef<long> shape = getTensorShape(vectorToSlice);
|
||||||
assert("Not a vector" && isVectorShape(shape));
|
assert("Not a vector" && isVectorShape(shape));
|
||||||
size_t axis = shape[0] != 1 ? 0 : 1;
|
size_t axis = shape[0] != 1 ? 0 : 1;
|
||||||
@@ -137,7 +137,7 @@ sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewr
|
|||||||
}
|
}
|
||||||
|
|
||||||
DenseMap<CoreId, SmallVector<Value>>
|
DenseMap<CoreId, SmallVector<Value>>
|
||||||
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
|
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, PatternRewriter& rewriter, Location loc) {
|
||||||
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
|
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
|
||||||
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
|
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
|
||||||
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
|
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
|
||||||
|
|||||||
@@ -89,18 +89,18 @@ llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewr
|
|||||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||||
size_t axis,
|
size_t axis,
|
||||||
int64_t sliceSize,
|
int64_t sliceSize,
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
mlir::PatternRewriter& rewriter,
|
||||||
mlir::Location loc);
|
mlir::Location loc);
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||||
int64_t sliceSize,
|
int64_t sliceSize,
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
mlir::PatternRewriter& rewriter,
|
||||||
mlir::Location loc);
|
mlir::Location loc);
|
||||||
|
|
||||||
/// Partitions one logical vector into per-core crossbar-sized slices using the
|
/// Partitions one logical vector into per-core crossbar-sized slices using the
|
||||||
/// current PIM target geometry.
|
/// current PIM target geometry.
|
||||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||||
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc);
|
||||||
|
|
||||||
mlir::Value extractAxisSlice(
|
mlir::Value extractAxisSlice(
|
||||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||||
|
|||||||
@@ -0,0 +1,409 @@
|
|||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
|
||||||
|
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static constexpr StringLiteral kDenseLayout = "dense_nchw";
|
||||||
|
static constexpr StringLiteral kRowStripLayout = "nchw_row_strip";
|
||||||
|
|
||||||
|
struct RowStripPhysicalValue {
|
||||||
|
Value physicalValue;
|
||||||
|
RankedTensorType logicalType;
|
||||||
|
SmallVector<int64_t, 16> fragmentOffsets;
|
||||||
|
SmallVector<int64_t, 16> fragmentSizes;
|
||||||
|
std::string indexMap;
|
||||||
|
};
|
||||||
|
|
||||||
|
static FailureOr<RowStripPhysicalValue> getRowStripValue(llvm::DenseMap<Value, RowStripPhysicalValue>& rowStripValues,
|
||||||
|
Value value) {
|
||||||
|
auto it = rowStripValues.find(value);
|
||||||
|
if (it == rowStripValues.end())
|
||||||
|
return failure();
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatReconciliatorOp reconciliator,
|
||||||
|
Value physicalValue) {
|
||||||
|
auto logicalType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
|
||||||
|
if (!logicalType)
|
||||||
|
return reconciliator.emitOpError("requires ranked logical output type"), failure();
|
||||||
|
RowStripPhysicalValue value;
|
||||||
|
value.physicalValue = physicalValue;
|
||||||
|
value.logicalType = logicalType;
|
||||||
|
value.fragmentOffsets.append(reconciliator.getFragmentOffsets().begin(), reconciliator.getFragmentOffsets().end());
|
||||||
|
value.fragmentSizes.append(reconciliator.getFragmentSizes().begin(), reconciliator.getFragmentSizes().end());
|
||||||
|
value.indexMap = reconciliator.getIndexMap().str();
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Value>
|
||||||
|
lowerRowStripRelu(const RowStripPhysicalValue& input, spatial::SpatReluPlanOp planOp, PatternRewriter& rewriter) {
|
||||||
|
auto packedType = cast<RankedTensorType>(input.physicalValue.getType());
|
||||||
|
auto computeOp =
|
||||||
|
createSpatCompute<1>(rewriter, planOp.getLoc(), TypeRange {packedType}, {}, input.physicalValue, [&](Value x) {
|
||||||
|
auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), packedType, x);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
|
||||||
|
});
|
||||||
|
return computeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Value>
|
||||||
|
materializeRowStripToDense(const RowStripPhysicalValue& rowStripValue, Location loc, PatternRewriter& rewriter) {
|
||||||
|
auto packedType = dyn_cast<RankedTensorType>(rowStripValue.physicalValue.getType());
|
||||||
|
if (!packedType || packedType.getRank() != 3 || !packedType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (rowStripValue.logicalType.getRank() != 4 || !rowStripValue.logicalType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (rowStripValue.indexMap != "packed_hwc_rows_to_nchw")
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t rank = rowStripValue.logicalType.getRank();
|
||||||
|
const int64_t fragmentCount = rowStripValue.fragmentOffsets.size() / rank;
|
||||||
|
const int64_t packedWidth = packedType.getDimSize(1);
|
||||||
|
const int64_t packedChannels = packedType.getDimSize(2);
|
||||||
|
if (fragmentCount != packedType.getDimSize(0))
|
||||||
|
return failure();
|
||||||
|
for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
|
||||||
|
if (rowStripValue.fragmentOffsets[fragmentIndex * rank + 0] != 0
|
||||||
|
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 1] != 0
|
||||||
|
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 2] != fragmentIndex
|
||||||
|
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 3] != 0)
|
||||||
|
return failure();
|
||||||
|
if (rowStripValue.fragmentSizes[fragmentIndex * rank + 0] != 1
|
||||||
|
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 1] != packedChannels
|
||||||
|
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 2] != 1
|
||||||
|
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 3] != packedWidth)
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto packedSliceType =
|
||||||
|
RankedTensorType::get({1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
|
||||||
|
auto expandedType =
|
||||||
|
RankedTensorType::get({1, 1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
|
||||||
|
auto logicalFragmentType =
|
||||||
|
RankedTensorType::get({1, packedChannels, 1, packedWidth}, packedType.getElementType(), packedType.getEncoding());
|
||||||
|
auto batchOp = createSpatComputeBatch(
|
||||||
|
rewriter,
|
||||||
|
loc,
|
||||||
|
TypeRange {rowStripValue.logicalType},
|
||||||
|
fragmentCount,
|
||||||
|
{},
|
||||||
|
ValueRange {rowStripValue.physicalValue},
|
||||||
|
[&](detail::SpatComputeBatchBodyArgs args) {
|
||||||
|
SmallVector<OpFoldResult> packedOffsets {args.lane, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> packedSizes {
|
||||||
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(packedWidth), rewriter.getIndexAttr(packedChannels)};
|
||||||
|
Value packedSlice = tensor::ExtractSliceOp::create(
|
||||||
|
rewriter, loc, packedSliceType, args.inputs.front(), packedOffsets, packedSizes, getUnitStrides(rewriter, 3));
|
||||||
|
|
||||||
|
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
expandedType,
|
||||||
|
packedSlice,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2},
|
||||||
|
{3}
|
||||||
|
});
|
||||||
|
Value transposeInit =
|
||||||
|
tensor::EmptyOp::create(rewriter, loc, logicalFragmentType.getShape(), logicalFragmentType.getElementType());
|
||||||
|
Value logicalFragment =
|
||||||
|
linalg::TransposeOp::create(rewriter, loc, expanded, transposeInit, SmallVector<int64_t> {0, 3, 1, 2})
|
||||||
|
.getResult()[0];
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> logicalOffsets {
|
||||||
|
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> logicalSizes {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(packedChannels),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(packedWidth)};
|
||||||
|
createParallelInsertSliceIntoBatchOutput(rewriter,
|
||||||
|
loc,
|
||||||
|
logicalFragment,
|
||||||
|
args.outputs.front(),
|
||||||
|
logicalOffsets,
|
||||||
|
logicalSizes,
|
||||||
|
getUnitStrides(rewriter, 4));
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
if (failed(batchOp))
|
||||||
|
return failure();
|
||||||
|
return batchOp->getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, OperationPass<ModuleOp>> {
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerSpatialPlansPass)
|
||||||
|
|
||||||
|
StringRef getArgument() const override { return "lower-spatial-plans"; }
|
||||||
|
StringRef getDescription() const override { return "Lower selected Spatial planning ops to low-level Spatial IR."; }
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
ModuleOp moduleOp = getOperation();
|
||||||
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||||
|
if (failed(entryFunc)) {
|
||||||
|
moduleOp.emitError("failed to locate the PIM entry function during LowerSpatialPlans");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
func::FuncOp funcOp = *entryFunc;
|
||||||
|
PatternRewriter rewriter(ctx);
|
||||||
|
llvm::DenseMap<Value, RowStripPhysicalValue> rowStripValues;
|
||||||
|
llvm::SmallPtrSet<Operation*, 16> eraseAfterLowering;
|
||||||
|
auto verifyLogicalPhase = [&](StringRef stage) -> bool {
|
||||||
|
if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc)))
|
||||||
|
return true;
|
||||||
|
moduleOp.emitError() << "RAPTOR_PHASE_CHECK logical Spatial graph verification failed " << stage;
|
||||||
|
signalPassFailure();
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!verifyLogicalPhase("at the start of LowerSpatialPlans"))
|
||||||
|
return;
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||||
|
if (auto planOp = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
|
||||||
|
FailureOr<RowStripPhysicalValue> rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
|
||||||
|
auto rowStripReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||||
|
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
|
||||||
|
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
|
||||||
|
});
|
||||||
|
if (rowStripReconciliator != planOp.getResult().getUsers().end()) {
|
||||||
|
rewriter.setInsertionPoint(planOp);
|
||||||
|
FailureOr<Value> lowered = lowerSelectedConv2DPlan(
|
||||||
|
planOp,
|
||||||
|
succeeded(rowStripInput) ? std::optional<Value> {rowStripInput->physicalValue} : std::nullopt,
|
||||||
|
/*emitRowStripLayout=*/true,
|
||||||
|
rewriter);
|
||||||
|
if (failed(lowered)) {
|
||||||
|
planOp.emitOpError("failed to lower selected row-strip Spatial Conv plan");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*rowStripReconciliator);
|
||||||
|
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(reconciliator, *lowered);
|
||||||
|
if (failed(rowStripValue)) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rowStripValues[reconciliator.getResult()] = *rowStripValue;
|
||||||
|
eraseAfterLowering.insert(planOp);
|
||||||
|
eraseAfterLowering.insert(reconciliator);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPoint(planOp);
|
||||||
|
FailureOr<Value> lowered =
|
||||||
|
lowerSelectedConv2DPlan(planOp, std::nullopt, /*emitRowStripLayout=*/false, rewriter);
|
||||||
|
if (failed(lowered)) {
|
||||||
|
planOp.emitOpError("failed to lower selected Spatial Conv plan");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(planOp, *lowered);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto planOp = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
|
||||||
|
if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) {
|
||||||
|
auto outputReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||||
|
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
|
||||||
|
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
|
||||||
|
});
|
||||||
|
if (outputReconciliator == planOp.getResult().getUsers().end()) {
|
||||||
|
planOp.emitOpError("row-strip Relu plan requires a row-strip reconciliator result");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<RowStripPhysicalValue> input = getRowStripValue(rowStripValues, planOp.getInput());
|
||||||
|
rewriter.setInsertionPoint(planOp);
|
||||||
|
FailureOr<Value> lowered = lowerRowStripRelu(*input, planOp, rewriter);
|
||||||
|
if (failed(lowered)) {
|
||||||
|
planOp.emitOpError("failed to lower selected row-strip Spatial Relu plan");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*outputReconciliator);
|
||||||
|
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(reconciliator, *lowered);
|
||||||
|
if (failed(output)) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rowStripValues[reconciliator.getResult()] = *output;
|
||||||
|
eraseAfterLowering.insert(planOp);
|
||||||
|
eraseAfterLowering.insert(reconciliator);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(planOp);
|
||||||
|
auto computeOp = createSpatCompute<1>(
|
||||||
|
rewriter, planOp.getLoc(), planOp.getOutput().getType(), {}, planOp.getInput(), [&](Value x) {
|
||||||
|
auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), planOp.getOutput().getType(), x);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
|
||||||
|
});
|
||||||
|
rewriter.replaceOp(planOp, computeOp.getResults());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto materializeOp = dyn_cast<spatial::SpatMaterializeLayoutOp>(&op)) {
|
||||||
|
if (materializeOp.getSourcePhysicalLayout() == kDenseLayout
|
||||||
|
&& materializeOp.getTargetPhysicalLayout() == kDenseLayout) {
|
||||||
|
rewriter.replaceOp(materializeOp, materializeOp.getInput());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (materializeOp.getSourcePhysicalLayout() != kRowStripLayout
|
||||||
|
|| materializeOp.getTargetPhysicalLayout() != kDenseLayout) {
|
||||||
|
materializeOp.emitOpError("non-dense materialize_layout lowering is not supported yet");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
FailureOr<RowStripPhysicalValue> rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
|
||||||
|
if (failed(rowStripValue)) {
|
||||||
|
materializeOp.emitOpError("expected a row-strip reconciliator input during row-strip materialization");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPoint(materializeOp);
|
||||||
|
FailureOr<Value> dense = materializeRowStripToDense(*rowStripValue, materializeOp.getLoc(), rewriter);
|
||||||
|
if (failed(dense)) {
|
||||||
|
materializeOp.emitOpError("failed to materialize selected row-strip layout back to dense NCHW");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(materializeOp, *dense);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto reconciliatorOp = dyn_cast<spatial::SpatReconciliatorOp>(&op)) {
|
||||||
|
if (reconciliatorOp.getPhysicalLayout() == kDenseLayout) {
|
||||||
|
rewriter.replaceOp(reconciliatorOp, reconciliatorOp.getInput());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (reconciliatorOp.getPhysicalLayout() != kRowStripLayout) {
|
||||||
|
reconciliatorOp.emitOpError("non-dense reconciliator lowering is not supported yet");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!eraseAfterLowering.contains(reconciliatorOp)) {
|
||||||
|
reconciliatorOp.emitOpError("unhandled row-strip reconciliator remained during LowerSpatialPlans");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool erasedAny = true;
|
||||||
|
while (erasedAny) {
|
||||||
|
erasedAny = false;
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||||
|
if (!eraseAfterLowering.contains(&op))
|
||||||
|
continue;
|
||||||
|
if (!op.use_empty())
|
||||||
|
continue;
|
||||||
|
eraseAfterLowering.erase(&op);
|
||||||
|
rewriter.eraseOp(&op);
|
||||||
|
erasedAny = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!eraseAfterLowering.empty()) {
|
||||||
|
for (Operation& op : funcOp.getBody().front())
|
||||||
|
if (eraseAfterLowering.contains(&op))
|
||||||
|
op.emitOpError("selected row-strip planning op could not be fully eliminated during LowerSpatialPlans");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ConversionTarget helperTarget(*ctx);
|
||||||
|
helperTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
|
tensor::TensorDialect,
|
||||||
|
linalg::LinalgDialect,
|
||||||
|
affine::AffineDialect,
|
||||||
|
arith::ArithDialect,
|
||||||
|
scf::SCFDialect,
|
||||||
|
func::FuncDialect>();
|
||||||
|
helperTarget.addLegalOp<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
|
||||||
|
helperTarget.addIllegalOp<ONNXGemmOp, ONNXTransposeOp>();
|
||||||
|
helperTarget.markOpRecursivelyLegal<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
|
||||||
|
|
||||||
|
RewritePatternSet helperPatterns(ctx);
|
||||||
|
populateGemmPatterns(helperPatterns, ctx);
|
||||||
|
populateTransposePatterns(helperPatterns, ctx);
|
||||||
|
if (failed(applyPartialConversion(moduleOp, helperTarget, std::move(helperPatterns)))) {
|
||||||
|
moduleOp.emitError("failed to lower helper ONNX ops emitted by selected Spatial plan lowering");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
FrozenRewritePatternSet nestedHelperPatterns([&] {
|
||||||
|
RewritePatternSet patterns(ctx);
|
||||||
|
populateGemmPatterns(patterns, ctx);
|
||||||
|
populateTransposePatterns(patterns, ctx);
|
||||||
|
return patterns;
|
||||||
|
}());
|
||||||
|
ConversionTarget nestedHelperTarget(*ctx);
|
||||||
|
nestedHelperTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
|
tensor::TensorDialect,
|
||||||
|
linalg::LinalgDialect,
|
||||||
|
affine::AffineDialect,
|
||||||
|
arith::ArithDialect,
|
||||||
|
scf::SCFDialect,
|
||||||
|
func::FuncDialect>();
|
||||||
|
nestedHelperTarget.addIllegalOp<ONNXGemmOp, ONNXTransposeOp>();
|
||||||
|
SmallVector<Operation*> computeLikeOps;
|
||||||
|
funcOp.walk([&](Operation* op) {
|
||||||
|
if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(op))
|
||||||
|
computeLikeOps.push_back(op);
|
||||||
|
});
|
||||||
|
for (Operation* op : computeLikeOps) {
|
||||||
|
if (failed(applyFullConversion(op, nestedHelperTarget, nestedHelperPatterns))) {
|
||||||
|
op->emitOpError("failed to lower nested helper ONNX ops emitted by selected Spatial plan lowering");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!verifyLogicalPhase("after nested helper conversions"))
|
||||||
|
return;
|
||||||
|
bool hasIllegalOps = false;
|
||||||
|
moduleOp.walk([&](Operation* op) {
|
||||||
|
if (isa<ONNXEntryPointOp>(op))
|
||||||
|
return;
|
||||||
|
if (isa<spatial::SpatConv2DPlanOp,
|
||||||
|
spatial::SpatReluPlanOp,
|
||||||
|
spatial::SpatReconciliatorOp,
|
||||||
|
spatial::SpatMaterializeLayoutOp>(op)
|
||||||
|
|| op->getDialect()->getNamespace() == "onnx") {
|
||||||
|
op->emitOpError("operation must not remain after LowerSpatialPlans");
|
||||||
|
hasIllegalOps = true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if (hasIllegalOps)
|
||||||
|
signalPassFailure();
|
||||||
|
else
|
||||||
|
dumpModule(moduleOp, "spatial1_premerge");
|
||||||
|
|
||||||
|
if (!verifyLogicalPhase("at the end of LowerSpatialPlans"))
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createLowerSpatialPlansPass() { return std::make_unique<LowerSpatialPlansPass>(); }
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -18,6 +18,7 @@
|
|||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
#include "ONNXToSpatialVerifier.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -41,10 +42,16 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
|||||||
static void populateEmptyFunction(func::FuncOp funcOp) {
|
static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||||
IRRewriter rewriter(funcOp.getContext());
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
SmallVector<spatial::SpatGraphCompute> computes(funcOp.getOps<spatial::SpatGraphCompute>());
|
||||||
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
|
SmallVector<spatial::SpatGraphComputeBatch> computeBatches(funcOp.getOps<spatial::SpatGraphComputeBatch>());
|
||||||
if (!computes.empty() || !computeBatches.empty())
|
SmallVector<spatial::SpatConv2DPlanOp> convPlans(funcOp.getOps<spatial::SpatConv2DPlanOp>());
|
||||||
|
SmallVector<spatial::SpatReluPlanOp> reluPlans(funcOp.getOps<spatial::SpatReluPlanOp>());
|
||||||
|
SmallVector<spatial::SpatReconciliatorOp> reconciliators(funcOp.getOps<spatial::SpatReconciliatorOp>());
|
||||||
|
SmallVector<spatial::SpatMaterializeLayoutOp> materializers(funcOp.getOps<spatial::SpatMaterializeLayoutOp>());
|
||||||
|
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !reconciliators.empty()
|
||||||
|
|| !materializers.empty()) {
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
||||||
rewriter.setInsertionPoint(returnOp);
|
rewriter.setInsertionPoint(returnOp);
|
||||||
@@ -58,7 +65,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
|||||||
sourceLocs.push_back(source.getLoc());
|
sourceLocs.push_back(source.getLoc());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto newCompute = spatial::SpatCompute::create(
|
auto newCompute = spatial::SpatGraphCompute::create(
|
||||||
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
||||||
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
|
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
|
||||||
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
|
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
|
||||||
@@ -67,7 +74,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
rewriter.setInsertionPointToEnd(newBlock);
|
rewriter.setInsertionPointToEnd(newBlock);
|
||||||
for (Operation& op : funcOp.getOps())
|
for (Operation& op : funcOp.getOps())
|
||||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
|
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op))
|
||||||
rewriter.clone(op, mapper);
|
rewriter.clone(op, mapper);
|
||||||
|
|
||||||
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
||||||
@@ -75,7 +82,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
|||||||
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
||||||
|
|
||||||
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
|
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
|
||||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
|
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op)) {
|
||||||
op.dropAllUses();
|
op.dropAllUses();
|
||||||
rewriter.eraseOp(&op);
|
rewriter.eraseOp(&op);
|
||||||
}
|
}
|
||||||
@@ -152,6 +159,11 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
|
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX conversion");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
ConversionTarget earlyPostTarget(*ctx);
|
ConversionTarget earlyPostTarget(*ctx);
|
||||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
@@ -168,6 +180,11 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
|
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
|
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after weight annotation");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
ConversionTarget postTarget(*ctx);
|
ConversionTarget postTarget(*ctx);
|
||||||
postTarget.addLegalDialect<spatial::SpatialDialect,
|
postTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
@@ -176,11 +193,16 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
affine::AffineDialect,
|
affine::AffineDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
scf::SCFDialect>();
|
scf::SCFDialect>();
|
||||||
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
|
postTarget.addDynamicallyLegalOp<spatial::SpatGraphCompute>(
|
||||||
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
|
[](spatial::SpatGraphCompute computeOp) { return !requiresPostRewrite(computeOp); });
|
||||||
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
postTarget.addDynamicallyLegalOp<spatial::SpatGraphComputeBatch>(
|
||||||
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
[](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
||||||
|
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
|
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed before post rewrites");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
RewritePatternSet postPatterns(ctx);
|
RewritePatternSet postPatterns(ctx);
|
||||||
populatePostPatterns(postPatterns, ctx);
|
populatePostPatterns(postPatterns, ctx);
|
||||||
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
|
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
|
||||||
@@ -191,6 +213,11 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
|
|
||||||
populateEmptyFunction(*entryFunc);
|
populateEmptyFunction(*entryFunc);
|
||||||
|
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
|
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX-to-Spatial");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
dumpModule(moduleOp, "spatial0");
|
dumpModule(moduleOp, "spatial0");
|
||||||
|
|
||||||
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/Diagnostics.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
#include "Common/IR/WeightUtils.hpp"
|
#include "Common/IR/WeightUtils.hpp"
|
||||||
@@ -13,6 +15,8 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr StringLiteral kPhaseMarker = "RAPTOR_PHASE_CHECK";
|
||||||
|
|
||||||
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
func.walk([&](Operation* op) {
|
func.walk([&](Operation* op) {
|
||||||
if (!hasWeightAlways(op))
|
if (!hasWeightAlways(op))
|
||||||
@@ -23,134 +27,174 @@ void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diag
|
|||||||
continue;
|
continue;
|
||||||
|
|
||||||
diagnostics.report(op, [&](Operation* illegalOp) {
|
diagnostics.report(op, [&](Operation* illegalOp) {
|
||||||
illegalOp->emitOpError(
|
illegalOp->emitOpError()
|
||||||
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
|
<< kPhaseMarker
|
||||||
|
<< " weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights";
|
||||||
});
|
});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
Region* getParentRegion(Value value) {
|
bool isRegionOrAncestorOf(Region& region, Region* candidate) {
|
||||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
return candidate && (®ion == candidate || region.isAncestor(candidate));
|
||||||
return blockArg.getOwner()->getParent();
|
|
||||||
if (Operation* definingOp = value.getDefiningOp())
|
|
||||||
return definingOp->getParentRegion();
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isDefinedInsideRegion(Value value, Region& region) {
|
bool isValueDefinedInsideRegion(Value value, Region& region) {
|
||||||
Region* parentRegion = getParentRegion(value);
|
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||||
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
return isRegionOrAncestorOf(region, blockArg.getOwner()->getParent());
|
||||||
|
if (Operation* definingOp = value.getDefiningOp())
|
||||||
|
return isRegionOrAncestorOf(region, definingOp->getParentRegion());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isLegalExternalCapture(Value value, Region& region) {
|
||||||
|
if (isValueDefinedInsideRegion(value, region))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ComputeOpTy>
|
||||||
|
void verifyComputeBodyCaptures(ComputeOpTy compute, StringRef kind, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
|
Region& body = compute.getBody();
|
||||||
|
body.walk([&](Operation* nestedOp) {
|
||||||
|
for (OpOperand& operand : nestedOp->getOpOperands()) {
|
||||||
|
Value value = operand.get();
|
||||||
|
if (isLegalExternalCapture(value, body))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
|
||||||
|
InFlightDiagnostic diag =
|
||||||
|
illegalOp->emitOpError() << kPhaseMarker << " " << kind << " body captures non-constant external operand #"
|
||||||
|
<< operand.getOperandNumber() << " used by " << nestedOp->getName().getStringRef();
|
||||||
|
diag << " (type " << value.getType() << ")";
|
||||||
|
if (definingOp)
|
||||||
|
diag.attachNote(definingOp->getLoc()) << "defining op is " << definingOp->getName().getStringRef();
|
||||||
|
else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
|
||||||
|
if (Operation* owner = blockArg.getOwner()->getParentOp())
|
||||||
|
diag.attachNote(owner->getLoc())
|
||||||
|
<< "external block argument belongs to " << owner->getName().getStringRef();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isLegalHostBackedValue(Value value) {
|
bool isLegalHostBackedValue(Value value) {
|
||||||
Operation* definingOp = value.getDefiningOp();
|
Operation* definingOp = value.getDefiningOp();
|
||||||
if (!definingOp)
|
if (!definingOp)
|
||||||
return isa<BlockArgument>(value);
|
return isa<BlockArgument>(value);
|
||||||
|
|
||||||
if (isa<spatial::SpatChannelReceiveOp>(definingOp))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
return definingOp->getDialect()->getNamespace() != "spat";
|
return definingOp->getDialect()->getNamespace() != "spat";
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
|
template <typename ComputeOpTy>
|
||||||
ValueRange inputs,
|
void verifyScheduledInputs(ComputeOpTy compute,
|
||||||
bool allowChannelReceiveInputs,
|
bool allowChannelReceiveInputs,
|
||||||
StringRef kind,
|
StringRef kind,
|
||||||
pim::CappedDiagnosticReporter& diagnostics) {
|
pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(inputs)) {
|
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
unsigned currentInputIndex = inputIndex;
|
|
||||||
Operation* definingOp = input.getDefiningOp();
|
Operation* definingOp = input.getDefiningOp();
|
||||||
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
|
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
|
||||||
continue;
|
continue;
|
||||||
if (isLegalHostBackedValue(input))
|
if (isLegalHostBackedValue(input))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
|
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
|
||||||
InFlightDiagnostic diagnostic = illegalOp->emitOpError()
|
InFlightDiagnostic diag = illegalOp->emitOpError()
|
||||||
<< kind << " input #" << currentInputIndex
|
<< kPhaseMarker << " " << kind << " input #" << inputIndex
|
||||||
<< (allowChannelReceiveInputs ? " must come from the host or an explicit "
|
<< (allowChannelReceiveInputs ? " must come from the host or explicit spat.channel_receive"
|
||||||
"spat.channel_receive"
|
: " must come from the host");
|
||||||
: " must come from the host");
|
|
||||||
if (definingOp)
|
if (definingOp)
|
||||||
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
|
diag.attachNote(definingOp->getLoc()) << "illegal producer is " << definingOp->getName().getStringRef();
|
||||||
});
|
});
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void verifyNoExternalTensorCaptures(Operation* ownerOp,
|
void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
Region& region,
|
for (Operation& op : funcOp.getOps()) {
|
||||||
StringRef kind,
|
if (isa<func::ReturnOp,
|
||||||
pim::CappedDiagnosticReporter& diagnostics) {
|
spatial::SpatGraphCompute,
|
||||||
region.walk([&](Operation* op) {
|
spatial::SpatGraphComputeBatch,
|
||||||
for (OpOperand& operand : op->getOpOperands()) {
|
spatial::SpatConv2DPlanOp,
|
||||||
Value value = operand.get();
|
spatial::SpatReluPlanOp,
|
||||||
if (!isa<TensorType>(value.getType()))
|
spatial::SpatReconciliatorOp,
|
||||||
continue;
|
spatial::SpatMaterializeLayoutOp>(&op)) {
|
||||||
if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value))
|
continue;
|
||||||
continue;
|
}
|
||||||
|
if (isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch>(&op)) {
|
||||||
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError() << kPhaseMarker << " scheduled Spatial compute op is not allowed in logical graph phase";
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelSendOp>(&op)) {
|
||||||
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError() << kPhaseMarker
|
||||||
|
<< " explicit channel communication is not expected before merge materialization";
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (isCompileTimeOp(&op))
|
||||||
|
continue;
|
||||||
|
|
||||||
Operation* definingOp = value.getDefiningOp();
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
|
illegalOp->emitOpError()
|
||||||
continue;
|
<< kPhaseMarker << " non-foldable top-level runtime op remains in logical Spatial graph; lower it inside spat.graph_compute";
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
diagnostics.report(ownerOp, [&](Operation* illegalOp) {
|
void verifyScheduledTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
|
for (Operation& op : funcOp.getOps()) {
|
||||||
<< "values";
|
if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(&op)) {
|
||||||
diagnostic.attachNote(op->getLoc())
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
<< "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
|
illegalOp->emitOpError() << kPhaseMarker << " graph Spatial compute op remained after merge materialization";
|
||||||
<< (definingOp ? definingOp->getName().getStringRef() : StringRef("<block argument>"));
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
});
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
LogicalResult verifyNoComputeBodyCaptures(func::FuncOp funcOp) {
|
||||||
pim::CappedDiagnosticReporter diagnostics;
|
pim::CappedDiagnosticReporter diagnostics;
|
||||||
|
for (auto compute : funcOp.getOps<spatial::SpatGraphCompute>())
|
||||||
for (Operation& op : funcOp.getOps()) {
|
verifyComputeBodyCaptures(compute, "graph_compute", diagnostics);
|
||||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
for (auto batch : funcOp.getOps<spatial::SpatGraphComputeBatch>())
|
||||||
continue;
|
verifyComputeBodyCaptures(batch, "graph_compute_batch", diagnostics);
|
||||||
if (isCompileTimeOp(&op))
|
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
|
||||||
continue;
|
verifyComputeBodyCaptures(compute, "scheduled_compute", diagnostics);
|
||||||
|
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
|
||||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
verifyComputeBodyCaptures(batch, "scheduled_compute_batch", diagnostics);
|
||||||
illegalOp->emitOpError(
|
diagnostics.emitSuppressedSummary(funcOp, "compute body capture verification failed");
|
||||||
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
checkWeightUseChains(funcOp, diagnostics);
|
|
||||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
|
|
||||||
|
|
||||||
return success(!diagnostics.hasFailure());
|
return success(!diagnostics.hasFailure());
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
|
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { return verifyLogicalSpatialGraphInvariants(funcOp); }
|
||||||
|
|
||||||
|
LogicalResult verifyLogicalSpatialGraphInvariants(func::FuncOp funcOp) {
|
||||||
pim::CappedDiagnosticReporter diagnostics;
|
pim::CappedDiagnosticReporter diagnostics;
|
||||||
|
verifyLogicalTopLevelOps(funcOp, diagnostics);
|
||||||
|
checkWeightUseChains(funcOp, diagnostics);
|
||||||
|
if (failed(verifyNoComputeBodyCaptures(funcOp)))
|
||||||
|
return failure();
|
||||||
|
diagnostics.emitSuppressedSummary(funcOp, "logical Spatial graph verification failed");
|
||||||
|
return success(!diagnostics.hasFailure());
|
||||||
|
}
|
||||||
|
|
||||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
|
||||||
(void) verifyComputeLikeInputs(
|
pim::CappedDiagnosticReporter diagnostics;
|
||||||
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics);
|
verifyScheduledTopLevelOps(funcOp, diagnostics);
|
||||||
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
|
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
|
||||||
}
|
verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
|
||||||
|
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
|
||||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
|
||||||
(void) verifyComputeLikeInputs(computeBatchOp.getOperation(),
|
if (failed(verifyNoComputeBodyCaptures(funcOp)))
|
||||||
computeBatchOp.getInputs(),
|
return failure();
|
||||||
/*allowChannelReceiveInputs=*/false,
|
diagnostics.emitSuppressedSummary(funcOp, "scheduled Spatial verification failed");
|
||||||
"spat.compute_batch",
|
|
||||||
diagnostics);
|
|
||||||
verifyNoExternalTensorCaptures(
|
|
||||||
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
|
|
||||||
}
|
|
||||||
|
|
||||||
diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
|
|
||||||
return success(!diagnostics.hasFailure());
|
return success(!diagnostics.hasFailure());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
|
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
|
||||||
mlir::LogicalResult verifySpatialCommunicationInvariants(mlir::func::FuncOp funcOp);
|
mlir::LogicalResult verifyNoComputeBodyCaptures(mlir::func::FuncOp funcOp);
|
||||||
|
mlir::LogicalResult verifyLogicalSpatialGraphInvariants(mlir::func::FuncOp funcOp);
|
||||||
|
mlir::LogicalResult verifyScheduledSpatialInvariants(mlir::func::FuncOp funcOp);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -33,8 +33,8 @@ void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext*
|
|||||||
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
bool requiresPostRewrite(spatial::SpatGraphCompute computeOp);
|
||||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp);
|
||||||
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -16,12 +16,9 @@ struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
|
|||||||
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
|
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
|
||||||
Location loc = reluOp.getLoc();
|
Location loc = reluOp.getLoc();
|
||||||
Type resultType = reluOp.getResult().getType();
|
Type resultType = reluOp.getResult().getType();
|
||||||
constexpr size_t numInputs = 1;
|
auto reluPlan = spatial::SpatReluPlanOp::create(
|
||||||
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
|
rewriter, loc, resultType, adaptor.getX(), rewriter.getStringAttr("nchw"));
|
||||||
auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x);
|
rewriter.replaceOp(reluOp, reluPlan.getResult());
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
|
|
||||||
});
|
|
||||||
rewriter.replaceOp(reluOp, computeOp);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -118,17 +118,17 @@ static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatGraphCompute> {
|
||||||
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
using OpRewritePattern<spatial::SpatGraphCompute>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(spatial::SpatGraphCompute compute, PatternRewriter& rewriter) const override {
|
||||||
auto promoted = computePromotedOperands(compute);
|
auto promoted = computePromotedOperands(compute);
|
||||||
if (failed(promoted))
|
if (failed(promoted))
|
||||||
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
||||||
Block& oldBlock = compute.getBody().front();
|
Block& oldBlock = compute.getBody().front();
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute);
|
rewriter.setInsertionPointAfter(compute);
|
||||||
auto newCompute = spatial::SpatCompute::create(
|
auto newCompute = spatial::SpatGraphCompute::create(
|
||||||
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
|
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
|
||||||
SmallVector<Type> newBlockArgTypes;
|
SmallVector<Type> newBlockArgTypes;
|
||||||
SmallVector<Location> newBlockArgLocs;
|
SmallVector<Location> newBlockArgLocs;
|
||||||
@@ -182,10 +182,10 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
|
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
|
||||||
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatGraphComputeBatch> {
|
||||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
using OpRewritePattern<spatial::SpatGraphComputeBatch>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(spatial::SpatGraphComputeBatch compute, PatternRewriter& rewriter) const override {
|
||||||
auto promoted = computePromotedOperands(compute);
|
auto promoted = computePromotedOperands(compute);
|
||||||
if (failed(promoted))
|
if (failed(promoted))
|
||||||
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
||||||
@@ -197,7 +197,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
|||||||
rewriter, compute, static_cast<uint64_t>(compute.getLaneCount()), "promoted compute_batch lane count");
|
rewriter, compute, static_cast<uint64_t>(compute.getLaneCount()), "promoted compute_batch lane count");
|
||||||
if (failed(laneCountAttr))
|
if (failed(laneCountAttr))
|
||||||
return failure();
|
return failure();
|
||||||
auto newCompute = spatial::SpatComputeBatch::create(
|
auto newCompute = spatial::SpatGraphComputeBatch::create(
|
||||||
rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs);
|
rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs);
|
||||||
auto laneArg = compute.getLaneArgument();
|
auto laneArg = compute.getLaneArgument();
|
||||||
if (!laneArg)
|
if (!laneArg)
|
||||||
@@ -281,8 +281,8 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
bool requiresPostRewrite(spatial::SpatGraphCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||||
|
|
||||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -0,0 +1,21 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::FailureOr<mlir::Value>
|
||||||
|
lowerSelectedConv2DPlan(spatial::SpatConv2DPlanOp planOp,
|
||||||
|
std::optional<mlir::Value> rowStripInput,
|
||||||
|
bool emitRowStripLayout,
|
||||||
|
mlir::PatternRewriter& rewriter);
|
||||||
|
|
||||||
|
mlir::LogicalResult canLowerConvPlanToRowStrip(spatial::SpatConv2DPlanOp planOp);
|
||||||
|
mlir::LogicalResult canConsumeAndProduceRowStrip(spatial::SpatConv2DPlanOp planOp);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,200 @@
|
|||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
|
||||||
|
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static constexpr StringLiteral kLogicalLayout = "nchw";
|
||||||
|
static constexpr StringLiteral kDenseLayout = "dense_nchw";
|
||||||
|
static constexpr StringLiteral kRowStripLayout = "nchw_row_strip";
|
||||||
|
static constexpr StringLiteral kRowStripIndexMap = "packed_hwc_rows_to_nchw";
|
||||||
|
|
||||||
|
enum class SelectedLayout {
|
||||||
|
DenseNchw,
|
||||||
|
NchwRowStrip,
|
||||||
|
};
|
||||||
|
|
||||||
|
static SelectedLayout getSelectedLayout(llvm::DenseMap<Value, SelectedLayout>& layouts, Value value) {
|
||||||
|
auto it = layouts.find(value);
|
||||||
|
return it == layouts.end() ? SelectedLayout::DenseNchw : it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool usesSelectedRowStrip(Operation* user, llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||||
|
if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(user))
|
||||||
|
return getSelectedLayout(layouts, reluPlan.getResult()) == SelectedLayout::NchwRowStrip;
|
||||||
|
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(user))
|
||||||
|
return getSelectedLayout(layouts, convPlan.getResult()) == SelectedLayout::NchwRowStrip;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool allUsersCanHandleRowStrip(Value value, llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||||
|
for (Operation* user : value.getUsers()) {
|
||||||
|
if (usesSelectedRowStrip(user, layouts))
|
||||||
|
continue;
|
||||||
|
// Dense-only users must be materialized explicitly.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::pair<SmallVector<int64_t>, SmallVector<int64_t>> buildRowStripMetadata(RankedTensorType type) {
|
||||||
|
SmallVector<int64_t> offsets;
|
||||||
|
SmallVector<int64_t> sizes;
|
||||||
|
const int64_t channels = type.getDimSize(1);
|
||||||
|
const int64_t height = type.getDimSize(2);
|
||||||
|
const int64_t width = type.getDimSize(3);
|
||||||
|
offsets.reserve(height * 4);
|
||||||
|
sizes.reserve(height * 4);
|
||||||
|
for (int64_t row = 0; row < height; ++row) {
|
||||||
|
offsets.append({0, 0, row, 0});
|
||||||
|
sizes.append({1, channels, 1, width});
|
||||||
|
}
|
||||||
|
return {offsets, sizes};
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool canSelectConvRowStrip(spatial::SpatConv2DPlanOp convPlan,
|
||||||
|
llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||||
|
SelectedLayout inputLayout = getSelectedLayout(layouts, convPlan.getInput());
|
||||||
|
if (inputLayout == SelectedLayout::NchwRowStrip)
|
||||||
|
return succeeded(canConsumeAndProduceRowStrip(convPlan));
|
||||||
|
return succeeded(canLowerConvPlanToRowStrip(convPlan));
|
||||||
|
}
|
||||||
|
|
||||||
|
static SelectedLayout chooseConvLayout(spatial::SpatConv2DPlanOp convPlan,
|
||||||
|
llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||||
|
if (!canSelectConvRowStrip(convPlan, layouts))
|
||||||
|
return SelectedLayout::DenseNchw;
|
||||||
|
if (!allUsersCanHandleRowStrip(convPlan.getResult(), layouts))
|
||||||
|
return SelectedLayout::DenseNchw;
|
||||||
|
return SelectedLayout::NchwRowStrip;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SelectedLayout chooseReluLayout(spatial::SpatReluPlanOp reluPlan,
|
||||||
|
llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||||
|
if (getSelectedLayout(layouts, reluPlan.getInput()) != SelectedLayout::NchwRowStrip)
|
||||||
|
return SelectedLayout::DenseNchw;
|
||||||
|
if (!allUsersCanHandleRowStrip(reluPlan.getResult(), layouts))
|
||||||
|
return SelectedLayout::DenseNchw;
|
||||||
|
return SelectedLayout::NchwRowStrip;
|
||||||
|
}
|
||||||
|
|
||||||
|
static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewriter, Value value) {
|
||||||
|
auto outputType = cast<RankedTensorType>(value.getType());
|
||||||
|
auto [offsets, sizes] = buildRowStripMetadata(outputType);
|
||||||
|
return spatial::SpatReconciliatorOp::create(rewriter,
|
||||||
|
value.getLoc(),
|
||||||
|
outputType,
|
||||||
|
value,
|
||||||
|
rewriter.getStringAttr(kLogicalLayout),
|
||||||
|
rewriter.getStringAttr(kRowStripLayout),
|
||||||
|
rewriter.getDenseI64ArrayAttr(offsets),
|
||||||
|
rewriter.getDenseI64ArrayAttr(sizes),
|
||||||
|
rewriter.getStringAttr(kRowStripIndexMap));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void materializeDenseUses(IRRewriter& rewriter,
|
||||||
|
Value layoutValue,
|
||||||
|
llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||||
|
SmallVector<OpOperand*> denseUses;
|
||||||
|
for (OpOperand& use : layoutValue.getUses()) {
|
||||||
|
if (usesSelectedRowStrip(use.getOwner(), layouts))
|
||||||
|
continue;
|
||||||
|
denseUses.push_back(&use);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (OpOperand* use : denseUses) {
|
||||||
|
Operation* owner = use->getOwner();
|
||||||
|
rewriter.setInsertionPoint(owner);
|
||||||
|
auto materialized = spatial::SpatMaterializeLayoutOp::create(rewriter,
|
||||||
|
owner->getLoc(),
|
||||||
|
use->get().getType(),
|
||||||
|
use->get(),
|
||||||
|
rewriter.getStringAttr(kLogicalLayout),
|
||||||
|
rewriter.getStringAttr(kRowStripLayout),
|
||||||
|
rewriter.getStringAttr(kDenseLayout));
|
||||||
|
use->set(materialized.getResult());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SpatialLayoutPlanningPass final : PassWrapper<SpatialLayoutPlanningPass, OperationPass<ModuleOp>> {
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialLayoutPlanningPass)
|
||||||
|
|
||||||
|
StringRef getArgument() const override { return "spatial-layout-planning"; }
|
||||||
|
StringRef getDescription() const override { return "Select conservative Spatial layouts and insert reconciliation barriers."; }
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
auto entryFunc = getPimEntryFunc(getOperation());
|
||||||
|
if (failed(entryFunc)) {
|
||||||
|
getOperation().emitError("failed to locate the PIM entry function during Spatial layout planning");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
func::FuncOp funcOp = *entryFunc;
|
||||||
|
IRRewriter rewriter(&getContext());
|
||||||
|
llvm::DenseMap<Value, SelectedLayout> layouts;
|
||||||
|
|
||||||
|
bool changed = true;
|
||||||
|
while (changed) {
|
||||||
|
changed = false;
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||||
|
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
|
||||||
|
SelectedLayout selected = chooseConvLayout(convPlan, layouts);
|
||||||
|
if (layouts[convPlan.getResult()] != selected) {
|
||||||
|
layouts[convPlan.getResult()] = selected;
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
|
||||||
|
SelectedLayout selected = chooseReluLayout(reluPlan, layouts);
|
||||||
|
if (layouts[reluPlan.getResult()] != selected) {
|
||||||
|
layouts[reluPlan.getResult()] = selected;
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||||
|
Value producedValue;
|
||||||
|
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(&op))
|
||||||
|
producedValue = convPlan.getResult();
|
||||||
|
else if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(&op))
|
||||||
|
producedValue = reluPlan.getResult();
|
||||||
|
else
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (getSelectedLayout(layouts, producedValue) != SelectedLayout::NchwRowStrip)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(&op);
|
||||||
|
auto reconciliator = insertRowStripReconciliator(rewriter, producedValue);
|
||||||
|
rewriter.replaceAllUsesExcept(producedValue, reconciliator.getResult(), reconciliator);
|
||||||
|
materializeDenseUses(rewriter, reconciliator.getResult(), layouts);
|
||||||
|
}
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
|
getOperation().emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after SpatialLayoutPlanning");
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createSpatialLayoutPlanningPass() { return std::make_unique<SpatialLayoutPlanningPass>(); }
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -102,7 +102,7 @@ static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
|
|||||||
return mapper.lookup(value);
|
return mapper.lookup(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||||
size_t& fallbackCoreId) {
|
size_t& fallbackCoreId) {
|
||||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||||
@@ -171,7 +171,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||||
IRRewriter& rewriter) {
|
IRRewriter& rewriter) {
|
||||||
Location loc = computeBatchOp.getLoc();
|
Location loc = computeBatchOp.getLoc();
|
||||||
Block& oldBlock = computeBatchOp.getBody().front();
|
Block& oldBlock = computeBatchOp.getBody().front();
|
||||||
|
|||||||
@@ -17,10 +17,10 @@ std::optional<unsigned> getDirectComputeLikeInputIndex(Operation* owner, unsigne
|
|||||||
return operandNumber - inputBegin;
|
return operandNumber - inputBegin;
|
||||||
};
|
};
|
||||||
|
|
||||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
|
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner))
|
||||||
return getInputIndex(owner, compute.getInputs().size());
|
return getInputIndex(owner, compute.getInputs().size());
|
||||||
|
|
||||||
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner))
|
if (auto computeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(owner))
|
||||||
return getInputIndex(owner, computeBatch.getInputs().size());
|
return getInputIndex(owner, computeBatch.getInputs().size());
|
||||||
|
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
@@ -32,13 +32,13 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
|||||||
Value replacement) {
|
Value replacement) {
|
||||||
Block& body = owner->getRegion(0).front();
|
Block& body = owner->getRegion(0).front();
|
||||||
BlockArgument bodyArgument;
|
BlockArgument bodyArgument;
|
||||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner)) {
|
||||||
auto computeArg = compute.getInputArgument(inputIndex);
|
auto computeArg = compute.getInputArgument(inputIndex);
|
||||||
assert(computeArg && "expected compute input block argument");
|
assert(computeArg && "expected compute input block argument");
|
||||||
bodyArgument = *computeArg;
|
bodyArgument = *computeArg;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
|
auto batchArg = cast<spatial::SpatScheduledComputeBatch>(owner).getInputArgument(inputIndex);
|
||||||
assert(batchArg && "expected compute_batch input block argument");
|
assert(batchArg && "expected compute_batch input block argument");
|
||||||
bodyArgument = *batchArg;
|
bodyArgument = *batchArg;
|
||||||
}
|
}
|
||||||
@@ -46,10 +46,10 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
|||||||
|
|
||||||
rewriter.startOpModification(owner);
|
rewriter.startOpModification(owner);
|
||||||
bodyArgument.replaceAllUsesWith(replacement);
|
bodyArgument.replaceAllUsesWith(replacement);
|
||||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
|
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner))
|
||||||
compute.getInputsMutable().erase(inputIndex);
|
compute.getInputsMutable().erase(inputIndex);
|
||||||
else
|
else
|
||||||
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
|
cast<spatial::SpatScheduledComputeBatch>(owner).getInputsMutable().erase(inputIndex);
|
||||||
body.eraseArgument(bodyArgIndex);
|
body.eraseArgument(bodyArgIndex);
|
||||||
rewriter.finalizeOpModification(owner);
|
rewriter.finalizeOpModification(owner);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatScheduledCompute computeOp, size_t& fallbackCoreId) {
|
||||||
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||||
return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id");
|
return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id");
|
||||||
auto checkedCoreId =
|
auto checkedCoreId =
|
||||||
@@ -66,7 +66,7 @@ static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeO
|
|||||||
return *checkedCoreId;
|
return *checkedCoreId;
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
|
||||||
SmallVectorImpl<Operation*>& helperChain,
|
SmallVectorImpl<Operation*>& helperChain,
|
||||||
bool requireReturnUse = true) {
|
bool requireReturnUse = true) {
|
||||||
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
||||||
@@ -104,13 +104,13 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp,
|
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatScheduledCompute computeOp,
|
||||||
IRRewriter& rewriter,
|
IRRewriter& rewriter,
|
||||||
OperationFolder& constantFolder) {
|
OperationFolder& constantFolder) {
|
||||||
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||||
return false;
|
return false;
|
||||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
||||||
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
return isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||||
}))
|
}))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
@@ -145,7 +145,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
|
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatScheduledCompute computeOp,
|
||||||
IRRewriter& rewriter,
|
IRRewriter& rewriter,
|
||||||
OperationFolder& constantFolder) {
|
OperationFolder& constantFolder) {
|
||||||
Location loc = computeOp->getLoc();
|
Location loc = computeOp->getLoc();
|
||||||
|
|||||||
@@ -10,6 +10,14 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
static void copyRaptorDebugAttrs(Operation* source, Operation* target) {
|
||||||
|
for (NamedAttribute attr : source->getAttrs()) {
|
||||||
|
StringRef name = attr.getName().strref();
|
||||||
|
if (name.starts_with("raptor."))
|
||||||
|
target->setAttr(attr.getName(), attr.getValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
@@ -17,7 +25,8 @@ struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
|||||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getInput());
|
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getInput());
|
||||||
if (failed(sizeAttr))
|
if (failed(sizeAttr))
|
||||||
return failure();
|
return failure();
|
||||||
pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId());
|
auto send = pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId());
|
||||||
|
copyRaptorDebugAttrs(op.getOperation(), send.getOperation());
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -37,9 +46,10 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
|
|||||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult());
|
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult());
|
||||||
if (failed(sizeAttr))
|
if (failed(sizeAttr))
|
||||||
return failure();
|
return failure();
|
||||||
Value received = pim::PimReceiveOp::create(
|
auto receive = pim::PimReceiveOp::create(
|
||||||
rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId())
|
rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId());
|
||||||
.getOutput();
|
copyRaptorDebugAttrs(op.getOperation(), receive.getOperation());
|
||||||
|
Value received = receive.getOutput();
|
||||||
rewriter.replaceOp(op, received);
|
rewriter.replaceOp(op, received);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
for (auto& uses : extractSliceOp->getUses()) {
|
for (auto& uses : extractSliceOp->getUses()) {
|
||||||
if (isa<spatial::SpatCompute>(uses.getOwner())) {
|
if (isa<spatial::SpatScheduledCompute>(uses.getOwner())) {
|
||||||
if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
|
if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
@@ -72,7 +72,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
|
|
||||||
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
||||||
|
|
||||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
|
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(uses.getOwner())) {
|
||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
||||||
if (!inputIndex)
|
if (!inputIndex)
|
||||||
return failure();
|
return failure();
|
||||||
@@ -92,7 +92,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
replaceAndEraseDirectComputeLikeInput(
|
replaceAndEraseDirectComputeLikeInput(
|
||||||
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||||
}
|
}
|
||||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
else if (auto spatComputeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(uses.getOwner())) {
|
||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||||
if (!inputIndex)
|
if (!inputIndex)
|
||||||
return failure();
|
return failure();
|
||||||
@@ -114,7 +114,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
{
|
{
|
||||||
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatScheduledCompute>()) {
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
||||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||||
@@ -125,7 +125,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
|
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
|
||||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
}
|
}
|
||||||
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatScheduledComputeBatch>()) {
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
||||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||||
@@ -179,7 +179,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
|||||||
|
|
||||||
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
||||||
auto argUser = argUses.getOwner();
|
auto argUser = argUses.getOwner();
|
||||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
|
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(argUser)) {
|
||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
|
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
|
||||||
if (!inputIndex)
|
if (!inputIndex)
|
||||||
return failure();
|
return failure();
|
||||||
@@ -191,7 +191,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
|||||||
|
|
||||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
|
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
|
||||||
}
|
}
|
||||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
|
else if (auto spatComputeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(argUser)) {
|
||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
|
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
|
||||||
if (!inputIndex)
|
if (!inputIndex)
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ getCheckedByteOffset(int64_t elementOffset, size_t elementSize, Operation* ancho
|
|||||||
return pim::checkedCast<int64_t>(*byteOffset, anchor, fieldName);
|
return pim::checkedCast<int64_t>(*byteOffset, anchor, fieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
|
||||||
SmallVectorImpl<Operation*>& helperChain) {
|
SmallVectorImpl<Operation*>& helperChain) {
|
||||||
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
||||||
return failure();
|
return failure();
|
||||||
@@ -212,7 +212,7 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Operation*> helperChain;
|
SmallVector<Operation*> helperChain;
|
||||||
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
|
if (auto helperCompute = dyn_cast<spatial::SpatScheduledCompute>(currentUser)) {
|
||||||
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
|
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
@@ -643,7 +643,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
}
|
}
|
||||||
|
|
||||||
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
|
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
|
||||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
|
spatial::SpatScheduledCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
|
||||||
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -656,7 +656,7 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
|
|||||||
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
|
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
|
||||||
Operation* onlyUser = *op->getUsers().begin();
|
Operation* onlyUser = *op->getUsers().begin();
|
||||||
isExclusivelyOwnedByReturnChain =
|
isExclusivelyOwnedByReturnChain =
|
||||||
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatScheduledCompute>(onlyUser)
|
||||||
|| isReturnHelperChainOp(onlyUser);
|
|| isReturnHelperChainOp(onlyUser);
|
||||||
}
|
}
|
||||||
if (!isExclusivelyOwnedByReturnChain)
|
if (!isExclusivelyOwnedByReturnChain)
|
||||||
@@ -669,7 +669,7 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
|
||||||
markOpToRemove(computeOp);
|
markOpToRemove(computeOp);
|
||||||
if (!computeOp.getInputs().empty())
|
if (!computeOp.getInputs().empty())
|
||||||
for (Value input : computeOp.getInputs())
|
for (Value input : computeOp.getInputs())
|
||||||
|
|||||||
@@ -25,9 +25,11 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "Common/IR/ShapeUtils.hpp"
|
||||||
#include "Common/IR/ConstantUtils.hpp"
|
#include "Common/IR/ConstantUtils.hpp"
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Common/Support/CheckedArithmetic.hpp"
|
#include "Common/Support/CheckedArithmetic.hpp"
|
||||||
|
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/Common.hpp"
|
#include "Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/Patterns.hpp"
|
#include "Conversion/SpatialToPim/Patterns.hpp"
|
||||||
@@ -97,6 +99,64 @@ static FailureOr<Value> createZeroedDeviceHVector(IRRewriter& rewriter,
|
|||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool isHostBackedMemRefValue(Value value) {
|
||||||
|
while (Operation* definingOp = value.getDefiningOp()) {
|
||||||
|
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||||
|
value = subviewOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||||
|
value = castOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = collapseOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = expandOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return isa<memref::GetGlobalOp>(definingOp);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isHostBackedTensorValue(Value value) {
|
||||||
|
while (Operation* definingOp = value.getDefiningOp()) {
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(extractSliceOp.getSource().getType());
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getResult().getType());
|
||||||
|
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return false;
|
||||||
|
if (!onnx_mlir::isContiguousSubviewWithDynamicOffsets(sourceType.getShape(),
|
||||||
|
extractSliceOp.getMixedOffsets(),
|
||||||
|
extractSliceOp.getStaticSizes(),
|
||||||
|
extractSliceOp.getStaticStrides())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
value = extractSliceOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = collapseOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = expandOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto castOp = dyn_cast<tensor::CastOp>(definingOp)) {
|
||||||
|
value = castOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(definingOp))
|
||||||
|
return isHostBackedMemRefValue(toTensorOp.getBuffer());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static FailureOr<Value>
|
static FailureOr<Value>
|
||||||
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
|
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
|
||||||
auto vectorType = cast<RankedTensorType>(vector.getType());
|
auto vectorType = cast<RankedTensorType>(vector.getType());
|
||||||
@@ -120,6 +180,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
|
|||||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
|
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
|
||||||
if (failed(sizeAttr))
|
if (failed(sizeAttr))
|
||||||
return failure();
|
return failure();
|
||||||
|
if (isHostBackedTensorValue(vector)) {
|
||||||
|
return PimMemCopyHostToDevOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr)
|
||||||
|
.getOutput();
|
||||||
|
}
|
||||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
|
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,6 +201,12 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
func::FuncOp funcOp = *entryFunc;
|
func::FuncOp funcOp = *entryFunc;
|
||||||
|
if (failed(verifyScheduledSpatialInvariants(funcOp))) {
|
||||||
|
funcOp.emitOpError(
|
||||||
|
"RAPTOR_PHASE_CHECK scheduled Spatial verification failed at the start of SpatialToPim");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
OperationFolder constantFolder(&getContext());
|
OperationFolder constantFolder(&getContext());
|
||||||
@@ -176,19 +246,19 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
for (auto computeOp : funcOp.getOps<spatial::SpatScheduledCompute>()) {
|
||||||
markOpToRemove(computeOp);
|
markOpToRemove(computeOp);
|
||||||
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
|
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
|
||||||
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
computeOp.emitOpError("failed to lower spat.scheduled_compute to pim.core");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
for (auto computeBatchOp : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
|
||||||
markOpToRemove(computeBatchOp);
|
markOpToRemove(computeBatchOp);
|
||||||
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
|
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
|
||||||
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
computeBatchOp.emitOpError("failed to lower spat.scheduled_compute_batch to pim.core_batch");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -374,7 +444,7 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
|
|||||||
};
|
};
|
||||||
|
|
||||||
for (auto& op : funcOp.getBody().getOps())
|
for (auto& op : funcOp.getBody().getOps())
|
||||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
|
||||||
if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
|
if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
|
||||||
continue;
|
continue;
|
||||||
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
|
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
|
||||||
|
|||||||
@@ -41,8 +41,11 @@ private:
|
|||||||
|
|
||||||
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||||
mlir::LogicalResult
|
mlir::LogicalResult
|
||||||
lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder);
|
lowerComputeOp(spatial::SpatScheduledCompute computeOp,
|
||||||
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);
|
mlir::IRRewriter& rewriter,
|
||||||
|
mlir::OperationFolder& constantFolder);
|
||||||
|
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||||
|
mlir::IRRewriter& rewriter);
|
||||||
|
|
||||||
enum class ReturnPathLoweringResult {
|
enum class ReturnPathLoweringResult {
|
||||||
Handled,
|
Handled,
|
||||||
@@ -51,7 +54,7 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
||||||
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatScheduledCompute computeOp,
|
||||||
mlir::OpResult result,
|
mlir::OpResult result,
|
||||||
mlir::Value yieldValue,
|
mlir::Value yieldValue,
|
||||||
mlir::IRRewriter& rewriter);
|
mlir::IRRewriter& rewriter);
|
||||||
|
|||||||
@@ -13,10 +13,13 @@ using namespace bufferization;
|
|||||||
|
|
||||||
namespace onnx_mlir::pim {
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue,
|
||||||
|
Location loc,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const StaticValueKnowledge& knowledge) {
|
||||||
bool isContiguous =
|
bool isContiguous =
|
||||||
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
|
succeeded(resolveContiguousAddress(memrefValue, knowledge)) || succeeded(compileContiguousAddressExpr(memrefValue));
|
||||||
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
|
if (isContiguous && isDeviceLocalPimAddress(memrefValue, knowledge))
|
||||||
return memrefValue;
|
return memrefValue;
|
||||||
|
|
||||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||||
@@ -32,7 +35,7 @@ FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location lo
|
|||||||
if (failed(sizeAttr))
|
if (failed(sizeAttr))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (isHostBackedPimAddress(memrefValue)) {
|
if (isHostBackedPimAddress(memrefValue, knowledge)) {
|
||||||
return PimMemCopyHostToDevOp::create(
|
return PimMemCopyHostToDevOp::create(
|
||||||
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
|
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
|
||||||
.getOutput();
|
.getOutput();
|
||||||
|
|||||||
@@ -3,10 +3,15 @@
|
|||||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir::pim {
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
llvm::FailureOr<mlir::Value>
|
llvm::FailureOr<mlir::Value>
|
||||||
materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
materializeContiguousInputMemRef(mlir::Value memrefValue,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::RewriterBase& rewriter,
|
||||||
|
const onnx_mlir::StaticValueKnowledge& knowledge = {});
|
||||||
mlir::Value
|
mlir::Value
|
||||||
allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,26 @@ using namespace bufferization;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace pim {
|
namespace pim {
|
||||||
|
|
||||||
|
static StaticValueKnowledge getEnclosingBufferizationKnowledge(Operation* op) {
|
||||||
|
StaticValueKnowledge knowledge;
|
||||||
|
|
||||||
|
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
|
||||||
|
knowledge.indexValues[coreBatchOp.getLaneArgument()] = 0;
|
||||||
|
for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights()))
|
||||||
|
knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight;
|
||||||
|
for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs()))
|
||||||
|
knowledge.aliases[coreBatchOp.getInputArgument(index)] = input;
|
||||||
|
return knowledge;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
|
||||||
|
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights()))
|
||||||
|
knowledge.aliases[coreOp.getWeightArgument(index)] = weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
return knowledge;
|
||||||
|
}
|
||||||
|
|
||||||
struct MemCopyHostToDevOpInterface
|
struct MemCopyHostToDevOpInterface
|
||||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
||||||
LogicalResult bufferize(Operation* op,
|
LogicalResult bufferize(Operation* op,
|
||||||
@@ -148,7 +168,8 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
|
|||||||
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
|
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
auto contiguous = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
auto contiguous =
|
||||||
|
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguous))
|
if (failed(contiguous))
|
||||||
return failure();
|
return failure();
|
||||||
inputs.push_back(*contiguous);
|
inputs.push_back(*contiguous);
|
||||||
@@ -182,7 +203,8 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
|
|||||||
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
auto contiguousInput =
|
||||||
|
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousInput))
|
if (failed(contiguousInput))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -410,7 +432,8 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
auto contiguousInput =
|
||||||
|
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousInput))
|
if (failed(contiguousInput))
|
||||||
return failure();
|
return failure();
|
||||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
@@ -456,7 +479,8 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
auto contiguousInput =
|
||||||
|
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousInput))
|
if (failed(contiguousInput))
|
||||||
return failure();
|
return failure();
|
||||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
@@ -497,10 +521,12 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
|
auto contiguousLhs =
|
||||||
|
materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousLhs))
|
if (failed(contiguousLhs))
|
||||||
return failure();
|
return failure();
|
||||||
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
|
auto contiguousRhs =
|
||||||
|
materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousRhs))
|
if (failed(contiguousRhs))
|
||||||
return failure();
|
return failure();
|
||||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
@@ -534,10 +560,12 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInter
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
|
auto contiguousLhs =
|
||||||
|
materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousLhs))
|
if (failed(contiguousLhs))
|
||||||
return failure();
|
return failure();
|
||||||
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
|
auto contiguousRhs =
|
||||||
|
materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousRhs))
|
if (failed(contiguousRhs))
|
||||||
return failure();
|
return failure();
|
||||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
@@ -574,7 +602,8 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
auto contiguousInput =
|
||||||
|
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousInput))
|
if (failed(contiguousInput))
|
||||||
return failure();
|
return failure();
|
||||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
|
|||||||
@@ -116,6 +116,36 @@ lowerMemRefCopyToPimCopy(memref::CopyOp copyOp, PatternRewriter& rewriter, const
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyHostToDevOp copyOp, const StaticValueKnowledge& knowledge) {
|
||||||
|
bool sourceIsHost = isHostBackedPimAddress(copyOp.getHostSource(), knowledge);
|
||||||
|
bool targetIsHost = isHostBackedPimAddress(copyOp.getDeviceTarget(), knowledge);
|
||||||
|
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getHostSource(), knowledge);
|
||||||
|
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getDeviceTarget(), knowledge);
|
||||||
|
if (!sourceIsHost || !targetIsDevice || targetIsHost || sourceIsDevice)
|
||||||
|
return copyOp.emitOpError("pim.memcp_hd requires a host-backed source and a device-local target");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyDevToHostOp copyOp, const StaticValueKnowledge& knowledge) {
|
||||||
|
bool sourceIsHost = isHostBackedPimAddress(copyOp.getDeviceSource(), knowledge);
|
||||||
|
bool targetIsHost = isHostBackedPimAddress(copyOp.getHostTarget(), knowledge);
|
||||||
|
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getDeviceSource(), knowledge);
|
||||||
|
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getHostTarget(), knowledge);
|
||||||
|
if (!targetIsHost || !sourceIsDevice || sourceIsHost || targetIsDevice)
|
||||||
|
return copyOp.emitOpError("pim.memcp_dh requires a device-local source and a host-backed target");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyOp copyOp, const StaticValueKnowledge& knowledge) {
|
||||||
|
bool sourceIsHost = isHostBackedPimAddress(copyOp.getSource(), knowledge);
|
||||||
|
bool targetIsHost = isHostBackedPimAddress(copyOp.getTarget(), knowledge);
|
||||||
|
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getSource(), knowledge);
|
||||||
|
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getTarget(), knowledge);
|
||||||
|
if (!sourceIsDevice || !targetIsDevice || sourceIsHost || targetIsHost)
|
||||||
|
return copyOp.emitOpError("pim.memcp requires device-local source and target operands");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||||
StringRef getArgument() const override { return "bufferize-pim"; }
|
StringRef getArgument() const override { return "bufferize-pim"; }
|
||||||
@@ -129,6 +159,7 @@ struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<Mo
|
|||||||
private:
|
private:
|
||||||
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
||||||
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
|
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
|
||||||
|
LogicalResult verifyPimCopyAddressSpaces(ModuleOp moduleOp) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
static LogicalResult applyPatternsOnce(Operation* op, PatternApplicator& applicator, PatternRewriter& rewriter) {
|
static LogicalResult applyPatternsOnce(Operation* op, PatternApplicator& applicator, PatternRewriter& rewriter) {
|
||||||
@@ -240,6 +271,10 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (failed(verifyPimCopyAddressSpaces(moduleOp))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
annotateWeightsMemrefs(moduleOp, funcOp);
|
annotateWeightsMemrefs(moduleOp, funcOp);
|
||||||
|
|
||||||
@@ -346,6 +381,31 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult PimBufferizationPass::verifyPimCopyAddressSpaces(ModuleOp moduleOp) const {
|
||||||
|
bool hasFailure = false;
|
||||||
|
auto verifyWithKnowledge = [&](auto coreLikeOp, const StaticValueKnowledge& initialKnowledge) {
|
||||||
|
(void) walkPimCoreBlockStructurally(
|
||||||
|
coreLikeOp.getBody().front(), initialKnowledge, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||||
|
if (auto copyOp = dyn_cast<pim::PimMemCopyOp>(&op); copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge)))
|
||||||
|
hasFailure = true;
|
||||||
|
if (auto copyOp = dyn_cast<pim::PimMemCopyHostToDevOp>(&op);
|
||||||
|
copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge)))
|
||||||
|
hasFailure = true;
|
||||||
|
if (auto copyOp = dyn_cast<pim::PimMemCopyDevToHostOp>(&op);
|
||||||
|
copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge)))
|
||||||
|
hasFailure = true;
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
moduleOp.walk([&](pim::PimCoreOp coreOp) { verifyWithKnowledge(coreOp, seedCoreKnowledge(coreOp)); });
|
||||||
|
moduleOp.walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||||
|
StaticValueKnowledge knowledge = seedCoreBatchKnowledge(coreBatchOp, 0);
|
||||||
|
verifyWithKnowledge(coreBatchOp, knowledge);
|
||||||
|
});
|
||||||
|
return success(!hasFailure);
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
|
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -96,8 +96,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
|||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
|
||||||
auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
|
if (!mapOp->getParentOfType<pim::PimCoreOp>() && !mapOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||||
if (!coreOp)
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
|
auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ add_pim_library(OMPimVerification
|
|||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
OMPimCommon
|
OMPimCommon
|
||||||
|
OMPimCompilerOptions
|
||||||
OMPimBufferization
|
OMPimBufferization
|
||||||
PimOps
|
PimOps
|
||||||
SpatialOps
|
SpatialOps
|
||||||
|
|||||||
@@ -5,12 +5,17 @@
|
|||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
|
||||||
@@ -143,6 +148,479 @@ static bool isHostAddressableValue(Value value, const StaticValueKnowledge& know
|
|||||||
return isa_and_nonnull<memref::GetGlobalOp>(base.getDefiningOp());
|
return isa_and_nonnull<memref::GetGlobalOp>(base.getDefiningOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
enum class CommunicationEventKind { Send, Receive };
|
||||||
|
|
||||||
|
struct CommunicationEvent {
|
||||||
|
CommunicationEventKind kind = CommunicationEventKind::Send;
|
||||||
|
int64_t coreId = 0;
|
||||||
|
int64_t peerCoreId = 0;
|
||||||
|
int64_t size = 0;
|
||||||
|
uint64_t ordinal = 0;
|
||||||
|
std::optional<int64_t> minChannelId;
|
||||||
|
std::string materializer;
|
||||||
|
std::optional<int64_t> traceId;
|
||||||
|
std::optional<int64_t> commOrder;
|
||||||
|
std::optional<int64_t> traceClassId;
|
||||||
|
std::optional<int64_t> traceBlockOrdinal;
|
||||||
|
std::string traceKind;
|
||||||
|
std::string tracePhase;
|
||||||
|
std::string traceClassKind;
|
||||||
|
std::string tracePayload;
|
||||||
|
std::string traceMessages;
|
||||||
|
std::string tracePrevOp;
|
||||||
|
std::string traceNextOp;
|
||||||
|
Operation* op = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
using CommunicationEventVector = SmallVector<CommunicationEvent, 0>;
|
||||||
|
|
||||||
|
static StringRef getCommunicationEventKindName(CommunicationEventKind kind) {
|
||||||
|
return kind == CommunicationEventKind::Send ? "send" : "receive";
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr StringLiteral kRaptorMinChannelIdAttr = "raptor.min_channel_id";
|
||||||
|
constexpr StringLiteral kRaptorMaterializerAttr = "raptor.materializer";
|
||||||
|
constexpr StringLiteral kRaptorCommOrderAttr = "raptor.comm_order";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceIdAttr = "raptor.comm_trace_id";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceKindAttr = "raptor.comm_trace_kind";
|
||||||
|
constexpr StringLiteral kRaptorCommTracePhaseAttr = "raptor.comm_trace_phase";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceClassIdAttr = "raptor.comm_trace_class_id";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceClassKindAttr = "raptor.comm_trace_class_kind";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceBlockOrdinalAttr = "raptor.comm_trace_block_ordinal";
|
||||||
|
constexpr StringLiteral kRaptorCommTracePayloadAttr = "raptor.comm_trace_payload";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceMessagesAttr = "raptor.comm_trace_messages";
|
||||||
|
constexpr StringLiteral kRaptorCommTracePrevOpAttr = "raptor.comm_trace_prev_op";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceNextOpAttr = "raptor.comm_trace_next_op";
|
||||||
|
|
||||||
|
static std::optional<int64_t> getNearestIntegerAttr(Operation* op, StringRef name) {
|
||||||
|
for (Operation* current = op; current; current = current->getParentOp())
|
||||||
|
if (auto attr = current->getAttrOfType<IntegerAttr>(name))
|
||||||
|
return attr.getInt();
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string getNearestStringAttr(Operation* op, StringRef name) {
|
||||||
|
for (Operation* current = op; current; current = current->getParentOp())
|
||||||
|
if (auto attr = current->getAttrOfType<StringAttr>(name))
|
||||||
|
return attr.getValue().str();
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string formatLocation(Location loc) {
|
||||||
|
std::string text;
|
||||||
|
llvm::raw_string_ostream os(text);
|
||||||
|
loc.print(os);
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string formatOperationSummary(Operation* op) {
|
||||||
|
std::string text;
|
||||||
|
llvm::raw_string_ostream os(text);
|
||||||
|
OpPrintingFlags flags;
|
||||||
|
flags.skipRegions();
|
||||||
|
flags.elideLargeElementsAttrs();
|
||||||
|
op->print(os, flags);
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string formatCommunicationEvent(const CommunicationEvent& event) {
|
||||||
|
std::string text;
|
||||||
|
llvm::raw_string_ostream os(text);
|
||||||
|
os << "core " << event.coreId << " " << getCommunicationEventKindName(event.kind) << " "
|
||||||
|
<< (event.kind == CommunicationEventKind::Send ? "to" : "from") << " " << event.peerCoreId
|
||||||
|
<< " size " << event.size << "B ordinal " << event.ordinal;
|
||||||
|
if (event.minChannelId)
|
||||||
|
os << " min_channel " << *event.minChannelId;
|
||||||
|
if (event.commOrder)
|
||||||
|
os << " comm_order " << *event.commOrder;
|
||||||
|
if (!event.materializer.empty())
|
||||||
|
os << " materializer " << event.materializer;
|
||||||
|
if (event.traceId)
|
||||||
|
os << " trace#" << *event.traceId;
|
||||||
|
if (!event.tracePhase.empty())
|
||||||
|
os << " phase " << event.tracePhase;
|
||||||
|
if (event.traceClassId)
|
||||||
|
os << " class " << event.traceClassKind << "#" << *event.traceClassId;
|
||||||
|
if (event.traceBlockOrdinal)
|
||||||
|
os << " block_ordinal " << *event.traceBlockOrdinal;
|
||||||
|
if (!event.tracePayload.empty())
|
||||||
|
os << " payload " << event.tracePayload;
|
||||||
|
if (!event.traceMessages.empty())
|
||||||
|
os << " messages {" << event.traceMessages << "}";
|
||||||
|
if (!event.tracePrevOp.empty() || !event.traceNextOp.empty())
|
||||||
|
os << " inserted_between [" << event.tracePrevOp << " | " << event.traceNextOp << "]";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool areMatchedCommunicationEvents(const CommunicationEvent& lhs, const CommunicationEvent& rhs) {
|
||||||
|
if (lhs.coreId != rhs.peerCoreId || lhs.peerCoreId != rhs.coreId || lhs.size != rhs.size)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return (lhs.kind == CommunicationEventKind::Send && rhs.kind == CommunicationEventKind::Receive)
|
||||||
|
|| (lhs.kind == CommunicationEventKind::Receive && rhs.kind == CommunicationEventKind::Send);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static std::optional<size_t> findMatchingCounterpartIndex(const CommunicationEventVector& events,
|
||||||
|
const CommunicationEvent& event,
|
||||||
|
size_t begin) {
|
||||||
|
for (size_t index = begin; index < events.size(); ++index)
|
||||||
|
if (areMatchedCommunicationEvents(event, events[index]))
|
||||||
|
return index;
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printCounterpartProbe(llvm::raw_ostream& os,
|
||||||
|
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
|
||||||
|
const DenseMap<int64_t, size_t>& programCounters,
|
||||||
|
const CommunicationEvent& blockedEvent) {
|
||||||
|
auto peerEventsIt = coreEvents.find(blockedEvent.peerCoreId);
|
||||||
|
if (peerEventsIt == coreEvents.end()) {
|
||||||
|
os << " no local stream was collected for peer core " << blockedEvent.peerCoreId << "\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const CommunicationEventVector& peerEvents = peerEventsIt->second;
|
||||||
|
size_t peerPc = 0;
|
||||||
|
auto peerPcIt = programCounters.find(blockedEvent.peerCoreId);
|
||||||
|
if (peerPcIt != programCounters.end())
|
||||||
|
peerPc = peerPcIt->second;
|
||||||
|
|
||||||
|
os << " counterpart probe for " << formatCommunicationEvent(blockedEvent) << "\n";
|
||||||
|
os << " peer core " << blockedEvent.peerCoreId << " current pc " << peerPc << " of " << peerEvents.size()
|
||||||
|
<< "\n";
|
||||||
|
|
||||||
|
std::optional<size_t> nextMatch = findMatchingCounterpartIndex(peerEvents, blockedEvent, peerPc);
|
||||||
|
std::optional<size_t> anyMatch = findMatchingCounterpartIndex(peerEvents, blockedEvent, 0);
|
||||||
|
|
||||||
|
if (!nextMatch && !anyMatch) {
|
||||||
|
os << " no matching counterpart exists anywhere in the peer stream\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!nextMatch && anyMatch) {
|
||||||
|
os << " matching counterpart exists only before the peer pc at ordinal " << *anyMatch
|
||||||
|
<< "; this usually means the static stream expansion or ordering metadata is inconsistent\n";
|
||||||
|
os << " " << formatCommunicationEvent(peerEvents[*anyMatch]) << "\n";
|
||||||
|
os << " op: " << formatOperationSummary(peerEvents[*anyMatch].op) << "\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const CommunicationEvent& match = peerEvents[*nextMatch];
|
||||||
|
os << " next matching counterpart is at peer ordinal " << *nextMatch << " (distance +"
|
||||||
|
<< (*nextMatch >= peerPc ? *nextMatch - peerPc : 0) << ")\n";
|
||||||
|
os << " " << formatCommunicationEvent(match) << "\n";
|
||||||
|
os << " op: " << formatOperationSummary(match.op) << "\n";
|
||||||
|
|
||||||
|
if (*nextMatch == peerPc)
|
||||||
|
return;
|
||||||
|
|
||||||
|
os << " peer operations blocking before that counterpart:\n";
|
||||||
|
size_t end = std::min(peerEvents.size(), std::min(*nextMatch + static_cast<size_t>(1), peerPc + static_cast<size_t>(12)));
|
||||||
|
for (size_t index = peerPc; index < end; ++index) {
|
||||||
|
os << (index == peerPc ? " pc => " : " ") << "#" << index << " "
|
||||||
|
<< formatCommunicationEvent(peerEvents[index]) << "\n";
|
||||||
|
os << " op: " << formatOperationSummary(peerEvents[index].op) << "\n";
|
||||||
|
}
|
||||||
|
if (end <= *nextMatch)
|
||||||
|
os << " ... " << (*nextMatch - end + 1) << " more peer communication event(s) before the counterpart\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
static CommunicationEvent makeCommunicationEvent(CommunicationEventKind kind,
|
||||||
|
int64_t coreId,
|
||||||
|
int64_t peerCoreId,
|
||||||
|
int64_t size,
|
||||||
|
uint64_t ordinal,
|
||||||
|
Operation* op) {
|
||||||
|
return CommunicationEvent {kind,
|
||||||
|
coreId,
|
||||||
|
peerCoreId,
|
||||||
|
size,
|
||||||
|
ordinal,
|
||||||
|
getNearestIntegerAttr(op, kRaptorMinChannelIdAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorMaterializerAttr),
|
||||||
|
getNearestIntegerAttr(op, kRaptorCommTraceIdAttr),
|
||||||
|
getNearestIntegerAttr(op, kRaptorCommOrderAttr),
|
||||||
|
getNearestIntegerAttr(op, kRaptorCommTraceClassIdAttr),
|
||||||
|
getNearestIntegerAttr(op, kRaptorCommTraceBlockOrdinalAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTraceKindAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTracePhaseAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTraceClassKindAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTracePayloadAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTraceMessagesAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTracePrevOpAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTraceNextOpAttr),
|
||||||
|
op};
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult appendCoreCommunicationEvents(Block& block,
|
||||||
|
int64_t coreId,
|
||||||
|
const StaticValueKnowledge& initialKnowledge,
|
||||||
|
SmallVectorImpl<CommunicationEvent>& events,
|
||||||
|
pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
|
return walkPimCoreBlock(block, initialKnowledge, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||||
|
if (auto sendOp = dyn_cast<pim::PimSendOp>(&op)) {
|
||||||
|
auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
|
||||||
|
if (failed(targetCoreId)) {
|
||||||
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("cannot statically resolve send target core for PIM communication deadlock check");
|
||||||
|
});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
events.push_back(makeCommunicationEvent(CommunicationEventKind::Send,
|
||||||
|
coreId,
|
||||||
|
*targetCoreId,
|
||||||
|
sendOp.getSize(),
|
||||||
|
static_cast<uint64_t>(events.size()),
|
||||||
|
&op));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(&op)) {
|
||||||
|
auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
|
||||||
|
if (failed(sourceCoreId)) {
|
||||||
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("cannot statically resolve receive source core for PIM communication deadlock check");
|
||||||
|
});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
events.push_back(makeCommunicationEvent(CommunicationEventKind::Receive,
|
||||||
|
coreId,
|
||||||
|
*sourceCoreId,
|
||||||
|
receiveOp.getSize(),
|
||||||
|
static_cast<uint64_t>(events.size()),
|
||||||
|
&op));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printCommunicationWindow(llvm::raw_ostream& os,
|
||||||
|
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
|
||||||
|
int64_t coreId,
|
||||||
|
size_t pc,
|
||||||
|
unsigned radius = 4) {
|
||||||
|
auto eventsIt = coreEvents.find(coreId);
|
||||||
|
if (eventsIt == coreEvents.end())
|
||||||
|
return;
|
||||||
|
|
||||||
|
const CommunicationEventVector& events = eventsIt->second;
|
||||||
|
size_t begin = pc > radius ? pc - radius : 0;
|
||||||
|
size_t end = std::min(events.size(), pc + static_cast<size_t>(radius) + 1);
|
||||||
|
os << " local stream for core " << coreId << " around pc " << pc << " of " << events.size() << ":\n";
|
||||||
|
for (size_t index = begin; index < end; ++index) {
|
||||||
|
os << (index == pc ? " => " : " ") << formatCommunicationEvent(events[index]) << "\n";
|
||||||
|
os << " op: " << formatOperationSummary(events[index].op) << "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printCommunicationDeadlockReport(const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
|
||||||
|
const DenseMap<int64_t, size_t>& programCounters,
|
||||||
|
ArrayRef<int64_t> cycle) {
|
||||||
|
llvm::errs() << "\n=== PIM static communication deadlock report ===\n";
|
||||||
|
llvm::errs() << "wait cycle:";
|
||||||
|
for (int64_t coreId : cycle)
|
||||||
|
llvm::errs() << " " << coreId;
|
||||||
|
if (!cycle.empty())
|
||||||
|
llvm::errs() << " -> " << cycle.front();
|
||||||
|
llvm::errs() << "\n\nblocked heads:\n";
|
||||||
|
|
||||||
|
for (int64_t coreId : cycle) {
|
||||||
|
auto eventsIt = coreEvents.find(coreId);
|
||||||
|
auto pcIt = programCounters.find(coreId);
|
||||||
|
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
|
||||||
|
continue;
|
||||||
|
const CommunicationEvent& event = eventsIt->second[pcIt->second];
|
||||||
|
llvm::errs() << " " << formatCommunicationEvent(event) << "\n";
|
||||||
|
llvm::errs() << " loc: " << formatLocation(event.op->getLoc()) << "\n";
|
||||||
|
llvm::errs() << " op : " << formatOperationSummary(event.op) << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::errs() << "\npeer counterpart probes:\n";
|
||||||
|
for (int64_t coreId : cycle) {
|
||||||
|
auto eventsIt = coreEvents.find(coreId);
|
||||||
|
auto pcIt = programCounters.find(coreId);
|
||||||
|
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
|
||||||
|
continue;
|
||||||
|
printCounterpartProbe(llvm::errs(), coreEvents, programCounters, eventsIt->second[pcIt->second]);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::errs() << "\nlocal communication streams:\n";
|
||||||
|
for (int64_t coreId : cycle) {
|
||||||
|
auto pcIt = programCounters.find(coreId);
|
||||||
|
if (pcIt == programCounters.end())
|
||||||
|
continue;
|
||||||
|
printCommunicationWindow(llvm::errs(), coreEvents, coreId, pcIt->second);
|
||||||
|
}
|
||||||
|
llvm::errs() << "=== end PIM static communication deadlock report ===\n\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
static void emitCommunicationDeadlockCycle(ModuleOp moduleOp,
|
||||||
|
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
|
||||||
|
const DenseMap<int64_t, size_t>& programCounters,
|
||||||
|
ArrayRef<int64_t> cycle) {
|
||||||
|
printCommunicationDeadlockReport(coreEvents, programCounters, cycle);
|
||||||
|
|
||||||
|
auto diagnostic = moduleOp.emitError()
|
||||||
|
<< "PIM communication deadlock check found a blocking send/receive cycle while statically simulating the "
|
||||||
|
"expanded per-core communication streams; see the PIM static communication deadlock report above";
|
||||||
|
|
||||||
|
for (int64_t coreId : cycle) {
|
||||||
|
auto eventsIt = coreEvents.find(coreId);
|
||||||
|
auto pcIt = programCounters.find(coreId);
|
||||||
|
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
const CommunicationEvent& event = eventsIt->second[pcIt->second];
|
||||||
|
Diagnostic& note = diagnostic.attachNote(event.op->getLoc());
|
||||||
|
note << formatCommunicationEvent(event);
|
||||||
|
if (!event.materializer.empty())
|
||||||
|
note << " emitted by " << event.materializer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<int64_t>> findCommunicationWaitCycle(
|
||||||
|
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
|
||||||
|
const DenseMap<int64_t, size_t>& programCounters) {
|
||||||
|
for (const auto& [startCoreId, events] : coreEvents) {
|
||||||
|
auto startPcIt = programCounters.find(startCoreId);
|
||||||
|
if (startPcIt == programCounters.end() || startPcIt->second >= events.size())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
DenseMap<int64_t, size_t> positionInPath;
|
||||||
|
SmallVector<int64_t, 8> path;
|
||||||
|
int64_t currentCoreId = startCoreId;
|
||||||
|
while (true) {
|
||||||
|
auto eventsIt = coreEvents.find(currentCoreId);
|
||||||
|
auto pcIt = programCounters.find(currentCoreId);
|
||||||
|
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
|
||||||
|
break;
|
||||||
|
|
||||||
|
auto positionIt = positionInPath.find(currentCoreId);
|
||||||
|
if (positionIt != positionInPath.end()) {
|
||||||
|
SmallVector<int64_t> cycle;
|
||||||
|
for (size_t index = positionIt->second; index < path.size(); ++index)
|
||||||
|
cycle.push_back(path[index]);
|
||||||
|
return cycle;
|
||||||
|
}
|
||||||
|
|
||||||
|
positionInPath[currentCoreId] = path.size();
|
||||||
|
path.push_back(currentCoreId);
|
||||||
|
currentCoreId = eventsIt->second[pcIt->second].peerCoreId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyNoStaticCommunicationDeadlock(ModuleOp moduleOp,
|
||||||
|
pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
|
DenseMap<int64_t, CommunicationEventVector> coreEvents;
|
||||||
|
bool hasFailure = false;
|
||||||
|
|
||||||
|
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||||
|
if (funcOp.isExternal())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
||||||
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
||||||
|
int64_t coreId = coreOp.getCoreId();
|
||||||
|
if (failed(appendCoreCommunicationEvents(
|
||||||
|
coreOp.getBody().front(), coreId, StaticValueKnowledge {}, coreEvents[coreId], diagnostics)))
|
||||||
|
hasFailure = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
||||||
|
SmallVector<int32_t> coreIds = getBatchCoreIds(coreBatchOp);
|
||||||
|
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||||
|
for (size_t lane = 0; lane < laneCount; ++lane) {
|
||||||
|
StaticValueKnowledge laneKnowledge;
|
||||||
|
laneKnowledge.indexValues[coreBatchOp.getLaneArgument()] = static_cast<int64_t>(lane);
|
||||||
|
for (unsigned inputIndex = 0; inputIndex < coreBatchOp.getInputs().size(); ++inputIndex)
|
||||||
|
laneKnowledge.aliases[coreBatchOp.getInputArgument(inputIndex)] = coreBatchOp.getInputs()[inputIndex];
|
||||||
|
|
||||||
|
SmallVector<int32_t> laneCoreIds = getLaneChunkCoreIds(coreIds, laneCount, static_cast<unsigned>(lane));
|
||||||
|
for (int32_t coreId : laneCoreIds) {
|
||||||
|
if (failed(appendCoreCommunicationEvents(
|
||||||
|
coreBatchOp.getBody().front(), coreId, laneKnowledge, coreEvents[coreId], diagnostics)))
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hasFailure)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
DenseMap<int64_t, size_t> programCounters;
|
||||||
|
for (const auto& [coreId, events] : coreEvents)
|
||||||
|
programCounters[coreId] = 0;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
bool madeProgress = false;
|
||||||
|
for (const auto& [coreId, events] : coreEvents) {
|
||||||
|
size_t pc = programCounters[coreId];
|
||||||
|
if (pc >= events.size())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
const CommunicationEvent& event = events[pc];
|
||||||
|
auto peerEventsIt = coreEvents.find(event.peerCoreId);
|
||||||
|
if (peerEventsIt == coreEvents.end())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
size_t peerPc = programCounters[event.peerCoreId];
|
||||||
|
if (peerPc >= peerEventsIt->second.size())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
const CommunicationEvent& peerEvent = peerEventsIt->second[peerPc];
|
||||||
|
if (!areMatchedCommunicationEvents(event, peerEvent))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
++programCounters[coreId];
|
||||||
|
++programCounters[event.peerCoreId];
|
||||||
|
madeProgress = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (madeProgress)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
bool allDone = true;
|
||||||
|
for (const auto& [coreId, events] : coreEvents) {
|
||||||
|
if (programCounters[coreId] < events.size()) {
|
||||||
|
allDone = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (allDone)
|
||||||
|
return success();
|
||||||
|
|
||||||
|
auto cycle = findCommunicationWaitCycle(coreEvents, programCounters);
|
||||||
|
if (succeeded(cycle)) {
|
||||||
|
emitCommunicationDeadlockCycle(moduleOp, coreEvents, programCounters, *cycle);
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto diagnostic = moduleOp.emitError()
|
||||||
|
<< "PIM communication deadlock check stalled without finding a closed wait cycle; this usually means a "
|
||||||
|
"send/receive peer is missing or ordered after a finished core";
|
||||||
|
for (const auto& [coreId, events] : coreEvents) {
|
||||||
|
size_t pc = programCounters[coreId];
|
||||||
|
if (pc >= events.size())
|
||||||
|
continue;
|
||||||
|
const CommunicationEvent& event = events[pc];
|
||||||
|
diagnostic.attachNote(event.op->getLoc()) << formatCommunicationEvent(event);
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
||||||
|
|
||||||
@@ -212,11 +690,18 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool hasFailure = false;
|
||||||
|
if (pimDetectCommunicationDeadlock && failed(verifyNoStaticCommunicationDeadlock(moduleOp, diagnostics)))
|
||||||
|
hasFailure = true;
|
||||||
|
|
||||||
if (diagnostics.hasFailure()) {
|
if (diagnostics.hasFailure()) {
|
||||||
diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
|
diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
|
||||||
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
|
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
|
||||||
signalPassFailure();
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (hasFailure)
|
||||||
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ def SpatTensor :
|
|||||||
// Execution
|
// Execution
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def SpatCompute : SpatOp<"compute",
|
class SpatComputeLikeBase<string mnemonic> : SpatOp<mnemonic,
|
||||||
[SingleBlock, AttrSizedOperandSegments,
|
[SingleBlock, AttrSizedOperandSegments,
|
||||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||||
let summary = "Compute region with attached constant weights";
|
let summary = "Compute region with attached constant weights";
|
||||||
@@ -42,6 +42,12 @@ def SpatCompute : SpatOp<"compute",
|
|||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
let hasFolder = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatGraphCompute : SpatComputeLikeBase<"graph_compute"> {
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||||
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||||
@@ -50,16 +56,26 @@ def SpatCompute : SpatOp<"compute",
|
|||||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||||
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
|
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatGraphCompute>>
|
||||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let hasFolder = 1;
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatComputeBatch : SpatOp<"compute_batch",
|
def SpatScheduledCompute : SpatComputeLikeBase<"scheduled_compute"> {
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||||
|
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||||
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
|
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||||
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
|
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||||
|
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||||
|
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatScheduledCompute>>
|
||||||
|
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
class SpatComputeBatchLikeBase<string mnemonic> : SpatOp<mnemonic,
|
||||||
[SingleBlock, AttrSizedOperandSegments,
|
[SingleBlock, AttrSizedOperandSegments,
|
||||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||||
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
|
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
|
||||||
@@ -76,6 +92,11 @@ def SpatComputeBatch : SpatOp<"compute_batch",
|
|||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatGraphComputeBatch : SpatComputeBatchLikeBase<"graph_compute_batch"> {
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
std::optional<::mlir::BlockArgument> getLaneArgument();
|
std::optional<::mlir::BlockArgument> getLaneArgument();
|
||||||
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||||
@@ -86,21 +107,33 @@ def SpatComputeBatch : SpatOp<"compute_batch",
|
|||||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||||
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
|
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatGraphComputeBatch>>
|
||||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||||
}];
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
let hasVerifier = 1;
|
def SpatScheduledComputeBatch : SpatComputeBatchLikeBase<"scheduled_compute_batch"> {
|
||||||
let hasCustomAssemblyFormat = 1;
|
let extraClassDeclaration = [{
|
||||||
|
std::optional<::mlir::BlockArgument> getLaneArgument();
|
||||||
|
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||||
|
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||||
|
std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
|
||||||
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
|
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||||
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
|
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||||
|
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||||
|
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatScheduledComputeBatch>>
|
||||||
|
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatInParallelOp : SpatOp<"in_parallel", [
|
def SpatInParallelOp : SpatOp<"in_parallel", [
|
||||||
Pure,
|
Pure,
|
||||||
Terminator,
|
Terminator,
|
||||||
DeclareOpInterfaceMethods<InParallelOpInterface>,
|
DeclareOpInterfaceMethods<InParallelOpInterface>,
|
||||||
HasParent<"SpatComputeBatch">,
|
|
||||||
] # GraphRegionNoTerminator.traits> {
|
] # GraphRegionNoTerminator.traits> {
|
||||||
let summary = "Parallel combining terminator for resultful spat.compute_batch";
|
let summary = "Parallel combining terminator for resultful Spatial compute batches";
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$region);
|
let regions = (region SizedRegion<1>:$region);
|
||||||
|
|
||||||
@@ -159,6 +192,82 @@ def SpatConcatOp : SpatOp<"concat", []> {
|
|||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Planning
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def SpatConv2DPlanOp : SpatOp<"conv2d_plan", []> {
|
||||||
|
let summary = "Structured Conv2D planning op that preserves logical ONNX geometry";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SpatTensor:$input,
|
||||||
|
SpatTensor:$weight,
|
||||||
|
Optional<SpatTensor>:$bias,
|
||||||
|
DenseI64ArrayAttr:$pads,
|
||||||
|
DenseI64ArrayAttr:$strides,
|
||||||
|
DenseI64ArrayAttr:$dilations,
|
||||||
|
I64Attr:$group,
|
||||||
|
StrAttr:$logicalLayout
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatReluPlanOp : SpatOp<"relu_plan", []> {
|
||||||
|
let summary = "Layout-aware ReLU planning op";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SpatTensor:$input,
|
||||||
|
StrAttr:$logicalLayout
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatReconciliatorOp : SpatOp<"reconciliator", []> {
|
||||||
|
let summary = "Passive logical-to-physical layout selection record";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SpatTensor:$input,
|
||||||
|
StrAttr:$logicalLayout,
|
||||||
|
StrAttr:$physicalLayout,
|
||||||
|
DenseI64ArrayAttr:$fragmentOffsets,
|
||||||
|
DenseI64ArrayAttr:$fragmentSizes,
|
||||||
|
StrAttr:$indexMap
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatMaterializeLayoutOp : SpatOp<"materialize_layout", []> {
|
||||||
|
let summary = "Explicit layout conversion or materialization barrier";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SpatTensor:$input,
|
||||||
|
StrAttr:$logicalLayout,
|
||||||
|
StrAttr:$sourcePhysicalLayout,
|
||||||
|
StrAttr:$targetPhysicalLayout
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Communication
|
// Communication
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|||||||
@@ -29,11 +29,19 @@ std::optional<BlockArgument> insertBlockArgument(Region& body, unsigned argIdx,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
|
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
|
||||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
if (auto compute = dyn_cast<SpatGraphCompute>(op)) {
|
||||||
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
if (auto compute = dyn_cast<SpatScheduledCompute>(op)) {
|
||||||
|
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (auto batch = dyn_cast<SpatGraphComputeBatch>(op)) {
|
||||||
|
batch.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
cast<SpatScheduledComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||||
}
|
}
|
||||||
|
|
||||||
using CrossbarWeightSet = llvm::SetVector<Value, llvm::SmallVector<Value, 4>, llvm::SmallDenseSet<Value, 4>>;
|
using CrossbarWeightSet = llvm::SetVector<Value, llvm::SmallVector<Value, 4>, llvm::SmallDenseSet<Value, 4>>;
|
||||||
@@ -47,116 +55,205 @@ CrossbarWeightSet collectCrossbarWeights(Region& body) {
|
|||||||
return weights;
|
return weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
template <typename ComputeOpTy>
|
||||||
|
std::optional<BlockArgument> getComputeWeightArgument(ComputeOpTy compute, unsigned idx) {
|
||||||
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); }
|
return getBlockArgument(compute.getBody(), idx);
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
|
|
||||||
return getBlockArgument(getBody(), getWeights().size() + idx);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
template <typename ComputeOpTy>
|
||||||
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
std::optional<BlockArgument> getComputeInputArgument(ComputeOpTy compute, unsigned idx) {
|
||||||
auto index = std::distance(getWeights().begin(), existing);
|
return getBlockArgument(compute.getBody(), compute.getWeights().size() + idx);
|
||||||
return {
|
}
|
||||||
{*existing, *getWeightArgument(index)}
|
|
||||||
};
|
template <typename ComputeOpTy>
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
|
insertComputeWeight(ComputeOpTy compute, unsigned idx, Value weight, Location loc) {
|
||||||
|
if (auto existing = llvm::find(compute.getWeights(), weight); existing != compute.getWeights().end()) {
|
||||||
|
auto index = std::distance(compute.getWeights().begin(), existing);
|
||||||
|
return {{*existing, *getComputeWeightArgument(compute, index)}};
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned weightCount = getWeights().size();
|
unsigned weightCount = compute.getWeights().size();
|
||||||
unsigned inputCount = getInputs().size();
|
unsigned inputCount = compute.getInputs().size();
|
||||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
compute.getOperation()->insertOperands(idx, ValueRange {weight});
|
||||||
setComputeOperandSegmentSizes(
|
setComputeOperandSegmentSizes(
|
||||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
compute.getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||||
auto blockArg = insertBlockArgument(getBody(), idx, weight.getType(), loc);
|
auto blockArg = insertBlockArgument(compute.getBody(), idx, weight.getType(), loc);
|
||||||
if (!blockArg)
|
if (!blockArg)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
return std::make_tuple(compute.getOperation()->getOperand(idx), *blockArg);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigned idx, Value input, Location loc) {
|
template <typename ComputeBatchOpTy>
|
||||||
unsigned weightCount = getWeights().size();
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
unsigned inputCount = getInputs().size();
|
insertComputeBatchWeight(ComputeBatchOpTy batch, unsigned idx, Value weight, Location loc) {
|
||||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
if (auto existing = llvm::find(batch.getWeights(), weight); existing != batch.getWeights().end()) {
|
||||||
|
auto index = std::distance(batch.getWeights().begin(), existing);
|
||||||
|
return {{*existing, *batch.getWeightArgument(index)}};
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned weightCount = batch.getWeights().size();
|
||||||
|
unsigned inputCount = batch.getInputs().size();
|
||||||
|
batch.getOperation()->insertOperands(idx, ValueRange {weight});
|
||||||
setComputeOperandSegmentSizes(
|
setComputeOperandSegmentSizes(
|
||||||
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
batch.getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||||
auto blockArg = insertBlockArgument(getBody(), weightCount + idx, input.getType(), loc);
|
|
||||||
|
auto blockArg = insertBlockArgument(batch.getBody(), 1 + idx, weight.getType(), loc);
|
||||||
if (!blockArg)
|
if (!blockArg)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
return std::make_tuple(batch.getOperation()->getOperand(idx), *blockArg);
|
||||||
}
|
}
|
||||||
|
|
||||||
CrossbarWeightSet SpatCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
template <typename ComputeOpTy>
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
FailureOr<std::tuple<OpResult, SpatCompute>>
|
insertComputeInput(ComputeOpTy compute, unsigned idx, Value input, Location loc) {
|
||||||
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
unsigned weightCount = compute.getWeights().size();
|
||||||
if (idx > getNumResults())
|
unsigned inputCount = compute.getInputs().size();
|
||||||
return failure();
|
compute.getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||||
|
setComputeOperandSegmentSizes(
|
||||||
rewriter.setInsertionPoint(getOperation());
|
compute.getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||||
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
|
auto blockArg = insertBlockArgument(compute.getBody(), weightCount + idx, input.getType(), loc);
|
||||||
resultTypes.insert(resultTypes.begin() + idx, type);
|
if (!blockArg)
|
||||||
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
|
return std::nullopt;
|
||||||
newCompute->setAttrs((*this)->getAttrs());
|
return std::make_tuple(compute.getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||||
setComputeOperandSegmentSizes(newCompute.getOperation(),
|
|
||||||
static_cast<int32_t>(newCompute.getWeights().size()),
|
|
||||||
static_cast<int32_t>(newCompute.getInputs().size()));
|
|
||||||
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
|
|
||||||
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
|
||||||
getResult(oldResultIdx)
|
|
||||||
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
|
||||||
rewriter.eraseOp(getOperation());
|
|
||||||
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
template <typename ComputeOpTy>
|
||||||
|
void setComputeAsmBlockArgumentNames(ComputeOpTy compute, Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
if (region.empty())
|
if (region.empty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
for (unsigned index = 0; index < compute.getWeights().size(); ++index)
|
||||||
if (auto weightArg = getWeightArgument(index))
|
if (auto weightArg = compute.getWeightArgument(index))
|
||||||
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
||||||
|
|
||||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
for (unsigned index = 0; index < compute.getInputs().size(); ++index)
|
||||||
if (auto inputArg = getInputArgument(index))
|
if (auto inputArg = compute.getInputArgument(index))
|
||||||
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
|
template <typename ComputeOpTy>
|
||||||
|
FailureOr<std::tuple<OpResult, ComputeOpTy>>
|
||||||
|
insertComputeOutput(ComputeOpTy compute, RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
|
if (idx > compute.getNumResults())
|
||||||
|
return failure();
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) {
|
rewriter.setInsertionPoint(compute.getOperation());
|
||||||
|
SmallVector<Type> resultTypes(compute.getResultTypes().begin(), compute.getResultTypes().end());
|
||||||
|
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||||
|
auto newCompute =
|
||||||
|
ComputeOpTy::create(rewriter, compute.getLoc(), TypeRange(resultTypes), compute.getWeights(), compute.getInputs());
|
||||||
|
newCompute->setAttrs(compute->getAttrs());
|
||||||
|
setComputeOperandSegmentSizes(newCompute.getOperation(),
|
||||||
|
static_cast<int32_t>(newCompute.getWeights().size()),
|
||||||
|
static_cast<int32_t>(newCompute.getInputs().size()));
|
||||||
|
rewriter.inlineRegionBefore(compute.getBody(), newCompute.getBody(), newCompute.getBody().end());
|
||||||
|
for (unsigned oldResultIdx = 0; oldResultIdx < compute.getNumResults(); ++oldResultIdx)
|
||||||
|
compute.getResult(oldResultIdx)
|
||||||
|
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||||
|
rewriter.eraseOp(compute.getOperation());
|
||||||
|
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ComputeBatchOpTy>
|
||||||
|
FailureOr<std::tuple<OpResult, BlockArgument, ComputeBatchOpTy>>
|
||||||
|
insertComputeBatchOutput(ComputeBatchOpTy batch, RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
|
if (idx > batch.getNumResults())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(batch.getOperation());
|
||||||
|
SmallVector<Type> resultTypes(batch.getResultTypes().begin(), batch.getResultTypes().end());
|
||||||
|
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||||
|
auto newBatch =
|
||||||
|
ComputeBatchOpTy::create(rewriter, batch.getLoc(), TypeRange(resultTypes), batch.getLaneCountAttr(), batch.getWeights(), batch.getInputs());
|
||||||
|
newBatch->setAttrs(batch->getAttrs());
|
||||||
|
setComputeOperandSegmentSizes(newBatch.getOperation(),
|
||||||
|
static_cast<int32_t>(newBatch.getWeights().size()),
|
||||||
|
static_cast<int32_t>(newBatch.getInputs().size()));
|
||||||
|
rewriter.inlineRegionBefore(batch.getBody(), newBatch.getBody(), newBatch.getBody().end());
|
||||||
|
if (newBatch.getBody().empty()) {
|
||||||
|
rewriter.eraseOp(newBatch);
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto blockArg = newBatch.getBody().front().insertArgument(
|
||||||
|
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
|
||||||
|
for (unsigned oldResultIdx = 0; oldResultIdx < batch.getNumResults(); ++oldResultIdx)
|
||||||
|
batch.getResult(oldResultIdx)
|
||||||
|
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||||
|
rewriter.eraseOp(batch.getOperation());
|
||||||
|
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool isGraphComputeLike(Operation* op) { return isa<SpatGraphCompute, SpatGraphComputeBatch>(op); }
|
||||||
|
|
||||||
|
bool isGraphBatchComputeLike(Operation* op) { return isa<SpatGraphComputeBatch>(op); }
|
||||||
|
|
||||||
|
bool isScheduledComputeLike(Operation* op) { return isa<SpatScheduledCompute, SpatScheduledComputeBatch>(op); }
|
||||||
|
|
||||||
|
bool isScheduledBatchComputeLike(Operation* op) { return isa<SpatScheduledComputeBatch>(op); }
|
||||||
|
|
||||||
|
bool isAnySpatialComputeLike(Operation* op) {
|
||||||
|
return isa<SpatGraphCompute, SpatGraphComputeBatch, SpatScheduledCompute, SpatScheduledComputeBatch>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isAnySpatialComputeBatchLike(Operation* op) { return isa<SpatGraphComputeBatch, SpatScheduledComputeBatch>(op); }
|
||||||
|
|
||||||
|
std::optional<BlockArgument> SpatGraphCompute::getWeightArgument(unsigned idx) { return getComputeWeightArgument(*this, idx); }
|
||||||
|
std::optional<BlockArgument> SpatGraphCompute::getInputArgument(unsigned idx) { return getComputeInputArgument(*this, idx); }
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>> SpatGraphCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
|
return insertComputeWeight(*this, idx, weight, loc);
|
||||||
|
}
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>> SpatGraphCompute::insertInput(unsigned idx, Value input, Location loc) {
|
||||||
|
return insertComputeInput(*this, idx, input, loc);
|
||||||
|
}
|
||||||
|
CrossbarWeightSet SpatGraphCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||||
|
FailureOr<std::tuple<OpResult, SpatGraphCompute>>
|
||||||
|
SpatGraphCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
|
return insertComputeOutput(*this, rewriter, idx, type, loc);
|
||||||
|
}
|
||||||
|
void SpatGraphCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
|
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<BlockArgument> SpatScheduledCompute::getWeightArgument(unsigned idx) {
|
||||||
|
return getComputeWeightArgument(*this, idx);
|
||||||
|
}
|
||||||
|
std::optional<BlockArgument> SpatScheduledCompute::getInputArgument(unsigned idx) { return getComputeInputArgument(*this, idx); }
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
|
SpatScheduledCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
|
return insertComputeWeight(*this, idx, weight, loc);
|
||||||
|
}
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
|
SpatScheduledCompute::insertInput(unsigned idx, Value input, Location loc) {
|
||||||
|
return insertComputeInput(*this, idx, input, loc);
|
||||||
|
}
|
||||||
|
CrossbarWeightSet SpatScheduledCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||||
|
FailureOr<std::tuple<OpResult, SpatScheduledCompute>>
|
||||||
|
SpatScheduledCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
|
return insertComputeOutput(*this, rewriter, idx, type, loc);
|
||||||
|
}
|
||||||
|
void SpatScheduledCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
|
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<BlockArgument> SpatGraphComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
|
||||||
|
std::optional<BlockArgument> SpatGraphComputeBatch::getWeightArgument(unsigned idx) {
|
||||||
return getBlockArgument(getBody(), 1 + idx);
|
return getBlockArgument(getBody(), 1 + idx);
|
||||||
}
|
}
|
||||||
|
std::optional<BlockArgument> SpatGraphComputeBatch::getInputArgument(unsigned idx) {
|
||||||
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
|
|
||||||
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
|
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
|
||||||
}
|
}
|
||||||
|
std::optional<BlockArgument> SpatGraphComputeBatch::getOutputArgument(unsigned idx) {
|
||||||
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
|
|
||||||
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<std::tuple<Value, BlockArgument>>
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
SpatGraphComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
return insertComputeBatchWeight(*this, idx, weight, loc);
|
||||||
auto index = std::distance(getWeights().begin(), existing);
|
|
||||||
return {
|
|
||||||
{*existing, *getWeightArgument(index)}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned weightCount = getWeights().size();
|
|
||||||
unsigned inputCount = getInputs().size();
|
|
||||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
|
||||||
setComputeOperandSegmentSizes(
|
|
||||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
|
||||||
auto blockArg = insertBlockArgument(getBody(), 1 + idx, weight.getType(), loc);
|
|
||||||
if (!blockArg)
|
|
||||||
return std::nullopt;
|
|
||||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
|
||||||
}
|
}
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
|
SpatGraphComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
|
||||||
unsigned weightCount = getWeights().size();
|
unsigned weightCount = getWeights().size();
|
||||||
unsigned inputCount = getInputs().size();
|
unsigned inputCount = getInputs().size();
|
||||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||||
@@ -167,52 +264,68 @@ std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(un
|
|||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||||
}
|
}
|
||||||
|
CrossbarWeightSet SpatGraphComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||||
CrossbarWeightSet SpatComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
FailureOr<std::tuple<OpResult, BlockArgument, SpatGraphComputeBatch>>
|
||||||
|
SpatGraphComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
|
return insertComputeBatchOutput(*this, rewriter, idx, type, loc);
|
||||||
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
|
||||||
if (idx > getNumResults())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(getOperation());
|
|
||||||
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
|
|
||||||
resultTypes.insert(resultTypes.begin() + idx, type);
|
|
||||||
auto newBatch =
|
|
||||||
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
|
|
||||||
newBatch->setAttrs((*this)->getAttrs());
|
|
||||||
setComputeOperandSegmentSizes(newBatch.getOperation(),
|
|
||||||
static_cast<int32_t>(newBatch.getWeights().size()),
|
|
||||||
static_cast<int32_t>(newBatch.getInputs().size()));
|
|
||||||
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
|
|
||||||
if (newBatch.getBody().empty()) {
|
|
||||||
rewriter.eraseOp(newBatch);
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
auto blockArg = newBatch.getBody().front().insertArgument(
|
|
||||||
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
|
|
||||||
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
|
||||||
getResult(oldResultIdx)
|
|
||||||
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
|
||||||
rewriter.eraseOp(getOperation());
|
|
||||||
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
|
|
||||||
}
|
}
|
||||||
|
void SpatGraphComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
|
||||||
if (region.empty())
|
if (region.empty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
if (auto laneArg = getLaneArgument())
|
if (auto laneArg = getLaneArgument())
|
||||||
setNameFn(*laneArg, "lane");
|
setNameFn(*laneArg, "lane");
|
||||||
|
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
|
||||||
|
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||||
|
auto outputArg = getOutputArgument(index);
|
||||||
|
if (!outputArg)
|
||||||
|
continue;
|
||||||
|
if (index == 0) {
|
||||||
|
setNameFn(*outputArg, "out");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
setNameFn(*outputArg, ("out" + std::to_string(index)).c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
std::optional<BlockArgument> SpatScheduledComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
|
||||||
if (auto weightArg = getWeightArgument(index))
|
std::optional<BlockArgument> SpatScheduledComputeBatch::getWeightArgument(unsigned idx) {
|
||||||
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
return getBlockArgument(getBody(), 1 + idx);
|
||||||
|
}
|
||||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
std::optional<BlockArgument> SpatScheduledComputeBatch::getInputArgument(unsigned idx) {
|
||||||
if (auto inputArg = getInputArgument(index))
|
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
|
||||||
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
}
|
||||||
|
std::optional<BlockArgument> SpatScheduledComputeBatch::getOutputArgument(unsigned idx) {
|
||||||
|
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
||||||
|
}
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
|
SpatScheduledComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
|
return insertComputeBatchWeight(*this, idx, weight, loc);
|
||||||
|
}
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
|
SpatScheduledComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
|
||||||
|
unsigned weightCount = getWeights().size();
|
||||||
|
unsigned inputCount = getInputs().size();
|
||||||
|
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||||
|
setComputeOperandSegmentSizes(
|
||||||
|
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||||
|
auto blockArg = insertBlockArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
|
||||||
|
if (!blockArg)
|
||||||
|
return std::nullopt;
|
||||||
|
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||||
|
}
|
||||||
|
CrossbarWeightSet SpatScheduledComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||||
|
FailureOr<std::tuple<OpResult, BlockArgument, SpatScheduledComputeBatch>>
|
||||||
|
SpatScheduledComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
|
return insertComputeBatchOutput(*this, rewriter, idx, type, loc);
|
||||||
|
}
|
||||||
|
void SpatScheduledComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
|
if (region.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
if (auto laneArg = getLaneArgument())
|
||||||
|
setNameFn(*laneArg, "lane");
|
||||||
|
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
|
||||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||||
auto outputArg = getOutputArgument(index);
|
auto outputArg = getOutputArgument(index);
|
||||||
if (!outputArg)
|
if (!outputArg)
|
||||||
@@ -231,7 +344,11 @@ void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
|
|||||||
builder.createBlock(bodyRegion);
|
builder.createBlock(bodyRegion);
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
|
OpResult SpatInParallelOp::getParentResult(int64_t idx) {
|
||||||
|
Operation* parent = getOperation()->getParentOp();
|
||||||
|
assert(isAnySpatialComputeBatchLike(parent) && "expected Spatial compute batch parent");
|
||||||
|
return parent->getResult(idx);
|
||||||
|
}
|
||||||
|
|
||||||
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
|
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
|
||||||
|
|
||||||
|
|||||||
@@ -26,3 +26,19 @@
|
|||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
bool isGraphComputeLike(mlir::Operation* op);
|
||||||
|
bool isGraphBatchComputeLike(mlir::Operation* op);
|
||||||
|
bool isScheduledComputeLike(mlir::Operation* op);
|
||||||
|
bool isScheduledBatchComputeLike(mlir::Operation* op);
|
||||||
|
bool isAnySpatialComputeLike(mlir::Operation* op);
|
||||||
|
bool isAnySpatialComputeBatchLike(mlir::Operation* op);
|
||||||
|
|
||||||
|
using SpatCompute = SpatGraphCompute;
|
||||||
|
using SpatComputeBatch = SpatGraphComputeBatch;
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -115,6 +115,254 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename ComputeOpTy>
|
||||||
|
void printComputeLikeOp(ComputeOpTy op, OpAsmPrinter& printer) {
|
||||||
|
SmallVector<Value> weightArgs;
|
||||||
|
weightArgs.reserve(op.getWeights().size());
|
||||||
|
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
|
||||||
|
auto weightArg = op.getWeightArgument(index);
|
||||||
|
if (!weightArg)
|
||||||
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||||
|
weightArgs.push_back(*weightArg);
|
||||||
|
}
|
||||||
|
SmallVector<Value> inputArgs;
|
||||||
|
inputArgs.reserve(op.getInputs().size());
|
||||||
|
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
|
||||||
|
auto inputArg = op.getInputArgument(index);
|
||||||
|
if (!inputArg)
|
||||||
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||||
|
inputArgs.push_back(*inputArg);
|
||||||
|
}
|
||||||
|
|
||||||
|
printer << " ";
|
||||||
|
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
|
||||||
|
printer << " ";
|
||||||
|
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
|
||||||
|
|
||||||
|
if (auto coreIdAttr = op->template getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||||
|
printer << " coreId " << coreIdAttr.getInt();
|
||||||
|
printer << " crossbarWeights " << collectDistinctCrossbarWeights(op.getOperation()).size();
|
||||||
|
|
||||||
|
printer.printOptionalAttrDict(op->getAttrs(), {op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
||||||
|
|
||||||
|
printer << " : ";
|
||||||
|
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
|
||||||
|
printer << " ";
|
||||||
|
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
|
||||||
|
printer << " -> ";
|
||||||
|
printCompressedTypeSequence(printer, op.getResultTypes());
|
||||||
|
printer << " ";
|
||||||
|
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ComputeOpTy>
|
||||||
|
ParseResult parseComputeLikeOp(OpAsmParser& parser, OperationState& result) {
|
||||||
|
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||||
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||||
|
SmallVector<Type> weightTypes;
|
||||||
|
SmallVector<Type> inputTypes;
|
||||||
|
SmallVector<Type> outputTypes;
|
||||||
|
int32_t crossbarWeightCount = 0;
|
||||||
|
int32_t coreId = 0;
|
||||||
|
|
||||||
|
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||||
|
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||||
|
if (hasCoreId && parser.parseInteger(coreId))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||||
|
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||||
|
return failure();
|
||||||
|
(void) crossbarWeightCount;
|
||||||
|
|
||||||
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||||
|
|| parseCompressedRepeatedList(
|
||||||
|
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||||
|
|| parseCompressedRepeatedList(
|
||||||
|
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||||
|
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (weights.size() != weightTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||||
|
if (weightArgs.size() != weights.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||||
|
if (inputs.size() != inputTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||||
|
if (inputArgs.size() != inputs.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||||
|
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"coreId cannot be specified both positionally and in attr-dict");
|
||||||
|
|
||||||
|
auto& builder = parser.getBuilder();
|
||||||
|
result.addAttribute(
|
||||||
|
"operandSegmentSizes",
|
||||||
|
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||||
|
if (hasCoreId)
|
||||||
|
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
|
||||||
|
|
||||||
|
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||||
|
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||||
|
return failure();
|
||||||
|
result.addTypes(outputTypes);
|
||||||
|
|
||||||
|
Region* body = result.addRegion();
|
||||||
|
applyArgumentTypes(weightTypes, weightArgs);
|
||||||
|
applyArgumentTypes(inputTypes, inputArgs);
|
||||||
|
llvm::append_range(regionArgs, weightArgs);
|
||||||
|
llvm::append_range(regionArgs, inputArgs);
|
||||||
|
return parser.parseRegion(*body, regionArgs);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ComputeBatchOpTy>
|
||||||
|
void printComputeBatchLikeOp(ComputeBatchOpTy op, OpAsmPrinter& printer) {
|
||||||
|
auto laneArg = op.getLaneArgument();
|
||||||
|
SmallVector<Value> weightArgs;
|
||||||
|
weightArgs.reserve(op.getWeights().size());
|
||||||
|
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
|
||||||
|
auto weightArg = op.getWeightArgument(index);
|
||||||
|
if (!weightArg)
|
||||||
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||||
|
weightArgs.push_back(*weightArg);
|
||||||
|
}
|
||||||
|
SmallVector<Value> inputArgs;
|
||||||
|
inputArgs.reserve(op.getInputs().size());
|
||||||
|
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
|
||||||
|
auto inputArg = op.getInputArgument(index);
|
||||||
|
if (!inputArg)
|
||||||
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||||
|
inputArgs.push_back(*inputArg);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<BlockArgument> outputArgs;
|
||||||
|
if (!laneArg)
|
||||||
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||||
|
if (op.getNumResults() != 0) {
|
||||||
|
outputArgs.reserve(op.getNumResults());
|
||||||
|
for (unsigned index = 0; index < op.getNumResults(); ++index) {
|
||||||
|
auto outputArg = op.getOutputArgument(index);
|
||||||
|
if (!outputArg)
|
||||||
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||||
|
outputArgs.push_back(*outputArg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
printer << " ";
|
||||||
|
printer.printOperand(*laneArg);
|
||||||
|
printer << " = 0 to " << op.getLaneCount();
|
||||||
|
printer << " ";
|
||||||
|
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
|
||||||
|
printer << " ";
|
||||||
|
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
|
||||||
|
if (op.getNumResults() != 0) {
|
||||||
|
printer << " shared_outs";
|
||||||
|
printBlockArgumentList(printer, outputArgs);
|
||||||
|
}
|
||||||
|
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({op.getOperation(), 0, op.getLaneCount()}).size();
|
||||||
|
if (auto coreIdsAttr = op->template getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
||||||
|
printer << " coreIds ";
|
||||||
|
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
||||||
|
}
|
||||||
|
printer.printOptionalAttrDict(
|
||||||
|
op->getAttrs(),
|
||||||
|
{op.getLaneCountAttrName().getValue(), op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||||
|
printer << " : ";
|
||||||
|
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
|
||||||
|
printer << " ";
|
||||||
|
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
|
||||||
|
printer << " -> ";
|
||||||
|
printCompressedTypeSequence(printer, op.getResultTypes());
|
||||||
|
printer << " ";
|
||||||
|
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ComputeBatchOpTy>
|
||||||
|
ParseResult parseComputeBatchLikeOp(OpAsmParser& parser, OperationState& result) {
|
||||||
|
int64_t lowerBound = 0;
|
||||||
|
int32_t laneCount = 0;
|
||||||
|
OpAsmParser::Argument laneArg;
|
||||||
|
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||||
|
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||||
|
SmallVector<OpAsmParser::Argument> outputArgs;
|
||||||
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||||
|
SmallVector<Type> weightTypes;
|
||||||
|
SmallVector<Type> inputTypes;
|
||||||
|
SmallVector<Type> outputTypes;
|
||||||
|
int32_t crossbarWeightCount = 0;
|
||||||
|
SmallVector<int32_t> coreIds;
|
||||||
|
|
||||||
|
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||||
|
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||||
|
return failure();
|
||||||
|
if (lowerBound != 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
|
||||||
|
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||||
|
return failure();
|
||||||
|
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
|
||||||
|
if (parseBlockArgumentList(parser, outputArgs))
|
||||||
|
return failure();
|
||||||
|
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||||
|
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||||
|
return failure();
|
||||||
|
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||||
|
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||||
|
return failure();
|
||||||
|
(void) crossbarWeightCount;
|
||||||
|
|
||||||
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||||
|
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
||||||
|
|| parseCompressedRepeatedList(
|
||||||
|
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||||
|
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (weights.size() != weightTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||||
|
if (weightArgs.size() != weights.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||||
|
if (inputs.size() != inputTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||||
|
if (inputArgs.size() != inputs.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||||
|
if (outputArgs.size() != outputTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"number of shared output bindings and result types must match");
|
||||||
|
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"coreIds cannot be specified both positionally and in attr-dict");
|
||||||
|
|
||||||
|
auto& builder = parser.getBuilder();
|
||||||
|
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
||||||
|
result.addAttribute(
|
||||||
|
"operandSegmentSizes",
|
||||||
|
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||||
|
if (hasCoreIds)
|
||||||
|
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
||||||
|
|
||||||
|
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||||
|
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||||
|
return failure();
|
||||||
|
result.addTypes(outputTypes);
|
||||||
|
|
||||||
|
Region* body = result.addRegion();
|
||||||
|
applyBatchRegionArgumentTypes(
|
||||||
|
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
|
||||||
|
return parser.parseRegion(*body, regionArgs);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
||||||
@@ -218,260 +466,21 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
|
||||||
SmallVector<Value> weightArgs;
|
ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
weightArgs.reserve(getWeights().size());
|
return parseComputeLikeOp<SpatGraphCompute>(parser, result);
|
||||||
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
|
||||||
auto weightArg = getWeightArgument(index);
|
|
||||||
if (!weightArg)
|
|
||||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
|
||||||
weightArgs.push_back(*weightArg);
|
|
||||||
}
|
|
||||||
SmallVector<Value> inputArgs;
|
|
||||||
inputArgs.reserve(getInputs().size());
|
|
||||||
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
|
||||||
auto inputArg = getInputArgument(index);
|
|
||||||
if (!inputArg)
|
|
||||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
|
||||||
inputArgs.push_back(*inputArg);
|
|
||||||
}
|
|
||||||
|
|
||||||
printer << " ";
|
|
||||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
|
||||||
printer << " ";
|
|
||||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
|
||||||
|
|
||||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
|
||||||
printer << " coreId " << coreIdAttr.getInt();
|
|
||||||
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
|
|
||||||
|
|
||||||
printer.printOptionalAttrDict((*this)->getAttrs(),
|
|
||||||
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
|
||||||
|
|
||||||
printer << " : ";
|
|
||||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
|
||||||
printer << " ";
|
|
||||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
|
||||||
printer << " -> ";
|
|
||||||
printCompressedTypeSequence(printer, getResultTypes());
|
|
||||||
printer << " ";
|
|
||||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
|
||||||
}
|
}
|
||||||
|
void SpatScheduledCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
|
||||||
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
ParseResult SpatScheduledCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
return parseComputeLikeOp<SpatScheduledCompute>(parser, result);
|
||||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
|
||||||
SmallVector<Type> weightTypes;
|
|
||||||
SmallVector<Type> inputTypes;
|
|
||||||
SmallVector<Type> outputTypes;
|
|
||||||
int32_t crossbarWeightCount = 0;
|
|
||||||
int32_t coreId = 0;
|
|
||||||
|
|
||||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
|
||||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
|
||||||
if (hasCoreId && parser.parseInteger(coreId))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
|
||||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
|
||||||
return failure();
|
|
||||||
(void) crossbarWeightCount;
|
|
||||||
|
|
||||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
|
||||||
|| parseCompressedRepeatedList(
|
|
||||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
|
||||||
|| parseCompressedRepeatedList(
|
|
||||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
|
||||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (weights.size() != weightTypes.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
|
||||||
if (weightArgs.size() != weights.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
|
||||||
if (inputs.size() != inputTypes.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
|
||||||
if (inputArgs.size() != inputs.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
|
||||||
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"coreId cannot be specified both positionally and in attr-dict");
|
|
||||||
|
|
||||||
auto& builder = parser.getBuilder();
|
|
||||||
result.addAttribute(
|
|
||||||
"operandSegmentSizes",
|
|
||||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
|
||||||
if (hasCoreId)
|
|
||||||
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
|
|
||||||
|
|
||||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
|
||||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
|
||||||
return failure();
|
|
||||||
result.addTypes(outputTypes);
|
|
||||||
|
|
||||||
Region* body = result.addRegion();
|
|
||||||
applyArgumentTypes(weightTypes, weightArgs);
|
|
||||||
applyArgumentTypes(inputTypes, inputArgs);
|
|
||||||
llvm::append_range(regionArgs, weightArgs);
|
|
||||||
llvm::append_range(regionArgs, inputArgs);
|
|
||||||
return parser.parseRegion(*body, regionArgs);
|
|
||||||
}
|
}
|
||||||
|
void SpatGraphComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
|
||||||
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
ParseResult SpatGraphComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
auto laneArg = getLaneArgument();
|
return parseComputeBatchLikeOp<SpatGraphComputeBatch>(parser, result);
|
||||||
SmallVector<Value> weightArgs;
|
|
||||||
weightArgs.reserve(getWeights().size());
|
|
||||||
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
|
||||||
auto weightArg = getWeightArgument(index);
|
|
||||||
if (!weightArg)
|
|
||||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
|
||||||
weightArgs.push_back(*weightArg);
|
|
||||||
}
|
|
||||||
SmallVector<Value> inputArgs;
|
|
||||||
inputArgs.reserve(getInputs().size());
|
|
||||||
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
|
||||||
auto inputArg = getInputArgument(index);
|
|
||||||
if (!inputArg)
|
|
||||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
|
||||||
inputArgs.push_back(*inputArg);
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<BlockArgument> outputArgs;
|
|
||||||
if (!laneArg)
|
|
||||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
|
||||||
if (getNumResults() != 0) {
|
|
||||||
outputArgs.reserve(getNumResults());
|
|
||||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
|
||||||
auto outputArg = getOutputArgument(index);
|
|
||||||
if (!outputArg)
|
|
||||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
|
||||||
outputArgs.push_back(*outputArg);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
printer << " ";
|
|
||||||
printer.printOperand(*laneArg);
|
|
||||||
printer << " = 0 to " << getLaneCount();
|
|
||||||
|
|
||||||
printer << " ";
|
|
||||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
|
||||||
printer << " ";
|
|
||||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
|
||||||
|
|
||||||
if (getNumResults() != 0) {
|
|
||||||
printer << " shared_outs";
|
|
||||||
printBlockArgumentList(printer, outputArgs);
|
|
||||||
}
|
|
||||||
|
|
||||||
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({getOperation(), 0, getLaneCount()}).size();
|
|
||||||
|
|
||||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
|
||||||
printer << " coreIds ";
|
|
||||||
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
|
||||||
}
|
|
||||||
|
|
||||||
printer.printOptionalAttrDict(
|
|
||||||
(*this)->getAttrs(),
|
|
||||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
|
||||||
|
|
||||||
printer << " : ";
|
|
||||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
|
||||||
printer << " ";
|
|
||||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
|
||||||
printer << " -> ";
|
|
||||||
printCompressedTypeSequence(printer, getResultTypes());
|
|
||||||
printer << " ";
|
|
||||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
|
||||||
}
|
}
|
||||||
|
void SpatScheduledComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
|
||||||
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
ParseResult SpatScheduledComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
int64_t lowerBound = 0;
|
return parseComputeBatchLikeOp<SpatScheduledComputeBatch>(parser, result);
|
||||||
int32_t laneCount = 0;
|
|
||||||
OpAsmParser::Argument laneArg;
|
|
||||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
|
||||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
|
||||||
SmallVector<OpAsmParser::Argument> outputArgs;
|
|
||||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
|
||||||
SmallVector<Type> weightTypes;
|
|
||||||
SmallVector<Type> inputTypes;
|
|
||||||
SmallVector<Type> outputTypes;
|
|
||||||
int32_t crossbarWeightCount = 0;
|
|
||||||
SmallVector<int32_t> coreIds;
|
|
||||||
|
|
||||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
|
||||||
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
|
||||||
return failure();
|
|
||||||
if (lowerBound != 0)
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
|
|
||||||
|
|
||||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
|
|
||||||
if (parseBlockArgumentList(parser, outputArgs))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
|
||||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
|
||||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
|
||||||
return failure();
|
|
||||||
(void) crossbarWeightCount;
|
|
||||||
|
|
||||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
|
||||||
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
|
||||||
|| parseCompressedRepeatedList(
|
|
||||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
|
||||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (weights.size() != weightTypes.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
|
||||||
if (weightArgs.size() != weights.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
|
||||||
if (inputs.size() != inputTypes.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
|
||||||
if (inputArgs.size() != inputs.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
|
||||||
if (outputArgs.size() != outputTypes.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"number of shared output bindings and result types must match");
|
|
||||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"coreIds cannot be specified both positionally and in attr-dict");
|
|
||||||
|
|
||||||
auto& builder = parser.getBuilder();
|
|
||||||
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
|
||||||
result.addAttribute(
|
|
||||||
"operandSegmentSizes",
|
|
||||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
|
||||||
if (hasCoreIds)
|
|
||||||
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
|
||||||
|
|
||||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
|
||||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
|
||||||
return failure();
|
|
||||||
result.addTypes(outputTypes);
|
|
||||||
|
|
||||||
Region* body = result.addRegion();
|
|
||||||
applyBatchRegionArgumentTypes(
|
|
||||||
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
|
|
||||||
return parser.parseRegion(*body, regionArgs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatInParallelOp::print(OpAsmPrinter& printer) {
|
void SpatInParallelOp::print(OpAsmPrinter& printer) {
|
||||||
|
|||||||
@@ -10,8 +10,9 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
|
|
||||||
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
template <typename ComputeOpTy>
|
||||||
Block& block = getBody().front();
|
LogicalResult foldComputeLike(ComputeOpTy compute, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||||
|
Block& block = compute.getBody().front();
|
||||||
if (!llvm::hasSingleElement(block))
|
if (!llvm::hasSingleElement(block))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -22,7 +23,7 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m
|
|||||||
for (Value yieldedValue : yieldOp.getOperands()) {
|
for (Value yieldedValue : yieldOp.getOperands()) {
|
||||||
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
|
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
|
||||||
if (blockArg.getOwner() == &block) {
|
if (blockArg.getOwner() == &block) {
|
||||||
results.push_back(getOperand(blockArg.getArgNumber()));
|
results.push_back(compute.getOperand(blockArg.getArgNumber()));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -31,5 +32,13 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatGraphCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||||
|
return foldComputeLike(*this, results);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatScheduledCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||||
|
return foldComputeLike(*this, results);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
|
|||||||
return shapedType.getShape();
|
return shapedType.getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
template <typename ComputeBatchOpTy>
|
||||||
|
static bool isBatchOutputArgument(ComputeBatchOpTy batchOp, Value value) {
|
||||||
if (batchOp.getNumResults() == 0)
|
if (batchOp.getNumResults() == 0)
|
||||||
return false;
|
return false;
|
||||||
auto blockArg = dyn_cast<BlockArgument>(value);
|
auto blockArg = dyn_cast<BlockArgument>(value);
|
||||||
@@ -58,8 +59,28 @@ static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind)
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool isStaticIndexExpr(Value value) {
|
||||||
|
if (matchConstantIndexValue(value))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
|
||||||
|
if (affineApply) {
|
||||||
|
if (!isSingleResultSymbolFreeAffineMap(affineApply.getAffineMap()))
|
||||||
|
return false;
|
||||||
|
return llvm::all_of(affineApply.getMapOperands(), isStaticIndexExpr);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto addOp = value.getDefiningOp<arith::AddIOp>())
|
||||||
|
return isStaticIndexExpr(addOp.getLhs()) && isStaticIndexExpr(addOp.getRhs());
|
||||||
|
|
||||||
|
if (auto mulOp = value.getDefiningOp<arith::MulIOp>())
|
||||||
|
return isStaticIndexExpr(mulOp.getLhs()) && isStaticIndexExpr(mulOp.getRhs());
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
|
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
|
||||||
if (value == laneArg || matchConstantIndexValue(value))
|
if (value == laneArg || isStaticIndexExpr(value))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
|
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
|
||||||
@@ -83,10 +104,15 @@ static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto addOp = value.getDefiningOp<arith::AddIOp>();
|
auto addOp = value.getDefiningOp<arith::AddIOp>();
|
||||||
if (!addOp)
|
if (addOp)
|
||||||
|
return (isSupportedLaneOffsetExpr(addOp.getLhs(), laneArg) && isStaticIndexExpr(addOp.getRhs()))
|
||||||
|
|| (isSupportedLaneOffsetExpr(addOp.getRhs(), laneArg) && isStaticIndexExpr(addOp.getLhs()));
|
||||||
|
|
||||||
|
auto mulOp = value.getDefiningOp<arith::MulIOp>();
|
||||||
|
if (!mulOp)
|
||||||
return false;
|
return false;
|
||||||
return (addOp.getLhs() == laneArg && matchConstantIndexValue(addOp.getRhs()))
|
return (isSupportedLaneOffsetExpr(mulOp.getLhs(), laneArg) && isStaticIndexExpr(mulOp.getRhs()))
|
||||||
|| (addOp.getRhs() == laneArg && matchConstantIndexValue(addOp.getLhs()));
|
|| (isSupportedLaneOffsetExpr(mulOp.getRhs(), laneArg) && isStaticIndexExpr(mulOp.getLhs()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult
|
static LogicalResult
|
||||||
@@ -158,17 +184,27 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
|
|||||||
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
|
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
InFlightDiagnostic diagnostic =
|
||||||
<< kind << " body may only directly reference external constants";
|
ownerOp->emitOpError() << kind << " body may not capture external values";
|
||||||
diagnostic.attachNote(op->getLoc())
|
diagnostic.attachNote(op->getLoc())
|
||||||
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
<< "owner='" << ownerOp->getName() << "' nestedOp='" << op->getName() << "' operand#"
|
||||||
|
<< operand.getOperandNumber() << " type=" << value.getType()
|
||||||
|
<< " category=" << (isa<TensorType>(value.getType()) ? "tensor" : (value.getType().isIndex() ? "index"
|
||||||
|
: "scalar"));
|
||||||
|
if (Operation* definingOp = value.getDefiningOp())
|
||||||
|
diagnostic.attachNote(definingOp->getLoc()) << "defining op is '" << definingOp->getName() << "'";
|
||||||
|
else if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||||
|
diagnostic.attachNote(blockArg.getOwner()->getParentOp()->getLoc())
|
||||||
|
<< "value is block argument #" << blockArg.getArgNumber() << " of '"
|
||||||
|
<< blockArg.getOwner()->getParentOp()->getName() << "'";
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
return success(!hasFailure);
|
return success(!hasFailure);
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
|
template <typename ComputeBatchOpTy>
|
||||||
|
static LogicalResult verifyBatchBody(ComputeBatchOpTy batchOp, Block& block) {
|
||||||
if (batchOp.getNumResults() == 0) {
|
if (batchOp.getNumResults() == 0) {
|
||||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||||
if (!yieldOp)
|
if (!yieldOp)
|
||||||
@@ -344,144 +380,266 @@ LogicalResult SpatConcatOp::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult verifyComputeResultsUses(Operation* op) {
|
static bool isKnownLogicalLayout(StringRef layout) { return layout == "nchw"; }
|
||||||
if (!isa<SpatCompute, SpatComputeBatch>(op))
|
|
||||||
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
|
static bool isKnownPhysicalLayout(StringRef layout) {
|
||||||
if (!llvm::all_of(op->getResults(), [](Value result) {
|
return layout == "dense_nchw" || layout == "nchw_row_strip";
|
||||||
return llvm::all_of(result.getUsers(), [](Operation* op) {
|
}
|
||||||
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
|
|
||||||
});
|
static LogicalResult verifyPlanTensorTypes(Operation* op, Value input, Value output, StringRef kind) {
|
||||||
})) {
|
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
||||||
return op->emitError("ComputeResult used directly inside another Compute");
|
auto outputType = dyn_cast<RankedTensorType>(output.getType());
|
||||||
|
if (!inputType || !outputType)
|
||||||
|
return op->emitOpError() << kind << " requires ranked tensor input and output types";
|
||||||
|
if (inputType.getElementType() != outputType.getElementType())
|
||||||
|
return op->emitOpError() << kind << " requires matching input/output element types";
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatConv2DPlanOp::verify() {
|
||||||
|
auto inputType = dyn_cast<RankedTensorType>(getInput().getType());
|
||||||
|
auto weightType = dyn_cast<RankedTensorType>(getWeight().getType());
|
||||||
|
auto outputType = dyn_cast<RankedTensorType>(getOutput().getType());
|
||||||
|
if (!inputType || !weightType || !outputType)
|
||||||
|
return emitError("requires ranked tensor input, weight, and output");
|
||||||
|
if (inputType.getRank() != 4 || weightType.getRank() != 4 || outputType.getRank() != 4)
|
||||||
|
return emitError("requires rank-4 input, weight, and output tensors");
|
||||||
|
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||||
|
return emitError("requires a known logical layout");
|
||||||
|
if (getPads().size() != 4)
|
||||||
|
return emitError("requires exactly four pad values");
|
||||||
|
if (getStrides().size() != 2)
|
||||||
|
return emitError("requires exactly two stride values");
|
||||||
|
if (getDilations().size() != 2)
|
||||||
|
return emitError("requires exactly two dilation values");
|
||||||
|
if (getGroup() < 1)
|
||||||
|
return emitError("requires group >= 1");
|
||||||
|
if (inputType.getElementType() != weightType.getElementType()
|
||||||
|
|| inputType.getElementType() != outputType.getElementType()) {
|
||||||
|
return emitError("requires matching input, weight, and output element types");
|
||||||
|
}
|
||||||
|
if (getBias()) {
|
||||||
|
auto biasType = dyn_cast<RankedTensorType>(getBias().getType());
|
||||||
|
if (!biasType)
|
||||||
|
return emitError("requires ranked tensor bias type");
|
||||||
|
if (biasType.getElementType() != outputType.getElementType())
|
||||||
|
return emitError("requires bias element type to match output element type");
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatCompute::verify() {
|
LogicalResult SpatReluPlanOp::verify() {
|
||||||
auto& block = getBody().front();
|
if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.relu_plan")))
|
||||||
unsigned expectedArgCount = getWeights().size() + getInputs().size();
|
return failure();
|
||||||
if (block.getNumArguments() != expectedArgCount)
|
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||||
return emitError("compute body must have weight and input block arguments");
|
return emitError("requires a known logical layout");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
LogicalResult SpatReconciliatorOp::verify() {
|
||||||
auto blockArg = getWeightArgument(weightIndex);
|
if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.reconciliator")))
|
||||||
if (!blockArg || blockArg->getType() != weight.getType())
|
return failure();
|
||||||
return emitError("compute weight block argument types must match weight operand types exactly");
|
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||||
|
return emitError("requires a known logical layout");
|
||||||
|
if (!isKnownPhysicalLayout(getPhysicalLayout()))
|
||||||
|
return emitError("requires a known physical layout");
|
||||||
|
|
||||||
|
auto logicalType = dyn_cast<RankedTensorType>(getOutput().getType());
|
||||||
|
if (!logicalType)
|
||||||
|
return emitError("requires ranked tensor output");
|
||||||
|
|
||||||
|
auto offsets = getFragmentOffsets();
|
||||||
|
auto sizes = getFragmentSizes();
|
||||||
|
if (offsets.size() != sizes.size())
|
||||||
|
return emitError("fragment offset and size arrays must have the same length");
|
||||||
|
if (offsets.empty())
|
||||||
|
return success();
|
||||||
|
|
||||||
|
int64_t rank = logicalType.getRank();
|
||||||
|
if (rank <= 0 || offsets.size() % rank != 0)
|
||||||
|
return emitError("fragment metadata must be a whole number of rank-sized fragments");
|
||||||
|
|
||||||
|
ArrayRef<int64_t> shape = logicalType.getShape();
|
||||||
|
for (int64_t index = 0; index < static_cast<int64_t>(offsets.size()); ++index) {
|
||||||
|
int64_t dim = index % rank;
|
||||||
|
int64_t offset = offsets[index];
|
||||||
|
int64_t size = sizes[index];
|
||||||
|
if (offset < 0 || size < 0)
|
||||||
|
return emitError("fragment offsets and sizes must be non-negative");
|
||||||
|
int64_t logicalDim = shape[dim];
|
||||||
|
if (!ShapedType::isDynamic(logicalDim) && offset + size > logicalDim)
|
||||||
|
return emitError("fragment bounds must stay within the logical tensor shape");
|
||||||
}
|
}
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
|
||||||
auto blockArg = getInputArgument(inputIndex);
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatMaterializeLayoutOp::verify() {
|
||||||
|
if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.materialize_layout")))
|
||||||
|
return failure();
|
||||||
|
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||||
|
return emitError("requires a known logical layout");
|
||||||
|
if (!isKnownPhysicalLayout(getSourcePhysicalLayout()))
|
||||||
|
return emitError("requires a known source physical layout");
|
||||||
|
if (!isKnownPhysicalLayout(getTargetPhysicalLayout()))
|
||||||
|
return emitError("requires a known target physical layout");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult verifyComputeResultsUses(Operation* op) {
|
||||||
|
if (!isAnySpatialComputeLike(op))
|
||||||
|
return op->emitError("verifyComputeResultUses: op is not a Spatial compute-like operation");
|
||||||
|
if (!llvm::all_of(op->getResults(), [](Value result) {
|
||||||
|
return llvm::all_of(result.getUsers(), [](Operation* op) {
|
||||||
|
return !isAnySpatialComputeLike(op->getParentOp());
|
||||||
|
});
|
||||||
|
})) {
|
||||||
|
return op->emitError("compute result used directly inside another Spatial compute body");
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ComputeOpTy>
|
||||||
|
LogicalResult verifyComputeLikeOp(ComputeOpTy compute, StringRef opName) {
|
||||||
|
auto& block = compute.getBody().front();
|
||||||
|
unsigned expectedArgCount = compute.getWeights().size() + compute.getInputs().size();
|
||||||
|
if (block.getNumArguments() != expectedArgCount)
|
||||||
|
return compute.emitOpError("compute body must have weight and input block arguments");
|
||||||
|
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||||
|
auto blockArg = compute.getWeightArgument(weightIndex);
|
||||||
|
if (!blockArg || blockArg->getType() != weight.getType())
|
||||||
|
return compute.emitOpError("compute weight block argument types must match weight operand types exactly");
|
||||||
|
}
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
|
auto blockArg = compute.getInputArgument(inputIndex);
|
||||||
if (!blockArg || blockArg->getType() != input.getType())
|
if (!blockArg || blockArg->getType() != input.getType())
|
||||||
return emitError("compute input block argument types must match input operand types exactly");
|
return compute.emitOpError("compute input block argument types must match input operand types exactly");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (block.mightHaveTerminator()) {
|
if (block.mightHaveTerminator()) {
|
||||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||||
if (!yieldOp)
|
if (!yieldOp)
|
||||||
return emitError("ComputeOp must have a single yield operation");
|
return compute.emitOpError("ComputeOp must have a single yield operation");
|
||||||
|
|
||||||
auto resultTypes = getResultTypes();
|
auto resultTypes = compute.getResultTypes();
|
||||||
auto yieldTypes = yieldOp->getOperandTypes();
|
auto yieldTypes = yieldOp->getOperandTypes();
|
||||||
if (resultTypes.size() != yieldTypes.size())
|
if (resultTypes.size() != yieldTypes.size())
|
||||||
return emitError("ComputeOp must have same number of results as yieldOp operands");
|
return compute.emitOpError("ComputeOp must have same number of results as yieldOp operands");
|
||||||
|
|
||||||
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
|
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
|
||||||
auto resultType = std::get<0>(it);
|
auto resultType = std::get<0>(it);
|
||||||
auto yieldType = std::get<1>(it);
|
auto yieldType = std::get<1>(it);
|
||||||
|
|
||||||
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType)))
|
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType)))
|
||||||
return emitError("ComputeOp output must be of the same type as yieldOp operand");
|
return compute.emitOpError("ComputeOp output must be of the same type as yieldOp operand");
|
||||||
|
|
||||||
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
|
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
|
||||||
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
|
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
|
||||||
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding())
|
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding())
|
||||||
return emitError("ComputeOp output must have the same encoding as yieldOp operand");
|
return compute.emitOpError("ComputeOp output must have the same encoding as yieldOp operand");
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
return emitError("ComputeOp output has an encoding while yieldOp operand does not have one");
|
return compute.emitOpError("ComputeOp output has an encoding while yieldOp operand does not have one");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (dyn_cast<RankedTensorType>(yieldType)) {
|
else if (dyn_cast<RankedTensorType>(yieldType)) {
|
||||||
return emitError("ComputeOp output must not have an encoding if yieldOp operand has one");
|
return compute.emitOpError("ComputeOp output must not have an encoding if yieldOp operand has one");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
|
for (unsigned inputIndex = 0; inputIndex < compute.getInputs().size(); ++inputIndex)
|
||||||
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
|
if (auto inputArg = compute.getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
|
||||||
return emitError("ComputeOp block argument is not used");
|
return compute.emitOpError("ComputeOp block argument is not used");
|
||||||
if (failed(verifyStaticWeights(*this, "compute")))
|
if (failed(verifyStaticWeights(compute, opName)))
|
||||||
return failure();
|
return failure();
|
||||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
|
if (failed(verifyOnlyConstantExternalValues(compute.getOperation(), compute.getBody(), opName)))
|
||||||
return failure();
|
return failure();
|
||||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
if (failed(verifyComputeResultsUses(compute.getOperation())))
|
||||||
return failure();
|
return failure();
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatComputeBatch::verify() {
|
LogicalResult SpatGraphCompute::verify() { return verifyComputeLikeOp(*this, "spat.graph_compute"); }
|
||||||
int32_t count = getLaneCount();
|
|
||||||
|
LogicalResult SpatScheduledCompute::verify() { return verifyComputeLikeOp(*this, "spat.scheduled_compute"); }
|
||||||
|
|
||||||
|
template <typename ComputeBatchOpTy>
|
||||||
|
LogicalResult verifyComputeBatchLikeOp(ComputeBatchOpTy batch, StringRef opName) {
|
||||||
|
int32_t count = batch.getLaneCount();
|
||||||
if (count <= 0)
|
if (count <= 0)
|
||||||
return emitError("laneCount must be positive");
|
return batch.emitOpError("laneCount must be positive");
|
||||||
|
|
||||||
auto laneCountSz = static_cast<size_t>(count);
|
auto laneCountSz = static_cast<size_t>(count);
|
||||||
|
|
||||||
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
|
if (auto coreIdAttr = batch->getAttr(kCoreIdsAttrName)) {
|
||||||
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
||||||
if (!coreIdsAttr)
|
if (!coreIdsAttr)
|
||||||
return emitError("compute_batch coreIds attribute must be a dense i32 array");
|
return batch.emitOpError("compute_batch coreIds attribute must be a dense i32 array");
|
||||||
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
|
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
|
||||||
return emitError("compute_batch coreIds array length must match laneCount");
|
return batch.emitOpError("compute_batch coreIds array length must match laneCount");
|
||||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
||||||
return emitError("compute_batch coreIds values must be non-negative");
|
return batch.emitOpError("compute_batch coreIds values must be non-negative");
|
||||||
DenseSet<int32_t> seenCoreIds;
|
DenseSet<int32_t> seenCoreIds;
|
||||||
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
||||||
if (!seenCoreIds.insert(coreId).second)
|
if (!seenCoreIds.insert(coreId).second)
|
||||||
return emitError("compute_batch coreIds values must be unique");
|
return batch.emitOpError("compute_batch coreIds values must be unique");
|
||||||
}
|
}
|
||||||
|
|
||||||
Block& block = getBody().front();
|
Block& block = batch.getBody().front();
|
||||||
if (block.getNumArguments() == 0)
|
if (block.getNumArguments() == 0)
|
||||||
return emitError("compute_batch body must have exactly one lane block argument");
|
return batch.emitOpError("compute_batch body must have exactly one lane block argument");
|
||||||
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
|
unsigned expectedArgCount = 1 + batch.getWeights().size() + batch.getInputs().size() + batch.getNumResults();
|
||||||
if (block.getNumArguments() != expectedArgCount)
|
if (block.getNumArguments() != expectedArgCount)
|
||||||
return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
|
return batch.emitOpError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
|
||||||
auto laneArg = getLaneArgument();
|
auto laneArg = batch.getLaneArgument();
|
||||||
if (!laneArg || !laneArg->getType().isIndex())
|
if (!laneArg || !laneArg->getType().isIndex())
|
||||||
return emitError("compute_batch first block argument must have index type");
|
return batch.emitOpError("compute_batch first block argument must have index type");
|
||||||
|
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
for (auto [weightIndex, weight] : llvm::enumerate(batch.getWeights())) {
|
||||||
auto blockArg = getWeightArgument(weightIndex);
|
auto blockArg = batch.getWeightArgument(weightIndex);
|
||||||
if (!blockArg || blockArg->getType() != weight.getType())
|
if (!blockArg || blockArg->getType() != weight.getType())
|
||||||
return emitError("compute_batch weight block argument types must match weight operand types exactly");
|
return batch.emitOpError("compute_batch weight block argument types must match weight operand types exactly");
|
||||||
}
|
}
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) {
|
||||||
auto blockArg = getInputArgument(inputIndex);
|
auto blockArg = batch.getInputArgument(inputIndex);
|
||||||
if (!blockArg || blockArg->getType() != input.getType())
|
if (!blockArg || blockArg->getType() != input.getType())
|
||||||
return emitError("compute_batch input block argument types must match input operand types exactly");
|
return batch.emitOpError("compute_batch input block argument types must match input operand types exactly");
|
||||||
}
|
}
|
||||||
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
|
for (auto [resultIndex, resultType] : llvm::enumerate(batch.getResultTypes())) {
|
||||||
auto blockArg = getOutputArgument(resultIndex);
|
auto blockArg = batch.getOutputArgument(resultIndex);
|
||||||
if (!blockArg || blockArg->getType() != resultType)
|
if (!blockArg || blockArg->getType() != resultType)
|
||||||
return emitError("compute_batch output block argument types must match result types exactly");
|
return batch.emitOpError("compute_batch output block argument types must match result types exactly");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
if (failed(verifyComputeResultsUses(batch.getOperation())))
|
||||||
return failure();
|
return failure();
|
||||||
if (failed(verifyStaticWeights(*this, "compute_batch")))
|
if (failed(verifyStaticWeights(batch, opName)))
|
||||||
return failure();
|
return failure();
|
||||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
|
if (failed(verifyOnlyConstantExternalValues(batch.getOperation(), batch.getBody(), opName)))
|
||||||
return failure();
|
return failure();
|
||||||
return verifyBatchBody(*this, block);
|
return verifyBatchBody(batch, block);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatGraphComputeBatch::verify() { return verifyComputeBatchLikeOp(*this, "spat.graph_compute_batch"); }
|
||||||
|
|
||||||
|
LogicalResult SpatScheduledComputeBatch::verify() {
|
||||||
|
return verifyComputeBatchLikeOp(*this, "spat.scheduled_compute_batch");
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatInParallelOp::verify() {
|
LogicalResult SpatInParallelOp::verify() {
|
||||||
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
|
Operation* parent = getOperation()->getParentOp();
|
||||||
if (!batchOp)
|
if (!isAnySpatialComputeBatchLike(parent))
|
||||||
return emitOpError("expected spat.compute_batch parent");
|
return emitOpError("expected spat.graph_compute_batch or spat.scheduled_compute_batch parent");
|
||||||
if (batchOp.getNumResults() == 0)
|
if (parent->getNumResults() == 0)
|
||||||
return emitOpError("requires a resultful spat.compute_batch parent");
|
return emitOpError("requires a resultful spat.compute_batch parent");
|
||||||
|
|
||||||
auto laneArg = batchOp.getLaneArgument();
|
std::optional<BlockArgument> laneArg;
|
||||||
|
if (auto graphBatch = dyn_cast<SpatGraphComputeBatch>(parent))
|
||||||
|
laneArg = graphBatch.getLaneArgument();
|
||||||
|
else
|
||||||
|
laneArg = cast<SpatScheduledComputeBatch>(parent).getLaneArgument();
|
||||||
if (!laneArg)
|
if (!laneArg)
|
||||||
return emitOpError("expected compute_batch lane block argument");
|
return emitOpError("expected compute_batch lane block argument");
|
||||||
for (Operation& op : getRegion().front().getOperations()) {
|
for (Operation& op : getRegion().front().getOperations()) {
|
||||||
@@ -494,7 +652,10 @@ LogicalResult SpatInParallelOp::verify() {
|
|||||||
|
|
||||||
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
|
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
|
||||||
for (OpOperand& destination : destinations)
|
for (OpOperand& destination : destinations)
|
||||||
if (!isBatchOutputArgument(batchOp, destination.get()))
|
if ((isa<SpatGraphComputeBatch>(parent)
|
||||||
|
&& !isBatchOutputArgument(cast<SpatGraphComputeBatch>(parent), destination.get()))
|
||||||
|
|| (isa<SpatScheduledComputeBatch>(parent)
|
||||||
|
&& !isBatchOutputArgument(cast<SpatScheduledComputeBatch>(parent), destination.get())))
|
||||||
return op.emitOpError("may only insert into a compute_batch output block argument");
|
return op.emitOpError("may only insert into a compute_batch output block argument");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+4645
-263
File diff suppressed because it is too large
Load Diff
@@ -40,10 +40,10 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
using namespace onnx_mlir::compact_asm;
|
using namespace onnx_mlir::compact_asm;
|
||||||
using SpatCompute = spatial::SpatCompute;
|
using SpatCompute = spatial::SpatGraphCompute;
|
||||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
using SpatComputeBatch = spatial::SpatGraphComputeBatch;
|
||||||
|
|
||||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
static std::optional<int32_t> getComputeCoreId(spatial::SpatScheduledCompute compute) {
|
||||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
||||||
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
|
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
|
||||||
if (failed(checkedCoreId))
|
if (failed(checkedCoreId))
|
||||||
@@ -209,7 +209,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
|||||||
};
|
};
|
||||||
|
|
||||||
for (Operation& op : funcOp.getBody().front()) {
|
for (Operation& op : funcOp.getBody().front()) {
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(&op)) {
|
||||||
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
|
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
|
||||||
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
|
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
|
||||||
SmallVector<int32_t> coreIds;
|
SmallVector<int32_t> coreIds;
|
||||||
@@ -229,7 +229,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
|||||||
totalCrossbarCount += perInstanceCrossbarCount;
|
totalCrossbarCount += perInstanceCrossbarCount;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
if (auto batch = dyn_cast<spatial::SpatScheduledComputeBatch>(&op)) {
|
||||||
uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody());
|
uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody());
|
||||||
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
|
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
|
||||||
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
|
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
|
||||||
@@ -353,7 +353,17 @@ public:
|
|||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
func::FuncOp func = getOperation();
|
func::FuncOp func = getOperation();
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
|
||||||
|
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed at the start of MergeComputeNodes");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
mergeTriviallyConnectedComputes(func);
|
mergeTriviallyConnectedComputes(func);
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
|
||||||
|
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after trivial merge simplification");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const spatial::MergeScheduleResult* analysisResult = nullptr;
|
const spatial::MergeScheduleResult* analysisResult = nullptr;
|
||||||
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
|
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
|
||||||
@@ -367,8 +377,8 @@ public:
|
|||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (failed(verifySpatialCommunicationInvariants(func))) {
|
if (failed(verifyScheduledSpatialInvariants(func))) {
|
||||||
func.emitOpError("merged Spatial communication invariant verification failed");
|
func.emitOpError("RAPTOR_PHASE_CHECK scheduled Spatial verification failed after merge materialization");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,8 @@
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
|
std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
|
||||||
|
std::unique_ptr<mlir::Pass> createSpatialLayoutPlanningPass();
|
||||||
|
std::unique_ptr<mlir::Pass> createLowerSpatialPlansPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
|
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
|
||||||
|
|
||||||
|
|||||||
@@ -72,6 +72,8 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
|
|||||||
void PimAccelerator::registerPasses(int optLevel) const {
|
void PimAccelerator::registerPasses(int optLevel) const {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
|
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
|
||||||
registerPass(createONNXToSpatialPass);
|
registerPass(createONNXToSpatialPass);
|
||||||
|
registerPass(createSpatialLayoutPlanningPass);
|
||||||
|
registerPass(createLowerSpatialPlansPass);
|
||||||
registerPass(createSpatialToGraphvizPass);
|
registerPass(createSpatialToGraphvizPass);
|
||||||
registerPass(createSpatialToPimPass);
|
registerPass(createSpatialToPimPass);
|
||||||
registerPass(createPimBufferizationPass);
|
registerPass(createPimBufferizationPass);
|
||||||
|
|||||||
Reference in New Issue
Block a user