7 Commits

Author SHA1 Message Date
NiccoloN 568fd90542 cose
Validate Operations / validate-operations (push) Waiting to run
2026-06-25 18:57:12 +02:00
ilgeco be0bcc9dcc E' ancora tutto rotto
Validate Operations / validate-operations (push) Waiting to run
2026-06-25 16:24:14 +02:00
ilgeco 62dd40ee89 DeadLock 2026-06-24 15:52:07 +02:00
ilgeco 2b4115699a Convolutions support 2026-06-18 11:00:21 +02:00
ilgeco 3a985b3675 Different type of convolution 2026-06-18 10:59:02 +02:00
ilgeco 4ab24eb288 peft cost model 2026-06-18 10:57:59 +02:00
ilgeco e083c27d80 Add register reuse + peft scheduler cost model + Useless merger 2026-06-18 10:56:57 +02:00
64 changed files with 28408 additions and 1050 deletions
@@ -258,24 +258,23 @@ where
let (memory, crossbars) = core.get_memory_crossbar();
let crossbar = crossbars.get_mut(group).unwrap();
let crossbar_stored_bytes = crossbar.stored_bytes();
let crossbar_byte_width = crossbar.width();
let crossbar_elem_width = crossbar_byte_width / size_of::<M>();
ensure!(
crossbar_byte_width % size_of::<M>() == 0,
"M not divisor of the crosbbar size"
);
let crossbar_height = crossbar.height();
let crossbar_byte_size = crossbar_byte_width * crossbar_height;
let crossbar_stored_bytes = crossbar.stored_bytes();
let bytes_per_column = crossbar_height * size_of::<M>();
ensure!(bytes_per_column != 0, "crossbar height can not be zero");
ensure!(
crossbar_stored_bytes % bytes_per_column == 0,
"Stored crossbar bytes do not describe an integral number of columns"
);
let crossbar_elem_width = crossbar_stored_bytes / bytes_per_column;
ensure!(crossbar_elem_width != 0, "Crossbar contains no stored columns");
let loads = memory
.reserve_load(r1_val, crossbar_height * size_of::<F>())?
.execute_load::<F>()?;
let load = loads[0];
let vec: Cow<[M]> = load.up();
let matrix = crossbar.load::<M>(crossbar_byte_size)?[0];
let matrix = crossbar.load::<M>(crossbar_stored_bytes)?[0];
// --- FAER IMPLEMENTATION ---
+61
View File
@@ -56,6 +56,22 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
}
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (result) {
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands()) {
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size())
return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
}
return yieldedValue;
}
}
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
@@ -512,6 +528,24 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
continue;
}
if (auto ifOp = mlir::dyn_cast<mlir::scf::IfOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto condition = resolveIndexValueImpl(ifOp.getCondition(), knowledge);
if (failed(condition))
return mlir::failure();
mlir::Region& selectedRegion = *condition != 0 ? ifOp.getThenRegion() : ifOp.getElseRegion();
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(selectedRegion.front().getTerminator());
if (!yieldOp || result.getResultNumber() >= yieldOp.getNumOperands())
return mlir::failure();
value = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
@@ -622,6 +656,33 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
continue;
}
if (auto ifOp = mlir::dyn_cast<mlir::scf::IfOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto thenYield = mlir::dyn_cast<mlir::scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
auto elseYield = mlir::dyn_cast<mlir::scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());
if (!thenYield || !elseYield || result.getResultNumber() >= thenYield.getNumOperands()
|| result.getResultNumber() >= elseYield.getNumOperands()) {
return mlir::failure();
}
auto thenAddress = compileContiguousAddressExprImpl(thenYield.getOperand(result.getResultNumber()));
auto elseAddress = compileContiguousAddressExprImpl(elseYield.getOperand(result.getResultNumber()));
if (failed(thenAddress) || failed(elseAddress) || thenAddress->base != elseAddress->base)
return mlir::failure();
auto condition = compileIndexValueImpl(ifOp.getCondition());
if (failed(condition))
return mlir::failure();
CompiledIndexExprNode selectExpr;
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
selectExpr.operands = {*condition, thenAddress->byteOffset, elseAddress->byteOffset};
return CompiledAddressExpr {thenAddress->base, makeCompiledIndexExpr(std::move(selectExpr))};
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
+14
View File
@@ -17,6 +17,20 @@ std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRe
std::fstream openReportFile(const std::string& name) { return openReportFileWithExtension(name, "txt"); }
std::fstream openAppendedReportFileWithExtension(const std::string& name, llvm::StringRef extension) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return {};
std::string reportsDir = outputDir + "/reports";
createDirectory(reportsDir);
return std::fstream(reportsDir + "/" + name + "." + extension.str(), std::ios::out | std::ios::app);
}
std::fstream openAppendedReportFile(const std::string& name) {
return openAppendedReportFileWithExtension(name, "txt");
}
std::string formatReportMemory(uint64_t bytes) {
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
int i = 0;
+2
View File
@@ -12,6 +12,8 @@ namespace onnx_mlir {
std::fstream openReportFile(const std::string& name);
std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension);
std::fstream openAppendedReportFile(const std::string& name);
std::fstream openAppendedReportFileWithExtension(const std::string& name, llvm::StringRef extension);
std::string formatReportMemory(uint64_t bytes);
struct ReportField {
+26 -2
View File
@@ -588,13 +588,37 @@ void PimCodeGen::emitInstruction(const pim_binary::InstructionRecord& instructio
++emittedInstructionCount;
if (coreJsonStream)
*coreJsonStream << json::Value(pim_binary::makeInstructionJson(instruction)) << ',';
updateScalarRegisterCache(instruction);
}
void PimCodeGen::updateScalarRegisterCache(const pim_binary::InstructionRecord& instruction) const {
switch (instruction.opcode) {
case pim_binary::Opcode::sldi:
scalarRegisterValues[instruction.rd] = instruction.r2OrImm;
break;
case pim_binary::Opcode::sld:
case pim_binary::Opcode::sadd:
case pim_binary::Opcode::ssub:
case pim_binary::Opcode::smul:
case pim_binary::Opcode::saddi:
case pim_binary::Opcode::smuli:
scalarRegisterValues[instruction.rd].reset();
break;
default:
break;
}
}
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const {
auto registerIndex = pim::checkedU8OrCrash(registerNumber, "register number");
auto immediateValue = pim::checkedI32OrCrash(immediate, "register immediate");
if (scalarRegisterValues[registerIndex] == immediateValue)
return;
pim_binary::InstructionRecord instruction;
instruction.opcode = pim_binary::Opcode::sldi;
instruction.rd = static_cast<uint8_t>(registerNumber);
instruction.r2OrImm = pim::checkedI32OrCrash(immediate, "register immediate");
instruction.rd = registerIndex;
instruction.r2OrImm = immediateValue;
emitInstruction(instruction);
}
+3
View File
@@ -9,6 +9,7 @@
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_os_ostream.h"
#include <array>
#include <fstream>
#include <limits>
#include <optional>
@@ -170,6 +171,7 @@ class PimCodeGen {
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
std::optional<unsigned> batchLane;
mutable uint32_t emittedInstructionCount = 0;
mutable std::array<std::optional<int32_t>, 256> scalarRegisterValues = {};
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getValueAddress(value, knowledge, batchLane);
@@ -177,6 +179,7 @@ class PimCodeGen {
size_t remapCoreId(size_t coreId) const;
void emitInstruction(const pim_binary::InstructionRecord& instruction) const;
void updateScalarRegisterCache(const pim_binary::InstructionRecord& instruction) const;
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
void setupRd(size_t rdAddress, size_t rdOffset) const;
+60
View File
@@ -32,6 +32,31 @@ llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport(
llvm::cl::init(PimMemoryReportNone),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimConvLoweringType> pimConvLowering(
"pim-conv-lowering",
llvm::cl::desc("Convolution lowering strategy for PIM"),
llvm::cl::values(clEnumValN(PimConvLoweringAuto, "auto", "Select the Conv lowering strategy automatically")),
llvm::cl::values(clEnumValN(PimConvLoweringLegacy, "legacy", "Use the legacy explicit-im2col Conv lowering")),
llvm::cl::values(clEnumValN(PimConvLoweringDepthwise, "depthwise", "Force the depthwise-specialized Conv lowering")),
llvm::cl::values(
clEnumValN(PimConvLoweringPackedIm2Col, "packed-im2col", "Use explicit im2col with packed multi-position GEMM")),
llvm::cl::values(clEnumValN(PimConvLoweringStreamedPatch,
"streamed-patch",
"Use streamed/chunked im2col rows without multi-position packing")),
llvm::cl::values(clEnumValN(PimConvLoweringStreamedPacked,
"streamed-packed",
"Use streamed/chunked im2col rows with packed multi-position GEMM")),
llvm::cl::values(clEnumValN(PimConvLoweringOutputChannelTiled,
"output-channel-tiled",
"Force Conv lowering that relies on Gemm output-channel tiling")),
llvm::cl::values(
clEnumValN(PimConvLoweringInputKTiled, "input-k-tiled", "Force Conv lowering that relies on Gemm K tiling")),
llvm::cl::values(clEnumValN(PimConvLoweringTiled2D,
"tiled-2d",
"Force Conv lowering that relies on Gemm 2D K/C tiling")),
llvm::cl::init(PimConvLoweringAuto),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool>
pimOnlyCodegen("pim-only-codegen",
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
@@ -49,11 +74,46 @@ llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<uint64_t> pimConvIm2colMaxElements(
"pim-conv-im2col-max-elements",
llvm::cl::desc("Maximum number of im2col elements to materialize globally for one Conv before streaming/chunking"),
llvm::cl::init(1ull << 20),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<uint64_t> pimConvStreamChunkPositions(
"pim-conv-stream-chunk-positions",
llvm::cl::desc("Maximum number of Conv output positions to materialize in one streamed chunk"),
llvm::cl::init(1024),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> pimReportConvLowering("pim-report-conv-lowering",
llvm::cl::desc("Emit a bounded Conv lowering report"),
llvm::cl::init(true),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> pimDetectCommunicationDeadlock(
"pim-detect-communication-deadlock",
llvm::cl::desc("Expensively simulate the statically expanded PIM send/receive order at verification time and fail if a blocking communication deadlock is found"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> pimMaterializeScalarFanoutGlobalOrder(
"pim-materialize-scalar-fanout-global-order",
llvm::cl::desc("Experimental expensive materializer mode: emit scalar-source fanout as globally ordered communication events instead of all-send fanout loops"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> pimTraceCommunicationMaterialization(
"pim-trace-communication-materialization",
llvm::cl::desc("Emit verbose materializer-time diagnostics and provenance attributes for every Spatial communication op"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<size_t>
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
+19
View File
@@ -30,19 +30,38 @@ typedef enum {
PimMemoryReportFull = 2,
} PimMemoryReportLevel;
typedef enum {
PimConvLoweringAuto = 0,
PimConvLoweringLegacy = 1,
PimConvLoweringDepthwise = 2,
PimConvLoweringPackedIm2Col = 3,
PimConvLoweringStreamedPatch = 4,
PimConvLoweringStreamedPacked = 5,
PimConvLoweringOutputChannelTiled = 6,
PimConvLoweringInputKTiled = 7,
PimConvLoweringTiled2D = 8,
} PimConvLoweringType;
extern llvm::cl::OptionCategory OnnxMlirOptions;
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
extern llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport;
extern llvm::cl::opt<PimConvLoweringType> pimConvLowering;
extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> pimDisableMemoryCoalescing;
extern llvm::cl::opt<bool> useExperimentalConvImpl;
extern llvm::cl::opt<bool> pimEmitJson;
extern llvm::cl::opt<bool> pimReportConvLowering;
extern llvm::cl::opt<bool> pimDetectCommunicationDeadlock;
extern llvm::cl::opt<bool> pimMaterializeScalarFanoutGlobalOrder;
extern llvm::cl::opt<bool> pimTraceCommunicationMaterialization;
extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore;
extern llvm::cl::opt<long> coresCount;
extern llvm::cl::opt<uint64_t> pimConvIm2colMaxElements;
extern llvm::cl::opt<uint64_t> pimConvStreamChunkPositions;
bool hasExplicitPimCoreCount();
void verifyExplicitPimCoreCount();
+2
View File
@@ -29,6 +29,8 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitSpatial) {
pm.addPass(createONNXToSpatialPass());
pm.addPass(createSpatialLayoutPlanningPass());
pm.addPass(createLowerSpatialPlansPass());
pm.addPass(createMergeComputeNodesPass());
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
}
@@ -26,6 +26,8 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp
SpatialLayoutPlanningPass.cpp
LowerSpatialPlansPass.cpp
Common/AttributeUtils.cpp
Common/ComputeRegionBuilder.cpp
Common/IndexingUtils.cpp
@@ -9,7 +9,7 @@ using namespace mlir;
namespace onnx_mlir {
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
Value sumTensors(ArrayRef<Value> tensors, PatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
@@ -87,17 +87,17 @@ inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
}
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
/// Builds a `spat.graph_compute` with a fixed number of SSA inputs and erases it if
/// the body callback reports failure.
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto createSpatGraphCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
@@ -124,23 +124,23 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
}
}
/// Builds a `spat.compute` whose body consumes the block arguments as a single
/// Builds a `spat.graph_compute` whose body consumes the block arguments as a single
/// `ValueRange`, which is convenient for variadic reductions/concats.
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto createSpatGraphCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
@@ -163,29 +163,29 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto createSpatGraphComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
if (mlir::failed(laneCountAttr))
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
auto batchOp = spatial::SpatGraphComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
@@ -218,20 +218,53 @@ auto createSpatComputeBatch(RewriterT& rewriter,
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(args);
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
}
else {
auto bodyResult = std::forward<BodyFn>(body)(args);
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(batchOp);
rewriter.eraseOp(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
}
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
}
}
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphCompute<NumInputs>(
rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
}
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphCompute(rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
}
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
return createSpatGraphComputeBatch(
rewriter, loc, resultTypes, laneCount, weights, inputs, std::forward<BodyFn>(body));
}
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
@@ -262,6 +295,6 @@ mlir::Value materializeOrComputeUnary(mlir::Value input,
return computeOp.getResult(0);
}
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::PatternRewriter& rewriter);
} // namespace onnx_mlir
@@ -83,7 +83,7 @@ SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int
}
SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(tensorToSlice);
assert("Invalid axis" && axis < shape.size());
@@ -129,7 +129,7 @@ SmallVector<Value> sliceTensor(
}
SmallVector<Value>
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
sliceVector(const Value& vectorToSlice, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(vectorToSlice);
assert("Not a vector" && isVectorShape(shape));
size_t axis = shape[0] != 1 ? 0 : 1;
@@ -137,7 +137,7 @@ sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewr
}
DenseMap<CoreId, SmallVector<Value>>
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, PatternRewriter& rewriter, Location loc) {
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
@@ -89,18 +89,18 @@ llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewr
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
/// Partitions one logical vector into per-core crossbar-sized slices using the
/// current PIM target geometry.
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc);
mlir::Value extractAxisSlice(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
@@ -19,9 +19,11 @@ using namespace mlir;
namespace onnx_mlir {
bool isWeightLikeComputeOperand(Value value) {
static bool isWeightMaterializationValue(Value value, bool requireMatrixShape) {
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
if (!rankedType)
return false;
if (requireMatrixShape && !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
@@ -29,8 +31,14 @@ bool isWeightLikeComputeOperand(Value value) {
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp)) {
auto sourceType = dyn_cast<RankedTensorType>(value.getType());
if (!sourceType)
return false;
if (requireMatrixShape && !isMatrixShape(sourceType.getShape()))
return false;
return true;
}
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
@@ -55,6 +63,8 @@ bool isWeightLikeComputeOperand(Value value) {
return false;
}
bool isWeightLikeComputeOperand(Value value) { return isWeightMaterializationValue(value, /*requireMatrixShape=*/true); }
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
@@ -91,7 +101,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
continue;
}
if (isWeightLikeComputeOperand(operand)) {
if (isWeightMaterializationValue(operand, /*requireMatrixShape=*/false)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
@@ -0,0 +1,409 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static constexpr StringLiteral kDenseLayout = "dense_nchw";
static constexpr StringLiteral kRowStripLayout = "nchw_row_strip";
struct RowStripPhysicalValue {
Value physicalValue;
RankedTensorType logicalType;
SmallVector<int64_t, 16> fragmentOffsets;
SmallVector<int64_t, 16> fragmentSizes;
std::string indexMap;
};
static FailureOr<RowStripPhysicalValue> getRowStripValue(llvm::DenseMap<Value, RowStripPhysicalValue>& rowStripValues,
Value value) {
auto it = rowStripValues.find(value);
if (it == rowStripValues.end())
return failure();
return it->second;
}
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatReconciliatorOp reconciliator,
Value physicalValue) {
auto logicalType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
if (!logicalType)
return reconciliator.emitOpError("requires ranked logical output type"), failure();
RowStripPhysicalValue value;
value.physicalValue = physicalValue;
value.logicalType = logicalType;
value.fragmentOffsets.append(reconciliator.getFragmentOffsets().begin(), reconciliator.getFragmentOffsets().end());
value.fragmentSizes.append(reconciliator.getFragmentSizes().begin(), reconciliator.getFragmentSizes().end());
value.indexMap = reconciliator.getIndexMap().str();
return value;
}
static FailureOr<Value>
lowerRowStripRelu(const RowStripPhysicalValue& input, spatial::SpatReluPlanOp planOp, PatternRewriter& rewriter) {
auto packedType = cast<RankedTensorType>(input.physicalValue.getType());
auto computeOp =
createSpatCompute<1>(rewriter, planOp.getLoc(), TypeRange {packedType}, {}, input.physicalValue, [&](Value x) {
auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), packedType, x);
spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
});
return computeOp.getResult(0);
}
static FailureOr<Value>
materializeRowStripToDense(const RowStripPhysicalValue& rowStripValue, Location loc, PatternRewriter& rewriter) {
auto packedType = dyn_cast<RankedTensorType>(rowStripValue.physicalValue.getType());
if (!packedType || packedType.getRank() != 3 || !packedType.hasStaticShape())
return failure();
if (rowStripValue.logicalType.getRank() != 4 || !rowStripValue.logicalType.hasStaticShape())
return failure();
if (rowStripValue.indexMap != "packed_hwc_rows_to_nchw")
return failure();
const int64_t rank = rowStripValue.logicalType.getRank();
const int64_t fragmentCount = rowStripValue.fragmentOffsets.size() / rank;
const int64_t packedWidth = packedType.getDimSize(1);
const int64_t packedChannels = packedType.getDimSize(2);
if (fragmentCount != packedType.getDimSize(0))
return failure();
for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
if (rowStripValue.fragmentOffsets[fragmentIndex * rank + 0] != 0
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 1] != 0
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 2] != fragmentIndex
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 3] != 0)
return failure();
if (rowStripValue.fragmentSizes[fragmentIndex * rank + 0] != 1
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 1] != packedChannels
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 2] != 1
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 3] != packedWidth)
return failure();
}
auto packedSliceType =
RankedTensorType::get({1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
auto expandedType =
RankedTensorType::get({1, 1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
auto logicalFragmentType =
RankedTensorType::get({1, packedChannels, 1, packedWidth}, packedType.getElementType(), packedType.getEncoding());
auto batchOp = createSpatComputeBatch(
rewriter,
loc,
TypeRange {rowStripValue.logicalType},
fragmentCount,
{},
ValueRange {rowStripValue.physicalValue},
[&](detail::SpatComputeBatchBodyArgs args) {
SmallVector<OpFoldResult> packedOffsets {args.lane, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> packedSizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(packedWidth), rewriter.getIndexAttr(packedChannels)};
Value packedSlice = tensor::ExtractSliceOp::create(
rewriter, loc, packedSliceType, args.inputs.front(), packedOffsets, packedSizes, getUnitStrides(rewriter, 3));
Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedSlice,
SmallVector<ReassociationIndices> {
{0, 1},
{2},
{3}
});
Value transposeInit =
tensor::EmptyOp::create(rewriter, loc, logicalFragmentType.getShape(), logicalFragmentType.getElementType());
Value logicalFragment =
linalg::TransposeOp::create(rewriter, loc, expanded, transposeInit, SmallVector<int64_t> {0, 3, 1, 2})
.getResult()[0];
SmallVector<OpFoldResult> logicalOffsets {
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> logicalSizes {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(packedChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(packedWidth)};
createParallelInsertSliceIntoBatchOutput(rewriter,
loc,
logicalFragment,
args.outputs.front(),
logicalOffsets,
logicalSizes,
getUnitStrides(rewriter, 4));
return success();
});
if (failed(batchOp))
return failure();
return batchOp->getResult(0);
}
struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerSpatialPlansPass)
StringRef getArgument() const override { return "lower-spatial-plans"; }
StringRef getDescription() const override { return "Lower selected Spatial planning ops to low-level Spatial IR."; }
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext();
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
moduleOp.emitError("failed to locate the PIM entry function during LowerSpatialPlans");
signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
PatternRewriter rewriter(ctx);
llvm::DenseMap<Value, RowStripPhysicalValue> rowStripValues;
llvm::SmallPtrSet<Operation*, 16> eraseAfterLowering;
auto verifyLogicalPhase = [&](StringRef stage) -> bool {
if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc)))
return true;
moduleOp.emitError() << "RAPTOR_PHASE_CHECK logical Spatial graph verification failed " << stage;
signalPassFailure();
return false;
};
if (!verifyLogicalPhase("at the start of LowerSpatialPlans"))
return;
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
if (auto planOp = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
FailureOr<RowStripPhysicalValue> rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
auto rowStripReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
});
if (rowStripReconciliator != planOp.getResult().getUsers().end()) {
rewriter.setInsertionPoint(planOp);
FailureOr<Value> lowered = lowerSelectedConv2DPlan(
planOp,
succeeded(rowStripInput) ? std::optional<Value> {rowStripInput->physicalValue} : std::nullopt,
/*emitRowStripLayout=*/true,
rewriter);
if (failed(lowered)) {
planOp.emitOpError("failed to lower selected row-strip Spatial Conv plan");
signalPassFailure();
return;
}
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*rowStripReconciliator);
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(reconciliator, *lowered);
if (failed(rowStripValue)) {
signalPassFailure();
return;
}
rowStripValues[reconciliator.getResult()] = *rowStripValue;
eraseAfterLowering.insert(planOp);
eraseAfterLowering.insert(reconciliator);
continue;
}
rewriter.setInsertionPoint(planOp);
FailureOr<Value> lowered =
lowerSelectedConv2DPlan(planOp, std::nullopt, /*emitRowStripLayout=*/false, rewriter);
if (failed(lowered)) {
planOp.emitOpError("failed to lower selected Spatial Conv plan");
signalPassFailure();
return;
}
rewriter.replaceOp(planOp, *lowered);
continue;
}
if (auto planOp = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) {
auto outputReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
});
if (outputReconciliator == planOp.getResult().getUsers().end()) {
planOp.emitOpError("row-strip Relu plan requires a row-strip reconciliator result");
signalPassFailure();
return;
}
FailureOr<RowStripPhysicalValue> input = getRowStripValue(rowStripValues, planOp.getInput());
rewriter.setInsertionPoint(planOp);
FailureOr<Value> lowered = lowerRowStripRelu(*input, planOp, rewriter);
if (failed(lowered)) {
planOp.emitOpError("failed to lower selected row-strip Spatial Relu plan");
signalPassFailure();
return;
}
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*outputReconciliator);
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(reconciliator, *lowered);
if (failed(output)) {
signalPassFailure();
return;
}
rowStripValues[reconciliator.getResult()] = *output;
eraseAfterLowering.insert(planOp);
eraseAfterLowering.insert(reconciliator);
continue;
}
rewriter.setInsertionPoint(planOp);
auto computeOp = createSpatCompute<1>(
rewriter, planOp.getLoc(), planOp.getOutput().getType(), {}, planOp.getInput(), [&](Value x) {
auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), planOp.getOutput().getType(), x);
spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
});
rewriter.replaceOp(planOp, computeOp.getResults());
continue;
}
if (auto materializeOp = dyn_cast<spatial::SpatMaterializeLayoutOp>(&op)) {
if (materializeOp.getSourcePhysicalLayout() == kDenseLayout
&& materializeOp.getTargetPhysicalLayout() == kDenseLayout) {
rewriter.replaceOp(materializeOp, materializeOp.getInput());
continue;
}
if (materializeOp.getSourcePhysicalLayout() != kRowStripLayout
|| materializeOp.getTargetPhysicalLayout() != kDenseLayout) {
materializeOp.emitOpError("non-dense materialize_layout lowering is not supported yet");
signalPassFailure();
return;
}
FailureOr<RowStripPhysicalValue> rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
if (failed(rowStripValue)) {
materializeOp.emitOpError("expected a row-strip reconciliator input during row-strip materialization");
signalPassFailure();
return;
}
rewriter.setInsertionPoint(materializeOp);
FailureOr<Value> dense = materializeRowStripToDense(*rowStripValue, materializeOp.getLoc(), rewriter);
if (failed(dense)) {
materializeOp.emitOpError("failed to materialize selected row-strip layout back to dense NCHW");
signalPassFailure();
return;
}
rewriter.replaceOp(materializeOp, *dense);
continue;
}
if (auto reconciliatorOp = dyn_cast<spatial::SpatReconciliatorOp>(&op)) {
if (reconciliatorOp.getPhysicalLayout() == kDenseLayout) {
rewriter.replaceOp(reconciliatorOp, reconciliatorOp.getInput());
continue;
}
if (reconciliatorOp.getPhysicalLayout() != kRowStripLayout) {
reconciliatorOp.emitOpError("non-dense reconciliator lowering is not supported yet");
signalPassFailure();
return;
}
if (!eraseAfterLowering.contains(reconciliatorOp)) {
reconciliatorOp.emitOpError("unhandled row-strip reconciliator remained during LowerSpatialPlans");
signalPassFailure();
return;
}
}
}
bool erasedAny = true;
while (erasedAny) {
erasedAny = false;
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
if (!eraseAfterLowering.contains(&op))
continue;
if (!op.use_empty())
continue;
eraseAfterLowering.erase(&op);
rewriter.eraseOp(&op);
erasedAny = true;
}
}
if (!eraseAfterLowering.empty()) {
for (Operation& op : funcOp.getBody().front())
if (eraseAfterLowering.contains(&op))
op.emitOpError("selected row-strip planning op could not be fully eliminated during LowerSpatialPlans");
signalPassFailure();
return;
}
ConversionTarget helperTarget(*ctx);
helperTarget.addLegalDialect<spatial::SpatialDialect,
tensor::TensorDialect,
linalg::LinalgDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect,
func::FuncDialect>();
helperTarget.addLegalOp<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
helperTarget.addIllegalOp<ONNXGemmOp, ONNXTransposeOp>();
helperTarget.markOpRecursivelyLegal<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
RewritePatternSet helperPatterns(ctx);
populateGemmPatterns(helperPatterns, ctx);
populateTransposePatterns(helperPatterns, ctx);
if (failed(applyPartialConversion(moduleOp, helperTarget, std::move(helperPatterns)))) {
moduleOp.emitError("failed to lower helper ONNX ops emitted by selected Spatial plan lowering");
signalPassFailure();
return;
}
FrozenRewritePatternSet nestedHelperPatterns([&] {
RewritePatternSet patterns(ctx);
populateGemmPatterns(patterns, ctx);
populateTransposePatterns(patterns, ctx);
return patterns;
}());
ConversionTarget nestedHelperTarget(*ctx);
nestedHelperTarget.addLegalDialect<spatial::SpatialDialect,
tensor::TensorDialect,
linalg::LinalgDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect,
func::FuncDialect>();
nestedHelperTarget.addIllegalOp<ONNXGemmOp, ONNXTransposeOp>();
SmallVector<Operation*> computeLikeOps;
funcOp.walk([&](Operation* op) {
if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(op))
computeLikeOps.push_back(op);
});
for (Operation* op : computeLikeOps) {
if (failed(applyFullConversion(op, nestedHelperTarget, nestedHelperPatterns))) {
op->emitOpError("failed to lower nested helper ONNX ops emitted by selected Spatial plan lowering");
signalPassFailure();
return;
}
}
if (!verifyLogicalPhase("after nested helper conversions"))
return;
bool hasIllegalOps = false;
moduleOp.walk([&](Operation* op) {
if (isa<ONNXEntryPointOp>(op))
return;
if (isa<spatial::SpatConv2DPlanOp,
spatial::SpatReluPlanOp,
spatial::SpatReconciliatorOp,
spatial::SpatMaterializeLayoutOp>(op)
|| op->getDialect()->getNamespace() == "onnx") {
op->emitOpError("operation must not remain after LowerSpatialPlans");
hasIllegalOps = true;
}
});
if (hasIllegalOps)
signalPassFailure();
else
dumpModule(moduleOp, "spatial1_premerge");
if (!verifyLogicalPhase("at the end of LowerSpatialPlans"))
return;
}
};
} // namespace
std::unique_ptr<Pass> createLowerSpatialPlansPass() { return std::make_unique<LowerSpatialPlansPass>(); }
} // namespace onnx_mlir
@@ -18,6 +18,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "ONNXToSpatialVerifier.hpp"
using namespace mlir;
@@ -41,10 +42,16 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
if (!computes.empty() || !computeBatches.empty())
SmallVector<spatial::SpatGraphCompute> computes(funcOp.getOps<spatial::SpatGraphCompute>());
SmallVector<spatial::SpatGraphComputeBatch> computeBatches(funcOp.getOps<spatial::SpatGraphComputeBatch>());
SmallVector<spatial::SpatConv2DPlanOp> convPlans(funcOp.getOps<spatial::SpatConv2DPlanOp>());
SmallVector<spatial::SpatReluPlanOp> reluPlans(funcOp.getOps<spatial::SpatReluPlanOp>());
SmallVector<spatial::SpatReconciliatorOp> reconciliators(funcOp.getOps<spatial::SpatReconciliatorOp>());
SmallVector<spatial::SpatMaterializeLayoutOp> materializers(funcOp.getOps<spatial::SpatMaterializeLayoutOp>());
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !reconciliators.empty()
|| !materializers.empty()) {
return;
}
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
rewriter.setInsertionPoint(returnOp);
@@ -58,7 +65,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
sourceLocs.push_back(source.getLoc());
}
auto newCompute = spatial::SpatCompute::create(
auto newCompute = spatial::SpatGraphCompute::create(
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
@@ -67,7 +74,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : funcOp.getOps())
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op))
rewriter.clone(op, mapper);
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
@@ -75,7 +82,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op)) {
op.dropAllUses();
rewriter.eraseOp(&op);
}
@@ -152,6 +159,11 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX conversion");
signalPassFailure();
return;
}
ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
@@ -168,6 +180,11 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc);
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after weight annotation");
signalPassFailure();
return;
}
ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
@@ -176,11 +193,16 @@ void ONNXToSpatialPass::runOnOperation() {
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatGraphCompute>(
[](spatial::SpatGraphCompute computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatGraphComputeBatch>(
[](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed before post rewrites");
signalPassFailure();
return;
}
RewritePatternSet postPatterns(ctx);
populatePostPatterns(postPatterns, ctx);
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
@@ -191,6 +213,11 @@ void ONNXToSpatialPass::runOnOperation() {
populateEmptyFunction(*entryFunc);
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX-to-Spatial");
signalPassFailure();
return;
}
dumpModule(moduleOp, "spatial0");
if (failed(verifyONNXToSpatial(*entryFunc))) {
@@ -1,4 +1,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LLVM.h"
#include "Common/IR/WeightUtils.hpp"
@@ -13,6 +15,8 @@ namespace onnx_mlir {
namespace {
constexpr StringLiteral kPhaseMarker = "RAPTOR_PHASE_CHECK";
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
func.walk([&](Operation* op) {
if (!hasWeightAlways(op))
@@ -23,134 +27,174 @@ void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diag
continue;
diagnostics.report(op, [&](Operation* illegalOp) {
illegalOp->emitOpError(
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
illegalOp->emitOpError()
<< kPhaseMarker
<< " weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights";
});
return;
}
});
}
Region* getParentRegion(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent();
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion();
return nullptr;
bool isRegionOrAncestorOf(Region& region, Region* candidate) {
return candidate && (&region == candidate || region.isAncestor(candidate));
}
bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
bool isValueDefinedInsideRegion(Value value, Region& region) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return isRegionOrAncestorOf(region, blockArg.getOwner()->getParent());
if (Operation* definingOp = value.getDefiningOp())
return isRegionOrAncestorOf(region, definingOp->getParentRegion());
return false;
}
bool isLegalExternalCapture(Value value, Region& region) {
if (isValueDefinedInsideRegion(value, region))
return true;
Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
}
template <typename ComputeOpTy>
void verifyComputeBodyCaptures(ComputeOpTy compute, StringRef kind, pim::CappedDiagnosticReporter& diagnostics) {
Region& body = compute.getBody();
body.walk([&](Operation* nestedOp) {
for (OpOperand& operand : nestedOp->getOpOperands()) {
Value value = operand.get();
if (isLegalExternalCapture(value, body))
continue;
Operation* definingOp = value.getDefiningOp();
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
InFlightDiagnostic diag =
illegalOp->emitOpError() << kPhaseMarker << " " << kind << " body captures non-constant external operand #"
<< operand.getOperandNumber() << " used by " << nestedOp->getName().getStringRef();
diag << " (type " << value.getType() << ")";
if (definingOp)
diag.attachNote(definingOp->getLoc()) << "defining op is " << definingOp->getName().getStringRef();
else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
if (Operation* owner = blockArg.getOwner()->getParentOp())
diag.attachNote(owner->getLoc())
<< "external block argument belongs to " << owner->getName().getStringRef();
}
});
}
});
}
bool isLegalHostBackedValue(Value value) {
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return isa<BlockArgument>(value);
if (isa<spatial::SpatChannelReceiveOp>(definingOp))
return false;
return definingOp->getDialect()->getNamespace() != "spat";
}
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
ValueRange inputs,
bool allowChannelReceiveInputs,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
for (auto [inputIndex, input] : llvm::enumerate(inputs)) {
unsigned currentInputIndex = inputIndex;
template <typename ComputeOpTy>
void verifyScheduledInputs(ComputeOpTy compute,
bool allowChannelReceiveInputs,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
Operation* definingOp = input.getDefiningOp();
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
continue;
if (isLegalHostBackedValue(input))
continue;
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError()
<< kind << " input #" << currentInputIndex
<< (allowChannelReceiveInputs ? " must come from the host or an explicit "
"spat.channel_receive"
: " must come from the host");
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
InFlightDiagnostic diag = illegalOp->emitOpError()
<< kPhaseMarker << " " << kind << " input #" << inputIndex
<< (allowChannelReceiveInputs ? " must come from the host or explicit spat.channel_receive"
: " must come from the host");
if (definingOp)
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
diag.attachNote(definingOp->getLoc()) << "illegal producer is " << definingOp->getName().getStringRef();
});
return failure();
}
return success();
}
void verifyNoExternalTensorCaptures(Operation* ownerOp,
Region& region,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (!isa<TensorType>(value.getType()))
continue;
if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value))
continue;
void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
for (Operation& op : funcOp.getOps()) {
if (isa<func::ReturnOp,
spatial::SpatGraphCompute,
spatial::SpatGraphComputeBatch,
spatial::SpatConv2DPlanOp,
spatial::SpatReluPlanOp,
spatial::SpatReconciliatorOp,
spatial::SpatMaterializeLayoutOp>(&op)) {
continue;
}
if (isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch>(&op)) {
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << kPhaseMarker << " scheduled Spatial compute op is not allowed in logical graph phase";
});
continue;
}
if (isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelSendOp>(&op)) {
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << kPhaseMarker
<< " explicit channel communication is not expected before merge materialization";
});
continue;
}
if (isCompileTimeOp(&op))
continue;
Operation* definingOp = value.getDefiningOp();
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError()
<< kPhaseMarker << " non-foldable top-level runtime op remains in logical Spatial graph; lower it inside spat.graph_compute";
});
}
}
diagnostics.report(ownerOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
<< "values";
diagnostic.attachNote(op->getLoc())
<< "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
<< (definingOp ? definingOp->getName().getStringRef() : StringRef("<block argument>"));
void verifyScheduledTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
for (Operation& op : funcOp.getOps()) {
if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(&op)) {
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << kPhaseMarker << " graph Spatial compute op remained after merge materialization";
});
}
});
}
}
} // namespace
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
LogicalResult verifyNoComputeBodyCaptures(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
for (Operation& op : funcOp.getOps()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isCompileTimeOp(&op))
continue;
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError(
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
});
}
checkWeightUseChains(funcOp, diagnostics);
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
for (auto compute : funcOp.getOps<spatial::SpatGraphCompute>())
verifyComputeBodyCaptures(compute, "graph_compute", diagnostics);
for (auto batch : funcOp.getOps<spatial::SpatGraphComputeBatch>())
verifyComputeBodyCaptures(batch, "graph_compute_batch", diagnostics);
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
verifyComputeBodyCaptures(compute, "scheduled_compute", diagnostics);
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
verifyComputeBodyCaptures(batch, "scheduled_compute_batch", diagnostics);
diagnostics.emitSuppressedSummary(funcOp, "compute body capture verification failed");
return success(!diagnostics.hasFailure());
}
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { return verifyLogicalSpatialGraphInvariants(funcOp); }
LogicalResult verifyLogicalSpatialGraphInvariants(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
verifyLogicalTopLevelOps(funcOp, diagnostics);
checkWeightUseChains(funcOp, diagnostics);
if (failed(verifyNoComputeBodyCaptures(funcOp)))
return failure();
diagnostics.emitSuppressedSummary(funcOp, "logical Spatial graph verification failed");
return success(!diagnostics.hasFailure());
}
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
(void) verifyComputeLikeInputs(
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics);
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
}
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
(void) verifyComputeLikeInputs(computeBatchOp.getOperation(),
computeBatchOp.getInputs(),
/*allowChannelReceiveInputs=*/false,
"spat.compute_batch",
diagnostics);
verifyNoExternalTensorCaptures(
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
}
diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
verifyScheduledTopLevelOps(funcOp, diagnostics);
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
if (failed(verifyNoComputeBodyCaptures(funcOp)))
return failure();
diagnostics.emitSuppressedSummary(funcOp, "scheduled Spatial verification failed");
return success(!diagnostics.hasFailure());
}
@@ -6,6 +6,8 @@
namespace onnx_mlir {
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifySpatialCommunicationInvariants(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifyNoComputeBodyCaptures(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifyLogicalSpatialGraphInvariants(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifyScheduledSpatialInvariants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -33,8 +33,8 @@ void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext*
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
bool requiresPostRewrite(spatial::SpatGraphCompute computeOp);
bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
File diff suppressed because it is too large Load Diff
@@ -950,7 +950,12 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
auto shapeInfo = analyzeMatMulShape(matmulOp);
if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector || !shapeInfo->outputBatchShape.empty())
if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector)
return failure();
const bool hasNonSingletonOutputBatch =
!shapeInfo->outputBatchShape.empty() && getStaticShapeElementCount(shapeInfo->outputBatchShape) != 1;
if (hasNonSingletonOutputBatch)
return failure();
Location loc = matmulOp.getLoc();
@@ -991,7 +996,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
gemmResult =
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}))
.getResult();
rewriter.replaceOp(matmulOp, gemmResult);
if (shapeInfo->outputBatchShape.empty()) {
rewriter.replaceOp(matmulOp, gemmResult);
return success();
}
auto directOutType =
RankedTensorType::get({1, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding());
Value batchedResult = ensureBatchedTensor(gemmResult, /*batchSize=*/1, shapeInfo->m, shapeInfo->n, rewriter, loc);
Value finalResult = finalizeNormalizedMatMulResult(batchedResult, directOutType, *shapeInfo, rewriter, loc);
rewriter.replaceOp(matmulOp, finalResult);
return success();
}
};
@@ -16,12 +16,9 @@ struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
Location loc = reluOp.getLoc();
Type resultType = reluOp.getResult().getType();
constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x);
spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
});
rewriter.replaceOp(reluOp, computeOp);
auto reluPlan = spatial::SpatReluPlanOp::create(
rewriter, loc, resultType, adaptor.getX(), rewriter.getStringAttr("nchw"));
rewriter.replaceOp(reluOp, reluPlan.getResult());
return success();
}
};
@@ -118,17 +118,17 @@ static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
}
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatGraphCompute> {
using OpRewritePattern<spatial::SpatGraphCompute>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(spatial::SpatGraphCompute compute, PatternRewriter& rewriter) const override {
auto promoted = computePromotedOperands(compute);
if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
Block& oldBlock = compute.getBody().front();
rewriter.setInsertionPointAfter(compute);
auto newCompute = spatial::SpatCompute::create(
auto newCompute = spatial::SpatGraphCompute::create(
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
@@ -182,10 +182,10 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
};
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatGraphComputeBatch> {
using OpRewritePattern<spatial::SpatGraphComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(spatial::SpatGraphComputeBatch compute, PatternRewriter& rewriter) const override {
auto promoted = computePromotedOperands(compute);
if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
@@ -197,7 +197,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
rewriter, compute, static_cast<uint64_t>(compute.getLaneCount()), "promoted compute_batch lane count");
if (failed(laneCountAttr))
return failure();
auto newCompute = spatial::SpatComputeBatch::create(
auto newCompute = spatial::SpatGraphComputeBatch::create(
rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs);
auto laneArg = compute.getLaneArgument();
if (!laneArg)
@@ -281,8 +281,8 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
});
}
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatGraphCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
} // namespace onnx_mlir
@@ -0,0 +1,21 @@
#pragma once
#include <optional>
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
mlir::FailureOr<mlir::Value>
lowerSelectedConv2DPlan(spatial::SpatConv2DPlanOp planOp,
std::optional<mlir::Value> rowStripInput,
bool emitRowStripLayout,
mlir::PatternRewriter& rewriter);
mlir::LogicalResult canLowerConvPlanToRowStrip(spatial::SpatConv2DPlanOp planOp);
mlir::LogicalResult canConsumeAndProduceRowStrip(spatial::SpatConv2DPlanOp planOp);
} // namespace onnx_mlir
@@ -0,0 +1,206 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
using namespace mlir;
namespace onnx_mlir {
namespace {
static constexpr StringLiteral kLogicalLayout = "nchw";
static constexpr StringLiteral kDenseLayout = "dense_nchw";
static constexpr StringLiteral kRowStripLayout = "nchw_row_strip";
static constexpr StringLiteral kRowStripIndexMap = "packed_hwc_rows_to_nchw";
enum class SelectedLayout {
DenseNchw,
NchwRowStrip,
};
static SelectedLayout getSelectedLayout(llvm::DenseMap<Value, SelectedLayout>& layouts, Value value) {
auto it = layouts.find(value);
return it == layouts.end() ? SelectedLayout::DenseNchw : it->second;
}
static bool usesSelectedRowStrip(Operation* user, llvm::DenseMap<Value, SelectedLayout>& layouts) {
if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(user))
return getSelectedLayout(layouts, reluPlan.getResult()) == SelectedLayout::NchwRowStrip;
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(user))
return getSelectedLayout(layouts, convPlan.getResult()) == SelectedLayout::NchwRowStrip;
return false;
}
static bool allUsersCanHandleRowStrip(Value value, llvm::DenseMap<Value, SelectedLayout>& layouts) {
for (Operation* user : value.getUsers()) {
if (usesSelectedRowStrip(user, layouts))
continue;
// Dense-only users must be materialized explicitly.
continue;
}
return true;
}
static std::pair<SmallVector<int64_t>, SmallVector<int64_t>> buildRowStripMetadata(RankedTensorType type) {
SmallVector<int64_t> offsets;
SmallVector<int64_t> sizes;
const int64_t channels = type.getDimSize(1);
const int64_t height = type.getDimSize(2);
const int64_t width = type.getDimSize(3);
offsets.reserve(height * 4);
sizes.reserve(height * 4);
for (int64_t row = 0; row < height; ++row) {
offsets.append({0, 0, row, 0});
sizes.append({1, channels, 1, width});
}
return {offsets, sizes};
}
static bool canSelectConvRowStrip(spatial::SpatConv2DPlanOp convPlan,
llvm::DenseMap<Value, SelectedLayout>& layouts) {
SelectedLayout inputLayout = getSelectedLayout(layouts, convPlan.getInput());
if (inputLayout == SelectedLayout::NchwRowStrip)
return succeeded(canConsumeAndProduceRowStrip(convPlan));
return succeeded(canLowerConvPlanToRowStrip(convPlan));
}
static SelectedLayout chooseConvLayout(spatial::SpatConv2DPlanOp convPlan,
llvm::DenseMap<Value, SelectedLayout>& layouts) {
if (!canSelectConvRowStrip(convPlan, layouts))
return SelectedLayout::DenseNchw;
if (!allUsersCanHandleRowStrip(convPlan.getResult(), layouts))
return SelectedLayout::DenseNchw;
return SelectedLayout::NchwRowStrip;
}
static SelectedLayout chooseReluLayout(spatial::SpatReluPlanOp reluPlan,
llvm::DenseMap<Value, SelectedLayout>& layouts) {
if (getSelectedLayout(layouts, reluPlan.getInput()) != SelectedLayout::NchwRowStrip)
return SelectedLayout::DenseNchw;
if (!allUsersCanHandleRowStrip(reluPlan.getResult(), layouts))
return SelectedLayout::DenseNchw;
return SelectedLayout::NchwRowStrip;
}
static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewriter, Value value) {
auto outputType = cast<RankedTensorType>(value.getType());
auto [offsets, sizes] = buildRowStripMetadata(outputType);
return spatial::SpatReconciliatorOp::create(rewriter,
value.getLoc(),
outputType,
value,
ValueRange {},
rewriter.getStringAttr(kLogicalLayout),
rewriter.getStringAttr(kRowStripLayout),
rewriter.getDenseI64ArrayAttr(offsets),
rewriter.getDenseI64ArrayAttr(sizes),
rewriter.getStringAttr(kRowStripIndexMap),
nullptr,
nullptr,
nullptr,
nullptr,
nullptr);
}
static void materializeDenseUses(IRRewriter& rewriter,
Value layoutValue,
llvm::DenseMap<Value, SelectedLayout>& layouts) {
SmallVector<OpOperand*> denseUses;
for (OpOperand& use : layoutValue.getUses()) {
if (usesSelectedRowStrip(use.getOwner(), layouts))
continue;
denseUses.push_back(&use);
}
for (OpOperand* use : denseUses) {
Operation* owner = use->getOwner();
rewriter.setInsertionPoint(owner);
auto materialized = spatial::SpatMaterializeLayoutOp::create(rewriter,
owner->getLoc(),
use->get().getType(),
use->get(),
rewriter.getStringAttr(kLogicalLayout),
rewriter.getStringAttr(kRowStripLayout),
rewriter.getStringAttr(kDenseLayout));
use->set(materialized.getResult());
}
}
struct SpatialLayoutPlanningPass final : PassWrapper<SpatialLayoutPlanningPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialLayoutPlanningPass)
StringRef getArgument() const override { return "spatial-layout-planning"; }
StringRef getDescription() const override { return "Select conservative Spatial layouts and insert reconciliation barriers."; }
void runOnOperation() override {
auto entryFunc = getPimEntryFunc(getOperation());
if (failed(entryFunc)) {
getOperation().emitError("failed to locate the PIM entry function during Spatial layout planning");
signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
llvm::DenseMap<Value, SelectedLayout> layouts;
bool changed = true;
while (changed) {
changed = false;
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
SelectedLayout selected = chooseConvLayout(convPlan, layouts);
if (layouts[convPlan.getResult()] != selected) {
layouts[convPlan.getResult()] = selected;
changed = true;
}
continue;
}
if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
SelectedLayout selected = chooseReluLayout(reluPlan, layouts);
if (layouts[reluPlan.getResult()] != selected) {
layouts[reluPlan.getResult()] = selected;
changed = true;
}
continue;
}
}
}
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
Value producedValue;
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(&op))
producedValue = convPlan.getResult();
else if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(&op))
producedValue = reluPlan.getResult();
else
continue;
if (getSelectedLayout(layouts, producedValue) != SelectedLayout::NchwRowStrip)
continue;
rewriter.setInsertionPointAfter(&op);
auto reconciliator = insertRowStripReconciliator(rewriter, producedValue);
rewriter.replaceAllUsesExcept(producedValue, reconciliator.getResult(), reconciliator);
materializeDenseUses(rewriter, reconciliator.getResult(), layouts);
}
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
getOperation().emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after SpatialLayoutPlanning");
signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<Pass> createSpatialLayoutPlanningPass() { return std::make_unique<SpatialLayoutPlanningPass>(); }
} // namespace onnx_mlir
@@ -5,6 +5,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include <limits>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
@@ -26,7 +27,83 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) {
});
}
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp,
static bool isMaterializableExternalTensorOp(Operation* op) {
return isa<spatial::SpatChannelReceiveOp,
spatial::SpatExtractRowsOp,
tensor::ExtractSliceOp,
tensor::ExpandShapeOp,
tensor::CollapseShapeOp>(op);
}
//TODO REMOVE THIS UGLY FIX
//TODO: Remove this helper once compute_batch external tensor captures are
// fixed at the producer side.
//
// This function is a temporary SpatialToPim repair path. It clones selected
// external tensor producers, such as channel_receive and tensor view/slice ops,
// into the new pim.core_batch body when the old spat.compute_batch body refers
// to tensor values defined outside the batch.
//
// The real invariant should be stronger:
//
// A spat.compute_batch body must not capture external tensor values.
// Every tensor used inside the body must be either:
// - a compute_batch block argument,
// - defined inside the compute_batch body,
// - or a legal constant-like value.
//
// If this invariant is violated, the responsible producer, most likely merge
// schedule materialization, should emit verifier-clean Spatial IR instead of
// relying on SpatialToPim to clone external producer chains later.
//
// After that producer-side fix:
// 1. remove isMaterializableExternalTensorOp,
// 2. remove materializeExternalTensorValue,
// 3. make lowerComputeBatchOp emit a hard diagnostic for any unmapped external
// tensor operand,
// 4. keep/strengthen the Spatial verifier so the invalid capture is rejected
// before SpatialToPim.
//
// Be careful not to replace every external tensor capture with a normal
// compute_batch input blindly: host-backed tensors and explicit inter-core
// communication have different semantics. In particular, channel_receive-like
// values should be materialized through the communication model, not silently
// treated as host inputs.
static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
Location loc,
Block& oldBlock,
Value value,
IRMapping& mapper) {
if (mapper.contains(value))
return mapper.lookup(value);
if (!isa<TensorType>(value.getType()))
return value;
Operation* definingOp = value.getDefiningOp();
if (!definingOp || definingOp->hasTrait<OpTrait::ConstantLike>())
return failure();
if (definingOp->getBlock() == &oldBlock)
return failure();
if (!isMaterializableExternalTensorOp(definingOp))
return failure();
for (Value operand : definingOp->getOperands()) {
FailureOr<Value> materializedOperand = materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper);
if (succeeded(materializedOperand))
mapper.map(operand, *materializedOperand);
}
Operation* cloned = rewriter.clone(*definingOp, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
return mapper.lookup(value);
}
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
@@ -62,26 +139,49 @@ static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value ba
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
}
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
SmallVector<OpFoldResult, 4> attrs;
attrs.reserve(values.size());
for (int64_t value : values)
attrs.push_back(builder.getIndexAttr(value));
return attrs;
}
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
SmallVector<OpFoldResult, 4> strides;
strides.reserve(rank);
for (int64_t dim = 0; dim < rank; ++dim)
strides.push_back(builder.getIndexAttr(1));
return strides;
}
static Value createHostTargetOffset(IRRewriter& rewriter,
tensor::ParallelInsertSliceOp insertSlice,
Location loc,
ShapedType destinationType,
ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<int64_t> additionalOffsets,
IRMapping& mapper) {
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
Value totalOffset;
Location loc = insertSlice.getLoc();
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
for (auto [dim, offset] : llvm::enumerate(mixedOffsets)) {
int64_t scale = strides[dim] * elementBytes;
Value scaledOffset;
if (auto attr = dyn_cast<Attribute>(offset)) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
assert(intAttr && "expected integer offset attribute");
scaledOffset =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), intAttr.getInt() * scale);
}
else {
scaledOffset = getOrCreateIndexConstant(rewriter,
rewriter.getInsertionBlock()->getParentOp(),
(intAttr.getInt() + additionalOffsets[dim]) * scale);
} else {
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
if (additionalOffsets[dim] != 0) {
Value staticOffset = getOrCreateIndexConstant(rewriter,
rewriter.getInsertionBlock()->getParentOp(),
additionalOffsets[dim] * scale);
scaledOffset = arith::AddIOp::create(rewriter, loc, scaledOffset, staticOffset).getResult();
}
}
totalOffset =
@@ -93,9 +193,130 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
return totalOffset;
}
static Value createHostTargetOffset(IRRewriter& rewriter,
tensor::ParallelInsertSliceOp insertSlice,
ShapedType destinationType,
IRMapping& mapper) {
SmallVector<int64_t> zeroOffsets(destinationType.getRank(), 0);
return createHostTargetOffset(rewriter,
insertSlice.getLoc(),
destinationType,
insertSlice.getMixedOffsets(),
zeroOffsets,
mapper);
}
static SmallVector<OpFoldResult, 4> buildFragmentOffsets(IRRewriter& rewriter,
Location loc,
ArrayRef<OpFoldResult> baseOffsets,
ArrayRef<int64_t> fragmentOffsets,
IRMapping& mapper) {
SmallVector<OpFoldResult, 4> combined;
combined.reserve(fragmentOffsets.size());
for (auto [dim, baseOffset] : llvm::enumerate(baseOffsets)) {
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
int64_t base = cast<IntegerAttr>(attr).getInt();
combined.push_back(rewriter.getIndexAttr(base + fragmentOffsets[dim]));
continue;
}
Value dynamicBase = mapper.lookupOrDefault(cast<Value>(baseOffset));
if (fragmentOffsets[dim] == 0) {
combined.push_back(dynamicBase);
continue;
}
Value staticOffset =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), fragmentOffsets[dim]);
combined.push_back(arith::AddIOp::create(rewriter, loc, dynamicBase, staticOffset).getResult());
}
return combined;
}
static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
spatial::SpatReconciliatorOp reconciliator,
Value hostTarget,
ArrayRef<OpFoldResult> baseOffsets,
IRMapping& mapper) {
auto hostTargetType = dyn_cast<RankedTensorType>(hostTarget.getType());
auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
if (!hostTargetType || !resultType || !resultType.hasStaticShape())
return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor results");
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
if (!operandIndicesAttr || !fragmentStridesAttr)
return reconciliator.emitOpError(
"fragment assembly lowering requires explicit operand indices and unit strides");
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
int64_t rank = resultType.getRank();
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
llvm::append_range(fragmentOperands, reconciliator.getFragments());
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
int64_t operandIndex = operandIndices[fragmentIndex];
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
return reconciliator.emitOpError("fragment assembly operand index is out of range");
SmallVector<int64_t, 4> fragmentOffsets;
int64_t fragmentElements = 1;
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1)
return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
fragmentOffsets.push_back(flatOffsets[flatIndex]);
fragmentElements *= flatSizes[flatIndex];
}
Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]);
auto sourceType = dyn_cast<ShapedType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape())
return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
SmallVector<int64_t, 4> fragmentShape;
fragmentShape.reserve(rank);
for (int64_t dim = 0; dim < rank; ++dim)
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
Value fragment = source;
if (llvm::to_vector(sourceType.getShape()) != fragmentShape) {
SmallVector<int64_t, 4> extractOffsets(rank, 0);
extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
fragment = tensor::ExtractSliceOp::create(rewriter,
reconciliator.getLoc(),
source,
getStaticIndexAttrs(rewriter, extractOffsets),
getStaticIndexAttrs(rewriter, fragmentShape),
getUnitStrides(rewriter, rank));
}
hostTarget = tensor::InsertSliceOp::create(rewriter,
reconciliator.getLoc(),
fragment,
hostTarget,
buildFragmentOffsets(rewriter,
reconciliator.getLoc(),
baseOffsets,
fragmentOffsets,
mapper),
getStaticIndexAttrs(rewriter, fragmentShape),
getUnitStrides(rewriter, rank))
.getResult();
}
return hostTarget;
}
} // namespace
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
IRRewriter& rewriter) {
Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front();
@@ -131,8 +352,10 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
SmallVector<unsigned> returnOperandIndices;
if (computeBatchOp.getNumResults() != 0) {
returnOperandIndices.resize(computeBatchOp.getNumResults());
returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits<unsigned>::max());
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
if (result.use_empty())
continue;
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
if (failed(returnOperandIndex))
return computeBatchOp.emitOpError(
@@ -195,6 +418,18 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
if (isa<spatial::SpatYieldOp>(op))
continue;
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
std::optional<StringRef> modeAttr = reconciliator.getMode();
if (modeAttr && *modeAttr == "fragment_assembly") {
for (Operation* user : reconciliator.getOutput().getUsers()) {
if (!isa<tensor::ParallelInsertSliceOp>(user))
return reconciliator.emitOpError(
"fragment assembly reconciliator lowering expects only tensor.parallel_insert_slice users");
}
continue;
}
}
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
if (!firstOutputArg)
@@ -211,10 +446,28 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
if (resultIndex >= returnOperandIndices.size())
return insertSlice.emitOpError("result index out of range while lowering host batch output");
if (returnOperandIndices[resultIndex] == std::numeric_limits<unsigned>::max())
continue;
Value mappedSource = mapper.lookup(insertSlice.getSource());
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
if (auto reconciliator =
insertSlice.getSource().getDefiningOp<spatial::SpatReconciliatorOp>()) {
std::optional<StringRef> modeAttr = reconciliator.getMode();
if (modeAttr && *modeAttr == "fragment_assembly") {
FailureOr<Value> updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter,
reconciliator,
hostTarget,
insertSlice.getMixedOffsets(),
mapper);
if (failed(updatedHostTarget))
return failure();
hostOutputTensors[resultIndex] = *updatedHostTarget;
continue;
}
}
Value mappedSource = mapper.lookup(insertSlice.getSource());
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource);
@@ -264,9 +517,18 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
Operation* definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == &oldBlock)
continue;
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
return computeBatchOp.emitOpError(
"expected external tensor communication to be materialized in Spatial before batch lowering");
if (succeeded(materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper)))
continue;
InFlightDiagnostic diagnostic =
computeBatchOp.emitOpError("expected external tensor communication to be materialized in Spatial before batch lowering");
diagnostic << " while cloning nested op '" << op.getName() << "' tensor operand #" << operandIndex;
if (definingOp)
diagnostic << " from external producer '" << definingOp->getName() << "'";
return diagnostic;
}
Operation* cloned = rewriter.clone(op, mapper);
@@ -17,10 +17,10 @@ std::optional<unsigned> getDirectComputeLikeInputIndex(Operation* owner, unsigne
return operandNumber - inputBegin;
};
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner))
return getInputIndex(owner, compute.getInputs().size());
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner))
if (auto computeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(owner))
return getInputIndex(owner, computeBatch.getInputs().size());
return std::nullopt;
@@ -32,13 +32,13 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
Value replacement) {
Block& body = owner->getRegion(0).front();
BlockArgument bodyArgument;
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner)) {
auto computeArg = compute.getInputArgument(inputIndex);
assert(computeArg && "expected compute input block argument");
bodyArgument = *computeArg;
}
else {
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
auto batchArg = cast<spatial::SpatScheduledComputeBatch>(owner).getInputArgument(inputIndex);
assert(batchArg && "expected compute_batch input block argument");
bodyArgument = *batchArg;
}
@@ -46,10 +46,10 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
rewriter.startOpModification(owner);
bodyArgument.replaceAllUsesWith(replacement);
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner))
compute.getInputsMutable().erase(inputIndex);
else
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
cast<spatial::SpatScheduledComputeBatch>(owner).getInputsMutable().erase(inputIndex);
body.eraseArgument(bodyArgIndex);
rewriter.finalizeOpModification(owner);
}
@@ -30,6 +30,91 @@ static bool isChannelUseChainOp(Operation* op) {
pim::PimTransposeOp>(op);
}
static Value createStaticHostTargetOffset(IRRewriter& rewriter,
Location loc,
ShapedType destinationType,
ArrayRef<int64_t> fragmentOffsets) {
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
int64_t byteOffset = 0;
for (auto [dim, offset] : llvm::enumerate(fragmentOffsets))
byteOffset += offset * strides[dim] * elementBytes;
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), byteOffset);
}
static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
spatial::SpatReconciliatorOp reconciliator,
IRMapping& mapping) {
auto resultType = dyn_cast<ShapedType>(reconciliator.getOutput().getType());
if (!resultType || !resultType.hasStaticShape())
return reconciliator.emitOpError("fragment assembly lowering requires a static ranked tensor result");
std::optional<StringRef> modeAttr = reconciliator.getMode();
std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !fragmentStridesAttr)
return reconciliator.emitOpError("fragment assembly lowering requires explicit fragment metadata");
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
int64_t rank = resultType.getRank();
SmallVector<Value> fragmentOperands {reconciliator.getInput()};
llvm::append_range(fragmentOperands, reconciliator.getFragments());
Value currentOutput = createEmptyTensorFromShaped(rewriter, reconciliator.getLoc(), resultType);
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
int64_t operandIndex = operandIndices[fragmentIndex];
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
return reconciliator.emitOpError("fragment assembly operand index is out of range");
SmallVector<int64_t, 4> fragmentOffsets;
int64_t fragmentElements = 1;
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1)
return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
fragmentOffsets.push_back(flatOffsets[flatIndex]);
fragmentElements *= flatSizes[flatIndex];
}
Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]);
auto sourceType = dyn_cast<ShapedType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape())
return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");
int64_t fragmentBytes =
fragmentElements * static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
auto sizeAttr = pim::getCheckedI32Attr(rewriter,
reconciliator.getOperation(),
fragmentBytes,
"fragment assembly host copy size");
if (failed(sizeAttr))
return failure();
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
Value hostTargetOffset = createStaticHostTargetOffset(rewriter, reconciliator.getLoc(), resultType, fragmentOffsets);
Value deviceSourceOffset = getOrCreateIndexConstant(rewriter,
rewriter.getInsertionBlock()->getParentOp(),
packedFragmentOrdinal * fragmentBytes);
currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter,
reconciliator.getLoc(),
currentOutput.getType(),
hostTargetOffset,
deviceSourceOffset,
currentOutput,
source,
*sizeAttr)
.getOutput();
}
return currentOutput;
}
static void
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
for (Value operand : op->getOperands()) {
@@ -55,7 +140,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
}
}
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatScheduledCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id");
auto checkedCoreId =
@@ -66,7 +151,7 @@ static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeO
return *checkedCoreId;
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
SmallVectorImpl<Operation*>& helperChain,
bool requireReturnUse = true) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
@@ -104,13 +189,13 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
return success();
}
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp,
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatScheduledCompute computeOp,
IRRewriter& rewriter,
OperationFolder& constantFolder) {
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
return false;
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
return isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
}))
return false;
@@ -131,6 +216,17 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
mapping.map(*weightArg, weight);
}
for (Operation& op : block.without_terminator()) {
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
std::optional<StringRef> modeAttr = reconciliator.getMode();
if (modeAttr && *modeAttr == "fragment_assembly") {
auto lowered = lowerFragmentAssemblyReconciliator(rewriter, reconciliator, mapping);
if (failed(lowered))
return false;
mapping.map(reconciliator.getOutput(), *lowered);
continue;
}
}
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
Operation* clonedOp = rewriter.clone(op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
@@ -145,7 +241,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
} // namespace
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatScheduledCompute computeOp,
IRRewriter& rewriter,
OperationFolder& constantFolder) {
Location loc = computeOp->getLoc();
@@ -1,6 +1,10 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -11,6 +15,107 @@ namespace raptor {
} // namespace raptor
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
SmallVector<OpFoldResult, 4> attrs;
attrs.reserve(values.size());
for (int64_t value : values)
attrs.push_back(builder.getIndexAttr(value));
return attrs;
}
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
SmallVector<OpFoldResult, 4> strides;
strides.reserve(rank);
for (int64_t dim = 0; dim < rank; ++dim)
strides.push_back(builder.getIndexAttr(1));
return strides;
}
struct LowerFragmentAssemblyReconciliatorPattern
: OpConversionPattern<spatial::SpatReconciliatorOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(spatial::SpatReconciliatorOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
std::optional<StringRef> modeAttr = op.getMode();
if (!modeAttr || *modeAttr != "fragment_assembly")
return failure();
auto resultType = dyn_cast<ShapedType>(op.getOutput().getType());
if (!resultType || !resultType.hasStaticShape())
return op.emitOpError("fragment assembly lowering requires a static ranked tensor result");
std::optional<ArrayRef<int64_t>> operandIndicesAttr = op.getFragmentOperandIndices();
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = op.getFragmentStrides();
if (!operandIndicesAttr || !fragmentStridesAttr)
return op.emitOpError("fragment assembly lowering requires explicit fragment metadata");
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
ArrayRef<int64_t> flatOffsets = op.getFragmentOffsets();
ArrayRef<int64_t> flatSizes = op.getFragmentSizes();
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
int64_t rank = resultType.getRank();
SmallVector<Value> fragmentOperands {adaptor.getInput()};
llvm::append_range(fragmentOperands, adaptor.getFragments());
Value currentOutput =
tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), resultType.getElementType()).getResult();
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
int64_t operandIndex = operandIndices[fragmentIndex];
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
return op.emitOpError("fragment assembly operand index is out of range");
SmallVector<int64_t, 4> fragmentOffsets;
int64_t fragmentElements = 1;
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
if (flatStrides[flatIndex] != 1)
return op.emitOpError("fragment assembly lowering only supports unit strides");
fragmentOffsets.push_back(flatOffsets[flatIndex]);
fragmentElements *= flatSizes[flatIndex];
}
Value source = fragmentOperands[operandIndex];
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape())
return op.emitOpError("fragment assembly lowering requires static ranked tensor operands");
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
SmallVector<int64_t, 4> fragmentShape;
fragmentShape.reserve(rank);
for (int64_t dim = 0; dim < rank; ++dim)
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
Value fragment = source;
if (llvm::to_vector(sourceType.getShape()) != fragmentShape) {
SmallVector<int64_t, 4> extractOffsets(rank, 0);
extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
fragment = tensor::ExtractSliceOp::create(rewriter,
op.getLoc(),
source,
getStaticIndexAttrs(rewriter, extractOffsets),
getStaticIndexAttrs(rewriter, fragmentShape),
getUnitStrides(rewriter, rank));
}
currentOutput = tensor::InsertSliceOp::create(rewriter,
op.getLoc(),
fragment,
currentOutput,
getStaticIndexAttrs(rewriter, fragmentOffsets),
getStaticIndexAttrs(rewriter, fragmentShape),
getUnitStrides(rewriter, rank))
.getResult();
}
rewriter.replaceOp(op, currentOutput);
return success();
}
};
void populateInitialPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
@@ -19,6 +124,7 @@ void populateInitialPatterns(RewritePatternSet& patterns) {
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
patterns.add<LowerFragmentAssemblyReconciliatorPattern>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -10,6 +10,14 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static void copyRaptorDebugAttrs(Operation* source, Operation* target) {
for (NamedAttribute attr : source->getAttrs()) {
StringRef name = attr.getName().strref();
if (name.starts_with("raptor."))
target->setAttr(attr.getName(), attr.getValue());
}
}
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
using OpRewritePattern::OpRewritePattern;
@@ -17,7 +25,8 @@ struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getInput());
if (failed(sizeAttr))
return failure();
pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId());
auto send = pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId());
copyRaptorDebugAttrs(op.getOperation(), send.getOperation());
rewriter.eraseOp(op);
return success();
}
@@ -37,9 +46,10 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult());
if (failed(sizeAttr))
return failure();
Value received = pim::PimReceiveOp::create(
rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId())
.getOutput();
auto receive = pim::PimReceiveOp::create(
rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId());
copyRaptorDebugAttrs(op.getOperation(), receive.getOperation());
Value received = receive.getOutput();
rewriter.replaceOp(op, received);
return success();
}
@@ -59,7 +59,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
return failure();
for (auto& uses : extractSliceOp->getUses()) {
if (isa<spatial::SpatCompute>(uses.getOwner())) {
if (isa<spatial::SpatScheduledCompute>(uses.getOwner())) {
if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
return failure();
}
@@ -72,7 +72,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(uses.getOwner())) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
if (!inputIndex)
return failure();
@@ -92,7 +92,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
else if (auto spatComputeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(uses.getOwner())) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
if (!inputIndex)
return failure();
@@ -114,7 +114,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
}
else {
{
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatScheduledCompute>()) {
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
@@ -125,7 +125,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatScheduledComputeBatch>()) {
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
@@ -179,7 +179,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
auto argUser = argUses.getOwner();
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(argUser)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
if (!inputIndex)
return failure();
@@ -191,7 +191,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
}
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
else if (auto spatComputeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(argUser)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
if (!inputIndex)
return failure();
@@ -86,7 +86,7 @@ getCheckedByteOffset(int64_t elementOffset, size_t elementSize, Operation* ancho
return pim::checkedCast<int64_t>(*byteOffset, anchor, fieldName);
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
SmallVectorImpl<Operation*>& helperChain) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
return failure();
@@ -212,7 +212,7 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
}
SmallVector<Operation*> helperChain;
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
if (auto helperCompute = dyn_cast<spatial::SpatScheduledCompute>(currentUser)) {
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
return std::nullopt;
@@ -643,7 +643,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
}
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
spatial::SpatScheduledCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
}
@@ -656,7 +656,7 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
Operation* onlyUser = *op->getUsers().begin();
isExclusivelyOwnedByReturnChain =
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatScheduledCompute>(onlyUser)
|| isReturnHelperChainOp(onlyUser);
}
if (!isExclusivelyOwnedByReturnChain)
@@ -669,7 +669,7 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
return;
}
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
markOpToRemove(computeOp);
if (!computeOp.getInputs().empty())
for (Value input : computeOp.getInputs())
@@ -25,9 +25,11 @@
#include <cassert>
#include <utility>
#include "Common/IR/ShapeUtils.hpp"
#include "Common/IR/ConstantUtils.hpp"
#include "Common/PimCommon.hpp"
#include "Common/Support/CheckedArithmetic.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/Common.hpp"
#include "Conversion/SpatialToPim/Patterns.hpp"
@@ -97,6 +99,64 @@ static FailureOr<Value> createZeroedDeviceHVector(IRRewriter& rewriter,
.getOutput();
}
static bool isHostBackedMemRefValue(Value value) {
while (Operation* definingOp = value.getDefiningOp()) {
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
value = subviewOp.getSource();
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
return isa<memref::GetGlobalOp>(definingOp);
}
return false;
}
static bool isHostBackedTensorValue(Value value) {
while (Operation* definingOp = value.getDefiningOp()) {
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto sourceType = dyn_cast<RankedTensorType>(extractSliceOp.getSource().getType());
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getResult().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return false;
if (!onnx_mlir::isContiguousSubviewWithDynamicOffsets(sourceType.getShape(),
extractSliceOp.getMixedOffsets(),
extractSliceOp.getStaticSizes(),
extractSliceOp.getStaticStrides())) {
return false;
}
value = extractSliceOp.getSource();
continue;
}
if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (auto castOp = dyn_cast<tensor::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(definingOp))
return isHostBackedMemRefValue(toTensorOp.getBuffer());
return false;
}
return false;
}
static FailureOr<Value>
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
auto vectorType = cast<RankedTensorType>(vector.getType());
@@ -120,6 +180,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
if (failed(sizeAttr))
return failure();
if (isHostBackedTensorValue(vector)) {
return PimMemCopyHostToDevOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr)
.getOutput();
}
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
}
@@ -137,6 +201,12 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
return;
}
func::FuncOp funcOp = *entryFunc;
if (failed(verifyScheduledSpatialInvariants(funcOp))) {
funcOp.emitOpError(
"RAPTOR_PHASE_CHECK scheduled Spatial verification failed at the start of SpatialToPim");
signalPassFailure();
return;
}
IRRewriter rewriter(&getContext());
OperationFolder constantFolder(&getContext());
@@ -176,19 +246,19 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
return;
}
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
for (auto computeOp : funcOp.getOps<spatial::SpatScheduledCompute>()) {
markOpToRemove(computeOp);
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
computeOp.emitOpError("failed to lower spat.compute to pim.core");
computeOp.emitOpError("failed to lower spat.scheduled_compute to pim.core");
signalPassFailure();
return;
}
}
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
for (auto computeBatchOp : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
markOpToRemove(computeBatchOp);
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
computeBatchOp.emitOpError("failed to lower spat.scheduled_compute_batch to pim.core_batch");
signalPassFailure();
return;
}
@@ -374,7 +444,7 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
};
for (auto& op : funcOp.getBody().getOps())
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
continue;
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
@@ -41,8 +41,11 @@ private:
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
mlir::LogicalResult
lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder);
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);
lowerComputeOp(spatial::SpatScheduledCompute computeOp,
mlir::IRRewriter& rewriter,
mlir::OperationFolder& constantFolder);
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
mlir::IRRewriter& rewriter);
enum class ReturnPathLoweringResult {
Handled,
@@ -51,7 +54,7 @@ private:
};
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatScheduledCompute computeOp,
mlir::OpResult result,
mlir::Value yieldValue,
mlir::IRRewriter& rewriter);
@@ -13,10 +13,13 @@ using namespace bufferization;
namespace onnx_mlir::pim {
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue,
Location loc,
RewriterBase& rewriter,
const StaticValueKnowledge& knowledge) {
bool isContiguous =
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
succeeded(resolveContiguousAddress(memrefValue, knowledge)) || succeeded(compileContiguousAddressExpr(memrefValue));
if (isContiguous && isDeviceLocalPimAddress(memrefValue, knowledge))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
@@ -32,7 +35,7 @@ FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location lo
if (failed(sizeAttr))
return failure();
if (isHostBackedPimAddress(memrefValue)) {
if (isHostBackedPimAddress(memrefValue, knowledge)) {
return PimMemCopyHostToDevOp::create(
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
.getOutput();
@@ -3,10 +3,15 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
namespace onnx_mlir::pim {
llvm::FailureOr<mlir::Value>
materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
materializeContiguousInputMemRef(mlir::Value memrefValue,
mlir::Location loc,
mlir::RewriterBase& rewriter,
const onnx_mlir::StaticValueKnowledge& knowledge = {});
mlir::Value
allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
@@ -15,6 +15,26 @@ using namespace bufferization;
namespace onnx_mlir {
namespace pim {
static StaticValueKnowledge getEnclosingBufferizationKnowledge(Operation* op) {
StaticValueKnowledge knowledge;
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
knowledge.indexValues[coreBatchOp.getLaneArgument()] = 0;
for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights()))
knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight;
for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs()))
knowledge.aliases[coreBatchOp.getInputArgument(index)] = input;
return knowledge;
}
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights()))
knowledge.aliases[coreOp.getWeightArgument(index)] = weight;
}
return knowledge;
}
struct MemCopyHostToDevOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
LogicalResult bufferize(Operation* op,
@@ -148,7 +168,8 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
if (failed(inputOpt))
return failure();
auto contiguous = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
auto contiguous =
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguous))
return failure();
inputs.push_back(*contiguous);
@@ -182,7 +203,8 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
auto contiguousInput =
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousInput))
return failure();
@@ -410,7 +432,8 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
if (failed(outputBufferOpt))
return failure();
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
auto contiguousInput =
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousInput))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
@@ -456,7 +479,8 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
if (failed(outputBufferOpt))
return failure();
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
auto contiguousInput =
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousInput))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
@@ -497,10 +521,12 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
if (failed(outputBufferOpt))
return failure();
auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
auto contiguousLhs =
materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousLhs))
return failure();
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
auto contiguousRhs =
materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousRhs))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
@@ -534,10 +560,12 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInter
if (failed(outputBufferOpt))
return failure();
auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
auto contiguousLhs =
materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousLhs))
return failure();
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
auto contiguousRhs =
materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousRhs))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
@@ -574,7 +602,8 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
if (failed(outputBufferOpt))
return failure();
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
auto contiguousInput =
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
if (failed(contiguousInput))
return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
@@ -116,6 +116,36 @@ lowerMemRefCopyToPimCopy(memref::CopyOp copyOp, PatternRewriter& rewriter, const
return success();
}
static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyHostToDevOp copyOp, const StaticValueKnowledge& knowledge) {
bool sourceIsHost = isHostBackedPimAddress(copyOp.getHostSource(), knowledge);
bool targetIsHost = isHostBackedPimAddress(copyOp.getDeviceTarget(), knowledge);
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getHostSource(), knowledge);
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getDeviceTarget(), knowledge);
if (!sourceIsHost || !targetIsDevice || targetIsHost || sourceIsDevice)
return copyOp.emitOpError("pim.memcp_hd requires a host-backed source and a device-local target");
return success();
}
static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyDevToHostOp copyOp, const StaticValueKnowledge& knowledge) {
bool sourceIsHost = isHostBackedPimAddress(copyOp.getDeviceSource(), knowledge);
bool targetIsHost = isHostBackedPimAddress(copyOp.getHostTarget(), knowledge);
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getDeviceSource(), knowledge);
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getHostTarget(), knowledge);
if (!targetIsHost || !sourceIsDevice || sourceIsHost || targetIsDevice)
return copyOp.emitOpError("pim.memcp_dh requires a device-local source and a host-backed target");
return success();
}
static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyOp copyOp, const StaticValueKnowledge& knowledge) {
bool sourceIsHost = isHostBackedPimAddress(copyOp.getSource(), knowledge);
bool targetIsHost = isHostBackedPimAddress(copyOp.getTarget(), knowledge);
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getSource(), knowledge);
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getTarget(), knowledge);
if (!sourceIsDevice || !targetIsDevice || sourceIsHost || targetIsHost)
return copyOp.emitOpError("pim.memcp requires device-local source and target operands");
return success();
}
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
StringRef getArgument() const override { return "bufferize-pim"; }
@@ -129,6 +159,7 @@ struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<Mo
private:
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
LogicalResult verifyPimCopyAddressSpaces(ModuleOp moduleOp) const;
};
static LogicalResult applyPatternsOnce(Operation* op, PatternApplicator& applicator, PatternRewriter& rewriter) {
@@ -240,6 +271,10 @@ void PimBufferizationPass::runOnOperation() {
signalPassFailure();
return;
}
if (failed(verifyPimCopyAddressSpaces(moduleOp))) {
signalPassFailure();
return;
}
annotateWeightsMemrefs(moduleOp, funcOp);
@@ -346,6 +381,31 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod
return success();
}
LogicalResult PimBufferizationPass::verifyPimCopyAddressSpaces(ModuleOp moduleOp) const {
bool hasFailure = false;
auto verifyWithKnowledge = [&](auto coreLikeOp, const StaticValueKnowledge& initialKnowledge) {
(void) walkPimCoreBlockStructurally(
coreLikeOp.getBody().front(), initialKnowledge, [&](Operation& op, const StaticValueKnowledge& knowledge) {
if (auto copyOp = dyn_cast<pim::PimMemCopyOp>(&op); copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge)))
hasFailure = true;
if (auto copyOp = dyn_cast<pim::PimMemCopyHostToDevOp>(&op);
copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge)))
hasFailure = true;
if (auto copyOp = dyn_cast<pim::PimMemCopyDevToHostOp>(&op);
copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge)))
hasFailure = true;
return success();
});
};
moduleOp.walk([&](pim::PimCoreOp coreOp) { verifyWithKnowledge(coreOp, seedCoreKnowledge(coreOp)); });
moduleOp.walk([&](pim::PimCoreBatchOp coreBatchOp) {
StaticValueKnowledge knowledge = seedCoreBatchKnowledge(coreBatchOp, 0);
verifyWithKnowledge(coreBatchOp, knowledge);
});
return success(!hasFailure);
}
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
} // namespace onnx_mlir
@@ -96,8 +96,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
if (!coreOp)
if (!mapOp->getParentOfType<pim::PimCoreOp>() && !mapOp->getParentOfType<pim::PimCoreBatchOp>())
return failure();
auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
@@ -5,6 +5,7 @@ add_pim_library(OMPimVerification
LINK_LIBS PUBLIC
OMPimCommon
OMPimCompilerOptions
OMPimBufferization
PimOps
SpatialOps
@@ -5,12 +5,17 @@
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <string>
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
@@ -143,6 +148,479 @@ static bool isHostAddressableValue(Value value, const StaticValueKnowledge& know
return isa_and_nonnull<memref::GetGlobalOp>(base.getDefiningOp());
}
enum class CommunicationEventKind { Send, Receive };
struct CommunicationEvent {
CommunicationEventKind kind = CommunicationEventKind::Send;
int64_t coreId = 0;
int64_t peerCoreId = 0;
int64_t size = 0;
uint64_t ordinal = 0;
std::optional<int64_t> minChannelId;
std::string materializer;
std::optional<int64_t> traceId;
std::optional<int64_t> commOrder;
std::optional<int64_t> traceClassId;
std::optional<int64_t> traceBlockOrdinal;
std::string traceKind;
std::string tracePhase;
std::string traceClassKind;
std::string tracePayload;
std::string traceMessages;
std::string tracePrevOp;
std::string traceNextOp;
Operation* op = nullptr;
};
using CommunicationEventVector = SmallVector<CommunicationEvent, 0>;
static StringRef getCommunicationEventKindName(CommunicationEventKind kind) {
return kind == CommunicationEventKind::Send ? "send" : "receive";
}
constexpr StringLiteral kRaptorMinChannelIdAttr = "raptor.min_channel_id";
constexpr StringLiteral kRaptorMaterializerAttr = "raptor.materializer";
constexpr StringLiteral kRaptorCommOrderAttr = "raptor.comm_order";
constexpr StringLiteral kRaptorCommTraceIdAttr = "raptor.comm_trace_id";
constexpr StringLiteral kRaptorCommTraceKindAttr = "raptor.comm_trace_kind";
constexpr StringLiteral kRaptorCommTracePhaseAttr = "raptor.comm_trace_phase";
constexpr StringLiteral kRaptorCommTraceClassIdAttr = "raptor.comm_trace_class_id";
constexpr StringLiteral kRaptorCommTraceClassKindAttr = "raptor.comm_trace_class_kind";
constexpr StringLiteral kRaptorCommTraceBlockOrdinalAttr = "raptor.comm_trace_block_ordinal";
constexpr StringLiteral kRaptorCommTracePayloadAttr = "raptor.comm_trace_payload";
constexpr StringLiteral kRaptorCommTraceMessagesAttr = "raptor.comm_trace_messages";
constexpr StringLiteral kRaptorCommTracePrevOpAttr = "raptor.comm_trace_prev_op";
constexpr StringLiteral kRaptorCommTraceNextOpAttr = "raptor.comm_trace_next_op";
static std::optional<int64_t> getNearestIntegerAttr(Operation* op, StringRef name) {
for (Operation* current = op; current; current = current->getParentOp())
if (auto attr = current->getAttrOfType<IntegerAttr>(name))
return attr.getInt();
return std::nullopt;
}
static std::string getNearestStringAttr(Operation* op, StringRef name) {
for (Operation* current = op; current; current = current->getParentOp())
if (auto attr = current->getAttrOfType<StringAttr>(name))
return attr.getValue().str();
return "";
}
static std::string formatLocation(Location loc) {
std::string text;
llvm::raw_string_ostream os(text);
loc.print(os);
return os.str();
}
static std::string formatOperationSummary(Operation* op) {
std::string text;
llvm::raw_string_ostream os(text);
OpPrintingFlags flags;
flags.skipRegions();
flags.elideLargeElementsAttrs();
op->print(os, flags);
return os.str();
}
static std::string formatCommunicationEvent(const CommunicationEvent& event) {
std::string text;
llvm::raw_string_ostream os(text);
os << "core " << event.coreId << " " << getCommunicationEventKindName(event.kind) << " "
<< (event.kind == CommunicationEventKind::Send ? "to" : "from") << " " << event.peerCoreId
<< " size " << event.size << "B ordinal " << event.ordinal;
if (event.minChannelId)
os << " min_channel " << *event.minChannelId;
if (event.commOrder)
os << " comm_order " << *event.commOrder;
if (!event.materializer.empty())
os << " materializer " << event.materializer;
if (event.traceId)
os << " trace#" << *event.traceId;
if (!event.tracePhase.empty())
os << " phase " << event.tracePhase;
if (event.traceClassId)
os << " class " << event.traceClassKind << "#" << *event.traceClassId;
if (event.traceBlockOrdinal)
os << " block_ordinal " << *event.traceBlockOrdinal;
if (!event.tracePayload.empty())
os << " payload " << event.tracePayload;
if (!event.traceMessages.empty())
os << " messages {" << event.traceMessages << "}";
if (!event.tracePrevOp.empty() || !event.traceNextOp.empty())
os << " inserted_between [" << event.tracePrevOp << " | " << event.traceNextOp << "]";
return os.str();
}
static bool areMatchedCommunicationEvents(const CommunicationEvent& lhs, const CommunicationEvent& rhs) {
if (lhs.coreId != rhs.peerCoreId || lhs.peerCoreId != rhs.coreId || lhs.size != rhs.size)
return false;
return (lhs.kind == CommunicationEventKind::Send && rhs.kind == CommunicationEventKind::Receive)
|| (lhs.kind == CommunicationEventKind::Receive && rhs.kind == CommunicationEventKind::Send);
}
static std::optional<size_t> findMatchingCounterpartIndex(const CommunicationEventVector& events,
const CommunicationEvent& event,
size_t begin) {
for (size_t index = begin; index < events.size(); ++index)
if (areMatchedCommunicationEvents(event, events[index]))
return index;
return std::nullopt;
}
static void printCounterpartProbe(llvm::raw_ostream& os,
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
const DenseMap<int64_t, size_t>& programCounters,
const CommunicationEvent& blockedEvent) {
auto peerEventsIt = coreEvents.find(blockedEvent.peerCoreId);
if (peerEventsIt == coreEvents.end()) {
os << " no local stream was collected for peer core " << blockedEvent.peerCoreId << "\n";
return;
}
const CommunicationEventVector& peerEvents = peerEventsIt->second;
size_t peerPc = 0;
auto peerPcIt = programCounters.find(blockedEvent.peerCoreId);
if (peerPcIt != programCounters.end())
peerPc = peerPcIt->second;
os << " counterpart probe for " << formatCommunicationEvent(blockedEvent) << "\n";
os << " peer core " << blockedEvent.peerCoreId << " current pc " << peerPc << " of " << peerEvents.size()
<< "\n";
std::optional<size_t> nextMatch = findMatchingCounterpartIndex(peerEvents, blockedEvent, peerPc);
std::optional<size_t> anyMatch = findMatchingCounterpartIndex(peerEvents, blockedEvent, 0);
if (!nextMatch && !anyMatch) {
os << " no matching counterpart exists anywhere in the peer stream\n";
return;
}
if (!nextMatch && anyMatch) {
os << " matching counterpart exists only before the peer pc at ordinal " << *anyMatch
<< "; this usually means the static stream expansion or ordering metadata is inconsistent\n";
os << " " << formatCommunicationEvent(peerEvents[*anyMatch]) << "\n";
os << " op: " << formatOperationSummary(peerEvents[*anyMatch].op) << "\n";
return;
}
const CommunicationEvent& match = peerEvents[*nextMatch];
os << " next matching counterpart is at peer ordinal " << *nextMatch << " (distance +"
<< (*nextMatch >= peerPc ? *nextMatch - peerPc : 0) << ")\n";
os << " " << formatCommunicationEvent(match) << "\n";
os << " op: " << formatOperationSummary(match.op) << "\n";
if (*nextMatch == peerPc)
return;
os << " peer operations blocking before that counterpart:\n";
size_t end = std::min(peerEvents.size(), std::min(*nextMatch + static_cast<size_t>(1), peerPc + static_cast<size_t>(12)));
for (size_t index = peerPc; index < end; ++index) {
os << (index == peerPc ? " pc => " : " ") << "#" << index << " "
<< formatCommunicationEvent(peerEvents[index]) << "\n";
os << " op: " << formatOperationSummary(peerEvents[index].op) << "\n";
}
if (end <= *nextMatch)
os << " ... " << (*nextMatch - end + 1) << " more peer communication event(s) before the counterpart\n";
}
static CommunicationEvent makeCommunicationEvent(CommunicationEventKind kind,
int64_t coreId,
int64_t peerCoreId,
int64_t size,
uint64_t ordinal,
Operation* op) {
return CommunicationEvent {kind,
coreId,
peerCoreId,
size,
ordinal,
getNearestIntegerAttr(op, kRaptorMinChannelIdAttr),
getNearestStringAttr(op, kRaptorMaterializerAttr),
getNearestIntegerAttr(op, kRaptorCommTraceIdAttr),
getNearestIntegerAttr(op, kRaptorCommOrderAttr),
getNearestIntegerAttr(op, kRaptorCommTraceClassIdAttr),
getNearestIntegerAttr(op, kRaptorCommTraceBlockOrdinalAttr),
getNearestStringAttr(op, kRaptorCommTraceKindAttr),
getNearestStringAttr(op, kRaptorCommTracePhaseAttr),
getNearestStringAttr(op, kRaptorCommTraceClassKindAttr),
getNearestStringAttr(op, kRaptorCommTracePayloadAttr),
getNearestStringAttr(op, kRaptorCommTraceMessagesAttr),
getNearestStringAttr(op, kRaptorCommTracePrevOpAttr),
getNearestStringAttr(op, kRaptorCommTraceNextOpAttr),
op};
}
static LogicalResult appendCoreCommunicationEvents(Block& block,
int64_t coreId,
const StaticValueKnowledge& initialKnowledge,
SmallVectorImpl<CommunicationEvent>& events,
pim::CappedDiagnosticReporter& diagnostics) {
return walkPimCoreBlock(block, initialKnowledge, [&](Operation& op, const StaticValueKnowledge& knowledge) {
if (auto sendOp = dyn_cast<pim::PimSendOp>(&op)) {
auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
if (failed(targetCoreId)) {
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("cannot statically resolve send target core for PIM communication deadlock check");
});
return failure();
}
events.push_back(makeCommunicationEvent(CommunicationEventKind::Send,
coreId,
*targetCoreId,
sendOp.getSize(),
static_cast<uint64_t>(events.size()),
&op));
return success();
}
if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(&op)) {
auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
if (failed(sourceCoreId)) {
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("cannot statically resolve receive source core for PIM communication deadlock check");
});
return failure();
}
events.push_back(makeCommunicationEvent(CommunicationEventKind::Receive,
coreId,
*sourceCoreId,
receiveOp.getSize(),
static_cast<uint64_t>(events.size()),
&op));
return success();
}
return success();
});
}
static void printCommunicationWindow(llvm::raw_ostream& os,
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
int64_t coreId,
size_t pc,
unsigned radius = 4) {
auto eventsIt = coreEvents.find(coreId);
if (eventsIt == coreEvents.end())
return;
const CommunicationEventVector& events = eventsIt->second;
size_t begin = pc > radius ? pc - radius : 0;
size_t end = std::min(events.size(), pc + static_cast<size_t>(radius) + 1);
os << " local stream for core " << coreId << " around pc " << pc << " of " << events.size() << ":\n";
for (size_t index = begin; index < end; ++index) {
os << (index == pc ? " => " : " ") << formatCommunicationEvent(events[index]) << "\n";
os << " op: " << formatOperationSummary(events[index].op) << "\n";
}
}
static void printCommunicationDeadlockReport(const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
const DenseMap<int64_t, size_t>& programCounters,
ArrayRef<int64_t> cycle) {
llvm::errs() << "\n=== PIM static communication deadlock report ===\n";
llvm::errs() << "wait cycle:";
for (int64_t coreId : cycle)
llvm::errs() << " " << coreId;
if (!cycle.empty())
llvm::errs() << " -> " << cycle.front();
llvm::errs() << "\n\nblocked heads:\n";
for (int64_t coreId : cycle) {
auto eventsIt = coreEvents.find(coreId);
auto pcIt = programCounters.find(coreId);
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
continue;
const CommunicationEvent& event = eventsIt->second[pcIt->second];
llvm::errs() << " " << formatCommunicationEvent(event) << "\n";
llvm::errs() << " loc: " << formatLocation(event.op->getLoc()) << "\n";
llvm::errs() << " op : " << formatOperationSummary(event.op) << "\n";
}
llvm::errs() << "\npeer counterpart probes:\n";
for (int64_t coreId : cycle) {
auto eventsIt = coreEvents.find(coreId);
auto pcIt = programCounters.find(coreId);
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
continue;
printCounterpartProbe(llvm::errs(), coreEvents, programCounters, eventsIt->second[pcIt->second]);
}
llvm::errs() << "\nlocal communication streams:\n";
for (int64_t coreId : cycle) {
auto pcIt = programCounters.find(coreId);
if (pcIt == programCounters.end())
continue;
printCommunicationWindow(llvm::errs(), coreEvents, coreId, pcIt->second);
}
llvm::errs() << "=== end PIM static communication deadlock report ===\n\n";
}
static void emitCommunicationDeadlockCycle(ModuleOp moduleOp,
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
const DenseMap<int64_t, size_t>& programCounters,
ArrayRef<int64_t> cycle) {
printCommunicationDeadlockReport(coreEvents, programCounters, cycle);
auto diagnostic = moduleOp.emitError()
<< "PIM communication deadlock check found a blocking send/receive cycle while statically simulating the "
"expanded per-core communication streams; see the PIM static communication deadlock report above";
for (int64_t coreId : cycle) {
auto eventsIt = coreEvents.find(coreId);
auto pcIt = programCounters.find(coreId);
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
continue;
const CommunicationEvent& event = eventsIt->second[pcIt->second];
Diagnostic& note = diagnostic.attachNote(event.op->getLoc());
note << formatCommunicationEvent(event);
if (!event.materializer.empty())
note << " emitted by " << event.materializer;
}
}
static FailureOr<SmallVector<int64_t>> findCommunicationWaitCycle(
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
const DenseMap<int64_t, size_t>& programCounters) {
for (const auto& [startCoreId, events] : coreEvents) {
auto startPcIt = programCounters.find(startCoreId);
if (startPcIt == programCounters.end() || startPcIt->second >= events.size())
continue;
DenseMap<int64_t, size_t> positionInPath;
SmallVector<int64_t, 8> path;
int64_t currentCoreId = startCoreId;
while (true) {
auto eventsIt = coreEvents.find(currentCoreId);
auto pcIt = programCounters.find(currentCoreId);
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
break;
auto positionIt = positionInPath.find(currentCoreId);
if (positionIt != positionInPath.end()) {
SmallVector<int64_t> cycle;
for (size_t index = positionIt->second; index < path.size(); ++index)
cycle.push_back(path[index]);
return cycle;
}
positionInPath[currentCoreId] = path.size();
path.push_back(currentCoreId);
currentCoreId = eventsIt->second[pcIt->second].peerCoreId;
}
}
return failure();
}
static LogicalResult verifyNoStaticCommunicationDeadlock(ModuleOp moduleOp,
pim::CappedDiagnosticReporter& diagnostics) {
DenseMap<int64_t, CommunicationEventVector> coreEvents;
bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
if (funcOp.isExternal())
continue;
for (Operation& op : funcOp.getBody().front().getOperations()) {
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
int64_t coreId = coreOp.getCoreId();
if (failed(appendCoreCommunicationEvents(
coreOp.getBody().front(), coreId, StaticValueKnowledge {}, coreEvents[coreId], diagnostics)))
hasFailure = true;
continue;
}
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
SmallVector<int32_t> coreIds = getBatchCoreIds(coreBatchOp);
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
for (size_t lane = 0; lane < laneCount; ++lane) {
StaticValueKnowledge laneKnowledge;
laneKnowledge.indexValues[coreBatchOp.getLaneArgument()] = static_cast<int64_t>(lane);
for (unsigned inputIndex = 0; inputIndex < coreBatchOp.getInputs().size(); ++inputIndex)
laneKnowledge.aliases[coreBatchOp.getInputArgument(inputIndex)] = coreBatchOp.getInputs()[inputIndex];
SmallVector<int32_t> laneCoreIds = getLaneChunkCoreIds(coreIds, laneCount, static_cast<unsigned>(lane));
for (int32_t coreId : laneCoreIds) {
if (failed(appendCoreCommunicationEvents(
coreBatchOp.getBody().front(), coreId, laneKnowledge, coreEvents[coreId], diagnostics)))
hasFailure = true;
}
}
}
}
}
if (hasFailure)
return failure();
DenseMap<int64_t, size_t> programCounters;
for (const auto& [coreId, events] : coreEvents)
programCounters[coreId] = 0;
while (true) {
bool madeProgress = false;
for (const auto& [coreId, events] : coreEvents) {
size_t pc = programCounters[coreId];
if (pc >= events.size())
continue;
const CommunicationEvent& event = events[pc];
auto peerEventsIt = coreEvents.find(event.peerCoreId);
if (peerEventsIt == coreEvents.end())
continue;
size_t peerPc = programCounters[event.peerCoreId];
if (peerPc >= peerEventsIt->second.size())
continue;
const CommunicationEvent& peerEvent = peerEventsIt->second[peerPc];
if (!areMatchedCommunicationEvents(event, peerEvent))
continue;
++programCounters[coreId];
++programCounters[event.peerCoreId];
madeProgress = true;
}
if (madeProgress)
continue;
bool allDone = true;
for (const auto& [coreId, events] : coreEvents) {
if (programCounters[coreId] < events.size()) {
allDone = false;
break;
}
}
if (allDone)
return success();
auto cycle = findCommunicationWaitCycle(coreEvents, programCounters);
if (succeeded(cycle)) {
emitCommunicationDeadlockCycle(moduleOp, coreEvents, programCounters, *cycle);
return failure();
}
auto diagnostic = moduleOp.emitError()
<< "PIM communication deadlock check stalled without finding a closed wait cycle; this usually means a "
"send/receive peer is missing or ordered after a finished core";
for (const auto& [coreId, events] : coreEvents) {
size_t pc = programCounters[coreId];
if (pc >= events.size())
continue;
const CommunicationEvent& event = events[pc];
diagnostic.attachNote(event.op->getLoc()) << formatCommunicationEvent(event);
}
return failure();
}
}
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
@@ -212,11 +690,18 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
}
}
bool hasFailure = false;
if (pimDetectCommunicationDeadlock && failed(verifyNoStaticCommunicationDeadlock(moduleOp, diagnostics)))
hasFailure = true;
if (diagnostics.hasFailure()) {
diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
signalPassFailure();
hasFailure = true;
}
if (hasFailure)
signalPassFailure();
}
private:
+127 -12
View File
@@ -26,7 +26,7 @@ def SpatTensor :
// Execution
//===----------------------------------------------------------------------===//
def SpatCompute : SpatOp<"compute",
class SpatComputeLikeBase<string mnemonic> : SpatOp<mnemonic,
[SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Compute region with attached constant weights";
@@ -42,6 +42,12 @@ def SpatCompute : SpatOp<"compute",
let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatGraphCompute : SpatComputeLikeBase<"graph_compute"> {
let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
@@ -50,16 +56,26 @@ def SpatCompute : SpatOp<"compute",
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatGraphCompute>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
let hasVerifier = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatComputeBatch : SpatOp<"compute_batch",
def SpatScheduledCompute : SpatComputeLikeBase<"scheduled_compute"> {
let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatScheduledCompute>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
}
class SpatComputeBatchLikeBase<string mnemonic> : SpatOp<mnemonic,
[SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
@@ -76,6 +92,11 @@ def SpatComputeBatch : SpatOp<"compute_batch",
let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatGraphComputeBatch : SpatComputeBatchLikeBase<"graph_compute_batch"> {
let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getLaneArgument();
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
@@ -86,21 +107,33 @@ def SpatComputeBatch : SpatOp<"compute_batch",
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatGraphComputeBatch>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
}
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
def SpatScheduledComputeBatch : SpatComputeBatchLikeBase<"scheduled_compute_batch"> {
let extraClassDeclaration = [{
std::optional<::mlir::BlockArgument> getLaneArgument();
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatScheduledComputeBatch>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
}
def SpatInParallelOp : SpatOp<"in_parallel", [
Pure,
Terminator,
DeclareOpInterfaceMethods<InParallelOpInterface>,
HasParent<"SpatComputeBatch">,
] # GraphRegionNoTerminator.traits> {
let summary = "Parallel combining terminator for resultful spat.compute_batch";
let summary = "Parallel combining terminator for resultful Spatial compute batches";
let regions = (region SizedRegion<1>:$region);
@@ -159,6 +192,88 @@ def SpatConcatOp : SpatOp<"concat", []> {
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Planning
//===----------------------------------------------------------------------===//
def SpatConv2DPlanOp : SpatOp<"conv2d_plan", []> {
let summary = "Structured Conv2D planning op that preserves logical ONNX geometry";
let arguments = (ins
SpatTensor:$input,
SpatTensor:$weight,
Optional<SpatTensor>:$bias,
DenseI64ArrayAttr:$pads,
DenseI64ArrayAttr:$strides,
DenseI64ArrayAttr:$dilations,
I64Attr:$group,
StrAttr:$logicalLayout
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
def SpatReluPlanOp : SpatOp<"relu_plan", []> {
let summary = "Layout-aware ReLU planning op";
let arguments = (ins
SpatTensor:$input,
StrAttr:$logicalLayout
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
def SpatReconciliatorOp : SpatOp<"reconciliator", []> {
let summary = "Logical-to-physical layout record or explicit fragment assembly";
let arguments = (ins
SpatTensor:$input,
Variadic<SpatTensor>:$fragments,
StrAttr:$logicalLayout,
StrAttr:$physicalLayout,
DenseI64ArrayAttr:$fragmentOffsets,
DenseI64ArrayAttr:$fragmentSizes,
StrAttr:$indexMap,
OptionalAttr<StrAttr>:$mode,
OptionalAttr<DenseI64ArrayAttr>:$fragmentOperandIndices,
OptionalAttr<DenseI64ArrayAttr>:$fragmentStrides,
OptionalAttr<StrAttr>:$conflictPolicy,
OptionalAttr<StrAttr>:$coveragePolicy
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
def SpatMaterializeLayoutOp : SpatOp<"materialize_layout", []> {
let summary = "Explicit layout conversion or materialization barrier";
let arguments = (ins
SpatTensor:$input,
StrAttr:$logicalLayout,
StrAttr:$sourcePhysicalLayout,
StrAttr:$targetPhysicalLayout
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
+235 -118
View File
@@ -29,11 +29,19 @@ std::optional<BlockArgument> insertBlockArgument(Region& body, unsigned argIdx,
}
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
if (auto compute = dyn_cast<SpatCompute>(op)) {
if (auto compute = dyn_cast<SpatGraphCompute>(op)) {
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
return;
}
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
if (auto compute = dyn_cast<SpatScheduledCompute>(op)) {
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
return;
}
if (auto batch = dyn_cast<SpatGraphComputeBatch>(op)) {
batch.getProperties().setOperandSegmentSizes({weightCount, inputCount});
return;
}
cast<SpatScheduledComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
}
using CrossbarWeightSet = llvm::SetVector<Value, llvm::SmallVector<Value, 4>, llvm::SmallDenseSet<Value, 4>>;
@@ -47,116 +55,205 @@ CrossbarWeightSet collectCrossbarWeights(Region& body) {
return weights;
}
} // namespace
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); }
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
return getBlockArgument(getBody(), getWeights().size() + idx);
template <typename ComputeOpTy>
std::optional<BlockArgument> getComputeWeightArgument(ComputeOpTy compute, unsigned idx) {
return getBlockArgument(compute.getBody(), idx);
}
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
auto index = std::distance(getWeights().begin(), existing);
return {
{*existing, *getWeightArgument(index)}
};
template <typename ComputeOpTy>
std::optional<BlockArgument> getComputeInputArgument(ComputeOpTy compute, unsigned idx) {
return getBlockArgument(compute.getBody(), compute.getWeights().size() + idx);
}
template <typename ComputeOpTy>
std::optional<std::tuple<Value, BlockArgument>>
insertComputeWeight(ComputeOpTy compute, unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(compute.getWeights(), weight); existing != compute.getWeights().end()) {
auto index = std::distance(compute.getWeights().begin(), existing);
return {{*existing, *getComputeWeightArgument(compute, index)}};
}
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight});
unsigned weightCount = compute.getWeights().size();
unsigned inputCount = compute.getInputs().size();
compute.getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBlockArgument(getBody(), idx, weight.getType(), loc);
compute.getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBlockArgument(compute.getBody(), idx, weight.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
return std::make_tuple(compute.getOperation()->getOperand(idx), *blockArg);
}
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
template <typename ComputeBatchOpTy>
std::optional<std::tuple<Value, BlockArgument>>
insertComputeBatchWeight(ComputeBatchOpTy batch, unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(batch.getWeights(), weight); existing != batch.getWeights().end()) {
auto index = std::distance(batch.getWeights().begin(), existing);
return {{*existing, *batch.getWeightArgument(index)}};
}
unsigned weightCount = batch.getWeights().size();
unsigned inputCount = batch.getInputs().size();
batch.getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
auto blockArg = insertBlockArgument(getBody(), weightCount + idx, input.getType(), loc);
batch.getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBlockArgument(batch.getBody(), 1 + idx, weight.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
return std::make_tuple(batch.getOperation()->getOperand(idx), *blockArg);
}
CrossbarWeightSet SpatCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, SpatCompute>>
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > getNumResults())
return failure();
rewriter.setInsertionPoint(getOperation());
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
newCompute->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes(newCompute.getOperation(),
static_cast<int32_t>(newCompute.getWeights().size()),
static_cast<int32_t>(newCompute.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx)
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
template <typename ComputeOpTy>
std::optional<std::tuple<Value, BlockArgument>>
insertComputeInput(ComputeOpTy compute, unsigned idx, Value input, Location loc) {
unsigned weightCount = compute.getWeights().size();
unsigned inputCount = compute.getInputs().size();
compute.getOperation()->insertOperands(weightCount + idx, ValueRange {input});
setComputeOperandSegmentSizes(
compute.getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
auto blockArg = insertBlockArgument(compute.getBody(), weightCount + idx, input.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(compute.getOperation()->getOperand(weightCount + idx), *blockArg);
}
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
template <typename ComputeOpTy>
void setComputeAsmBlockArgumentNames(ComputeOpTy compute, Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
for (unsigned index = 0; index < getWeights().size(); ++index)
if (auto weightArg = getWeightArgument(index))
for (unsigned index = 0; index < compute.getWeights().size(); ++index)
if (auto weightArg = compute.getWeightArgument(index))
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
if (auto inputArg = getInputArgument(index))
for (unsigned index = 0; index < compute.getInputs().size(); ++index)
if (auto inputArg = compute.getInputArgument(index))
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
}
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
template <typename ComputeOpTy>
FailureOr<std::tuple<OpResult, ComputeOpTy>>
insertComputeOutput(ComputeOpTy compute, RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > compute.getNumResults())
return failure();
std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) {
rewriter.setInsertionPoint(compute.getOperation());
SmallVector<Type> resultTypes(compute.getResultTypes().begin(), compute.getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newCompute =
ComputeOpTy::create(rewriter, compute.getLoc(), TypeRange(resultTypes), compute.getWeights(), compute.getInputs());
newCompute->setAttrs(compute->getAttrs());
setComputeOperandSegmentSizes(newCompute.getOperation(),
static_cast<int32_t>(newCompute.getWeights().size()),
static_cast<int32_t>(newCompute.getInputs().size()));
rewriter.inlineRegionBefore(compute.getBody(), newCompute.getBody(), newCompute.getBody().end());
for (unsigned oldResultIdx = 0; oldResultIdx < compute.getNumResults(); ++oldResultIdx)
compute.getResult(oldResultIdx)
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(compute.getOperation());
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
}
template <typename ComputeBatchOpTy>
FailureOr<std::tuple<OpResult, BlockArgument, ComputeBatchOpTy>>
insertComputeBatchOutput(ComputeBatchOpTy batch, RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > batch.getNumResults())
return failure();
rewriter.setInsertionPoint(batch.getOperation());
SmallVector<Type> resultTypes(batch.getResultTypes().begin(), batch.getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newBatch =
ComputeBatchOpTy::create(rewriter, batch.getLoc(), TypeRange(resultTypes), batch.getLaneCountAttr(), batch.getWeights(), batch.getInputs());
newBatch->setAttrs(batch->getAttrs());
setComputeOperandSegmentSizes(newBatch.getOperation(),
static_cast<int32_t>(newBatch.getWeights().size()),
static_cast<int32_t>(newBatch.getInputs().size()));
rewriter.inlineRegionBefore(batch.getBody(), newBatch.getBody(), newBatch.getBody().end());
if (newBatch.getBody().empty()) {
rewriter.eraseOp(newBatch);
return failure();
}
auto blockArg = newBatch.getBody().front().insertArgument(
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
for (unsigned oldResultIdx = 0; oldResultIdx < batch.getNumResults(); ++oldResultIdx)
batch.getResult(oldResultIdx)
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(batch.getOperation());
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
}
} // namespace
bool isGraphComputeLike(Operation* op) { return isa<SpatGraphCompute, SpatGraphComputeBatch>(op); }
bool isGraphBatchComputeLike(Operation* op) { return isa<SpatGraphComputeBatch>(op); }
bool isScheduledComputeLike(Operation* op) { return isa<SpatScheduledCompute, SpatScheduledComputeBatch>(op); }
bool isScheduledBatchComputeLike(Operation* op) { return isa<SpatScheduledComputeBatch>(op); }
bool isAnySpatialComputeLike(Operation* op) {
return isa<SpatGraphCompute, SpatGraphComputeBatch, SpatScheduledCompute, SpatScheduledComputeBatch>(op);
}
bool isAnySpatialComputeBatchLike(Operation* op) { return isa<SpatGraphComputeBatch, SpatScheduledComputeBatch>(op); }
std::optional<BlockArgument> SpatGraphCompute::getWeightArgument(unsigned idx) { return getComputeWeightArgument(*this, idx); }
std::optional<BlockArgument> SpatGraphCompute::getInputArgument(unsigned idx) { return getComputeInputArgument(*this, idx); }
std::optional<std::tuple<Value, BlockArgument>> SpatGraphCompute::insertWeight(unsigned idx, Value weight, Location loc) {
return insertComputeWeight(*this, idx, weight, loc);
}
std::optional<std::tuple<Value, BlockArgument>> SpatGraphCompute::insertInput(unsigned idx, Value input, Location loc) {
return insertComputeInput(*this, idx, input, loc);
}
CrossbarWeightSet SpatGraphCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, SpatGraphCompute>>
SpatGraphCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
return insertComputeOutput(*this, rewriter, idx, type, loc);
}
void SpatGraphCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
}
std::optional<BlockArgument> SpatScheduledCompute::getWeightArgument(unsigned idx) {
return getComputeWeightArgument(*this, idx);
}
std::optional<BlockArgument> SpatScheduledCompute::getInputArgument(unsigned idx) { return getComputeInputArgument(*this, idx); }
std::optional<std::tuple<Value, BlockArgument>>
SpatScheduledCompute::insertWeight(unsigned idx, Value weight, Location loc) {
return insertComputeWeight(*this, idx, weight, loc);
}
std::optional<std::tuple<Value, BlockArgument>>
SpatScheduledCompute::insertInput(unsigned idx, Value input, Location loc) {
return insertComputeInput(*this, idx, input, loc);
}
CrossbarWeightSet SpatScheduledCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, SpatScheduledCompute>>
SpatScheduledCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
return insertComputeOutput(*this, rewriter, idx, type, loc);
}
void SpatScheduledCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
}
std::optional<BlockArgument> SpatGraphComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
std::optional<BlockArgument> SpatGraphComputeBatch::getWeightArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + idx);
}
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
std::optional<BlockArgument> SpatGraphComputeBatch::getInputArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
}
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
std::optional<BlockArgument> SpatGraphComputeBatch::getOutputArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
}
std::optional<std::tuple<Value, BlockArgument>>
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
auto index = std::distance(getWeights().begin(), existing);
return {
{*existing, *getWeightArgument(index)}
};
}
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBlockArgument(getBody(), 1 + idx, weight.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
SpatGraphComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
return insertComputeBatchWeight(*this, idx, weight, loc);
}
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
std::optional<std::tuple<Value, BlockArgument>>
SpatGraphComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
@@ -167,52 +264,68 @@ std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(un
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
}
CrossbarWeightSet SpatComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > getNumResults())
return failure();
rewriter.setInsertionPoint(getOperation());
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newBatch =
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
newBatch->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes(newBatch.getOperation(),
static_cast<int32_t>(newBatch.getWeights().size()),
static_cast<int32_t>(newBatch.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
if (newBatch.getBody().empty()) {
rewriter.eraseOp(newBatch);
return failure();
}
auto blockArg = newBatch.getBody().front().insertArgument(
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx)
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
CrossbarWeightSet SpatGraphComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, BlockArgument, SpatGraphComputeBatch>>
SpatGraphComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
return insertComputeBatchOutput(*this, rewriter, idx, type, loc);
}
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
void SpatGraphComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
if (auto laneArg = getLaneArgument())
setNameFn(*laneArg, "lane");
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
continue;
if (index == 0) {
setNameFn(*outputArg, "out");
continue;
}
setNameFn(*outputArg, ("out" + std::to_string(index)).c_str());
}
}
for (unsigned index = 0; index < getWeights().size(); ++index)
if (auto weightArg = getWeightArgument(index))
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
if (auto inputArg = getInputArgument(index))
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
std::optional<BlockArgument> SpatScheduledComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
std::optional<BlockArgument> SpatScheduledComputeBatch::getWeightArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + idx);
}
std::optional<BlockArgument> SpatScheduledComputeBatch::getInputArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
}
std::optional<BlockArgument> SpatScheduledComputeBatch::getOutputArgument(unsigned idx) {
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
}
std::optional<std::tuple<Value, BlockArgument>>
SpatScheduledComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
return insertComputeBatchWeight(*this, idx, weight, loc);
}
std::optional<std::tuple<Value, BlockArgument>>
SpatScheduledComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
auto blockArg = insertBlockArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
}
CrossbarWeightSet SpatScheduledComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, BlockArgument, SpatScheduledComputeBatch>>
SpatScheduledComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
return insertComputeBatchOutput(*this, rewriter, idx, type, loc);
}
void SpatScheduledComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
if (auto laneArg = getLaneArgument())
setNameFn(*laneArg, "lane");
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
@@ -231,7 +344,11 @@ void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
builder.createBlock(bodyRegion);
}
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
OpResult SpatInParallelOp::getParentResult(int64_t idx) {
Operation* parent = getOperation()->getParentOp();
assert(isAnySpatialComputeBatchLike(parent) && "expected Spatial compute batch parent");
return parent->getResult(idx);
}
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
+16
View File
@@ -26,3 +26,19 @@
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc"
namespace onnx_mlir {
namespace spatial {
bool isGraphComputeLike(mlir::Operation* op);
bool isGraphBatchComputeLike(mlir::Operation* op);
bool isScheduledComputeLike(mlir::Operation* op);
bool isScheduledBatchComputeLike(mlir::Operation* op);
bool isAnySpatialComputeLike(mlir::Operation* op);
bool isAnySpatialComputeBatchLike(mlir::Operation* op);
using SpatCompute = SpatGraphCompute;
using SpatComputeBatch = SpatGraphComputeBatch;
} // namespace spatial
} // namespace onnx_mlir
+260 -251
View File
@@ -115,6 +115,254 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
return success();
}
template <typename ComputeOpTy>
void printComputeLikeOp(ComputeOpTy op, OpAsmPrinter& printer) {
SmallVector<Value> weightArgs;
weightArgs.reserve(op.getWeights().size());
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
auto weightArg = op.getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(op.getInputs().size());
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
auto inputArg = op.getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
printer << " ";
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
if (auto coreIdAttr = op->template getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
printer << " coreId " << coreIdAttr.getInt();
printer << " crossbarWeights " << collectDistinctCrossbarWeights(op.getOperation()).size();
printer.printOptionalAttrDict(op->getAttrs(), {op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, op.getResultTypes());
printer << " ";
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
}
template <typename ComputeOpTy>
ParseResult parseComputeLikeOp(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
int32_t coreId = 0;
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
SmallVector<OpAsmParser::Argument> inputArgs;
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
if (hasCoreId && parser.parseInteger(coreId))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreId cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreId)
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyArgumentTypes(weightTypes, weightArgs);
applyArgumentTypes(inputTypes, inputArgs);
llvm::append_range(regionArgs, weightArgs);
llvm::append_range(regionArgs, inputArgs);
return parser.parseRegion(*body, regionArgs);
}
template <typename ComputeBatchOpTy>
void printComputeBatchLikeOp(ComputeBatchOpTy op, OpAsmPrinter& printer) {
auto laneArg = op.getLaneArgument();
SmallVector<Value> weightArgs;
weightArgs.reserve(op.getWeights().size());
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
auto weightArg = op.getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(op.getInputs().size());
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
auto inputArg = op.getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
SmallVector<BlockArgument> outputArgs;
if (!laneArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
if (op.getNumResults() != 0) {
outputArgs.reserve(op.getNumResults());
for (unsigned index = 0; index < op.getNumResults(); ++index) {
auto outputArg = op.getOutputArgument(index);
if (!outputArg)
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
outputArgs.push_back(*outputArg);
}
}
printer << " ";
printer.printOperand(*laneArg);
printer << " = 0 to " << op.getLaneCount();
printer << " ";
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
if (op.getNumResults() != 0) {
printer << " shared_outs";
printBlockArgumentList(printer, outputArgs);
}
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({op.getOperation(), 0, op.getLaneCount()}).size();
if (auto coreIdsAttr = op->template getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
printer << " coreIds ";
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
}
printer.printOptionalAttrDict(
op->getAttrs(),
{op.getLaneCountAttrName().getValue(), op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, op.getResultTypes());
printer << " ";
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
}
template <typename ComputeBatchOpTy>
ParseResult parseComputeBatchLikeOp(OpAsmParser& parser, OperationState& result) {
int64_t lowerBound = 0;
int32_t laneCount = 0;
OpAsmParser::Argument laneArg;
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> inputArgs;
SmallVector<OpAsmParser::Argument> outputArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
SmallVector<int32_t> coreIds;
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
return failure();
if (lowerBound != 0)
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
if (parseBlockArgumentList(parser, outputArgs))
return failure();
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (outputArgs.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(),
"number of shared output bindings and result types must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreIds cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreIds)
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyBatchRegionArgumentTypes(
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
return parser.parseRegion(*body, regionArgs);
}
} // namespace
void SpatYieldOp::print(OpAsmPrinter& printer) {
@@ -218,260 +466,21 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
return success();
}
void SpatCompute::print(OpAsmPrinter& printer) {
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index) {
auto weightArg = getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index) {
auto inputArg = getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
printer << " ";
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
printer << " coreId " << coreIdAttr.getInt();
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
printer.printOptionalAttrDict((*this)->getAttrs(),
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, getResultTypes());
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) {
return parseComputeLikeOp<SpatGraphCompute>(parser, result);
}
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
int32_t coreId = 0;
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
SmallVector<OpAsmParser::Argument> inputArgs;
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
if (hasCoreId && parser.parseInteger(coreId))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreId cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreId)
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyArgumentTypes(weightTypes, weightArgs);
applyArgumentTypes(inputTypes, inputArgs);
llvm::append_range(regionArgs, weightArgs);
llvm::append_range(regionArgs, inputArgs);
return parser.parseRegion(*body, regionArgs);
void SpatScheduledCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
ParseResult SpatScheduledCompute::parse(OpAsmParser& parser, OperationState& result) {
return parseComputeLikeOp<SpatScheduledCompute>(parser, result);
}
void SpatComputeBatch::print(OpAsmPrinter& printer) {
auto laneArg = getLaneArgument();
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index) {
auto weightArg = getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index) {
auto inputArg = getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
SmallVector<BlockArgument> outputArgs;
if (!laneArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
if (getNumResults() != 0) {
outputArgs.reserve(getNumResults());
for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
outputArgs.push_back(*outputArg);
}
}
printer << " ";
printer.printOperand(*laneArg);
printer << " = 0 to " << getLaneCount();
printer << " ";
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (getNumResults() != 0) {
printer << " shared_outs";
printBlockArgumentList(printer, outputArgs);
}
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({getOperation(), 0, getLaneCount()}).size();
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
printer << " coreIds ";
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
}
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
printCompressedTypeSequence(printer, getResultTypes());
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
void SpatGraphComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
ParseResult SpatGraphComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
return parseComputeBatchLikeOp<SpatGraphComputeBatch>(parser, result);
}
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
int64_t lowerBound = 0;
int32_t laneCount = 0;
OpAsmParser::Argument laneArg;
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> inputArgs;
SmallVector<OpAsmParser::Argument> outputArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
SmallVector<int32_t> coreIds;
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
return failure();
if (lowerBound != 0)
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
if (parseBlockArgumentList(parser, outputArgs))
return failure();
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (outputArgs.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(),
"number of shared output bindings and result types must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreIds cannot be specified both positionally and in attr-dict");
auto& builder = parser.getBuilder();
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
result.addAttribute(
"operandSegmentSizes",
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
if (hasCoreIds)
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
return failure();
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyBatchRegionArgumentTypes(
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
return parser.parseRegion(*body, regionArgs);
void SpatScheduledComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
ParseResult SpatScheduledComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
return parseComputeBatchLikeOp<SpatScheduledComputeBatch>(parser, result);
}
void SpatInParallelOp::print(OpAsmPrinter& printer) {
@@ -10,8 +10,9 @@ using namespace mlir;
namespace onnx_mlir {
namespace spatial {
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
Block& block = getBody().front();
template <typename ComputeOpTy>
LogicalResult foldComputeLike(ComputeOpTy compute, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
Block& block = compute.getBody().front();
if (!llvm::hasSingleElement(block))
return failure();
@@ -22,7 +23,7 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m
for (Value yieldedValue : yieldOp.getOperands()) {
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
if (blockArg.getOwner() == &block) {
results.push_back(getOperand(blockArg.getArgNumber()));
results.push_back(compute.getOperand(blockArg.getArgNumber()));
continue;
}
}
@@ -31,5 +32,13 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m
return success();
}
LogicalResult SpatGraphCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
return foldComputeLike(*this, results);
}
LogicalResult SpatScheduledCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
return foldComputeLike(*this, results);
}
} // namespace spatial
} // namespace onnx_mlir
+370 -76
View File
@@ -35,7 +35,8 @@ static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
return shapedType.getShape();
}
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
template <typename ComputeBatchOpTy>
static bool isBatchOutputArgument(ComputeBatchOpTy batchOp, Value value) {
if (batchOp.getNumResults() == 0)
return false;
auto blockArg = dyn_cast<BlockArgument>(value);
@@ -58,8 +59,28 @@ static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind)
return success();
}
static bool isStaticIndexExpr(Value value) {
if (matchConstantIndexValue(value))
return true;
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
if (affineApply) {
if (!isSingleResultSymbolFreeAffineMap(affineApply.getAffineMap()))
return false;
return llvm::all_of(affineApply.getMapOperands(), isStaticIndexExpr);
}
if (auto addOp = value.getDefiningOp<arith::AddIOp>())
return isStaticIndexExpr(addOp.getLhs()) && isStaticIndexExpr(addOp.getRhs());
if (auto mulOp = value.getDefiningOp<arith::MulIOp>())
return isStaticIndexExpr(mulOp.getLhs()) && isStaticIndexExpr(mulOp.getRhs());
return false;
}
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
if (value == laneArg || matchConstantIndexValue(value))
if (value == laneArg || isStaticIndexExpr(value))
return true;
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
@@ -83,10 +104,15 @@ static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
}
auto addOp = value.getDefiningOp<arith::AddIOp>();
if (!addOp)
if (addOp)
return (isSupportedLaneOffsetExpr(addOp.getLhs(), laneArg) && isStaticIndexExpr(addOp.getRhs()))
|| (isSupportedLaneOffsetExpr(addOp.getRhs(), laneArg) && isStaticIndexExpr(addOp.getLhs()));
auto mulOp = value.getDefiningOp<arith::MulIOp>();
if (!mulOp)
return false;
return (addOp.getLhs() == laneArg && matchConstantIndexValue(addOp.getRhs()))
|| (addOp.getRhs() == laneArg && matchConstantIndexValue(addOp.getLhs()));
return (isSupportedLaneOffsetExpr(mulOp.getLhs(), laneArg) && isStaticIndexExpr(mulOp.getRhs()))
|| (isSupportedLaneOffsetExpr(mulOp.getRhs(), laneArg) && isStaticIndexExpr(mulOp.getLhs()));
}
static LogicalResult
@@ -158,17 +184,27 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
continue;
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
<< kind << " body may only directly reference external constants";
InFlightDiagnostic diagnostic =
ownerOp->emitOpError() << kind << " body may not capture external values";
diagnostic.attachNote(op->getLoc())
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
<< "owner='" << ownerOp->getName() << "' nestedOp='" << op->getName() << "' operand#"
<< operand.getOperandNumber() << " type=" << value.getType()
<< " category=" << (isa<TensorType>(value.getType()) ? "tensor" : (value.getType().isIndex() ? "index"
: "scalar"));
if (Operation* definingOp = value.getDefiningOp())
diagnostic.attachNote(definingOp->getLoc()) << "defining op is '" << definingOp->getName() << "'";
else if (auto blockArg = dyn_cast<BlockArgument>(value))
diagnostic.attachNote(blockArg.getOwner()->getParentOp()->getLoc())
<< "value is block argument #" << blockArg.getArgNumber() << " of '"
<< blockArg.getOwner()->getParentOp()->getName() << "'";
hasFailure = true;
}
});
return success(!hasFailure);
}
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
template <typename ComputeBatchOpTy>
static LogicalResult verifyBatchBody(ComputeBatchOpTy batchOp, Block& block) {
if (batchOp.getNumResults() == 0) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
@@ -344,144 +380,399 @@ LogicalResult SpatConcatOp::verify() {
return success();
}
LogicalResult verifyComputeResultsUses(Operation* op) {
if (!isa<SpatCompute, SpatComputeBatch>(op))
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
if (!llvm::all_of(op->getResults(), [](Value result) {
return llvm::all_of(result.getUsers(), [](Operation* op) {
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
});
})) {
return op->emitError("ComputeResult used directly inside another Compute");
static bool isKnownLogicalLayout(StringRef layout) { return layout == "nchw"; }
static bool isKnownPhysicalLayout(StringRef layout) {
return layout == "dense_nchw" || layout == "nchw_row_strip" || layout == "fragmented";
}
static LogicalResult verifyPlanTensorTypes(Operation* op, Value input, Value output, StringRef kind) {
auto inputType = dyn_cast<RankedTensorType>(input.getType());
auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!inputType || !outputType)
return op->emitOpError() << kind << " requires ranked tensor input and output types";
if (inputType.getElementType() != outputType.getElementType())
return op->emitOpError() << kind << " requires matching input/output element types";
return success();
}
LogicalResult SpatConv2DPlanOp::verify() {
auto inputType = dyn_cast<RankedTensorType>(getInput().getType());
auto weightType = dyn_cast<RankedTensorType>(getWeight().getType());
auto outputType = dyn_cast<RankedTensorType>(getOutput().getType());
if (!inputType || !weightType || !outputType)
return emitError("requires ranked tensor input, weight, and output");
if (inputType.getRank() != 4 || weightType.getRank() != 4 || outputType.getRank() != 4)
return emitError("requires rank-4 input, weight, and output tensors");
if (!isKnownLogicalLayout(getLogicalLayout()))
return emitError("requires a known logical layout");
if (getPads().size() != 4)
return emitError("requires exactly four pad values");
if (getStrides().size() != 2)
return emitError("requires exactly two stride values");
if (getDilations().size() != 2)
return emitError("requires exactly two dilation values");
if (getGroup() < 1)
return emitError("requires group >= 1");
if (inputType.getElementType() != weightType.getElementType()
|| inputType.getElementType() != outputType.getElementType()) {
return emitError("requires matching input, weight, and output element types");
}
if (getBias()) {
auto biasType = dyn_cast<RankedTensorType>(getBias().getType());
if (!biasType)
return emitError("requires ranked tensor bias type");
if (biasType.getElementType() != outputType.getElementType())
return emitError("requires bias element type to match output element type");
}
return success();
}
LogicalResult SpatCompute::verify() {
auto& block = getBody().front();
unsigned expectedArgCount = getWeights().size() + getInputs().size();
if (block.getNumArguments() != expectedArgCount)
return emitError("compute body must have weight and input block arguments");
LogicalResult SpatReluPlanOp::verify() {
if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.relu_plan")))
return failure();
if (!isKnownLogicalLayout(getLogicalLayout()))
return emitError("requires a known logical layout");
return success();
}
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
auto blockArg = getWeightArgument(weightIndex);
if (!blockArg || blockArg->getType() != weight.getType())
return emitError("compute weight block argument types must match weight operand types exactly");
LogicalResult SpatReconciliatorOp::verify() {
auto modeAttr = getModeAttr();
bool isFragmentAssembly = modeAttr && modeAttr.getValue() == "fragment_assembly";
if (!isFragmentAssembly && failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.reconciliator")))
return failure();
if (!isKnownLogicalLayout(getLogicalLayout()))
return emitError("requires a known logical layout");
if (!isKnownPhysicalLayout(getPhysicalLayout()))
return emitError("requires a known physical layout");
auto logicalType = dyn_cast<RankedTensorType>(getOutput().getType());
if (!logicalType)
return emitError("requires ranked tensor output");
auto offsets = getFragmentOffsets();
auto sizes = getFragmentSizes();
if (offsets.size() != sizes.size())
return emitError("fragment offset and size arrays must have the same length");
int64_t rank = logicalType.getRank();
if (offsets.empty())
return success();
if (rank <= 0 || offsets.size() % rank != 0)
return emitError("fragment metadata must be a whole number of rank-sized fragments");
auto verifyBoundsOnly = [&](ArrayRef<int64_t> strideValues) -> LogicalResult {
ArrayRef<int64_t> shape = logicalType.getShape();
for (int64_t index = 0; index < static_cast<int64_t>(offsets.size()); ++index) {
int64_t dim = index % rank;
int64_t offset = offsets[index];
int64_t size = sizes[index];
int64_t stride = strideValues.empty() ? 1 : strideValues[index];
if (offset < 0 || size < 0 || stride < 0)
return emitError("fragment offsets, sizes, and strides must be non-negative");
int64_t logicalDim = shape[dim];
if (!ShapedType::isDynamic(logicalDim) && offset + size > logicalDim)
return emitError("fragment bounds must stay within the logical tensor shape");
if (stride != 1)
return emitError("fragment assembly currently requires unit strides");
}
return success();
};
if (!isFragmentAssembly) {
if (failed(verifyBoundsOnly({})))
return failure();
if (!getFragments().empty())
return emitError("legacy reconciliator does not accept extra fragment operands");
if (getFragmentStridesAttr() || getConflictPolicyAttr() || getCoveragePolicyAttr())
return emitError("legacy reconciliator does not accept fragment assembly attributes");
return success();
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
auto blockArg = getInputArgument(inputIndex);
auto stridesAttr = getFragmentStridesAttr();
auto operandIndicesAttr = getFragmentOperandIndicesAttr();
if (!operandIndicesAttr)
return emitError("fragment assembly reconciliator requires fragment operand indices");
if (!stridesAttr)
return emitError("fragment assembly reconciliator requires fragment strides");
ArrayRef<int64_t> operandIndices = operandIndicesAttr.asArrayRef();
ArrayRef<int64_t> strides = stridesAttr.asArrayRef();
if (strides.size() != offsets.size())
return emitError("fragment stride and offset arrays must have the same length");
if (!getConflictPolicyAttr() || !getCoveragePolicyAttr())
return emitError("fragment assembly reconciliator requires conflict and coverage policies");
if (getConflictPolicy() != "disjoint")
return emitError("fragment assembly reconciliator currently supports only conflict_policy=\"disjoint\"");
if (getCoveragePolicy() != "complete" && getCoveragePolicy() != "partial")
return emitError("fragment assembly reconciliator coverage_policy must be \"complete\" or \"partial\"");
SmallVector<Value> operands;
operands.push_back(getInput());
llvm::append_range(operands, getFragments());
int64_t operandCount = static_cast<int64_t>(operands.size());
int64_t fragmentCount = static_cast<int64_t>(operandIndices.size());
if (operandCount == 0)
return emitError("fragment assembly reconciliator requires at least one operand");
if (static_cast<int64_t>(offsets.size()) != fragmentCount * rank)
return emitError("fragment assembly metadata count must match operand count * result rank");
if (failed(verifyBoundsOnly(strides)))
return failure();
SmallVector<std::pair<SmallVector<int64_t, 4>, SmallVector<int64_t, 4>>, 8> slices;
slices.reserve(static_cast<size_t>(fragmentCount));
SmallVector<SmallVector<SmallVector<int64_t, 4>, 4>, 8> sizesByOperand(static_cast<size_t>(operandCount));
for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
int64_t operandIndex = operandIndices[fragmentIndex];
if (operandIndex < 0 || operandIndex >= operandCount)
return emitError("fragment assembly operand index is out of range");
auto operandType = dyn_cast<RankedTensorType>(operands[operandIndex].getType());
if (!operandType || !operandType.hasStaticShape())
return emitError("fragment assembly reconciliator requires static ranked tensor operands");
if (operandType.getRank() != rank)
return emitError("fragment assembly reconciliator requires operand/result rank match");
SmallVector<int64_t, 4> fragmentOffsets;
SmallVector<int64_t, 4> fragmentSizes;
fragmentOffsets.reserve(rank);
fragmentSizes.reserve(rank);
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t flatIndex = fragmentIndex * rank + dim;
fragmentOffsets.push_back(offsets[flatIndex]);
fragmentSizes.push_back(sizes[flatIndex]);
}
sizesByOperand[static_cast<size_t>(operandIndex)].push_back(fragmentSizes);
for (const auto& [existingOffsets, existingSizes] : slices) {
bool overlaps = true;
for (int64_t dim = 0; dim < rank; ++dim) {
int64_t begin = fragmentOffsets[dim];
int64_t end = begin + fragmentSizes[dim];
int64_t existingBegin = existingOffsets[dim];
int64_t existingEnd = existingBegin + existingSizes[dim];
if (end <= existingBegin || existingEnd <= begin) {
overlaps = false;
break;
}
}
if (overlaps)
return emitError("fragment assembly reconciliator requires disjoint static slices");
}
slices.push_back({std::move(fragmentOffsets), std::move(fragmentSizes)});
}
for (int64_t operandIndex = 0; operandIndex < operandCount; ++operandIndex) {
if (sizesByOperand[static_cast<size_t>(operandIndex)].empty())
return emitError("fragment assembly reconciliator requires every operand to contribute at least one fragment");
auto operandType = cast<RankedTensorType>(operands[operandIndex].getType());
ArrayRef<int64_t> operandShape = operandType.getShape();
auto& fragmentShapes = sizesByOperand[static_cast<size_t>(operandIndex)];
if (fragmentShapes.size() == 1) {
if (!llvm::equal(operandShape, fragmentShapes.front()))
return emitError("single-fragment reconciliator operand shape must match declared fragment size");
continue;
}
ArrayRef<int64_t> fragmentShape = fragmentShapes.front();
for (ArrayRef<int64_t> otherShape : fragmentShapes)
if (!llvm::equal(fragmentShape, otherShape))
return emitError("packed reconciliator operand requires equal fragment sizes per operand");
if (llvm::equal(operandShape, fragmentShape))
continue;
if (!llvm::equal(operandShape.drop_front(), fragmentShape.drop_front()))
return emitError("packed reconciliator operand must match fragment shape on non-packed dimensions");
if (operandShape.front() != static_cast<int64_t>(fragmentShapes.size()) * fragmentShape.front())
return emitError("packed reconciliator operand first dimension must equal fragment_count * fragment_size");
}
if (getCoveragePolicy() == "complete") {
int64_t covered = 0;
int64_t logicalElements = 1;
for (int64_t dimSize : logicalType.getShape()) {
if (ShapedType::isDynamic(dimSize))
return emitError("fragment assembly complete coverage requires static result shape");
logicalElements *= dimSize;
}
for (const auto& [ignoredOffsets, fragmentSizes] : slices) {
int64_t fragmentElements = 1;
for (int64_t dimSize : fragmentSizes)
fragmentElements *= dimSize;
covered += fragmentElements;
}
if (covered != logicalElements)
return emitError("fragment assembly complete coverage must cover the whole result exactly");
}
return success();
}
LogicalResult SpatMaterializeLayoutOp::verify() {
if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.materialize_layout")))
return failure();
if (!isKnownLogicalLayout(getLogicalLayout()))
return emitError("requires a known logical layout");
if (!isKnownPhysicalLayout(getSourcePhysicalLayout()))
return emitError("requires a known source physical layout");
if (!isKnownPhysicalLayout(getTargetPhysicalLayout()))
return emitError("requires a known target physical layout");
return success();
}
LogicalResult verifyComputeResultsUses(Operation* op) {
if (!isAnySpatialComputeLike(op))
return op->emitError("verifyComputeResultUses: op is not a Spatial compute-like operation");
if (!llvm::all_of(op->getResults(), [](Value result) {
return llvm::all_of(result.getUsers(), [](Operation* op) {
return !isAnySpatialComputeLike(op->getParentOp());
});
})) {
return op->emitError("compute result used directly inside another Spatial compute body");
}
return success();
}
template <typename ComputeOpTy>
LogicalResult verifyComputeLikeOp(ComputeOpTy compute, StringRef opName) {
auto& block = compute.getBody().front();
unsigned expectedArgCount = compute.getWeights().size() + compute.getInputs().size();
if (block.getNumArguments() != expectedArgCount)
return compute.emitOpError("compute body must have weight and input block arguments");
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
auto blockArg = compute.getWeightArgument(weightIndex);
if (!blockArg || blockArg->getType() != weight.getType())
return compute.emitOpError("compute weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
auto blockArg = compute.getInputArgument(inputIndex);
if (!blockArg || blockArg->getType() != input.getType())
return emitError("compute input block argument types must match input operand types exactly");
return compute.emitOpError("compute input block argument types must match input operand types exactly");
}
if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return emitError("ComputeOp must have a single yield operation");
return compute.emitOpError("ComputeOp must have a single yield operation");
auto resultTypes = getResultTypes();
auto resultTypes = compute.getResultTypes();
auto yieldTypes = yieldOp->getOperandTypes();
if (resultTypes.size() != yieldTypes.size())
return emitError("ComputeOp must have same number of results as yieldOp operands");
return compute.emitOpError("ComputeOp must have same number of results as yieldOp operands");
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
auto resultType = std::get<0>(it);
auto yieldType = std::get<1>(it);
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType)))
return emitError("ComputeOp output must be of the same type as yieldOp operand");
return compute.emitOpError("ComputeOp output must be of the same type as yieldOp operand");
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding())
return emitError("ComputeOp output must have the same encoding as yieldOp operand");
return compute.emitOpError("ComputeOp output must have the same encoding as yieldOp operand");
}
else {
return emitError("ComputeOp output has an encoding while yieldOp operand does not have one");
return compute.emitOpError("ComputeOp output has an encoding while yieldOp operand does not have one");
}
}
else if (dyn_cast<RankedTensorType>(yieldType)) {
return emitError("ComputeOp output must not have an encoding if yieldOp operand has one");
return compute.emitOpError("ComputeOp output must not have an encoding if yieldOp operand has one");
}
}
}
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
return emitError("ComputeOp block argument is not used");
if (failed(verifyStaticWeights(*this, "compute")))
for (unsigned inputIndex = 0; inputIndex < compute.getInputs().size(); ++inputIndex)
if (auto inputArg = compute.getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
return compute.emitOpError("ComputeOp block argument is not used");
if (failed(verifyStaticWeights(compute, opName)))
return failure();
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
if (failed(verifyOnlyConstantExternalValues(compute.getOperation(), compute.getBody(), opName)))
return failure();
if (failed(verifyComputeResultsUses(this->getOperation())))
if (failed(verifyComputeResultsUses(compute.getOperation())))
return failure();
return success();
}
LogicalResult SpatComputeBatch::verify() {
int32_t count = getLaneCount();
LogicalResult SpatGraphCompute::verify() { return verifyComputeLikeOp(*this, "spat.graph_compute"); }
LogicalResult SpatScheduledCompute::verify() { return verifyComputeLikeOp(*this, "spat.scheduled_compute"); }
template <typename ComputeBatchOpTy>
LogicalResult verifyComputeBatchLikeOp(ComputeBatchOpTy batch, StringRef opName) {
int32_t count = batch.getLaneCount();
if (count <= 0)
return emitError("laneCount must be positive");
return batch.emitOpError("laneCount must be positive");
auto laneCountSz = static_cast<size_t>(count);
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
if (auto coreIdAttr = batch->getAttr(kCoreIdsAttrName)) {
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
if (!coreIdsAttr)
return emitError("compute_batch coreIds attribute must be a dense i32 array");
return batch.emitOpError("compute_batch coreIds attribute must be a dense i32 array");
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
return emitError("compute_batch coreIds array length must match laneCount");
return batch.emitOpError("compute_batch coreIds array length must match laneCount");
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
return emitError("compute_batch coreIds values must be non-negative");
return batch.emitOpError("compute_batch coreIds values must be non-negative");
DenseSet<int32_t> seenCoreIds;
for (int32_t coreId : coreIdsAttr.asArrayRef())
if (!seenCoreIds.insert(coreId).second)
return emitError("compute_batch coreIds values must be unique");
return batch.emitOpError("compute_batch coreIds values must be unique");
}
Block& block = getBody().front();
Block& block = batch.getBody().front();
if (block.getNumArguments() == 0)
return emitError("compute_batch body must have exactly one lane block argument");
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
return batch.emitOpError("compute_batch body must have exactly one lane block argument");
unsigned expectedArgCount = 1 + batch.getWeights().size() + batch.getInputs().size() + batch.getNumResults();
if (block.getNumArguments() != expectedArgCount)
return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
auto laneArg = getLaneArgument();
return batch.emitOpError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
auto laneArg = batch.getLaneArgument();
if (!laneArg || !laneArg->getType().isIndex())
return emitError("compute_batch first block argument must have index type");
return batch.emitOpError("compute_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
auto blockArg = getWeightArgument(weightIndex);
for (auto [weightIndex, weight] : llvm::enumerate(batch.getWeights())) {
auto blockArg = batch.getWeightArgument(weightIndex);
if (!blockArg || blockArg->getType() != weight.getType())
return emitError("compute_batch weight block argument types must match weight operand types exactly");
return batch.emitOpError("compute_batch weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
auto blockArg = getInputArgument(inputIndex);
for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) {
auto blockArg = batch.getInputArgument(inputIndex);
if (!blockArg || blockArg->getType() != input.getType())
return emitError("compute_batch input block argument types must match input operand types exactly");
return batch.emitOpError("compute_batch input block argument types must match input operand types exactly");
}
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
auto blockArg = getOutputArgument(resultIndex);
for (auto [resultIndex, resultType] : llvm::enumerate(batch.getResultTypes())) {
auto blockArg = batch.getOutputArgument(resultIndex);
if (!blockArg || blockArg->getType() != resultType)
return emitError("compute_batch output block argument types must match result types exactly");
return batch.emitOpError("compute_batch output block argument types must match result types exactly");
}
if (failed(verifyComputeResultsUses(this->getOperation())))
if (failed(verifyComputeResultsUses(batch.getOperation())))
return failure();
if (failed(verifyStaticWeights(*this, "compute_batch")))
if (failed(verifyStaticWeights(batch, opName)))
return failure();
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
if (failed(verifyOnlyConstantExternalValues(batch.getOperation(), batch.getBody(), opName)))
return failure();
return verifyBatchBody(*this, block);
return verifyBatchBody(batch, block);
}
LogicalResult SpatGraphComputeBatch::verify() { return verifyComputeBatchLikeOp(*this, "spat.graph_compute_batch"); }
LogicalResult SpatScheduledComputeBatch::verify() {
return verifyComputeBatchLikeOp(*this, "spat.scheduled_compute_batch");
}
LogicalResult SpatInParallelOp::verify() {
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
if (!batchOp)
return emitOpError("expected spat.compute_batch parent");
if (batchOp.getNumResults() == 0)
Operation* parent = getOperation()->getParentOp();
if (!isAnySpatialComputeBatchLike(parent))
return emitOpError("expected spat.graph_compute_batch or spat.scheduled_compute_batch parent");
if (parent->getNumResults() == 0)
return emitOpError("requires a resultful spat.compute_batch parent");
auto laneArg = batchOp.getLaneArgument();
std::optional<BlockArgument> laneArg;
if (auto graphBatch = dyn_cast<SpatGraphComputeBatch>(parent))
laneArg = graphBatch.getLaneArgument();
else
laneArg = cast<SpatScheduledComputeBatch>(parent).getLaneArgument();
if (!laneArg)
return emitOpError("expected compute_batch lane block argument");
for (Operation& op : getRegion().front().getOperations()) {
@@ -494,7 +785,10 @@ LogicalResult SpatInParallelOp::verify() {
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
for (OpOperand& destination : destinations)
if (!isBatchOutputArgument(batchOp, destination.get()))
if ((isa<SpatGraphComputeBatch>(parent)
&& !isBatchOutputArgument(cast<SpatGraphComputeBatch>(parent), destination.get()))
|| (isa<SpatScheduledComputeBatch>(parent)
&& !isBatchOutputArgument(cast<SpatScheduledComputeBatch>(parent), destination.get())))
return op.emitOpError("may only insert into a compute_batch output block argument");
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,128 @@
--- src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp 2026-06-24 18:51:29.043731129 +0000
+++ src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp 2026-06-24 18:51:29.026726895 +0000
@@ -4112,104 +4112,8 @@
Value originalOutput,
Location loc);
-FailureOr<SmallVector<OpFoldResult, 4>> rematerializeProjectionIndexListForBatchHostOutput(
- MaterializerState& state,
- MaterializedClass& sourceClass,
- ArrayRef<OpFoldResult> values,
- IRMapping& mapper,
- Location loc) {
- SmallVector<OpFoldResult, 4> localized;
- localized.reserve(values.size());
- for (OpFoldResult value : values) {
- FailureOr<OpFoldResult> remapped =
- rematerializeIndexOpFoldResultInClass(state, sourceClass, value, loc, &mapper);
- if (failed(remapped))
- return failure();
- localized.push_back(*remapped);
- }
- return localized;
-}
-
-LogicalResult createProjectionAwareBatchHostInsert(MaterializerState& state,
- MaterializedClass& sourceClass,
- Value originalOutput,
- Value payload,
- Value destination,
- ArrayRef<ProducerKey> keys,
- Location loc) {
- auto originalResult = dyn_cast<OpResult>(originalOutput);
- if (!originalResult)
- return failure();
-
- auto sourceBatch = dyn_cast_or_null<SpatComputeBatch>(originalResult.getOwner());
- if (!sourceBatch || sourceBatch.getNumResults() == 0)
- return failure();
-
- FailureOr<tensor::ParallelInsertSliceOp> projection =
- getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber());
- if (failed(projection))
- return failure();
-
- auto sourceLaneArg = sourceBatch.getLaneArgument();
- if (!sourceLaneArg)
- return failure();
-
- auto materializedBatch = dyn_cast<SpatScheduledComputeBatch>(sourceClass.op);
- if (!materializedBatch)
- return failure();
-
- auto materializedLaneArg = materializedBatch.getLaneArgument();
- if (!materializedLaneArg)
- return failure();
-
- if (keys.size() != sourceClass.cpus.size())
- return failure();
-
- SmallVector<int64_t, 8> logicalLanes;
- logicalLanes.reserve(keys.size());
- for (ProducerKey key : keys) {
- if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != originalResult.getResultNumber())
- return failure();
- logicalLanes.push_back(key.instance.laneStart);
- }
-
- IRMapping mapper;
- Value logicalLane = createIndexedIndexValue(state,
- sourceClass.op,
- ArrayRef<int64_t>(logicalLanes),
- *materializedLaneArg,
- loc,
- static_cast<int64_t>(sourceClass.cpus.size()),
- /*allowExhaustiveTiledSearch=*/false);
- mapper.map(*sourceLaneArg, logicalLane);
-
- FailureOr<SmallVector<OpFoldResult, 4>> offsets =
- rematerializeProjectionIndexListForBatchHostOutput(
- state, sourceClass, projection->getMixedOffsets(), mapper, loc);
- if (failed(offsets))
- return failure();
- FailureOr<SmallVector<OpFoldResult, 4>> sizes =
- rematerializeProjectionIndexListForBatchHostOutput(
- state, sourceClass, projection->getMixedSizes(), mapper, loc);
- if (failed(sizes))
- return failure();
- FailureOr<SmallVector<OpFoldResult, 4>> strides =
- rematerializeProjectionIndexListForBatchHostOutput(
- state, sourceClass, projection->getMixedStrides(), mapper, loc);
- if (failed(strides))
- return failure();
-
- tensor::ParallelInsertSliceOp::create(
- state.rewriter, loc, payload, destination, *offsets, *sizes, *strides);
- return success();
-}
-
LogicalResult
-setHostOutputValue(MaterializerState& state,
- MaterializedClass& sourceClass,
- Value originalOutput,
- Value payload,
- ArrayRef<ProducerKey> keys = {}) {
+setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) {
auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput);
if (resultIt == sourceClass.hostOutputToResultIndex.end())
return sourceClass.op->emitError("missing host result slot for materialized output")
@@ -4253,10 +4157,6 @@
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
- if (succeeded(createProjectionAwareBatchHostInsert(
- state, sourceClass, originalOutput, payload, *outputArg, keys, payload.getLoc())))
- return success();
-
createDim0ParallelInsertSlice(state, payload.getLoc(), payload, *outputArg, *laneArg);
return success();
}
@@ -4276,7 +4176,7 @@
MaterializedClass& ownerClass = state.classes[ownerIt->second];
if (sourceClass.id == ownerClass.id)
- return setHostOutputValue(state, ownerClass, originalOutput, payload, keys);
+ return setHostOutputValue(state, ownerClass, originalOutput, payload);
// Keep the old deadlock-free communication discipline: only scalar-to-scalar
// host-owner forwarding is introduced here. Batch host publication remains on
@@ -40,11 +40,10 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
using namespace onnx_mlir::compact_asm;
using SpatCompute = spatial::SpatCompute;
using SpatComputeBatch = spatial::SpatComputeBatch;
using spatial::getProducerValueRef;
using SpatCompute = spatial::SpatGraphCompute;
using SpatComputeBatch = spatial::SpatGraphComputeBatch;
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
static std::optional<int32_t> getComputeCoreId(spatial::SpatScheduledCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
if (failed(checkedCoreId))
@@ -187,32 +186,50 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
SmallVector<int32_t> coreIds;
};
//TODO Used for report refactor
struct CollectorConcatRow {
uint64_t computeId = 0;
int32_t coreId = -1;
uint64_t operandCount = 0;
};
uint64_t totalComputeOps = 0;
uint64_t totalLogicalComputes = 0;
uint64_t totalBatchComputeOps = 0;
uint64_t totalInstructionCount = 0;
uint64_t totalCrossbarCount = 0;
uint64_t nextBatchId = 0;
//TODO Used for report refactor
std::vector<ReportRow> collectedData;
//TODO Used for report refactor
std::vector<CollectorConcatRow> collectorConcatRows;
auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t {
return static_cast<uint64_t>(spatial::collectDistinctCrossbarWeights(op).size());
};
for (Operation& op : funcOp.getBody().front()) {
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(&op)) {
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
SmallVector<int32_t> coreIds;
if (auto coreId = getComputeCoreId(spatCompute))
coreIds.push_back(*coreId);
collectedData.push_back({totalComputeOps++, 1, perInstanceCrossbarCount, numInst, false, coreIds});
uint64_t computeId = totalComputeOps++;
collectedData.push_back({computeId, 1, perInstanceCrossbarCount, numInst, false, coreIds});
uint64_t maxConcatOperands = 0;
spatCompute.getBody().walk([&](spatial::SpatConcatOp concatOp) {
maxConcatOperands = std::max<uint64_t>(maxConcatOperands, concatOp.getInputs().size());
});
//TODO 128 is a magic number
if (maxConcatOperands >= 128 && !coreIds.empty())
collectorConcatRows.push_back({computeId, coreIds.front(), maxConcatOperands});
totalLogicalComputes += 1;
totalInstructionCount += numInst;
totalCrossbarCount += perInstanceCrossbarCount;
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
if (auto batch = dyn_cast<spatial::SpatScheduledComputeBatch>(&op)) {
uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody());
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
@@ -238,9 +255,17 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
{"Number of used crossbars", std::to_string(totalCrossbarCount) }
};
printReportTotalsBlock(os, totalFields);
if (!collectedData.empty())
if (!collectedData.empty() || !collectorConcatRows.empty())
os << "\n";
if (!collectorConcatRows.empty()) {
os << "Collector concat materialization:\n";
for (const CollectorConcatRow& row : collectorConcatRows)
os << "\tmaterialization_kind = single_collector_concat, compute = " << row.computeId
<< ", concat_operand_count = " << row.operandCount << ", collector_core = " << row.coreId << "\n";
os << "\n";
}
sortReportEntriesByFirstCore(collectedData);
for (uint64_t cI = 0; cI < totalComputeOps; ++cI) {
@@ -328,7 +353,17 @@ public:
void runOnOperation() override {
func::FuncOp func = getOperation();
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed at the start of MergeComputeNodes");
signalPassFailure();
return;
}
mergeTriviallyConnectedComputes(func);
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
func.emitOpError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after trivial merge simplification");
signalPassFailure();
return;
}
const spatial::MergeScheduleResult* analysisResult = nullptr;
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
@@ -342,8 +377,8 @@ public:
signalPassFailure();
return;
}
if (failed(verifySpatialCommunicationInvariants(func))) {
func.emitOpError("merged Spatial communication invariant verification failed");
if (failed(verifyScheduledSpatialInvariants(func))) {
func.emitOpError("RAPTOR_PHASE_CHECK scheduled Spatial verification failed after merge materialization");
signalPassFailure();
return;
}
@@ -12,6 +12,7 @@
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cmath>
#include <iterator>
#include <limits>
#include <optional>
@@ -21,6 +22,7 @@
#include "ComputeGraph.hpp"
#include "ComputeInstanceUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Support/TypeUtilities.hpp"
@@ -35,9 +37,223 @@ uint64_t countComputeBodyOperationInstances(Region& body);
namespace {
Cost getComputeBodyCost(Region& body) {
constexpr Cost kOperationCost = 100;
return checkedMultiply(static_cast<Cost>(countComputeBodyOperationInstances(body)), kOperationCost);
struct PimsimSchedulerCostModel {
static constexpr Cost kDefaultBitwidth = 8;
static constexpr Cost kCorePeriodNs = 1;
static constexpr Cost kLocalMemoryWidthBytes = 64;
static constexpr Cost kLocalMemoryLatencyCycles = 1;
static constexpr Cost kNetworkBusWidthBytes = 8;
static constexpr Cost kNetworkBaseLatencyNs = 2;
static constexpr Cost kNetworkPerHopLatencyNs = 1;
static constexpr Cost kVectorWidth = 16;
static constexpr Cost kVectorLatencyCycles = 4;
static constexpr Cost kDacResolutionBits = 1;
static constexpr Cost kDacLatencyCycles = 1;
static constexpr Cost kDacCount = 128;
static constexpr Cost kXbarReadLatencyNs = 30;
static constexpr Cost kSampleHoldLatencyCycles = 1;
static constexpr Cost kAdcLatencyCycles = 10;
static constexpr Cost kAdcCount = 2;
static constexpr Cost kShiftAdderLatencyCycles = 1;
static constexpr Cost kOutputBufferLatencyCycles = 1;
static constexpr Cost kInputBufferLatencyCycles = 0;
static constexpr Cost kFallbackOperationCost = 1;
static Cost ceilDiv(Cost numerator, Cost denominator) {
assert(denominator > 0 && "denominator must be positive");
return (numerator + denominator - 1) / denominator;
}
static std::optional<Cost> getStaticElementCount(Type type) {
auto shaped = dyn_cast<ShapedType>(type);
if (!shaped || !shaped.hasStaticShape())
return std::nullopt;
return static_cast<Cost>(shaped.getNumElements());
}
static Cost getBitwidthOrDefault(Type type) {
if (auto shaped = dyn_cast<ShapedType>(type))
type = shaped.getElementType();
if (auto intType = dyn_cast<IntegerType>(type))
return intType.getWidth();
if (auto floatType = dyn_cast<FloatType>(type))
return floatType.getWidth();
if (isa<IndexType>(type))
return 64;
return kDefaultBitwidth;
}
static Cost getComputeBitwidth(Type type) {
return std::min(getBitwidthOrDefault(type), kDefaultBitwidth);
}
static Cost getByteSize(Type type, Cost fallbackBitwidth = kDefaultBitwidth) {
auto elementCount = getStaticElementCount(type);
if (!elementCount)
return kFallbackOperationCost;
Cost bitwidth = fallbackBitwidth;
if (bitwidth <= 0)
bitwidth = getBitwidthOrDefault(type);
if (bitwidth <= 0)
bitwidth = kDefaultBitwidth;
return ceilDiv(checkedMultiply(*elementCount, bitwidth), static_cast<Cost>(8));
}
static Cost getVectorReadWriteCost(Cost readBytes, Cost writeBytes) {
Cost totalBytes = checkedAdd(readBytes, writeBytes);
return checkedMultiply(ceilDiv(totalBytes, kLocalMemoryWidthBytes),
checkedMultiply(kLocalMemoryLatencyCycles, kCorePeriodNs));
}
static Cost getVectorComputeCost(Cost elementCount) {
return checkedMultiply(ceilDiv(elementCount, kVectorWidth),
checkedMultiply(kVectorLatencyCycles, kCorePeriodNs));
}
static Cost getTensorMoveCost(Type type) {
return getVectorReadWriteCost(getByteSize(type), 0);
}
static std::pair<Cost, Cost> estimateMeshShape() {
Cost coreCount = static_cast<Cost>(std::max<long>(1, coresCount.getValue()));
Cost rows = static_cast<Cost>(std::sqrt(static_cast<long double>(coreCount)));
if (rows == 0)
rows = 1;
while (rows > 1 && coreCount % rows != 0)
--rows;
Cost cols = ceilDiv(coreCount, rows);
return {rows, cols};
}
static Cost getAverageInterCoreLatencyNs() {
auto [rows, cols] = estimateMeshShape();
auto averageAxisDistance = [](Cost size) -> Cost {
if (size <= 1)
return 0;
return checkedMultiply(size, size) - 1;
};
Cost avgRow = averageAxisDistance(rows) / (static_cast<Cost>(3) * rows);
Cost avgCol = averageAxisDistance(cols) / (static_cast<Cost>(3) * cols);
return checkedAdd(kNetworkBaseLatencyNs, checkedMultiply(kNetworkPerHopLatencyNs, checkedAdd(avgRow, avgCol)));
}
static Cost getInterCoreTransferCostFromBytes(Cost bytes) {
Cost localRead = checkedMultiply(ceilDiv(bytes, kLocalMemoryWidthBytes),
checkedMultiply(kLocalMemoryLatencyCycles, kCorePeriodNs));
Cost localWrite = checkedMultiply(ceilDiv(bytes, kLocalMemoryWidthBytes),
checkedMultiply(kLocalMemoryLatencyCycles, kCorePeriodNs));
Cost payloadFlits = ceilDiv(bytes, kNetworkBusWidthBytes);
Cost averageNoCLatency = getAverageInterCoreLatencyNs();
Cost network = checkedMultiply(checkedAdd(static_cast<Cost>(2), payloadFlits), averageNoCLatency);
return checkedAdd(checkedAdd(localRead, localWrite), network);
}
static Cost getUnaryVectorCost(Type inputType, Type outputType, bool scalarOutput = false) {
auto maybeElements = getStaticElementCount(inputType);
if (!maybeElements)
return kFallbackOperationCost;
Cost inputBytes = getByteSize(inputType, getComputeBitwidth(inputType));
Cost outputBytes = scalarOutput ? ceilDiv(getComputeBitwidth(outputType), static_cast<Cost>(8))
: getByteSize(outputType, getComputeBitwidth(outputType));
return checkedAdd(getVectorReadWriteCost(inputBytes, outputBytes), getVectorComputeCost(*maybeElements));
}
static Cost getBinaryVectorCost(Type lhsType, Type rhsType, Type outputType, bool scalarOutput = false) {
auto maybeElements = getStaticElementCount(lhsType);
if (!maybeElements)
return kFallbackOperationCost;
Cost readBytes = checkedAdd(getByteSize(lhsType, getComputeBitwidth(lhsType)),
getByteSize(rhsType, getComputeBitwidth(rhsType)));
Cost outputBytes = scalarOutput ? ceilDiv(getComputeBitwidth(outputType), static_cast<Cost>(8))
: getByteSize(outputType, getComputeBitwidth(outputType));
return checkedAdd(getVectorReadWriteCost(readBytes, outputBytes), getVectorComputeCost(*maybeElements));
}
static Cost getMatrixComputeLatency(Cost inputBitwidth) {
Cost xbarDim = static_cast<Cost>(crossbarSize.getValue());
Cost inputTimes = ceilDiv(inputBitwidth, kDacResolutionBits);
Cost dacTimes = ceilDiv(xbarDim, kDacCount);
Cost adcTimes = ceilDiv(xbarDim, kAdcCount);
Cost frontStage = kInputBufferLatencyCycles + kDacLatencyCycles + kXbarReadLatencyNs + kSampleHoldLatencyCycles;
Cost backPipe = std::max(kAdcLatencyCycles, checkedAdd(kShiftAdderLatencyCycles, kOutputBufferLatencyCycles));
Cost backStage = checkedAdd(checkedAdd(kAdcLatencyCycles, kShiftAdderLatencyCycles), kOutputBufferLatencyCycles);
backStage = checkedAdd(backStage, checkedMultiply(adcTimes - 1, backPipe));
Cost totalTimes = checkedMultiply(inputTimes, dacTimes);
Cost stagePipe = std::max(frontStage, backStage);
return checkedAdd(checkedAdd(frontStage, backStage),
checkedMultiply(totalTimes - 1, stagePipe));
}
static Cost getWvmmCost(Type inputType, Type outputType) {
Cost inputBitwidth = getComputeBitwidth(inputType);
Cost inputBytes = checkedMultiply(static_cast<Cost>(crossbarSize.getValue()),
ceilDiv(inputBitwidth, static_cast<Cost>(8)));
inputBytes = checkedMultiply(inputBytes, static_cast<Cost>(8));
Cost outputBytes = getByteSize(outputType, getComputeBitwidth(outputType));
return checkedAdd(getVectorReadWriteCost(inputBytes, outputBytes), getMatrixComputeLatency(inputBitwidth));
}
};
std::optional<uint64_t> getStaticTripCount(scf::ForOp loop);
[[maybe_unused]] Cost getOperationCost(Operation& op);
[[maybe_unused]] Cost getRegionCost(Region& body) {
Cost cost = 0;
for (Block& block : body)
for (Operation& op : block)
cost = checkedAdd(cost, getOperationCost(op));
return cost;
}
[[maybe_unused]] Cost getOperationCost(Operation& op) {
if (auto loop = dyn_cast<scf::ForOp>(&op)) {
std::optional<uint64_t> tripCount = getStaticTripCount(loop);
if (!tripCount)
return PimsimSchedulerCostModel::kFallbackOperationCost;
return checkedMultiply(getRegionCost(loop.getRegion()), static_cast<Cost>(*tripCount));
}
if (isa<SpatYieldOp, SpatInParallelOp, affine::AffineApplyOp, arith::ConstantOp,
tensor::EmptyOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp>(&op))
return 0;
if (auto wvmm = dyn_cast<SpatVMMOp>(&op))
return PimsimSchedulerCostModel::getWvmmCost(wvmm.getInput().getType(), wvmm.getOutput().getType());
if (auto vvdmul = dyn_cast<SpatVVDMulOp>(&op))
return PimsimSchedulerCostModel::getBinaryVectorCost(
vvdmul.getLhs().getType(), vvdmul.getRhs().getType(), vvdmul.getOutput().getType(), /*scalarOutput=*/true);
if (auto vadd = dyn_cast<SpatVAddOp>(&op))
return PimsimSchedulerCostModel::getBinaryVectorCost(vadd.getLhs().getType(), vadd.getRhs().getType(),
vadd.getOutput().getType());
if (auto vsub = dyn_cast<SpatVSubOp>(&op))
return PimsimSchedulerCostModel::getBinaryVectorCost(vsub.getLhs().getType(), vsub.getRhs().getType(),
vsub.getOutput().getType());
if (auto vmul = dyn_cast<SpatVMulOp>(&op))
return PimsimSchedulerCostModel::getBinaryVectorCost(vmul.getLhs().getType(), vmul.getRhs().getType(),
vmul.getOutput().getType());
if (auto vmax = dyn_cast<SpatVMaxOp>(&op))
return PimsimSchedulerCostModel::getBinaryVectorCost(vmax.getLhs().getType(), vmax.getRhs().getType(),
vmax.getOutput().getType());
if (auto vavg = dyn_cast<SpatVAvgOp>(&op))
return PimsimSchedulerCostModel::getUnaryVectorCost(vavg.getInput().getType(), vavg.getOutput().getType(),
/*scalarOutput=*/true);
if (auto relu = dyn_cast<SpatReluOp>(&op))
return PimsimSchedulerCostModel::getUnaryVectorCost(relu.getInput().getType(), relu.getOutput().getType());
if (auto sigm = dyn_cast<SpatSigmoidOp>(&op))
return PimsimSchedulerCostModel::getUnaryVectorCost(sigm.getInput().getType(), sigm.getOutput().getType());
if (auto softmax = dyn_cast<SpatSoftmaxOp>(&op)) {
Cost unary = PimsimSchedulerCostModel::getUnaryVectorCost(softmax.getInput().getType(), softmax.getOutput().getType());
return checkedMultiply(unary, static_cast<Cost>(4));
}
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op))
return PimsimSchedulerCostModel::getTensorMoveCost(extract.getResult().getType());
if (auto insert = dyn_cast<tensor::InsertSliceOp>(&op))
return PimsimSchedulerCostModel::getTensorMoveCost(insert.getSource().getType());
Cost nestedCost = 0;
for (Region& region : op.getRegions())
nestedCost = checkedAdd(nestedCost, getRegionCost(region));
return checkedAdd(PimsimSchedulerCostModel::kFallbackOperationCost, nestedCost);
}
std::optional<uint64_t> getStaticTripCount(scf::ForOp loop) {
@@ -54,6 +270,11 @@ std::optional<uint64_t> getStaticTripCount(scf::ForOp loop) {
return (distance + stride - 1) / stride;
}
Cost getComputeBodyCost(Region& body) {
constexpr Cost kOperationCost = 100;
return checkedMultiply(static_cast<Cost>(countComputeBodyOperationInstances(body)), kOperationCost);
}
uint64_t countOperationInstances(Operation& op) {
if (auto loop = dyn_cast<scf::ForOp>(&op)) {
std::optional<uint64_t> tripCount = getStaticTripCount(loop);
@@ -149,7 +370,8 @@ std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, V
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
if (!resultType || !resultType.hasStaticShape())
return std::nullopt;
projectedCost = checkedAdd(projectedCost, static_cast<Cost>(getSizeInBytes(resultType)));
projectedCost = checkedAdd(
projectedCost, PimsimSchedulerCostModel::getInterCoreTransferCostFromBytes(static_cast<Cost>(getSizeInBytes(resultType))));
}
if (projectedCost == 0)
@@ -162,7 +384,7 @@ Cost getInputTransferCost(const ComputeInstance& consumerInstance, Value input)
if (auto batch = dyn_cast<SpatComputeBatch>(consumerInstance.op))
if (std::optional<Cost> projectedCost = getBatchProjectedInputTransferCost(batch, input))
return *projectedCost;
return static_cast<Cost>(getSizeInBytes(inputType));
return PimsimSchedulerCostModel::getInterCoreTransferCostFromBytes(static_cast<Cost>(getSizeInBytes(inputType)));
}
uint32_t getLaneOverlapCount(const ComputeInstance& lhs, const ComputeInstance& rhs) {
@@ -451,7 +673,7 @@ std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> ed
continue;
auto inserted = edgeCosts.try_emplace({edge.source, edge.target}, edge.transferCost);
if (!inserted.second)
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
inserted.first->second = checkedAdd(inserted.first->second, edge.transferCost);
}
std::vector<ComputeGraphEdge> aggregatedEdges;
@@ -23,7 +23,10 @@ MergeSchedulerKind getSchedulerKind() {
llvm_unreachable("unknown merge scheduler kind");
}
void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result, unsigned long crossbarCapacity) {
void verifySchedule(const ComputeGraph& graph,
const MergeScheduleResult& result,
unsigned long crossbarCapacity,
size_t processorCount) {
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
@@ -79,7 +82,8 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost);
if (sourceCpu != targetCpu)
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
earliestTargetStart = addOrMax(
earliestTargetStart, getPeftTransferTime(edge.transferCost, sourceCpu, targetCpu, processorCount));
if (targetStart < earliestTargetStart) {
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
graph.nodes[edge.source].originalOrder,
@@ -115,7 +119,10 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
static_cast<unsigned long>(crossbarCountInCore.getValue()),
entryOp->getContext()});
}
verifySchedule(graph, schedule, static_cast<unsigned long>(crossbarCountInCore.getValue()));
verifySchedule(graph,
schedule,
static_cast<unsigned long>(crossbarCountInCore.getValue()),
options.processorCount);
return schedule;
}
@@ -4,6 +4,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <cmath>
#include <limits>
#include <queue>
#include <vector>
@@ -21,6 +22,63 @@ struct ScheduledTask {
Time endTime = 0;
};
struct MeshModel {
size_t rows = 1;
size_t cols = 1;
long double averageDistance = 0.0L;
static MeshModel infer(size_t processorCount) {
MeshModel model;
if (processorCount == 0)
return model;
model.rows = static_cast<size_t>(std::sqrt(static_cast<long double>(processorCount)));
if (model.rows == 0)
model.rows = 1;
while (model.rows > 1 && processorCount % model.rows != 0)
--model.rows;
model.cols = (processorCount + model.rows - 1) / model.rows;
auto averageAxisDistance = [](size_t size) -> long double {
if (size <= 1)
return 0.0L;
return static_cast<long double>(size * size - 1) / (3.0L * static_cast<long double>(size));
};
model.averageDistance = averageAxisDistance(model.rows) + averageAxisDistance(model.cols);
return model;
}
std::pair<size_t, size_t> getCoord(size_t processor) const {
return {processor / cols, processor % cols};
}
size_t getDistance(size_t lhs, size_t rhs) const {
auto [lhsRow, lhsCol] = getCoord(lhs);
auto [rhsRow, rhsCol] = getCoord(rhs);
size_t rowDistance = lhsRow > rhsRow ? lhsRow - rhsRow : rhsRow - lhsRow;
size_t colDistance = lhsCol > rhsCol ? lhsCol - rhsCol : rhsCol - lhsCol;
return rowDistance + colDistance;
}
Time scaleTransferCost(Time transferCost, size_t sourceProcessor, size_t targetProcessor) const {
if (sourceProcessor == targetProcessor || transferCost == 0)
return 0;
long double distance = static_cast<long double>(getDistance(sourceProcessor, targetProcessor));
long double scale = averageDistance > 0.0L ? distance / averageDistance : 1.0L;
scale = std::max(0.25L, scale);
return static_cast<Time>(std::ceil(static_cast<long double>(transferCost) * scale));
}
size_t getCenterDistance(size_t processor) const {
auto [row, col] = getCoord(processor);
size_t centerRow = rows / 2;
size_t centerCol = cols / 2;
size_t rowDistance = row > centerRow ? row - centerRow : centerRow - row;
size_t colDistance = col > centerCol ? col - centerCol : centerCol - col;
return rowDistance + colDistance;
}
};
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
std::queue<size_t> readySinks;
@@ -77,11 +135,16 @@ void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
} // namespace
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount) {
return MeshModel::infer(processorCount).scaleTransferCost(transferCost, sourceProcessor, targetProcessor);
}
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
const size_t nodeCount = graph.nodes.size();
const size_t processorCount = options.processorCount;
if (processorCount == 0)
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
MeshModel mesh = MeshModel::infer(processorCount);
verifyOctTableSize(nodeCount, processorCount);
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
@@ -89,7 +152,6 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
// MOCK: Replace this with your actual heterogeneous cost lookup.
// If graph.nodes[task] is modified to hold a vector of costs per processor, access it here.
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].cost; };
std::vector<Time> oct(nodeCount * processorCount, 0);
std::vector<Time> minOctPlusComp(nodeCount, 0);
@@ -177,6 +239,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
Time bestEft = 0;
Time bestOeft = std::numeric_limits<Time>::max();
unsigned int bestOverlapCount = 0;
size_t bestCenterDistance = std::numeric_limits<size_t>::max();
bool crossbarRejected = false;
for (size_t processor = 0; processor < processorCount; ++processor) {
@@ -191,7 +254,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
Time dataReady = 0;
for (const auto& [pred, comm] : graph.predecessors[task]) {
const ScheduledTask& predSchedule = schedules[pred];
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
Time commPenalty = getPeftTransferTime(comm, predSchedule.processor, processor, processorCount);
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
}
@@ -218,6 +281,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
Time eft = addOrMax(est, computeCost);
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
size_t centerDistance = mesh.getCenterDistance(processor);
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|| (oeft == bestOeft && eft == bestEft && est < bestEst)) {
@@ -226,13 +290,25 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
bestEft = eft;
bestOeft = oeft;
bestOverlapCount = overlapCount;
bestCenterDistance = centerDistance;
}
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapCount < bestOverlapCount) {
else if (oeft == bestOeft && eft == bestEft && est == bestEst
&& centerDistance < bestCenterDistance) {
bestProcessor = processor;
bestEst = est;
bestEft = eft;
bestOeft = oeft;
bestOverlapCount = overlapCount;
bestCenterDistance = centerDistance;
}
else if (oeft == bestOeft && eft == bestEft && est == bestEst
&& centerDistance == bestCenterDistance && overlapCount < bestOverlapCount) {
bestProcessor = processor;
bestEst = est;
bestEft = eft;
bestOeft = oeft;
bestOverlapCount = overlapCount;
bestCenterDistance = centerDistance;
}
}
@@ -14,6 +14,7 @@ struct PeftScheduleOptions {
mlir::MLIRContext* context = nullptr;
};
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount);
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options);
} // namespace spatial
+2
View File
@@ -8,6 +8,8 @@
namespace onnx_mlir {
std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
std::unique_ptr<mlir::Pass> createSpatialLayoutPlanningPass();
std::unique_ptr<mlir::Pass> createLowerSpatialPlansPass();
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
+2
View File
@@ -72,6 +72,8 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
void PimAccelerator::registerPasses(int optLevel) const {
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
registerPass(createONNXToSpatialPass);
registerPass(createSpatialLayoutPlanningPass);
registerPass(createLowerSpatialPlansPass);
registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPimPass);
registerPass(createPimBufferizationPass);
+9 -9
View File
@@ -6,15 +6,15 @@ from onnx import TensorProto
# ONNX dtype -> (ctype, printf, ONNX_TYPE_*)
DTYPES = {
TensorProto.FLOAT: ("float", "%g", "ONNX_TYPE_FLOAT"),
TensorProto.DOUBLE: ("double", "%g", "ONNX_TYPE_DOUBLE"),
TensorProto.INT64: ("int64_t", "%lld","ONNX_TYPE_INT64"),
TensorProto.INT32: ("int32_t", "%d", "ONNX_TYPE_INT32"),
TensorProto.UINT8: ("uint8_t", "%u", "ONNX_TYPE_UINT8"),
TensorProto.INT8: ("int8_t", "%d", "ONNX_TYPE_INT8"),
TensorProto.BOOL: ("uint8_t", "%u", "ONNX_TYPE_BOOL"), # stored as byte
TensorProto.FLOAT16: ("uint16_t", "%u", "ONNX_TYPE_FLOAT16"), # raw 16-bit
TensorProto.BFLOAT16:("uint16_t", "%u", "ONNX_TYPE_BFLOAT16"),
TensorProto.FLOAT: ("float", "%.9g", "ONNX_TYPE_FLOAT"),
TensorProto.DOUBLE: ("double", "%.17g", "ONNX_TYPE_DOUBLE"),
TensorProto.INT64: ("int64_t", "%lld", "ONNX_TYPE_INT64"),
TensorProto.INT32: ("int32_t", "%d", "ONNX_TYPE_INT32"),
TensorProto.UINT8: ("uint8_t", "%u", "ONNX_TYPE_UINT8"),
TensorProto.INT8: ("int8_t", "%d", "ONNX_TYPE_INT8"),
TensorProto.BOOL: ("uint8_t", "%u", "ONNX_TYPE_BOOL"),
TensorProto.FLOAT16: ("uint16_t", "%u", "ONNX_TYPE_FLOAT16"),
TensorProto.BFLOAT16:("uint16_t", "%u", "ONNX_TYPE_BFLOAT16"),
}
def esc(s): return s.replace("\\","\\\\").replace('"','\\"')
@@ -0,0 +1,295 @@
#!/usr/bin/env python3.13
import argparse
import math
import subprocess
import sys
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw
SCRIPT_DIR = Path(__file__).resolve().parent
VALIDATION_DIR = SCRIPT_DIR.parent
REPO_ROOT = VALIDATION_DIR.parent
if str(VALIDATION_DIR) not in sys.path:
sys.path.insert(0, str(VALIDATION_DIR))
from onnx_utils import _ONNX_TO_NP, onnx_io, write_inputs_to_memory_bin
from validate_one import (
MODE_COMPILE_ONLY,
build_dump_ranges,
parse_pim_simulator_outputs,
run_pim_simulator,
sanitize_output_name,
validate_network,
)
from yolo_real_image_validation import save_tensor_csv
IMAGENET_MEAN = np.asarray([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.asarray([0.229, 0.224, 0.225], dtype=np.float32)
DEFAULT_VGG_MODEL = VALIDATION_DIR / "networks" / "vgg16" / "depth_35" / "vgg16_depth_35.onnx"
DEFAULT_RESNET_MODEL = VALIDATION_DIR / "networks" / "resnet" / "resnet18_torchvision.onnx"
def resolve_default_paths():
return {
"raptor_path": REPO_ROOT / "build_release" / "Release" / "bin" / "onnx-mlir",
"onnx_include_dir": REPO_ROOT / "onnx-mlir" / "include",
"simulator_dir": REPO_ROOT / "backend-simulators" / "pim" / "pim-simulator",
}
def resolve_model_path(network: str | None, model: Path | None) -> Path:
if model is not None:
return model.resolve()
if network == "resnet":
return DEFAULT_RESNET_MODEL.resolve()
if network == "vgg":
return DEFAULT_VGG_MODEL.resolve()
raise SystemExit("Pass --model or select a default with --network {resnet,vgg}.")
def ensure_local_artifacts(args, model_path: Path):
validate_network(
network_onnx_path=model_path,
raptor_path=args.raptor_path,
onnx_include_dir=args.onnx_include_dir,
simulator_dir=args.simulator_dir,
crossbar_size=args.crossbar_size,
crossbar_count=args.crossbar_count,
core_count=args.core_count,
command_timeout_seconds=args.command_timeout_seconds,
mode=MODE_COMPILE_ONLY,
verbose=args.verbose,
)
def ensure_existing_artifacts(model_dir: Path):
required_paths = [
model_dir / "runner" / "build" / "runner",
model_dir / "raptor" / "pim" / "config.json",
model_dir / "raptor" / "pim" / "memory.bin",
]
missing = [str(path) for path in required_paths if not path.exists()]
if missing:
raise FileNotFoundError(
"Missing compiled local artifacts. Re-run without --skip-compile or restore these paths:\n "
+ "\n ".join(missing)
)
def preprocess_classification_image(image_path: Path) -> tuple[Image.Image, np.ndarray]:
image = Image.open(image_path).convert("RGB")
width, height = image.size
scale = 256.0 / min(width, height)
resized_size = (
max(1, int(round(width * scale))),
max(1, int(round(height * scale))),
)
resized = image.resize(resized_size, Image.Resampling.BILINEAR)
left = (resized.width - 224) // 2
top = (resized.height - 224) // 2
cropped = resized.crop((left, top, left + 224, top + 224))
array = np.asarray(cropped, dtype=np.float32) / 255.0
array = (array - IMAGENET_MEAN) / IMAGENET_STD
chw = np.transpose(array, (2, 0, 1))
tensor = np.expand_dims(chw.astype(np.float32, copy=False), axis=0)
return image, tensor
def load_labels(labels_path: Path | None) -> list[str] | None:
if labels_path is None:
return None
labels = [line.strip() for line in labels_path.read_text().splitlines()]
return labels or None
def softmax(values: np.ndarray) -> np.ndarray:
shifted = values - np.max(values)
exp = np.exp(shifted)
denom = exp.sum()
if not math.isfinite(float(denom)) or denom <= 0.0:
raise RuntimeError("Softmax received non-finite output scores.")
return exp / denom
def decode_classification_output(output: np.ndarray, labels: list[str] | None, top_k: int):
scores = np.asarray(output, dtype=np.float64).reshape(-1)
probabilities = softmax(scores)
limit = min(top_k, probabilities.size)
top_indices = np.argsort(probabilities)[-limit:][::-1]
results = []
for index in top_indices:
label = None
if labels is not None and 0 <= int(index) < len(labels):
label = labels[int(index)]
results.append(
{
"index": int(index),
"label": label,
"probability": float(probabilities[int(index)]),
}
)
return results
def render_result_line(result) -> str:
name = result["label"] if result["label"] else f'class {result["index"]}'
return f'{name}: {result["probability"] * 100.0:.2f}%'
def draw_classification_panel(image: Image.Image, results, output_path: Path):
annotated = image.copy()
draw = ImageDraw.Draw(annotated)
lines = [render_result_line(result) for result in results]
if not lines:
lines = ["No predictions"]
padding = 10
line_gap = 4
max_width = 0
line_heights = []
for line in lines:
left, top, right, bottom = draw.textbbox((0, 0), line)
max_width = max(max_width, right - left)
line_heights.append(bottom - top)
panel_height = padding * 2 + sum(line_heights) + line_gap * (len(lines) - 1)
panel_width = padding * 2 + max_width
origin_x = 12
origin_y = 12
draw.rounded_rectangle(
(origin_x, origin_y, origin_x + panel_width, origin_y + panel_height),
radius=10,
fill=(0, 0, 0),
)
y = origin_y + padding
for line, line_height in zip(lines, line_heights):
draw.text((origin_x + padding, y), line, fill=(255, 255, 255))
y += line_height + line_gap
annotated.save(output_path)
def run_reference_and_simulator(args, model_path: Path, tensor: np.ndarray):
model_dir = model_path.parent
runner_build_dir = model_dir / "runner" / "build"
runner_path = runner_build_dir / "runner"
pim_dir = model_dir / "raptor" / "pim"
simulation_dir = model_dir / "classification_demo" / "simulation"
reference_dir = model_dir / "classification_demo" / "reference"
inputs_dir = model_dir / "classification_demo" / "inputs"
simulation_dir.mkdir(parents=True, exist_ok=True)
reference_dir.mkdir(parents=True, exist_ok=True)
inputs_dir.mkdir(parents=True, exist_ok=True)
input_descriptors, output_descriptors = onnx_io(model_path)
if len(input_descriptors) != 1:
raise RuntimeError(f"Expected one classification input tensor, found {len(input_descriptors)}")
if len(output_descriptors) != 1:
raise RuntimeError(f"Expected one classification output tensor, found {len(output_descriptors)}")
input_index, _input_name, _input_dtype, input_shape = input_descriptors[0]
if list(tensor.shape) != list(input_shape):
raise RuntimeError(f"Preprocessed tensor shape {list(tensor.shape)} does not match model input {input_shape}")
input_csv = inputs_dir / "in0.csv"
save_tensor_csv(tensor, input_csv)
runner_cmd = [
str(runner_path),
f"--in{input_index}-csv-file",
str(input_csv),
f"--in{input_index}-shape",
"x".join(str(dim) for dim in tensor.shape),
"--save-csv-dir",
str(reference_dir),
]
subprocess.run(runner_cmd, cwd=runner_build_dir, check=True)
write_inputs_to_memory_bin(pim_dir / "memory.bin", pim_dir / "config.json", [tensor])
dump_ranges = build_dump_ranges(pim_dir / "config.json", output_descriptors)
output_bin_path = simulation_dir / "out.bin"
run_pim_simulator(
args.simulator_dir,
pim_dir,
output_bin_path,
dump_ranges,
timeout_sec=args.command_timeout_seconds,
)
output_index, output_name, output_dtype_code, output_shape = output_descriptors[0]
output_dtype = np.dtype(_ONNX_TO_NP[output_dtype_code])
reference_csv = reference_dir / f"output{output_index}_{sanitize_output_name(output_name)}.csv"
reference_output = np.loadtxt(reference_csv, delimiter=",", dtype=output_dtype).reshape(output_shape)
simulator_output = parse_pim_simulator_outputs(output_bin_path, output_descriptors)[0]
return reference_output, simulator_output
def print_topk(title: str, results):
print(title)
for rank, result in enumerate(results, start=1):
label_text = result["label"] if result["label"] else f'class {result["index"]}'
print(f' {rank}. {label_text} ({result["probability"] * 100.0:.2f}%) [index={result["index"]}]')
def main():
defaults = resolve_default_paths()
parser = argparse.ArgumentParser(description="Run a VGG or ResNet ONNX model through the Raptor simulator and annotate the image with top classification results.")
parser.add_argument("--model", type=Path, default=None)
parser.add_argument("--network", choices=("resnet", "vgg"), default=None)
parser.add_argument("--image", type=Path, required=True)
parser.add_argument("--labels", type=Path, default=None)
parser.add_argument("--output", type=Path, required=True)
parser.add_argument("--raptor-path", type=Path, default=defaults["raptor_path"])
parser.add_argument("--onnx-include-dir", type=Path, default=defaults["onnx_include_dir"])
parser.add_argument("--simulator-dir", type=Path, default=defaults["simulator_dir"])
parser.add_argument("--crossbar-size", type=int, default=2048)
parser.add_argument("--crossbar-count", type=int, default=256)
parser.add_argument("--core-count", type=int, default=1000)
parser.add_argument("--top-k", type=int, default=5)
parser.add_argument("--command-timeout-seconds", type=float, default=7200.0)
parser.add_argument("--skip-compile", action="store_true")
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args()
args.model = resolve_model_path(args.network, args.model)
args.image = args.image.resolve()
args.output = args.output.resolve()
args.labels = args.labels.resolve() if args.labels else None
args.raptor_path = args.raptor_path.resolve()
args.onnx_include_dir = args.onnx_include_dir.resolve()
args.simulator_dir = args.simulator_dir.resolve()
if not args.skip_compile:
ensure_local_artifacts(args, args.model)
else:
ensure_existing_artifacts(args.model.parent)
original_image, tensor = preprocess_classification_image(args.image)
labels = load_labels(args.labels)
reference_output, simulator_output = run_reference_and_simulator(args, args.model, tensor)
reference_results = decode_classification_output(reference_output, labels, args.top_k)
simulator_results = decode_classification_output(simulator_output, labels, args.top_k)
print_topk("Reference top-k:", reference_results)
print_topk("Simulator top-k:", simulator_results)
reference_scores = np.asarray(reference_output, dtype=np.float64).reshape(-1)
simulator_scores = np.asarray(simulator_output, dtype=np.float64).reshape(-1)
max_abs_diff = float(np.max(np.abs(reference_scores - simulator_scores)))
print(f"Max absolute score diff: {max_abs_diff:.6e}")
args.output.parent.mkdir(parents=True, exist_ok=True)
draw_classification_panel(original_image, simulator_results, args.output)
print(f"Annotated image saved to {args.output}")
if __name__ == "__main__":
main()
File diff suppressed because it is too large Load Diff