Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 78e97f9fd8 | |||
| 984f362623 | |||
| 568fd90542 | |||
| be0bcc9dcc | |||
| 62dd40ee89 | |||
| 2b4115699a | |||
| 3a985b3675 | |||
| 4ab24eb288 | |||
| e083c27d80 |
@@ -6,6 +6,7 @@
|
|||||||
* Always try the release build first before building with the debug version
|
* Always try the release build first before building with the debug version
|
||||||
* Use the debug build only when it is useful to obtain a clear stack trace with symbols, inspect names, place breakpoints, or test a small case interactively
|
* Use the debug build only when it is useful to obtain a clear stack trace with symbols, inspect names, place breakpoints, or test a small case interactively
|
||||||
* The debug build is very slow, so use it only on small fast tests such as operation validations, not on network validations
|
* The debug build is very slow, so use it only on small fast tests such as operation validations, not on network validations
|
||||||
|
* Always prepend rtk to shell commands if missing and if rtk is available
|
||||||
|
|
||||||
# Core engineering philosophy
|
# Core engineering philosophy
|
||||||
|
|
||||||
|
|||||||
@@ -258,24 +258,23 @@ where
|
|||||||
|
|
||||||
let (memory, crossbars) = core.get_memory_crossbar();
|
let (memory, crossbars) = core.get_memory_crossbar();
|
||||||
let crossbar = crossbars.get_mut(group).unwrap();
|
let crossbar = crossbars.get_mut(group).unwrap();
|
||||||
let crossbar_stored_bytes = crossbar.stored_bytes();
|
|
||||||
let crossbar_byte_width = crossbar.width();
|
|
||||||
|
|
||||||
let crossbar_elem_width = crossbar_byte_width / size_of::<M>();
|
|
||||||
ensure!(
|
|
||||||
crossbar_byte_width % size_of::<M>() == 0,
|
|
||||||
"M not divisor of the crosbbar size"
|
|
||||||
);
|
|
||||||
|
|
||||||
let crossbar_height = crossbar.height();
|
let crossbar_height = crossbar.height();
|
||||||
let crossbar_byte_size = crossbar_byte_width * crossbar_height;
|
let crossbar_stored_bytes = crossbar.stored_bytes();
|
||||||
|
let bytes_per_column = crossbar_height * size_of::<M>();
|
||||||
|
ensure!(bytes_per_column != 0, "crossbar height can not be zero");
|
||||||
|
ensure!(
|
||||||
|
crossbar_stored_bytes % bytes_per_column == 0,
|
||||||
|
"Stored crossbar bytes do not describe an integral number of columns"
|
||||||
|
);
|
||||||
|
let crossbar_elem_width = crossbar_stored_bytes / bytes_per_column;
|
||||||
|
ensure!(crossbar_elem_width != 0, "Crossbar contains no stored columns");
|
||||||
|
|
||||||
let loads = memory
|
let loads = memory
|
||||||
.reserve_load(r1_val, crossbar_height * size_of::<F>())?
|
.reserve_load(r1_val, crossbar_height * size_of::<F>())?
|
||||||
.execute_load::<F>()?;
|
.execute_load::<F>()?;
|
||||||
let load = loads[0];
|
let load = loads[0];
|
||||||
let vec: Cow<[M]> = load.up();
|
let vec: Cow<[M]> = load.up();
|
||||||
let matrix = crossbar.load::<M>(crossbar_byte_size)?[0];
|
let matrix = crossbar.load::<M>(crossbar_stored_bytes)?[0];
|
||||||
|
|
||||||
// --- FAER IMPLEMENTATION ---
|
// --- FAER IMPLEMENTATION ---
|
||||||
|
|
||||||
|
|||||||
Submodule backend-simulators/pim/pimsim-nn updated: 6d3b898e6b...3e3442b663
+1
-1
Submodule onnx-mlir updated: eb54c2afc4...82018d7ce5
@@ -56,6 +56,22 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
|
|||||||
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
|
||||||
|
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||||
|
if (result) {
|
||||||
|
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||||
|
if (yieldOp && result.getResultNumber() < yieldOp.getNumOperands()) {
|
||||||
|
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||||
|
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
||||||
|
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||||
|
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size())
|
||||||
|
return resolveLoopCarriedAliasImpl(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
||||||
|
}
|
||||||
|
return yieldedValue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
|
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
|
||||||
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
||||||
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
|
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
|
||||||
@@ -512,6 +528,24 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto ifOp = mlir::dyn_cast<mlir::scf::IfOp>(definingOp)) {
|
||||||
|
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||||
|
if (!result)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto condition = resolveIndexValueImpl(ifOp.getCondition(), knowledge);
|
||||||
|
if (failed(condition))
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
mlir::Region& selectedRegion = *condition != 0 ? ifOp.getThenRegion() : ifOp.getElseRegion();
|
||||||
|
auto yieldOp = mlir::dyn_cast<mlir::scf::YieldOp>(selectedRegion.front().getTerminator());
|
||||||
|
if (!yieldOp || result.getResultNumber() >= yieldOp.getNumOperands())
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
value = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
||||||
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
||||||
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
||||||
@@ -622,6 +656,33 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto ifOp = mlir::dyn_cast<mlir::scf::IfOp>(definingOp)) {
|
||||||
|
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||||
|
if (!result)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto thenYield = mlir::dyn_cast<mlir::scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
|
||||||
|
auto elseYield = mlir::dyn_cast<mlir::scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());
|
||||||
|
if (!thenYield || !elseYield || result.getResultNumber() >= thenYield.getNumOperands()
|
||||||
|
|| result.getResultNumber() >= elseYield.getNumOperands()) {
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto thenAddress = compileContiguousAddressExprImpl(thenYield.getOperand(result.getResultNumber()));
|
||||||
|
auto elseAddress = compileContiguousAddressExprImpl(elseYield.getOperand(result.getResultNumber()));
|
||||||
|
if (failed(thenAddress) || failed(elseAddress) || thenAddress->base != elseAddress->base)
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
auto condition = compileIndexValueImpl(ifOp.getCondition());
|
||||||
|
if (failed(condition))
|
||||||
|
return mlir::failure();
|
||||||
|
|
||||||
|
CompiledIndexExprNode selectExpr;
|
||||||
|
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
|
||||||
|
selectExpr.operands = {*condition, thenAddress->byteOffset, elseAddress->byteOffset};
|
||||||
|
return CompiledAddressExpr {thenAddress->base, makeCompiledIndexExpr(std::move(selectExpr))};
|
||||||
|
}
|
||||||
|
|
||||||
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
||||||
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
||||||
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
||||||
|
|||||||
@@ -17,6 +17,20 @@ std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRe
|
|||||||
|
|
||||||
std::fstream openReportFile(const std::string& name) { return openReportFileWithExtension(name, "txt"); }
|
std::fstream openReportFile(const std::string& name) { return openReportFileWithExtension(name, "txt"); }
|
||||||
|
|
||||||
|
std::fstream openAppendedReportFileWithExtension(const std::string& name, llvm::StringRef extension) {
|
||||||
|
std::string outputDir = getOutputDir();
|
||||||
|
if (outputDir.empty())
|
||||||
|
return {};
|
||||||
|
|
||||||
|
std::string reportsDir = outputDir + "/reports";
|
||||||
|
createDirectory(reportsDir);
|
||||||
|
return std::fstream(reportsDir + "/" + name + "." + extension.str(), std::ios::out | std::ios::app);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::fstream openAppendedReportFile(const std::string& name) {
|
||||||
|
return openAppendedReportFileWithExtension(name, "txt");
|
||||||
|
}
|
||||||
|
|
||||||
std::string formatReportMemory(uint64_t bytes) {
|
std::string formatReportMemory(uint64_t bytes) {
|
||||||
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
|
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
std::fstream openReportFile(const std::string& name);
|
std::fstream openReportFile(const std::string& name);
|
||||||
std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension);
|
std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension);
|
||||||
|
std::fstream openAppendedReportFile(const std::string& name);
|
||||||
|
std::fstream openAppendedReportFileWithExtension(const std::string& name, llvm::StringRef extension);
|
||||||
std::string formatReportMemory(uint64_t bytes);
|
std::string formatReportMemory(uint64_t bytes);
|
||||||
|
|
||||||
struct ReportField {
|
struct ReportField {
|
||||||
|
|||||||
@@ -588,13 +588,37 @@ void PimCodeGen::emitInstruction(const pim_binary::InstructionRecord& instructio
|
|||||||
++emittedInstructionCount;
|
++emittedInstructionCount;
|
||||||
if (coreJsonStream)
|
if (coreJsonStream)
|
||||||
*coreJsonStream << json::Value(pim_binary::makeInstructionJson(instruction)) << ',';
|
*coreJsonStream << json::Value(pim_binary::makeInstructionJson(instruction)) << ',';
|
||||||
|
updateScalarRegisterCache(instruction);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PimCodeGen::updateScalarRegisterCache(const pim_binary::InstructionRecord& instruction) const {
|
||||||
|
switch (instruction.opcode) {
|
||||||
|
case pim_binary::Opcode::sldi:
|
||||||
|
scalarRegisterValues[instruction.rd] = instruction.r2OrImm;
|
||||||
|
break;
|
||||||
|
case pim_binary::Opcode::sld:
|
||||||
|
case pim_binary::Opcode::sadd:
|
||||||
|
case pim_binary::Opcode::ssub:
|
||||||
|
case pim_binary::Opcode::smul:
|
||||||
|
case pim_binary::Opcode::saddi:
|
||||||
|
case pim_binary::Opcode::smuli:
|
||||||
|
scalarRegisterValues[instruction.rd].reset();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const {
|
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const {
|
||||||
|
auto registerIndex = pim::checkedU8OrCrash(registerNumber, "register number");
|
||||||
|
auto immediateValue = pim::checkedI32OrCrash(immediate, "register immediate");
|
||||||
|
if (scalarRegisterValues[registerIndex] == immediateValue)
|
||||||
|
return;
|
||||||
|
|
||||||
pim_binary::InstructionRecord instruction;
|
pim_binary::InstructionRecord instruction;
|
||||||
instruction.opcode = pim_binary::Opcode::sldi;
|
instruction.opcode = pim_binary::Opcode::sldi;
|
||||||
instruction.rd = static_cast<uint8_t>(registerNumber);
|
instruction.rd = registerIndex;
|
||||||
instruction.r2OrImm = pim::checkedI32OrCrash(immediate, "register immediate");
|
instruction.r2OrImm = immediateValue;
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
|
#include <array>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
@@ -170,6 +171,7 @@ class PimCodeGen {
|
|||||||
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
|
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
|
||||||
std::optional<unsigned> batchLane;
|
std::optional<unsigned> batchLane;
|
||||||
mutable uint32_t emittedInstructionCount = 0;
|
mutable uint32_t emittedInstructionCount = 0;
|
||||||
|
mutable std::array<std::optional<int32_t>, 256> scalarRegisterValues = {};
|
||||||
|
|
||||||
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||||
return memory.getValueAddress(value, knowledge, batchLane);
|
return memory.getValueAddress(value, knowledge, batchLane);
|
||||||
@@ -177,6 +179,7 @@ class PimCodeGen {
|
|||||||
size_t remapCoreId(size_t coreId) const;
|
size_t remapCoreId(size_t coreId) const;
|
||||||
|
|
||||||
void emitInstruction(const pim_binary::InstructionRecord& instruction) const;
|
void emitInstruction(const pim_binary::InstructionRecord& instruction) const;
|
||||||
|
void updateScalarRegisterCache(const pim_binary::InstructionRecord& instruction) const;
|
||||||
|
|
||||||
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
||||||
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
||||||
|
|||||||
@@ -32,6 +32,31 @@ llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport(
|
|||||||
llvm::cl::init(PimMemoryReportNone),
|
llvm::cl::init(PimMemoryReportNone),
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
llvm::cl::opt<PimConvLoweringType> pimConvLowering(
|
||||||
|
"pim-conv-lowering",
|
||||||
|
llvm::cl::desc("Convolution lowering strategy for PIM"),
|
||||||
|
llvm::cl::values(clEnumValN(PimConvLoweringAuto, "auto", "Select the Conv lowering strategy automatically")),
|
||||||
|
llvm::cl::values(clEnumValN(PimConvLoweringLegacy, "legacy", "Use the legacy explicit-im2col Conv lowering")),
|
||||||
|
llvm::cl::values(clEnumValN(PimConvLoweringDepthwise, "depthwise", "Force the depthwise-specialized Conv lowering")),
|
||||||
|
llvm::cl::values(
|
||||||
|
clEnumValN(PimConvLoweringPackedIm2Col, "packed-im2col", "Use explicit im2col with packed multi-position GEMM")),
|
||||||
|
llvm::cl::values(clEnumValN(PimConvLoweringStreamedPatch,
|
||||||
|
"streamed-patch",
|
||||||
|
"Use streamed/chunked im2col rows without multi-position packing")),
|
||||||
|
llvm::cl::values(clEnumValN(PimConvLoweringStreamedPacked,
|
||||||
|
"streamed-packed",
|
||||||
|
"Use streamed/chunked im2col rows with packed multi-position GEMM")),
|
||||||
|
llvm::cl::values(clEnumValN(PimConvLoweringOutputChannelTiled,
|
||||||
|
"output-channel-tiled",
|
||||||
|
"Force Conv lowering that relies on Gemm output-channel tiling")),
|
||||||
|
llvm::cl::values(
|
||||||
|
clEnumValN(PimConvLoweringInputKTiled, "input-k-tiled", "Force Conv lowering that relies on Gemm K tiling")),
|
||||||
|
llvm::cl::values(clEnumValN(PimConvLoweringTiled2D,
|
||||||
|
"tiled-2d",
|
||||||
|
"Force Conv lowering that relies on Gemm 2D K/C tiling")),
|
||||||
|
llvm::cl::init(PimConvLoweringAuto),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<bool>
|
llvm::cl::opt<bool>
|
||||||
pimOnlyCodegen("pim-only-codegen",
|
pimOnlyCodegen("pim-only-codegen",
|
||||||
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
||||||
@@ -49,11 +74,46 @@ llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
|
|||||||
llvm::cl::init(false),
|
llvm::cl::init(false),
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
llvm::cl::opt<uint64_t> pimConvIm2colMaxElements(
|
||||||
|
"pim-conv-im2col-max-elements",
|
||||||
|
llvm::cl::desc("Maximum number of im2col elements to materialize globally for one Conv before streaming/chunking"),
|
||||||
|
llvm::cl::init(1ull << 20),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
llvm::cl::opt<uint64_t> pimConvStreamChunkPositions(
|
||||||
|
"pim-conv-stream-chunk-positions",
|
||||||
|
llvm::cl::desc("Maximum number of Conv output positions to materialize in one streamed chunk"),
|
||||||
|
llvm::cl::init(1024),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
llvm::cl::opt<bool> pimReportConvLowering("pim-report-conv-lowering",
|
||||||
|
llvm::cl::desc("Emit a bounded Conv lowering report"),
|
||||||
|
llvm::cl::init(true),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
||||||
llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
|
llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
|
||||||
llvm::cl::init(false),
|
llvm::cl::init(false),
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
llvm::cl::opt<bool> pimDetectCommunicationDeadlock(
|
||||||
|
"pim-detect-communication-deadlock",
|
||||||
|
llvm::cl::desc("Expensively simulate the statically expanded PIM send/receive order at verification time and fail if a blocking communication deadlock is found"),
|
||||||
|
llvm::cl::init(false),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
llvm::cl::opt<bool> pimMaterializeScalarFanoutGlobalOrder(
|
||||||
|
"pim-materialize-scalar-fanout-global-order",
|
||||||
|
llvm::cl::desc("Experimental expensive materializer mode: emit scalar-source fanout as globally ordered communication events instead of all-send fanout loops"),
|
||||||
|
llvm::cl::init(false),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
llvm::cl::opt<bool> pimTraceCommunicationMaterialization(
|
||||||
|
"pim-trace-communication-materialization",
|
||||||
|
llvm::cl::desc("Emit verbose materializer-time diagnostics and provenance attributes for every Spatial communication op"),
|
||||||
|
llvm::cl::init(false),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<size_t>
|
llvm::cl::opt<size_t>
|
||||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
|
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
|
||||||
|
|
||||||
|
|||||||
@@ -30,19 +30,38 @@ typedef enum {
|
|||||||
PimMemoryReportFull = 2,
|
PimMemoryReportFull = 2,
|
||||||
} PimMemoryReportLevel;
|
} PimMemoryReportLevel;
|
||||||
|
|
||||||
|
typedef enum {
|
||||||
|
PimConvLoweringAuto = 0,
|
||||||
|
PimConvLoweringLegacy = 1,
|
||||||
|
PimConvLoweringDepthwise = 2,
|
||||||
|
PimConvLoweringPackedIm2Col = 3,
|
||||||
|
PimConvLoweringStreamedPatch = 4,
|
||||||
|
PimConvLoweringStreamedPacked = 5,
|
||||||
|
PimConvLoweringOutputChannelTiled = 6,
|
||||||
|
PimConvLoweringInputKTiled = 7,
|
||||||
|
PimConvLoweringTiled2D = 8,
|
||||||
|
} PimConvLoweringType;
|
||||||
|
|
||||||
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
||||||
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
||||||
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
|
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
|
||||||
extern llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport;
|
extern llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport;
|
||||||
|
extern llvm::cl::opt<PimConvLoweringType> pimConvLowering;
|
||||||
|
|
||||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||||
extern llvm::cl::opt<bool> pimDisableMemoryCoalescing;
|
extern llvm::cl::opt<bool> pimDisableMemoryCoalescing;
|
||||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||||
extern llvm::cl::opt<bool> pimEmitJson;
|
extern llvm::cl::opt<bool> pimEmitJson;
|
||||||
|
extern llvm::cl::opt<bool> pimReportConvLowering;
|
||||||
|
extern llvm::cl::opt<bool> pimDetectCommunicationDeadlock;
|
||||||
|
extern llvm::cl::opt<bool> pimMaterializeScalarFanoutGlobalOrder;
|
||||||
|
extern llvm::cl::opt<bool> pimTraceCommunicationMaterialization;
|
||||||
|
|
||||||
extern llvm::cl::opt<size_t> crossbarSize;
|
extern llvm::cl::opt<size_t> crossbarSize;
|
||||||
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||||
extern llvm::cl::opt<long> coresCount;
|
extern llvm::cl::opt<long> coresCount;
|
||||||
|
extern llvm::cl::opt<uint64_t> pimConvIm2colMaxElements;
|
||||||
|
extern llvm::cl::opt<uint64_t> pimConvStreamChunkPositions;
|
||||||
|
|
||||||
bool hasExplicitPimCoreCount();
|
bool hasExplicitPimCoreCount();
|
||||||
void verifyExplicitPimCoreCount();
|
void verifyExplicitPimCoreCount();
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
|
|
||||||
if (pimEmissionTarget >= EmitSpatial) {
|
if (pimEmissionTarget >= EmitSpatial) {
|
||||||
pm.addPass(createONNXToSpatialPass());
|
pm.addPass(createONNXToSpatialPass());
|
||||||
|
pm.addPass(createSpatialLayoutPlanningPass());
|
||||||
|
pm.addPass(createLowerSpatialPlansPass());
|
||||||
pm.addPass(createMergeComputeNodesPass());
|
pm.addPass(createMergeComputeNodesPass());
|
||||||
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
|
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ add_pim_library(OMONNXToSpatial
|
|||||||
Patterns/Tensor/Split.cpp
|
Patterns/Tensor/Split.cpp
|
||||||
Patterns/Tensor/Transpose.cpp
|
Patterns/Tensor/Transpose.cpp
|
||||||
ONNXToSpatialPass.cpp
|
ONNXToSpatialPass.cpp
|
||||||
|
SpatialLayoutPlanningPass.cpp
|
||||||
|
LowerSpatialPlansPass.cpp
|
||||||
Common/AttributeUtils.cpp
|
Common/AttributeUtils.cpp
|
||||||
Common/ComputeRegionBuilder.cpp
|
Common/ComputeRegionBuilder.cpp
|
||||||
Common/IndexingUtils.cpp
|
Common/IndexingUtils.cpp
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
Value sumTensors(ArrayRef<Value> tensors, PatternRewriter& rewriter) {
|
||||||
if (tensors.size() == 1)
|
if (tensors.size() == 1)
|
||||||
return tensors[0];
|
return tensors[0];
|
||||||
|
|
||||||
|
|||||||
@@ -87,17 +87,17 @@ inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int
|
|||||||
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
|
/// Builds a `spat.graph_compute` with a fixed number of SSA inputs and erases it if
|
||||||
/// the body callback reports failure.
|
/// the body callback reports failure.
|
||||||
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||||
auto createSpatCompute(RewriterT& rewriter,
|
auto createSpatGraphCompute(RewriterT& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
mlir::TypeRange resultTypes,
|
mlir::TypeRange resultTypes,
|
||||||
mlir::ValueRange weights,
|
mlir::ValueRange weights,
|
||||||
mlir::ValueRange inputs,
|
mlir::ValueRange inputs,
|
||||||
BodyFn&& body) {
|
BodyFn&& body) {
|
||||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
auto* block = new mlir::Block();
|
||||||
for (mlir::Value weight : weights)
|
for (mlir::Value weight : weights)
|
||||||
@@ -124,23 +124,23 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
if (mlir::failed(bodyResult)) {
|
if (mlir::failed(bodyResult)) {
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
rewriter.eraseOp(computeOp);
|
rewriter.eraseOp(computeOp);
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
|
||||||
}
|
}
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Builds a `spat.compute` whose body consumes the block arguments as a single
|
/// Builds a `spat.graph_compute` whose body consumes the block arguments as a single
|
||||||
/// `ValueRange`, which is convenient for variadic reductions/concats.
|
/// `ValueRange`, which is convenient for variadic reductions/concats.
|
||||||
template <typename RewriterT, typename BodyFn>
|
template <typename RewriterT, typename BodyFn>
|
||||||
auto createSpatCompute(RewriterT& rewriter,
|
auto createSpatGraphCompute(RewriterT& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
mlir::TypeRange resultTypes,
|
mlir::TypeRange resultTypes,
|
||||||
mlir::ValueRange weights,
|
mlir::ValueRange weights,
|
||||||
mlir::ValueRange inputs,
|
mlir::ValueRange inputs,
|
||||||
BodyFn&& body) {
|
BodyFn&& body) {
|
||||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||||
|
|
||||||
auto* block = new mlir::Block();
|
auto* block = new mlir::Block();
|
||||||
for (mlir::Value weight : weights)
|
for (mlir::Value weight : weights)
|
||||||
@@ -163,15 +163,15 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
if (mlir::failed(bodyResult)) {
|
if (mlir::failed(bodyResult)) {
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
rewriter.eraseOp(computeOp);
|
rewriter.eraseOp(computeOp);
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
|
||||||
}
|
}
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename RewriterT, typename BodyFn>
|
template <typename RewriterT, typename BodyFn>
|
||||||
auto createSpatComputeBatch(RewriterT& rewriter,
|
auto createSpatGraphComputeBatch(RewriterT& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
mlir::TypeRange resultTypes,
|
mlir::TypeRange resultTypes,
|
||||||
int64_t laneCount,
|
int64_t laneCount,
|
||||||
@@ -179,13 +179,13 @@ auto createSpatComputeBatch(RewriterT& rewriter,
|
|||||||
mlir::ValueRange inputs,
|
mlir::ValueRange inputs,
|
||||||
BodyFn&& body) {
|
BodyFn&& body) {
|
||||||
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
|
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
|
||||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
|
||||||
|
|
||||||
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
|
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
|
||||||
if (mlir::failed(laneCountAttr))
|
if (mlir::failed(laneCountAttr))
|
||||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
|
||||||
|
|
||||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
|
auto batchOp = spatial::SpatGraphComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
|
||||||
|
|
||||||
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
|
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
|
||||||
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
|
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
|
||||||
@@ -218,20 +218,53 @@ auto createSpatComputeBatch(RewriterT& rewriter,
|
|||||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||||
std::forward<BodyFn>(body)(args);
|
std::forward<BodyFn>(body)(args);
|
||||||
rewriter.setInsertionPointAfter(batchOp);
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto bodyResult = std::forward<BodyFn>(body)(args);
|
auto bodyResult = std::forward<BodyFn>(body)(args);
|
||||||
if (mlir::failed(bodyResult)) {
|
if (mlir::failed(bodyResult)) {
|
||||||
rewriter.setInsertionPointAfter(batchOp);
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
rewriter.eraseOp(batchOp);
|
rewriter.eraseOp(batchOp);
|
||||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
|
||||||
}
|
}
|
||||||
rewriter.setInsertionPointAfter(batchOp);
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatCompute(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
return createSpatGraphCompute<NumInputs>(
|
||||||
|
rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatCompute(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
return createSpatGraphCompute(rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename RewriterT, typename BodyFn>
|
||||||
|
auto createSpatComputeBatch(RewriterT& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::TypeRange resultTypes,
|
||||||
|
int64_t laneCount,
|
||||||
|
mlir::ValueRange weights,
|
||||||
|
mlir::ValueRange inputs,
|
||||||
|
BodyFn&& body) {
|
||||||
|
return createSpatGraphComputeBatch(
|
||||||
|
rewriter, loc, resultTypes, laneCount, weights, inputs, std::forward<BodyFn>(body));
|
||||||
|
}
|
||||||
|
|
||||||
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
|
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
mlir::Value source,
|
mlir::Value source,
|
||||||
@@ -262,6 +295,6 @@ mlir::Value materializeOrComputeUnary(mlir::Value input,
|
|||||||
return computeOp.getResult(0);
|
return computeOp.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::PatternRewriter& rewriter);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int
|
|||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> sliceTensor(
|
SmallVector<Value> sliceTensor(
|
||||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
|
||||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||||
assert("Invalid axis" && axis < shape.size());
|
assert("Invalid axis" && axis < shape.size());
|
||||||
|
|
||||||
@@ -129,7 +129,7 @@ SmallVector<Value> sliceTensor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value>
|
SmallVector<Value>
|
||||||
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
sliceVector(const Value& vectorToSlice, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
|
||||||
ArrayRef<long> shape = getTensorShape(vectorToSlice);
|
ArrayRef<long> shape = getTensorShape(vectorToSlice);
|
||||||
assert("Not a vector" && isVectorShape(shape));
|
assert("Not a vector" && isVectorShape(shape));
|
||||||
size_t axis = shape[0] != 1 ? 0 : 1;
|
size_t axis = shape[0] != 1 ? 0 : 1;
|
||||||
@@ -137,7 +137,7 @@ sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewr
|
|||||||
}
|
}
|
||||||
|
|
||||||
DenseMap<CoreId, SmallVector<Value>>
|
DenseMap<CoreId, SmallVector<Value>>
|
||||||
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
|
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, PatternRewriter& rewriter, Location loc) {
|
||||||
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
|
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
|
||||||
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
|
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
|
||||||
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
|
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
|
||||||
@@ -163,6 +163,38 @@ Value extractAxisSlice(
|
|||||||
.getResult();
|
.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value extractStaticSliceOrIdentity(RewriterBase& rewriter,
|
||||||
|
Location loc,
|
||||||
|
Value source,
|
||||||
|
RankedTensorType resultType,
|
||||||
|
ArrayRef<OpFoldResult> offsets,
|
||||||
|
ArrayRef<OpFoldResult> sizes,
|
||||||
|
ArrayRef<OpFoldResult> strides) {
|
||||||
|
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||||
|
size_t rank = static_cast<size_t>(sourceType.getRank());
|
||||||
|
|
||||||
|
bool isIdentitySlice =
|
||||||
|
sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank
|
||||||
|
&& strides.size() == rank;
|
||||||
|
if (isIdentitySlice) {
|
||||||
|
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
||||||
|
for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) {
|
||||||
|
std::optional<int64_t> staticOffset = mlir::getConstantIntValue(offset);
|
||||||
|
std::optional<int64_t> staticSize = mlir::getConstantIntValue(size);
|
||||||
|
std::optional<int64_t> staticStride = mlir::getConstantIntValue(stride);
|
||||||
|
if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim || *staticStride != 1) {
|
||||||
|
isIdentitySlice = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isIdentitySlice)
|
||||||
|
return source;
|
||||||
|
|
||||||
|
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
Value insertStaticSlice(
|
Value insertStaticSlice(
|
||||||
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
||||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||||
|
|||||||
@@ -89,22 +89,30 @@ llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewr
|
|||||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||||
size_t axis,
|
size_t axis,
|
||||||
int64_t sliceSize,
|
int64_t sliceSize,
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
mlir::PatternRewriter& rewriter,
|
||||||
mlir::Location loc);
|
mlir::Location loc);
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||||
int64_t sliceSize,
|
int64_t sliceSize,
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
mlir::PatternRewriter& rewriter,
|
||||||
mlir::Location loc);
|
mlir::Location loc);
|
||||||
|
|
||||||
/// Partitions one logical vector into per-core crossbar-sized slices using the
|
/// Partitions one logical vector into per-core crossbar-sized slices using the
|
||||||
/// current PIM target geometry.
|
/// current PIM target geometry.
|
||||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||||
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc);
|
||||||
|
|
||||||
mlir::Value extractAxisSlice(
|
mlir::Value extractAxisSlice(
|
||||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||||
|
|
||||||
|
mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::Value source,
|
||||||
|
mlir::RankedTensorType resultType,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> offsets,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> sizes,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> strides);
|
||||||
|
|
||||||
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
mlir::Value source,
|
mlir::Value source,
|
||||||
|
|||||||
@@ -19,9 +19,11 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
bool isWeightLikeComputeOperand(Value value) {
|
static bool isWeightMaterializationValue(Value value, bool requireMatrixShape) {
|
||||||
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
|
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
|
||||||
if (!rankedType || !isMatrixShape(rankedType.getShape()))
|
if (!rankedType)
|
||||||
|
return false;
|
||||||
|
if (requireMatrixShape && !isMatrixShape(rankedType.getShape()))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
@@ -29,8 +31,14 @@ bool isWeightLikeComputeOperand(Value value) {
|
|||||||
while (auto* definingOp = value.getDefiningOp()) {
|
while (auto* definingOp = value.getDefiningOp()) {
|
||||||
if (!visited.insert(definingOp).second)
|
if (!visited.insert(definingOp).second)
|
||||||
return false;
|
return false;
|
||||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
|
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp)) {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(value.getType());
|
||||||
|
if (!sourceType)
|
||||||
|
return false;
|
||||||
|
if (requireMatrixShape && !isMatrixShape(sourceType.getShape()))
|
||||||
|
return false;
|
||||||
return true;
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||||
value = extractSliceOp.getSource();
|
value = extractSliceOp.getSource();
|
||||||
@@ -55,6 +63,8 @@ bool isWeightLikeComputeOperand(Value value) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isWeightLikeComputeOperand(Value value) { return isWeightMaterializationValue(value, /*requireMatrixShape=*/true); }
|
||||||
|
|
||||||
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
||||||
if (auto mapped = mapper.lookupOrNull(value))
|
if (auto mapped = mapper.lookupOrNull(value))
|
||||||
return cast<Value>(mapped);
|
return cast<Value>(mapped);
|
||||||
@@ -91,7 +101,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isWeightLikeComputeOperand(operand)) {
|
if (isWeightMaterializationValue(operand, /*requireMatrixShape=*/false)) {
|
||||||
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
||||||
if (failed(clonedOperand))
|
if (failed(clonedOperand))
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -0,0 +1,409 @@
|
|||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
|
||||||
|
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static constexpr StringLiteral kDenseLayout = "dense_nchw";
|
||||||
|
static constexpr StringLiteral kRowStripLayout = "nchw_row_strip";
|
||||||
|
|
||||||
|
struct RowStripPhysicalValue {
|
||||||
|
Value physicalValue;
|
||||||
|
RankedTensorType logicalType;
|
||||||
|
SmallVector<int64_t, 16> fragmentOffsets;
|
||||||
|
SmallVector<int64_t, 16> fragmentSizes;
|
||||||
|
std::string indexMap;
|
||||||
|
};
|
||||||
|
|
||||||
|
static FailureOr<RowStripPhysicalValue> getRowStripValue(llvm::DenseMap<Value, RowStripPhysicalValue>& rowStripValues,
|
||||||
|
Value value) {
|
||||||
|
auto it = rowStripValues.find(value);
|
||||||
|
if (it == rowStripValues.end())
|
||||||
|
return failure();
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatBlueprintOp blueprint,
|
||||||
|
Value physicalValue) {
|
||||||
|
auto logicalType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||||
|
if (!logicalType)
|
||||||
|
return blueprint.emitOpError("requires ranked logical output type"), failure();
|
||||||
|
RowStripPhysicalValue value;
|
||||||
|
value.physicalValue = physicalValue;
|
||||||
|
value.logicalType = logicalType;
|
||||||
|
value.fragmentOffsets.append(blueprint.getFragmentOffsets().begin(), blueprint.getFragmentOffsets().end());
|
||||||
|
value.fragmentSizes.append(blueprint.getFragmentSizes().begin(), blueprint.getFragmentSizes().end());
|
||||||
|
value.indexMap = blueprint.getIndexMap().str();
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Value>
|
||||||
|
lowerRowStripRelu(const RowStripPhysicalValue& input, spatial::SpatReluPlanOp planOp, PatternRewriter& rewriter) {
|
||||||
|
auto packedType = cast<RankedTensorType>(input.physicalValue.getType());
|
||||||
|
auto computeOp =
|
||||||
|
createSpatCompute<1>(rewriter, planOp.getLoc(), TypeRange {packedType}, {}, input.physicalValue, [&](Value x) {
|
||||||
|
auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), packedType, x);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
|
||||||
|
});
|
||||||
|
return computeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Value>
|
||||||
|
materializeRowStripToDense(const RowStripPhysicalValue& rowStripValue, Location loc, PatternRewriter& rewriter) {
|
||||||
|
auto packedType = dyn_cast<RankedTensorType>(rowStripValue.physicalValue.getType());
|
||||||
|
if (!packedType || packedType.getRank() != 3 || !packedType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (rowStripValue.logicalType.getRank() != 4 || !rowStripValue.logicalType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (rowStripValue.indexMap != "packed_hwc_rows_to_nchw")
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t rank = rowStripValue.logicalType.getRank();
|
||||||
|
const int64_t fragmentCount = rowStripValue.fragmentOffsets.size() / rank;
|
||||||
|
const int64_t packedWidth = packedType.getDimSize(1);
|
||||||
|
const int64_t packedChannels = packedType.getDimSize(2);
|
||||||
|
if (fragmentCount != packedType.getDimSize(0))
|
||||||
|
return failure();
|
||||||
|
for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
|
||||||
|
if (rowStripValue.fragmentOffsets[fragmentIndex * rank + 0] != 0
|
||||||
|
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 1] != 0
|
||||||
|
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 2] != fragmentIndex
|
||||||
|
|| rowStripValue.fragmentOffsets[fragmentIndex * rank + 3] != 0)
|
||||||
|
return failure();
|
||||||
|
if (rowStripValue.fragmentSizes[fragmentIndex * rank + 0] != 1
|
||||||
|
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 1] != packedChannels
|
||||||
|
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 2] != 1
|
||||||
|
|| rowStripValue.fragmentSizes[fragmentIndex * rank + 3] != packedWidth)
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto packedSliceType =
|
||||||
|
RankedTensorType::get({1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
|
||||||
|
auto expandedType =
|
||||||
|
RankedTensorType::get({1, 1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
|
||||||
|
auto logicalFragmentType =
|
||||||
|
RankedTensorType::get({1, packedChannels, 1, packedWidth}, packedType.getElementType(), packedType.getEncoding());
|
||||||
|
auto batchOp = createSpatComputeBatch(
|
||||||
|
rewriter,
|
||||||
|
loc,
|
||||||
|
TypeRange {rowStripValue.logicalType},
|
||||||
|
fragmentCount,
|
||||||
|
{},
|
||||||
|
ValueRange {rowStripValue.physicalValue},
|
||||||
|
[&](detail::SpatComputeBatchBodyArgs args) {
|
||||||
|
SmallVector<OpFoldResult> packedOffsets {args.lane, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> packedSizes {
|
||||||
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(packedWidth), rewriter.getIndexAttr(packedChannels)};
|
||||||
|
Value packedSlice = tensor::ExtractSliceOp::create(
|
||||||
|
rewriter, loc, packedSliceType, args.inputs.front(), packedOffsets, packedSizes, getUnitStrides(rewriter, 3));
|
||||||
|
|
||||||
|
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
expandedType,
|
||||||
|
packedSlice,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2},
|
||||||
|
{3}
|
||||||
|
});
|
||||||
|
Value transposeInit =
|
||||||
|
tensor::EmptyOp::create(rewriter, loc, logicalFragmentType.getShape(), logicalFragmentType.getElementType());
|
||||||
|
Value logicalFragment =
|
||||||
|
linalg::TransposeOp::create(rewriter, loc, expanded, transposeInit, SmallVector<int64_t> {0, 3, 1, 2})
|
||||||
|
.getResult()[0];
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> logicalOffsets {
|
||||||
|
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> logicalSizes {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(packedChannels),
|
||||||
|
rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(packedWidth)};
|
||||||
|
createParallelInsertSliceIntoBatchOutput(rewriter,
|
||||||
|
loc,
|
||||||
|
logicalFragment,
|
||||||
|
args.outputs.front(),
|
||||||
|
logicalOffsets,
|
||||||
|
logicalSizes,
|
||||||
|
getUnitStrides(rewriter, 4));
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
if (failed(batchOp))
|
||||||
|
return failure();
|
||||||
|
return batchOp->getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, OperationPass<ModuleOp>> {
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerSpatialPlansPass)
|
||||||
|
|
||||||
|
StringRef getArgument() const override { return "lower-spatial-plans"; }
|
||||||
|
StringRef getDescription() const override { return "Lower selected Spatial planning ops to low-level Spatial IR."; }
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
ModuleOp moduleOp = getOperation();
|
||||||
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||||
|
if (failed(entryFunc)) {
|
||||||
|
moduleOp.emitError("failed to locate the PIM entry function during LowerSpatialPlans");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
func::FuncOp funcOp = *entryFunc;
|
||||||
|
PatternRewriter rewriter(ctx);
|
||||||
|
llvm::DenseMap<Value, RowStripPhysicalValue> rowStripValues;
|
||||||
|
llvm::SmallPtrSet<Operation*, 16> eraseAfterLowering;
|
||||||
|
auto verifyLogicalPhase = [&](StringRef stage) -> bool {
|
||||||
|
if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc)))
|
||||||
|
return true;
|
||||||
|
moduleOp.emitError() << "logical Spatial graph verification failed " << stage;
|
||||||
|
signalPassFailure();
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!verifyLogicalPhase("at the start of LowerSpatialPlans"))
|
||||||
|
return;
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||||
|
if (auto planOp = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
|
||||||
|
FailureOr<RowStripPhysicalValue> rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
|
||||||
|
auto rowStripBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||||
|
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(user);
|
||||||
|
return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
|
||||||
|
});
|
||||||
|
if (rowStripBlueprint != planOp.getResult().getUsers().end()) {
|
||||||
|
rewriter.setInsertionPoint(planOp);
|
||||||
|
FailureOr<Value> lowered = lowerSelectedConv2DPlan(
|
||||||
|
planOp,
|
||||||
|
succeeded(rowStripInput) ? std::optional<Value> {rowStripInput->physicalValue} : std::nullopt,
|
||||||
|
/*emitRowStripLayout=*/true,
|
||||||
|
rewriter);
|
||||||
|
if (failed(lowered)) {
|
||||||
|
planOp.emitOpError("failed to lower selected row-strip Spatial Conv plan");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto blueprint = cast<spatial::SpatBlueprintOp>(*rowStripBlueprint);
|
||||||
|
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(blueprint, *lowered);
|
||||||
|
if (failed(rowStripValue)) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rowStripValues[blueprint.getResult()] = *rowStripValue;
|
||||||
|
eraseAfterLowering.insert(planOp);
|
||||||
|
eraseAfterLowering.insert(blueprint);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPoint(planOp);
|
||||||
|
FailureOr<Value> lowered =
|
||||||
|
lowerSelectedConv2DPlan(planOp, std::nullopt, /*emitRowStripLayout=*/false, rewriter);
|
||||||
|
if (failed(lowered)) {
|
||||||
|
planOp.emitOpError("failed to lower selected Spatial Conv plan");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(planOp, *lowered);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto planOp = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
|
||||||
|
if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) {
|
||||||
|
auto outputBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||||
|
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(user);
|
||||||
|
return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
|
||||||
|
});
|
||||||
|
if (outputBlueprint == planOp.getResult().getUsers().end()) {
|
||||||
|
planOp.emitOpError("row-strip Relu plan requires a row-strip blueprint result");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<RowStripPhysicalValue> input = getRowStripValue(rowStripValues, planOp.getInput());
|
||||||
|
rewriter.setInsertionPoint(planOp);
|
||||||
|
FailureOr<Value> lowered = lowerRowStripRelu(*input, planOp, rewriter);
|
||||||
|
if (failed(lowered)) {
|
||||||
|
planOp.emitOpError("failed to lower selected row-strip Spatial Relu plan");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto blueprint = cast<spatial::SpatBlueprintOp>(*outputBlueprint);
|
||||||
|
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(blueprint, *lowered);
|
||||||
|
if (failed(output)) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rowStripValues[blueprint.getResult()] = *output;
|
||||||
|
eraseAfterLowering.insert(planOp);
|
||||||
|
eraseAfterLowering.insert(blueprint);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(planOp);
|
||||||
|
auto computeOp = createSpatCompute<1>(
|
||||||
|
rewriter, planOp.getLoc(), planOp.getOutput().getType(), {}, planOp.getInput(), [&](Value x) {
|
||||||
|
auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), planOp.getOutput().getType(), x);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
|
||||||
|
});
|
||||||
|
rewriter.replaceOp(planOp, computeOp.getResults());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto materializeOp = dyn_cast<spatial::SpatMaterializeLayoutOp>(&op)) {
|
||||||
|
if (materializeOp.getSourcePhysicalLayout() == kDenseLayout
|
||||||
|
&& materializeOp.getTargetPhysicalLayout() == kDenseLayout) {
|
||||||
|
rewriter.replaceOp(materializeOp, materializeOp.getInput());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (materializeOp.getSourcePhysicalLayout() != kRowStripLayout
|
||||||
|
|| materializeOp.getTargetPhysicalLayout() != kDenseLayout) {
|
||||||
|
materializeOp.emitOpError("non-dense materialize_layout lowering is not supported yet");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
FailureOr<RowStripPhysicalValue> rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
|
||||||
|
if (failed(rowStripValue)) {
|
||||||
|
materializeOp.emitOpError("expected a row-strip blueprint input during row-strip materialization");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rewriter.setInsertionPoint(materializeOp);
|
||||||
|
FailureOr<Value> dense = materializeRowStripToDense(*rowStripValue, materializeOp.getLoc(), rewriter);
|
||||||
|
if (failed(dense)) {
|
||||||
|
materializeOp.emitOpError("failed to materialize selected row-strip layout back to dense NCHW");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(materializeOp, *dense);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto blueprintOp = dyn_cast<spatial::SpatBlueprintOp>(&op)) {
|
||||||
|
if (blueprintOp.getPhysicalLayout() == kDenseLayout) {
|
||||||
|
rewriter.replaceOp(blueprintOp, blueprintOp.getInput());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (blueprintOp.getPhysicalLayout() != kRowStripLayout) {
|
||||||
|
blueprintOp.emitOpError("non-dense blueprint lowering is not supported yet");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!eraseAfterLowering.contains(blueprintOp)) {
|
||||||
|
blueprintOp.emitOpError("unhandled row-strip blueprint remained during LowerSpatialPlans");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool erasedAny = true;
|
||||||
|
while (erasedAny) {
|
||||||
|
erasedAny = false;
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||||
|
if (!eraseAfterLowering.contains(&op))
|
||||||
|
continue;
|
||||||
|
if (!op.use_empty())
|
||||||
|
continue;
|
||||||
|
eraseAfterLowering.erase(&op);
|
||||||
|
rewriter.eraseOp(&op);
|
||||||
|
erasedAny = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!eraseAfterLowering.empty()) {
|
||||||
|
for (Operation& op : funcOp.getBody().front())
|
||||||
|
if (eraseAfterLowering.contains(&op))
|
||||||
|
op.emitOpError("selected row-strip planning op could not be fully eliminated during LowerSpatialPlans");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ConversionTarget helperTarget(*ctx);
|
||||||
|
helperTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
|
tensor::TensorDialect,
|
||||||
|
linalg::LinalgDialect,
|
||||||
|
affine::AffineDialect,
|
||||||
|
arith::ArithDialect,
|
||||||
|
scf::SCFDialect,
|
||||||
|
func::FuncDialect>();
|
||||||
|
helperTarget.addLegalOp<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
|
||||||
|
helperTarget.addIllegalOp<ONNXGemmOp, ONNXTransposeOp>();
|
||||||
|
helperTarget.markOpRecursivelyLegal<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>();
|
||||||
|
|
||||||
|
RewritePatternSet helperPatterns(ctx);
|
||||||
|
populateGemmPatterns(helperPatterns, ctx);
|
||||||
|
populateTransposePatterns(helperPatterns, ctx);
|
||||||
|
if (failed(applyPartialConversion(moduleOp, helperTarget, std::move(helperPatterns)))) {
|
||||||
|
moduleOp.emitError("failed to lower helper ONNX ops emitted by selected Spatial plan lowering");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
FrozenRewritePatternSet nestedHelperPatterns([&] {
|
||||||
|
RewritePatternSet patterns(ctx);
|
||||||
|
populateGemmPatterns(patterns, ctx);
|
||||||
|
populateTransposePatterns(patterns, ctx);
|
||||||
|
return patterns;
|
||||||
|
}());
|
||||||
|
ConversionTarget nestedHelperTarget(*ctx);
|
||||||
|
nestedHelperTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
|
tensor::TensorDialect,
|
||||||
|
linalg::LinalgDialect,
|
||||||
|
affine::AffineDialect,
|
||||||
|
arith::ArithDialect,
|
||||||
|
scf::SCFDialect,
|
||||||
|
func::FuncDialect>();
|
||||||
|
nestedHelperTarget.addIllegalOp<ONNXGemmOp, ONNXTransposeOp>();
|
||||||
|
SmallVector<Operation*> computeLikeOps;
|
||||||
|
funcOp.walk([&](Operation* op) {
|
||||||
|
if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(op))
|
||||||
|
computeLikeOps.push_back(op);
|
||||||
|
});
|
||||||
|
for (Operation* op : computeLikeOps) {
|
||||||
|
if (failed(applyFullConversion(op, nestedHelperTarget, nestedHelperPatterns))) {
|
||||||
|
op->emitOpError("failed to lower nested helper ONNX ops emitted by selected Spatial plan lowering");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!verifyLogicalPhase("after nested helper conversions"))
|
||||||
|
return;
|
||||||
|
bool hasIllegalOps = false;
|
||||||
|
moduleOp.walk([&](Operation* op) {
|
||||||
|
if (isa<ONNXEntryPointOp>(op))
|
||||||
|
return;
|
||||||
|
if (isa<spatial::SpatConv2DPlanOp,
|
||||||
|
spatial::SpatReluPlanOp,
|
||||||
|
spatial::SpatBlueprintOp,
|
||||||
|
spatial::SpatMaterializeLayoutOp>(op)
|
||||||
|
|| op->getDialect()->getNamespace() == "onnx") {
|
||||||
|
op->emitOpError("operation must not remain after LowerSpatialPlans");
|
||||||
|
hasIllegalOps = true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if (hasIllegalOps)
|
||||||
|
signalPassFailure();
|
||||||
|
else
|
||||||
|
dumpModule(moduleOp, "spatial1_premerge");
|
||||||
|
|
||||||
|
if (!verifyLogicalPhase("at the end of LowerSpatialPlans"))
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createLowerSpatialPlansPass() { return std::make_unique<LowerSpatialPlansPass>(); }
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -18,6 +18,7 @@
|
|||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
#include "ONNXToSpatialVerifier.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -41,10 +42,16 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
|||||||
static void populateEmptyFunction(func::FuncOp funcOp) {
|
static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||||
IRRewriter rewriter(funcOp.getContext());
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
SmallVector<spatial::SpatGraphCompute> computes(funcOp.getOps<spatial::SpatGraphCompute>());
|
||||||
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
|
SmallVector<spatial::SpatGraphComputeBatch> computeBatches(funcOp.getOps<spatial::SpatGraphComputeBatch>());
|
||||||
if (!computes.empty() || !computeBatches.empty())
|
SmallVector<spatial::SpatConv2DPlanOp> convPlans(funcOp.getOps<spatial::SpatConv2DPlanOp>());
|
||||||
|
SmallVector<spatial::SpatReluPlanOp> reluPlans(funcOp.getOps<spatial::SpatReluPlanOp>());
|
||||||
|
SmallVector<spatial::SpatBlueprintOp> blueprints(funcOp.getOps<spatial::SpatBlueprintOp>());
|
||||||
|
SmallVector<spatial::SpatMaterializeLayoutOp> materializers(funcOp.getOps<spatial::SpatMaterializeLayoutOp>());
|
||||||
|
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !blueprints.empty()
|
||||||
|
|| !materializers.empty()) {
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
||||||
rewriter.setInsertionPoint(returnOp);
|
rewriter.setInsertionPoint(returnOp);
|
||||||
@@ -58,7 +65,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
|||||||
sourceLocs.push_back(source.getLoc());
|
sourceLocs.push_back(source.getLoc());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto newCompute = spatial::SpatCompute::create(
|
auto newCompute = spatial::SpatGraphCompute::create(
|
||||||
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
||||||
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
|
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
|
||||||
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
|
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
|
||||||
@@ -67,7 +74,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
rewriter.setInsertionPointToEnd(newBlock);
|
rewriter.setInsertionPointToEnd(newBlock);
|
||||||
for (Operation& op : funcOp.getOps())
|
for (Operation& op : funcOp.getOps())
|
||||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
|
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op))
|
||||||
rewriter.clone(op, mapper);
|
rewriter.clone(op, mapper);
|
||||||
|
|
||||||
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
||||||
@@ -75,7 +82,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
|||||||
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
||||||
|
|
||||||
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
|
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
|
||||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
|
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op)) {
|
||||||
op.dropAllUses();
|
op.dropAllUses();
|
||||||
rewriter.eraseOp(&op);
|
rewriter.eraseOp(&op);
|
||||||
}
|
}
|
||||||
@@ -152,6 +159,11 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
|
moduleOp.emitError("logical Spatial graph verification failed after ONNX conversion");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
ConversionTarget earlyPostTarget(*ctx);
|
ConversionTarget earlyPostTarget(*ctx);
|
||||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
@@ -168,6 +180,11 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
|
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
|
moduleOp.emitError("logical Spatial graph verification failed after weight annotation");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
ConversionTarget postTarget(*ctx);
|
ConversionTarget postTarget(*ctx);
|
||||||
postTarget.addLegalDialect<spatial::SpatialDialect,
|
postTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
@@ -176,11 +193,16 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
affine::AffineDialect,
|
affine::AffineDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
scf::SCFDialect>();
|
scf::SCFDialect>();
|
||||||
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
|
postTarget.addDynamicallyLegalOp<spatial::SpatGraphCompute>(
|
||||||
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
|
[](spatial::SpatGraphCompute computeOp) { return !requiresPostRewrite(computeOp); });
|
||||||
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
postTarget.addDynamicallyLegalOp<spatial::SpatGraphComputeBatch>(
|
||||||
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
[](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
||||||
|
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
|
moduleOp.emitError("logical Spatial graph verification failed before post rewrites");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
RewritePatternSet postPatterns(ctx);
|
RewritePatternSet postPatterns(ctx);
|
||||||
populatePostPatterns(postPatterns, ctx);
|
populatePostPatterns(postPatterns, ctx);
|
||||||
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
|
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
|
||||||
@@ -191,6 +213,11 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
|
|
||||||
populateEmptyFunction(*entryFunc);
|
populateEmptyFunction(*entryFunc);
|
||||||
|
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
|
moduleOp.emitError("logical Spatial graph verification failed after ONNX-to-Spatial");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
dumpModule(moduleOp, "spatial0");
|
dumpModule(moduleOp, "spatial0");
|
||||||
|
|
||||||
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/Diagnostics.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
#include "Common/IR/WeightUtils.hpp"
|
#include "Common/IR/WeightUtils.hpp"
|
||||||
@@ -13,6 +15,8 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr StringLiteral kPhaseMarker = "phase-check";
|
||||||
|
|
||||||
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
func.walk([&](Operation* op) {
|
func.walk([&](Operation* op) {
|
||||||
if (!hasWeightAlways(op))
|
if (!hasWeightAlways(op))
|
||||||
@@ -23,134 +27,191 @@ void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diag
|
|||||||
continue;
|
continue;
|
||||||
|
|
||||||
diagnostics.report(op, [&](Operation* illegalOp) {
|
diagnostics.report(op, [&](Operation* illegalOp) {
|
||||||
illegalOp->emitOpError(
|
illegalOp->emitOpError()
|
||||||
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
|
<< kPhaseMarker
|
||||||
|
<< " weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights";
|
||||||
});
|
});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
Region* getParentRegion(Value value) {
|
bool isRegionOrAncestorOf(Region& region, Region* candidate) {
|
||||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
return candidate && (®ion == candidate || region.isAncestor(candidate));
|
||||||
return blockArg.getOwner()->getParent();
|
|
||||||
if (Operation* definingOp = value.getDefiningOp())
|
|
||||||
return definingOp->getParentRegion();
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isDefinedInsideRegion(Value value, Region& region) {
|
bool isValueDefinedInsideRegion(Value value, Region& region) {
|
||||||
Region* parentRegion = getParentRegion(value);
|
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||||
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
return isRegionOrAncestorOf(region, blockArg.getOwner()->getParent());
|
||||||
|
if (Operation* definingOp = value.getDefiningOp())
|
||||||
|
return isRegionOrAncestorOf(region, definingOp->getParentRegion());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isLegalExternalCapture(Value value, Region& region) {
|
||||||
|
if (isValueDefinedInsideRegion(value, region))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ComputeOpTy>
|
||||||
|
void verifyComputeBodyCaptures(ComputeOpTy compute, StringRef kind, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
|
Region& body = compute.getBody();
|
||||||
|
body.walk([&](Operation* nestedOp) {
|
||||||
|
for (OpOperand& operand : nestedOp->getOpOperands()) {
|
||||||
|
Value value = operand.get();
|
||||||
|
if (isLegalExternalCapture(value, body))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
|
||||||
|
InFlightDiagnostic diag =
|
||||||
|
illegalOp->emitOpError() << kPhaseMarker << " " << kind << " body captures non-constant external operand #"
|
||||||
|
<< operand.getOperandNumber() << " used by " << nestedOp->getName().getStringRef();
|
||||||
|
diag << " (type " << value.getType() << ")";
|
||||||
|
if (definingOp)
|
||||||
|
diag.attachNote(definingOp->getLoc()) << "defining op is " << definingOp->getName().getStringRef();
|
||||||
|
else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
|
||||||
|
if (Operation* owner = blockArg.getOwner()->getParentOp())
|
||||||
|
diag.attachNote(owner->getLoc())
|
||||||
|
<< "external block argument belongs to " << owner->getName().getStringRef();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isLegalHostBackedValue(Value value) {
|
bool isLegalHostBackedValue(Value value) {
|
||||||
Operation* definingOp = value.getDefiningOp();
|
Operation* definingOp = value.getDefiningOp();
|
||||||
if (!definingOp)
|
if (!definingOp)
|
||||||
return isa<BlockArgument>(value);
|
return isa<BlockArgument>(value);
|
||||||
|
|
||||||
if (isa<spatial::SpatChannelReceiveOp>(definingOp))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
return definingOp->getDialect()->getNamespace() != "spat";
|
return definingOp->getDialect()->getNamespace() != "spat";
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
|
template <typename ComputeOpTy>
|
||||||
ValueRange inputs,
|
void verifyScheduledInputs(ComputeOpTy compute,
|
||||||
bool allowChannelReceiveInputs,
|
bool allowChannelReceiveInputs,
|
||||||
StringRef kind,
|
StringRef kind,
|
||||||
pim::CappedDiagnosticReporter& diagnostics) {
|
pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(inputs)) {
|
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
unsigned currentInputIndex = inputIndex;
|
|
||||||
Operation* definingOp = input.getDefiningOp();
|
Operation* definingOp = input.getDefiningOp();
|
||||||
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
|
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
|
||||||
continue;
|
continue;
|
||||||
if (isLegalHostBackedValue(input))
|
if (isLegalHostBackedValue(input))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
|
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
|
||||||
InFlightDiagnostic diagnostic = illegalOp->emitOpError()
|
InFlightDiagnostic diag = illegalOp->emitOpError()
|
||||||
<< kind << " input #" << currentInputIndex
|
<< kPhaseMarker << " " << kind << " input #" << inputIndex
|
||||||
<< (allowChannelReceiveInputs ? " must come from the host or an explicit "
|
<< (allowChannelReceiveInputs ? " must come from the host or explicit spat.channel_receive"
|
||||||
"spat.channel_receive"
|
|
||||||
: " must come from the host");
|
: " must come from the host");
|
||||||
if (definingOp)
|
if (definingOp)
|
||||||
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
|
diag.attachNote(definingOp->getLoc()) << "illegal producer is " << definingOp->getName().getStringRef();
|
||||||
});
|
});
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void verifyNoExternalTensorCaptures(Operation* ownerOp,
|
template <typename ComputeOpTy>
|
||||||
Region& region,
|
void verifyNoNestedFragmentAssemblyBlueprints(ComputeOpTy compute,
|
||||||
StringRef kind,
|
|
||||||
pim::CappedDiagnosticReporter& diagnostics) {
|
pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
region.walk([&](Operation* op) {
|
compute.getBody().walk([&](spatial::SpatBlueprintOp blueprint) {
|
||||||
for (OpOperand& operand : op->getOpOperands()) {
|
std::optional<StringRef> mode = blueprint.getMode();
|
||||||
Value value = operand.get();
|
if (!mode || *mode != "fragment_assembly")
|
||||||
if (!isa<TensorType>(value.getType()))
|
return;
|
||||||
continue;
|
diagnostics.report(blueprint.getOperation(), [&](Operation* illegalOp) {
|
||||||
if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value))
|
illegalOp->emitOpError("fragment assembly blueprint must be host-level after merge materialization");
|
||||||
continue;
|
});
|
||||||
|
|
||||||
Operation* definingOp = value.getDefiningOp();
|
|
||||||
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
diagnostics.report(ownerOp, [&](Operation* illegalOp) {
|
|
||||||
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
|
|
||||||
<< "values";
|
|
||||||
diagnostic.attachNote(op->getLoc())
|
|
||||||
<< "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
|
|
||||||
<< (definingOp ? definingOp->getName().getStringRef() : StringRef("<block argument>"));
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
|
for (Operation& op : funcOp.getOps()) {
|
||||||
|
if (isa<func::ReturnOp,
|
||||||
|
spatial::SpatGraphCompute,
|
||||||
|
spatial::SpatGraphComputeBatch,
|
||||||
|
spatial::SpatConv2DPlanOp,
|
||||||
|
spatial::SpatReluPlanOp,
|
||||||
|
spatial::SpatBlueprintOp,
|
||||||
|
spatial::SpatMaterializeLayoutOp>(&op)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch>(&op)) {
|
||||||
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError() << kPhaseMarker << " scheduled Spatial compute op is not allowed in logical graph phase";
|
||||||
});
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelSendOp>(&op)) {
|
||||||
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError() << kPhaseMarker
|
||||||
|
<< " explicit channel communication is not expected before merge materialization";
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (isCompileTimeOp(&op))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError()
|
||||||
|
<< kPhaseMarker << " non-foldable top-level runtime op remains in logical Spatial graph; lower it inside spat.graph_compute";
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void verifyScheduledTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
|
for (Operation& op : funcOp.getOps()) {
|
||||||
|
if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(&op)) {
|
||||||
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError() << kPhaseMarker << " graph Spatial compute op remained after merge materialization";
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
LogicalResult verifyNoComputeBodyCaptures(func::FuncOp funcOp) {
|
||||||
pim::CappedDiagnosticReporter diagnostics;
|
pim::CappedDiagnosticReporter diagnostics;
|
||||||
|
for (auto compute : funcOp.getOps<spatial::SpatGraphCompute>())
|
||||||
for (Operation& op : funcOp.getOps()) {
|
verifyComputeBodyCaptures(compute, "graph_compute", diagnostics);
|
||||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
for (auto batch : funcOp.getOps<spatial::SpatGraphComputeBatch>())
|
||||||
continue;
|
verifyComputeBodyCaptures(batch, "graph_compute_batch", diagnostics);
|
||||||
if (isCompileTimeOp(&op))
|
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
|
||||||
continue;
|
verifyComputeBodyCaptures(compute, "scheduled_compute", diagnostics);
|
||||||
|
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
|
||||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
verifyComputeBodyCaptures(batch, "scheduled_compute_batch", diagnostics);
|
||||||
illegalOp->emitOpError(
|
diagnostics.emitSuppressedSummary(funcOp, "compute body capture verification failed");
|
||||||
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
checkWeightUseChains(funcOp, diagnostics);
|
|
||||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
|
|
||||||
|
|
||||||
return success(!diagnostics.hasFailure());
|
return success(!diagnostics.hasFailure());
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
|
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { return verifyLogicalSpatialGraphInvariants(funcOp); }
|
||||||
|
|
||||||
|
LogicalResult verifyLogicalSpatialGraphInvariants(func::FuncOp funcOp) {
|
||||||
pim::CappedDiagnosticReporter diagnostics;
|
pim::CappedDiagnosticReporter diagnostics;
|
||||||
|
verifyLogicalTopLevelOps(funcOp, diagnostics);
|
||||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
checkWeightUseChains(funcOp, diagnostics);
|
||||||
(void) verifyComputeLikeInputs(
|
if (failed(verifyNoComputeBodyCaptures(funcOp)))
|
||||||
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics);
|
return failure();
|
||||||
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
|
diagnostics.emitSuppressedSummary(funcOp, "logical Spatial graph verification failed");
|
||||||
|
return success(!diagnostics.hasFailure());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
|
||||||
(void) verifyComputeLikeInputs(computeBatchOp.getOperation(),
|
pim::CappedDiagnosticReporter diagnostics;
|
||||||
computeBatchOp.getInputs(),
|
verifyScheduledTopLevelOps(funcOp, diagnostics);
|
||||||
/*allowChannelReceiveInputs=*/false,
|
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>()) {
|
||||||
"spat.compute_batch",
|
verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
|
||||||
diagnostics);
|
verifyNoNestedFragmentAssemblyBlueprints(compute, diagnostics);
|
||||||
verifyNoExternalTensorCaptures(
|
|
||||||
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
|
|
||||||
}
|
}
|
||||||
|
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
|
||||||
diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
|
verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
|
||||||
|
verifyNoNestedFragmentAssemblyBlueprints(batch, diagnostics);
|
||||||
|
}
|
||||||
|
if (failed(verifyNoComputeBodyCaptures(funcOp)))
|
||||||
|
return failure();
|
||||||
|
diagnostics.emitSuppressedSummary(funcOp, "scheduled Spatial verification failed");
|
||||||
return success(!diagnostics.hasFailure());
|
return success(!diagnostics.hasFailure());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
|
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
|
||||||
mlir::LogicalResult verifySpatialCommunicationInvariants(mlir::func::FuncOp funcOp);
|
mlir::LogicalResult verifyNoComputeBodyCaptures(mlir::func::FuncOp funcOp);
|
||||||
|
mlir::LogicalResult verifyLogicalSpatialGraphInvariants(mlir::func::FuncOp funcOp);
|
||||||
|
mlir::LogicalResult verifyScheduledSpatialInvariants(mlir::func::FuncOp funcOp);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -33,8 +33,8 @@ void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext*
|
|||||||
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
bool requiresPostRewrite(spatial::SpatGraphCompute computeOp);
|
||||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp);
|
||||||
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -285,9 +285,8 @@ static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a,
|
|||||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
|
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
|
||||||
rewriter.getIndexAttr(crossbarSize.getValue())};
|
rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||||
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
|
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
|
||||||
Value bTile =
|
Value bTile = extractStaticSliceOrIdentity(
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, args.weights.front(), bOffsets, bSizes, unitStrides)
|
rewriter, loc, args.weights.front(), bTileType, bOffsets, bSizes, unitStrides);
|
||||||
.getResult();
|
|
||||||
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
||||||
|
|
||||||
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||||
|
|||||||
@@ -950,7 +950,12 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||||
auto shapeInfo = analyzeMatMulShape(matmulOp);
|
auto shapeInfo = analyzeMatMulShape(matmulOp);
|
||||||
if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector || !shapeInfo->outputBatchShape.empty())
|
if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const bool hasNonSingletonOutputBatch =
|
||||||
|
!shapeInfo->outputBatchShape.empty() && getStaticShapeElementCount(shapeInfo->outputBatchShape) != 1;
|
||||||
|
if (hasNonSingletonOutputBatch)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Location loc = matmulOp.getLoc();
|
Location loc = matmulOp.getLoc();
|
||||||
@@ -991,9 +996,19 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
gemmResult =
|
gemmResult =
|
||||||
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}))
|
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}))
|
||||||
.getResult();
|
.getResult();
|
||||||
|
|
||||||
|
if (shapeInfo->outputBatchShape.empty()) {
|
||||||
rewriter.replaceOp(matmulOp, gemmResult);
|
rewriter.replaceOp(matmulOp, gemmResult);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto directOutType =
|
||||||
|
RankedTensorType::get({1, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding());
|
||||||
|
Value batchedResult = ensureBatchedTensor(gemmResult, /*batchSize=*/1, shapeInfo->m, shapeInfo->n, rewriter, loc);
|
||||||
|
Value finalResult = finalizeNormalizedMatMulResult(batchedResult, directOutType, *shapeInfo, rewriter, loc);
|
||||||
|
rewriter.replaceOp(matmulOp, finalResult);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||||
|
|||||||
@@ -16,12 +16,9 @@ struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
|
|||||||
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
|
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
|
||||||
Location loc = reluOp.getLoc();
|
Location loc = reluOp.getLoc();
|
||||||
Type resultType = reluOp.getResult().getType();
|
Type resultType = reluOp.getResult().getType();
|
||||||
constexpr size_t numInputs = 1;
|
auto reluPlan = spatial::SpatReluPlanOp::create(
|
||||||
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
|
rewriter, loc, resultType, adaptor.getX(), rewriter.getStringAttr("nchw"));
|
||||||
auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x);
|
rewriter.replaceOp(reluOp, reluPlan.getResult());
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
|
|
||||||
});
|
|
||||||
rewriter.replaceOp(reluOp, computeOp);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -118,17 +118,17 @@ static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatGraphCompute> {
|
||||||
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
using OpRewritePattern<spatial::SpatGraphCompute>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(spatial::SpatGraphCompute compute, PatternRewriter& rewriter) const override {
|
||||||
auto promoted = computePromotedOperands(compute);
|
auto promoted = computePromotedOperands(compute);
|
||||||
if (failed(promoted))
|
if (failed(promoted))
|
||||||
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
||||||
Block& oldBlock = compute.getBody().front();
|
Block& oldBlock = compute.getBody().front();
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute);
|
rewriter.setInsertionPointAfter(compute);
|
||||||
auto newCompute = spatial::SpatCompute::create(
|
auto newCompute = spatial::SpatGraphCompute::create(
|
||||||
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
|
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
|
||||||
SmallVector<Type> newBlockArgTypes;
|
SmallVector<Type> newBlockArgTypes;
|
||||||
SmallVector<Location> newBlockArgLocs;
|
SmallVector<Location> newBlockArgLocs;
|
||||||
@@ -182,10 +182,10 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
|
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
|
||||||
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatGraphComputeBatch> {
|
||||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
using OpRewritePattern<spatial::SpatGraphComputeBatch>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(spatial::SpatGraphComputeBatch compute, PatternRewriter& rewriter) const override {
|
||||||
auto promoted = computePromotedOperands(compute);
|
auto promoted = computePromotedOperands(compute);
|
||||||
if (failed(promoted))
|
if (failed(promoted))
|
||||||
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
||||||
@@ -197,7 +197,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
|||||||
rewriter, compute, static_cast<uint64_t>(compute.getLaneCount()), "promoted compute_batch lane count");
|
rewriter, compute, static_cast<uint64_t>(compute.getLaneCount()), "promoted compute_batch lane count");
|
||||||
if (failed(laneCountAttr))
|
if (failed(laneCountAttr))
|
||||||
return failure();
|
return failure();
|
||||||
auto newCompute = spatial::SpatComputeBatch::create(
|
auto newCompute = spatial::SpatGraphComputeBatch::create(
|
||||||
rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs);
|
rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs);
|
||||||
auto laneArg = compute.getLaneArgument();
|
auto laneArg = compute.getLaneArgument();
|
||||||
if (!laneArg)
|
if (!laneArg)
|
||||||
@@ -281,8 +281,8 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
bool requiresPostRewrite(spatial::SpatGraphCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||||
|
|
||||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -0,0 +1,21 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::FailureOr<mlir::Value>
|
||||||
|
lowerSelectedConv2DPlan(spatial::SpatConv2DPlanOp planOp,
|
||||||
|
std::optional<mlir::Value> rowStripInput,
|
||||||
|
bool emitRowStripLayout,
|
||||||
|
mlir::PatternRewriter& rewriter);
|
||||||
|
|
||||||
|
mlir::LogicalResult canLowerConvPlanToRowStrip(spatial::SpatConv2DPlanOp planOp);
|
||||||
|
mlir::LogicalResult canConsumeAndProduceRowStrip(spatial::SpatConv2DPlanOp planOp);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,207 @@
|
|||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
|
||||||
|
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static constexpr StringLiteral kLogicalLayout = "nchw";
|
||||||
|
static constexpr StringLiteral kDenseLayout = "dense_nchw";
|
||||||
|
static constexpr StringLiteral kRowStripLayout = "nchw_row_strip";
|
||||||
|
static constexpr StringLiteral kRowStripIndexMap = "packed_hwc_rows_to_nchw";
|
||||||
|
|
||||||
|
enum class SelectedLayout {
|
||||||
|
DenseNchw,
|
||||||
|
NchwRowStrip,
|
||||||
|
};
|
||||||
|
|
||||||
|
static SelectedLayout getSelectedLayout(llvm::DenseMap<Value, SelectedLayout>& layouts, Value value) {
|
||||||
|
auto it = layouts.find(value);
|
||||||
|
return it == layouts.end() ? SelectedLayout::DenseNchw : it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool usesSelectedRowStrip(Operation* user, llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||||
|
if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(user))
|
||||||
|
return getSelectedLayout(layouts, reluPlan.getResult()) == SelectedLayout::NchwRowStrip;
|
||||||
|
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(user))
|
||||||
|
return getSelectedLayout(layouts, convPlan.getResult()) == SelectedLayout::NchwRowStrip;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool allUsersCanHandleRowStrip(Value value, llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||||
|
for (Operation* user : value.getUsers()) {
|
||||||
|
if (usesSelectedRowStrip(user, layouts))
|
||||||
|
continue;
|
||||||
|
// Dense-only users must be materialized explicitly.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::pair<SmallVector<int64_t>, SmallVector<int64_t>> buildRowStripMetadata(RankedTensorType type) {
|
||||||
|
SmallVector<int64_t> offsets;
|
||||||
|
SmallVector<int64_t> sizes;
|
||||||
|
const int64_t channels = type.getDimSize(1);
|
||||||
|
const int64_t height = type.getDimSize(2);
|
||||||
|
const int64_t width = type.getDimSize(3);
|
||||||
|
offsets.reserve(height * 4);
|
||||||
|
sizes.reserve(height * 4);
|
||||||
|
for (int64_t row = 0; row < height; ++row) {
|
||||||
|
offsets.append({0, 0, row, 0});
|
||||||
|
sizes.append({1, channels, 1, width});
|
||||||
|
}
|
||||||
|
return {offsets, sizes};
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool canSelectConvRowStrip(spatial::SpatConv2DPlanOp convPlan,
|
||||||
|
llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||||
|
SelectedLayout inputLayout = getSelectedLayout(layouts, convPlan.getInput());
|
||||||
|
if (inputLayout == SelectedLayout::NchwRowStrip)
|
||||||
|
return succeeded(canConsumeAndProduceRowStrip(convPlan));
|
||||||
|
return succeeded(canLowerConvPlanToRowStrip(convPlan));
|
||||||
|
}
|
||||||
|
|
||||||
|
static SelectedLayout chooseConvLayout(spatial::SpatConv2DPlanOp convPlan,
|
||||||
|
llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||||
|
if (!canSelectConvRowStrip(convPlan, layouts))
|
||||||
|
return SelectedLayout::DenseNchw;
|
||||||
|
if (!allUsersCanHandleRowStrip(convPlan.getResult(), layouts))
|
||||||
|
return SelectedLayout::DenseNchw;
|
||||||
|
return SelectedLayout::NchwRowStrip;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SelectedLayout chooseReluLayout(spatial::SpatReluPlanOp reluPlan,
|
||||||
|
llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||||
|
if (getSelectedLayout(layouts, reluPlan.getInput()) != SelectedLayout::NchwRowStrip)
|
||||||
|
return SelectedLayout::DenseNchw;
|
||||||
|
if (!allUsersCanHandleRowStrip(reluPlan.getResult(), layouts))
|
||||||
|
return SelectedLayout::DenseNchw;
|
||||||
|
return SelectedLayout::NchwRowStrip;
|
||||||
|
}
|
||||||
|
|
||||||
|
static spatial::SpatBlueprintOp insertRowStripBlueprint(IRRewriter& rewriter, Value value) {
|
||||||
|
auto outputType = cast<RankedTensorType>(value.getType());
|
||||||
|
auto [offsets, sizes] = buildRowStripMetadata(outputType);
|
||||||
|
return spatial::SpatBlueprintOp::create(rewriter,
|
||||||
|
value.getLoc(),
|
||||||
|
outputType,
|
||||||
|
value,
|
||||||
|
ValueRange {},
|
||||||
|
rewriter.getStringAttr(kLogicalLayout),
|
||||||
|
rewriter.getStringAttr(kRowStripLayout),
|
||||||
|
rewriter.getDenseI64ArrayAttr(offsets),
|
||||||
|
rewriter.getDenseI64ArrayAttr(sizes),
|
||||||
|
rewriter.getStringAttr(kRowStripIndexMap),
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void materializeDenseUses(IRRewriter& rewriter,
|
||||||
|
Value layoutValue,
|
||||||
|
llvm::DenseMap<Value, SelectedLayout>& layouts) {
|
||||||
|
SmallVector<OpOperand*> denseUses;
|
||||||
|
for (OpOperand& use : layoutValue.getUses()) {
|
||||||
|
if (usesSelectedRowStrip(use.getOwner(), layouts))
|
||||||
|
continue;
|
||||||
|
denseUses.push_back(&use);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (OpOperand* use : denseUses) {
|
||||||
|
Operation* owner = use->getOwner();
|
||||||
|
rewriter.setInsertionPoint(owner);
|
||||||
|
auto materialized = spatial::SpatMaterializeLayoutOp::create(rewriter,
|
||||||
|
owner->getLoc(),
|
||||||
|
use->get().getType(),
|
||||||
|
use->get(),
|
||||||
|
rewriter.getStringAttr(kLogicalLayout),
|
||||||
|
rewriter.getStringAttr(kRowStripLayout),
|
||||||
|
rewriter.getStringAttr(kDenseLayout));
|
||||||
|
use->set(materialized.getResult());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SpatialLayoutPlanningPass final : PassWrapper<SpatialLayoutPlanningPass, OperationPass<ModuleOp>> {
|
||||||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialLayoutPlanningPass)
|
||||||
|
|
||||||
|
StringRef getArgument() const override { return "spatial-layout-planning"; }
|
||||||
|
StringRef getDescription() const override { return "Select conservative Spatial layouts and insert reconciliation barriers."; }
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
auto entryFunc = getPimEntryFunc(getOperation());
|
||||||
|
if (failed(entryFunc)) {
|
||||||
|
getOperation().emitError("failed to locate the PIM entry function during Spatial layout planning");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
func::FuncOp funcOp = *entryFunc;
|
||||||
|
IRRewriter rewriter(&getContext());
|
||||||
|
llvm::DenseMap<Value, SelectedLayout> layouts;
|
||||||
|
|
||||||
|
bool changed = true;
|
||||||
|
while (changed) {
|
||||||
|
changed = false;
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||||
|
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
|
||||||
|
SelectedLayout selected = chooseConvLayout(convPlan, layouts);
|
||||||
|
if (layouts[convPlan.getResult()] != selected) {
|
||||||
|
layouts[convPlan.getResult()] = selected;
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
|
||||||
|
SelectedLayout selected = chooseReluLayout(reluPlan, layouts);
|
||||||
|
if (layouts[reluPlan.getResult()] != selected) {
|
||||||
|
layouts[reluPlan.getResult()] = selected;
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||||
|
Value producedValue;
|
||||||
|
if (auto convPlan = dyn_cast<spatial::SpatConv2DPlanOp>(&op))
|
||||||
|
producedValue = convPlan.getResult();
|
||||||
|
else if (auto reluPlan = dyn_cast<spatial::SpatReluPlanOp>(&op))
|
||||||
|
producedValue = reluPlan.getResult();
|
||||||
|
else
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (getSelectedLayout(layouts, producedValue) != SelectedLayout::NchwRowStrip)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(&op);
|
||||||
|
auto blueprint = insertRowStripBlueprint(rewriter, producedValue);
|
||||||
|
rewriter.replaceAllUsesExcept(producedValue, blueprint.getResult(), blueprint);
|
||||||
|
materializeDenseUses(rewriter, blueprint.getResult(), layouts);
|
||||||
|
}
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||||
|
getOperation().emitError("logical Spatial graph verification failed after SpatialLayoutPlanning");
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createSpatialLayoutPlanningPass() { return std::make_unique<SpatialLayoutPlanningPass>(); }
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||||
@@ -26,7 +27,83 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
static bool isMaterializableExternalTensorOp(Operation* op) {
|
||||||
|
return isa<spatial::SpatChannelReceiveOp,
|
||||||
|
spatial::SpatExtractRowsOp,
|
||||||
|
tensor::ExtractSliceOp,
|
||||||
|
tensor::ExpandShapeOp,
|
||||||
|
tensor::CollapseShapeOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO REMOVE THIS UGLY FIX
|
||||||
|
//TODO: Remove this helper once compute_batch external tensor captures are
|
||||||
|
// fixed at the producer side.
|
||||||
|
//
|
||||||
|
// This function is a temporary SpatialToPim repair path. It clones selected
|
||||||
|
// external tensor producers, such as channel_receive and tensor view/slice ops,
|
||||||
|
// into the new pim.core_batch body when the old spat.compute_batch body refers
|
||||||
|
// to tensor values defined outside the batch.
|
||||||
|
//
|
||||||
|
// The real invariant should be stronger:
|
||||||
|
//
|
||||||
|
// A spat.compute_batch body must not capture external tensor values.
|
||||||
|
// Every tensor used inside the body must be either:
|
||||||
|
// - a compute_batch block argument,
|
||||||
|
// - defined inside the compute_batch body,
|
||||||
|
// - or a legal constant-like value.
|
||||||
|
//
|
||||||
|
// If this invariant is violated, the responsible producer, most likely merge
|
||||||
|
// schedule materialization, should emit verifier-clean Spatial IR instead of
|
||||||
|
// relying on SpatialToPim to clone external producer chains later.
|
||||||
|
//
|
||||||
|
// After that producer-side fix:
|
||||||
|
// 1. remove isMaterializableExternalTensorOp,
|
||||||
|
// 2. remove materializeExternalTensorValue,
|
||||||
|
// 3. make lowerComputeBatchOp emit a hard diagnostic for any unmapped external
|
||||||
|
// tensor operand,
|
||||||
|
// 4. keep/strengthen the Spatial verifier so the invalid capture is rejected
|
||||||
|
// before SpatialToPim.
|
||||||
|
//
|
||||||
|
// Be careful not to replace every external tensor capture with a normal
|
||||||
|
// compute_batch input blindly: host-backed tensors and explicit inter-core
|
||||||
|
// communication have different semantics. In particular, channel_receive-like
|
||||||
|
// values should be materialized through the communication model, not silently
|
||||||
|
// treated as host inputs.
|
||||||
|
static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
|
||||||
|
Location loc,
|
||||||
|
Block& oldBlock,
|
||||||
|
Value value,
|
||||||
|
IRMapping& mapper) {
|
||||||
|
if (mapper.contains(value))
|
||||||
|
return mapper.lookup(value);
|
||||||
|
|
||||||
|
if (!isa<TensorType>(value.getType()))
|
||||||
|
return value;
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp || definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (definingOp->getBlock() == &oldBlock)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (!isMaterializableExternalTensorOp(definingOp))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
for (Value operand : definingOp->getOperands()) {
|
||||||
|
FailureOr<Value> materializedOperand = materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper);
|
||||||
|
if (succeeded(materializedOperand))
|
||||||
|
mapper.map(operand, *materializedOperand);
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* cloned = rewriter.clone(*definingOp, mapper);
|
||||||
|
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
|
||||||
|
mapper.map(originalResult, clonedResult);
|
||||||
|
|
||||||
|
return mapper.lookup(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||||
size_t& fallbackCoreId) {
|
size_t& fallbackCoreId) {
|
||||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||||
@@ -54,6 +131,151 @@ static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
|||||||
return result.getUses().begin()->getOperandNumber();
|
return result.getUses().begin()->getOperandNumber();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct BatchFragmentAssemblyPlan {
|
||||||
|
unsigned returnIndex = 0;
|
||||||
|
int64_t localSourceElementOffset = 0;
|
||||||
|
int64_t fragmentByteSize = 0;
|
||||||
|
SmallVector<int64_t, 8> hostOffsetsByLane;
|
||||||
|
};
|
||||||
|
|
||||||
|
static Value createLaneIndexedOffset(IRRewriter& rewriter, Operation* anchor, Value laneArg, ArrayRef<int64_t> values, Location loc) {
|
||||||
|
assert(!values.empty() && "expected lane-indexed values");
|
||||||
|
if (llvm::all_of(values.drop_front(), [&](int64_t value) { return value == values.front(); }))
|
||||||
|
return getOrCreateIndexConstant(rewriter, anchor, values.front());
|
||||||
|
|
||||||
|
if (values.size() >= 2) {
|
||||||
|
int64_t step = values[1] - values[0];
|
||||||
|
bool arithmetic = llvm::all_of(llvm::seq<size_t>(2, values.size()), [&](size_t index) {
|
||||||
|
return values[index] == values.front() + static_cast<int64_t>(index) * step;
|
||||||
|
});
|
||||||
|
if (arithmetic) {
|
||||||
|
Value base = getOrCreateIndexConstant(rewriter, anchor, values.front());
|
||||||
|
if (step == 0)
|
||||||
|
return base;
|
||||||
|
Value stepValue = getOrCreateIndexConstant(rewriter, anchor, step);
|
||||||
|
Value scaledLane = arith::MulIOp::create(rewriter, loc, laneArg, stepValue).getResult();
|
||||||
|
return arith::AddIOp::create(rewriter, loc, base, scaledLane).getResult();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Value selected = getOrCreateIndexConstant(rewriter, anchor, values.front());
|
||||||
|
for (auto [lane, value] : llvm::enumerate(values.drop_front())) {
|
||||||
|
Value laneValue = getOrCreateIndexConstant(rewriter, anchor, static_cast<int64_t>(lane + 1));
|
||||||
|
Value cmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, laneArg, laneValue);
|
||||||
|
Value candidate = getOrCreateIndexConstant(rewriter, anchor, value);
|
||||||
|
selected = arith::SelectOp::create(rewriter, loc, cmp, candidate, selected);
|
||||||
|
}
|
||||||
|
return selected;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<BatchFragmentAssemblyPlan, 8>>
|
||||||
|
analyzeTopLevelFragmentAssemblyUses(OpResult result, RankedTensorType packedResultType, uint32_t laneCount) {
|
||||||
|
SmallVector<BatchFragmentAssemblyPlan, 8> plans;
|
||||||
|
if (!packedResultType.hasStaticShape() || laneCount == 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t packedElementCount = packedResultType.getNumElements();
|
||||||
|
if (packedElementCount % static_cast<int64_t>(laneCount) != 0)
|
||||||
|
return failure();
|
||||||
|
int64_t payloadElementCount = packedElementCount / static_cast<int64_t>(laneCount);
|
||||||
|
size_t elementSize = getElementTypeSizeInBytes(packedResultType.getElementType());
|
||||||
|
|
||||||
|
for (OpOperand& use : result.getUses()) {
|
||||||
|
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(use.getOwner());
|
||||||
|
if (!blueprint || blueprint->getParentOp() != blueprint->getParentOfType<func::FuncOp>())
|
||||||
|
return failure();
|
||||||
|
std::optional<StringRef> mode = blueprint.getMode();
|
||||||
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||||
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||||
|
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
|
||||||
|
if (!mode || *mode != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr)
|
||||||
|
return failure();
|
||||||
|
if (!blueprint.getOutput().hasOneUse() || !isa<func::ReturnOp>(*blueprint.getOutput().getUsers().begin()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
unsigned returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber();
|
||||||
|
auto hostResultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||||
|
if (!hostResultType || !hostResultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||||
|
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||||
|
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
|
||||||
|
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
||||||
|
ArrayRef<int64_t> flatStrides = *stridesAttr;
|
||||||
|
int64_t rank = hostResultType.getRank();
|
||||||
|
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
||||||
|
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
||||||
|
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||||
|
rank,
|
||||||
|
fragmentOperands.size(),
|
||||||
|
operandIndices,
|
||||||
|
sourceOffsets,
|
||||||
|
flatOffsets,
|
||||||
|
flatSizes,
|
||||||
|
flatStrides)))
|
||||||
|
return failure();
|
||||||
|
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||||
|
if (operandIndices[fragmentIndex] != static_cast<int64_t>(use.getOperandNumber()))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
int64_t sourceElementOffset = sourceOffsets[fragmentIndex];
|
||||||
|
int64_t lane = sourceElementOffset / payloadElementCount;
|
||||||
|
int64_t localSourceElementOffset = sourceElementOffset % payloadElementCount;
|
||||||
|
if (lane < 0 || lane >= static_cast<int64_t>(laneCount))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> fragmentOffsets;
|
||||||
|
SmallVector<int64_t, 4> fragmentSizes;
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
|
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||||
|
if (flatStrides[flatIndex] != 1)
|
||||||
|
return failure();
|
||||||
|
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||||
|
fragmentSizes.push_back(flatSizes[flatIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (failed(forEachContiguousDestinationChunk(
|
||||||
|
hostResultType.getShape(),
|
||||||
|
fragmentOffsets,
|
||||||
|
fragmentSizes,
|
||||||
|
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
|
||||||
|
int64_t hostElementOffset = 0;
|
||||||
|
SmallVector<int64_t> hostStrides = computeRowMajorStrides(hostResultType.getShape());
|
||||||
|
for (auto [dim, offset] : llvm::enumerate(chunkOffsets))
|
||||||
|
hostElementOffset += offset * hostStrides[dim];
|
||||||
|
int64_t hostByteOffset = hostElementOffset * static_cast<int64_t>(elementSize);
|
||||||
|
int64_t fragmentByteSize = chunkElements * static_cast<int64_t>(elementSize);
|
||||||
|
int64_t chunkSourceOffset = localSourceElementOffset + relativeSourceOffset;
|
||||||
|
|
||||||
|
auto planIt = llvm::find_if(plans, [&](const BatchFragmentAssemblyPlan& plan) {
|
||||||
|
return plan.returnIndex == returnIndex && plan.localSourceElementOffset == chunkSourceOffset
|
||||||
|
&& plan.fragmentByteSize == fragmentByteSize;
|
||||||
|
});
|
||||||
|
if (planIt == plans.end()) {
|
||||||
|
BatchFragmentAssemblyPlan plan;
|
||||||
|
plan.returnIndex = returnIndex;
|
||||||
|
plan.localSourceElementOffset = chunkSourceOffset;
|
||||||
|
plan.fragmentByteSize = fragmentByteSize;
|
||||||
|
plan.hostOffsetsByLane.assign(laneCount, std::numeric_limits<int64_t>::min());
|
||||||
|
plan.hostOffsetsByLane[static_cast<size_t>(lane)] = hostByteOffset;
|
||||||
|
plans.push_back(std::move(plan));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
planIt->hostOffsetsByLane[static_cast<size_t>(lane)] = hostByteOffset;
|
||||||
|
return success();
|
||||||
|
})))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const BatchFragmentAssemblyPlan& plan : plans)
|
||||||
|
if (llvm::any_of(plan.hostOffsetsByLane, [](int64_t offset) { return offset == std::numeric_limits<int64_t>::min(); }))
|
||||||
|
return failure();
|
||||||
|
return plans;
|
||||||
|
}
|
||||||
|
|
||||||
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
|
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
|
||||||
if (scale == 1)
|
if (scale == 1)
|
||||||
return base;
|
return base;
|
||||||
@@ -62,26 +284,49 @@ static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value ba
|
|||||||
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
|
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
|
||||||
|
SmallVector<OpFoldResult, 4> attrs;
|
||||||
|
attrs.reserve(values.size());
|
||||||
|
for (int64_t value : values)
|
||||||
|
attrs.push_back(builder.getIndexAttr(value));
|
||||||
|
return attrs;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
|
||||||
|
SmallVector<OpFoldResult, 4> strides;
|
||||||
|
strides.reserve(rank);
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim)
|
||||||
|
strides.push_back(builder.getIndexAttr(1));
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
|
||||||
static Value createHostTargetOffset(IRRewriter& rewriter,
|
static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||||
tensor::ParallelInsertSliceOp insertSlice,
|
Location loc,
|
||||||
ShapedType destinationType,
|
ShapedType destinationType,
|
||||||
|
ArrayRef<OpFoldResult> mixedOffsets,
|
||||||
|
ArrayRef<int64_t> additionalOffsets,
|
||||||
IRMapping& mapper) {
|
IRMapping& mapper) {
|
||||||
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
||||||
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
|
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
|
||||||
|
|
||||||
Value totalOffset;
|
Value totalOffset;
|
||||||
Location loc = insertSlice.getLoc();
|
for (auto [dim, offset] : llvm::enumerate(mixedOffsets)) {
|
||||||
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
|
|
||||||
int64_t scale = strides[dim] * elementBytes;
|
int64_t scale = strides[dim] * elementBytes;
|
||||||
Value scaledOffset;
|
Value scaledOffset;
|
||||||
if (auto attr = dyn_cast<Attribute>(offset)) {
|
if (auto attr = dyn_cast<Attribute>(offset)) {
|
||||||
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||||
assert(intAttr && "expected integer offset attribute");
|
assert(intAttr && "expected integer offset attribute");
|
||||||
scaledOffset =
|
scaledOffset = getOrCreateIndexConstant(rewriter,
|
||||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), intAttr.getInt() * scale);
|
rewriter.getInsertionBlock()->getParentOp(),
|
||||||
}
|
(intAttr.getInt() + additionalOffsets[dim]) * scale);
|
||||||
else {
|
} else {
|
||||||
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
|
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
|
||||||
|
if (additionalOffsets[dim] != 0) {
|
||||||
|
Value staticOffset = getOrCreateIndexConstant(rewriter,
|
||||||
|
rewriter.getInsertionBlock()->getParentOp(),
|
||||||
|
additionalOffsets[dim] * scale);
|
||||||
|
scaledOffset = arith::AddIOp::create(rewriter, loc, scaledOffset, staticOffset).getResult();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
totalOffset =
|
totalOffset =
|
||||||
@@ -93,9 +338,139 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
|||||||
return totalOffset;
|
return totalOffset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||||
|
tensor::ParallelInsertSliceOp insertSlice,
|
||||||
|
ShapedType destinationType,
|
||||||
|
IRMapping& mapper) {
|
||||||
|
SmallVector<int64_t> zeroOffsets(destinationType.getRank(), 0);
|
||||||
|
return createHostTargetOffset(rewriter,
|
||||||
|
insertSlice.getLoc(),
|
||||||
|
destinationType,
|
||||||
|
insertSlice.getMixedOffsets(),
|
||||||
|
zeroOffsets,
|
||||||
|
mapper);
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<OpFoldResult, 4> buildFragmentOffsets(IRRewriter& rewriter,
|
||||||
|
Location loc,
|
||||||
|
ArrayRef<OpFoldResult> baseOffsets,
|
||||||
|
ArrayRef<int64_t> fragmentOffsets,
|
||||||
|
IRMapping& mapper) {
|
||||||
|
SmallVector<OpFoldResult, 4> combined;
|
||||||
|
combined.reserve(fragmentOffsets.size());
|
||||||
|
for (auto [dim, baseOffset] : llvm::enumerate(baseOffsets)) {
|
||||||
|
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
||||||
|
int64_t base = cast<IntegerAttr>(attr).getInt();
|
||||||
|
combined.push_back(rewriter.getIndexAttr(base + fragmentOffsets[dim]));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value dynamicBase = mapper.lookupOrDefault(cast<Value>(baseOffset));
|
||||||
|
if (fragmentOffsets[dim] == 0) {
|
||||||
|
combined.push_back(dynamicBase);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value staticOffset =
|
||||||
|
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), fragmentOffsets[dim]);
|
||||||
|
combined.push_back(arith::AddIOp::create(rewriter, loc, dynamicBase, staticOffset).getResult());
|
||||||
|
}
|
||||||
|
return combined;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
|
||||||
|
spatial::SpatBlueprintOp blueprint,
|
||||||
|
Value hostTarget,
|
||||||
|
ArrayRef<OpFoldResult> baseOffsets,
|
||||||
|
IRMapping& mapper) {
|
||||||
|
auto hostTargetType = dyn_cast<RankedTensorType>(hostTarget.getType());
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||||
|
if (!hostTargetType || !resultType || !resultType.hasStaticShape())
|
||||||
|
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor results");
|
||||||
|
|
||||||
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||||
|
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = blueprint.getFragmentStrides();
|
||||||
|
if (!operandIndicesAttr || !fragmentStridesAttr)
|
||||||
|
return blueprint.emitOpError(
|
||||||
|
"fragment assembly lowering requires explicit operand indices and unit strides");
|
||||||
|
|
||||||
|
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||||
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||||
|
if (!sourceOffsetsAttr)
|
||||||
|
return blueprint.emitOpError("fragment assembly lowering requires explicit source offsets");
|
||||||
|
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||||
|
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
|
||||||
|
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
||||||
|
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||||
|
int64_t rank = resultType.getRank();
|
||||||
|
|
||||||
|
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
||||||
|
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
||||||
|
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||||
|
rank,
|
||||||
|
fragmentOperands.size(),
|
||||||
|
operandIndices,
|
||||||
|
sourceOffsets,
|
||||||
|
flatOffsets,
|
||||||
|
flatSizes,
|
||||||
|
flatStrides)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||||
|
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> fragmentOffsets;
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
|
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||||
|
if (flatStrides[flatIndex] != 1)
|
||||||
|
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
|
||||||
|
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]);
|
||||||
|
auto sourceType = dyn_cast<ShapedType>(source.getType());
|
||||||
|
if (!sourceType || !sourceType.hasStaticShape())
|
||||||
|
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> fragmentShape;
|
||||||
|
fragmentShape.reserve(rank);
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim)
|
||||||
|
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
|
||||||
|
|
||||||
|
Value fragment = source;
|
||||||
|
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
|
||||||
|
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
|
||||||
|
blueprint, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
|
||||||
|
if (failed(extractOffsets))
|
||||||
|
return failure();
|
||||||
|
fragment = tensor::ExtractSliceOp::create(rewriter,
|
||||||
|
blueprint.getLoc(),
|
||||||
|
source,
|
||||||
|
getStaticIndexAttrs(rewriter, *extractOffsets),
|
||||||
|
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||||
|
getUnitStrides(rewriter, rank));
|
||||||
|
}
|
||||||
|
|
||||||
|
hostTarget = tensor::InsertSliceOp::create(rewriter,
|
||||||
|
blueprint.getLoc(),
|
||||||
|
fragment,
|
||||||
|
hostTarget,
|
||||||
|
buildFragmentOffsets(rewriter,
|
||||||
|
blueprint.getLoc(),
|
||||||
|
baseOffsets,
|
||||||
|
fragmentOffsets,
|
||||||
|
mapper),
|
||||||
|
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||||
|
getUnitStrides(rewriter, rank))
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
return hostTarget;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||||
IRRewriter& rewriter) {
|
IRRewriter& rewriter) {
|
||||||
Location loc = computeBatchOp.getLoc();
|
Location loc = computeBatchOp.getLoc();
|
||||||
Block& oldBlock = computeBatchOp.getBody().front();
|
Block& oldBlock = computeBatchOp.getBody().front();
|
||||||
@@ -130,14 +505,29 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
|||||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds));
|
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds));
|
||||||
|
|
||||||
SmallVector<unsigned> returnOperandIndices;
|
SmallVector<unsigned> returnOperandIndices;
|
||||||
|
SmallVector<SmallVector<BatchFragmentAssemblyPlan, 1>, 4> fragmentAssemblyPlansByResult;
|
||||||
if (computeBatchOp.getNumResults() != 0) {
|
if (computeBatchOp.getNumResults() != 0) {
|
||||||
returnOperandIndices.resize(computeBatchOp.getNumResults());
|
returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits<unsigned>::max());
|
||||||
|
fragmentAssemblyPlansByResult.resize(computeBatchOp.getNumResults());
|
||||||
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
||||||
|
if (result.use_empty())
|
||||||
|
continue;
|
||||||
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
|
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
|
||||||
if (failed(returnOperandIndex))
|
if (succeeded(returnOperandIndex)) {
|
||||||
|
returnOperandIndices[resultIndex] = *returnOperandIndex;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(result.getType());
|
||||||
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
|
return computeBatchOp.emitOpError(
|
||||||
|
"resultful compute_batch publication lowering requires static ranked tensor results");
|
||||||
|
FailureOr<SmallVector<BatchFragmentAssemblyPlan, 8>> fragmentAssemblyPlans =
|
||||||
|
analyzeTopLevelFragmentAssemblyUses(cast<OpResult>(result), resultType, computeBatchOp.getLaneCount());
|
||||||
|
if (failed(fragmentAssemblyPlans))
|
||||||
return computeBatchOp.emitOpError(
|
return computeBatchOp.emitOpError(
|
||||||
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
||||||
returnOperandIndices[resultIndex] = *returnOperandIndex;
|
fragmentAssemblyPlansByResult[resultIndex].assign(fragmentAssemblyPlans->begin(), fragmentAssemblyPlans->end());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,6 +585,18 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
|||||||
if (isa<spatial::SpatYieldOp>(op))
|
if (isa<spatial::SpatYieldOp>(op))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
if (auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(op)) {
|
||||||
|
std::optional<StringRef> modeAttr = blueprint.getMode();
|
||||||
|
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||||
|
for (Operation* user : blueprint.getOutput().getUsers()) {
|
||||||
|
if (!isa<tensor::ParallelInsertSliceOp>(user))
|
||||||
|
return blueprint.emitOpError(
|
||||||
|
"fragment assembly blueprint lowering expects only tensor.parallel_insert_slice users");
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
|
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
|
||||||
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
|
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
|
||||||
if (!firstOutputArg)
|
if (!firstOutputArg)
|
||||||
@@ -211,10 +613,62 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
|||||||
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
|
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
|
||||||
if (resultIndex >= returnOperandIndices.size())
|
if (resultIndex >= returnOperandIndices.size())
|
||||||
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
||||||
|
bool hasDirectReturn = returnOperandIndices[resultIndex] != std::numeric_limits<unsigned>::max();
|
||||||
|
bool hasFragmentAssembly = resultIndex < fragmentAssemblyPlansByResult.size()
|
||||||
|
&& !fragmentAssemblyPlansByResult[resultIndex].empty();
|
||||||
|
if (!hasDirectReturn && !hasFragmentAssembly)
|
||||||
|
continue;
|
||||||
|
|
||||||
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
||||||
|
|
||||||
|
if (hasFragmentAssembly) {
|
||||||
|
BlockArgument laneArg = coreBatchOp.getLaneArgument();
|
||||||
|
auto mappedSourceType = dyn_cast<ShapedType>(mappedSource.getType());
|
||||||
|
if (!mappedSourceType || !mappedSourceType.hasStaticShape())
|
||||||
|
return insertSlice.emitOpError("fragment assembly batch lowering requires a static ranked lane-local source");
|
||||||
|
for (const BatchFragmentAssemblyPlan& plan : fragmentAssemblyPlansByResult[resultIndex]) {
|
||||||
|
Value outputTensor = outputTensors[plan.returnIndex](rewriter, insertSlice.getLoc());
|
||||||
|
auto sizeAttr = pim::getCheckedI32Attr(
|
||||||
|
rewriter, coreBatchOp.getOperation(), plan.fragmentByteSize, "fragment assembly host copy byte size");
|
||||||
|
if (failed(sizeAttr))
|
||||||
|
return failure();
|
||||||
|
Value hostTargetOffset =
|
||||||
|
createLaneIndexedOffset(rewriter, coreBatchOp.getOperation(), laneArg, plan.hostOffsetsByLane, insertSlice.getLoc());
|
||||||
|
Value deviceSourceOffset = getOrCreateIndexConstant(
|
||||||
|
rewriter, coreBatchOp.getOperation(),
|
||||||
|
plan.localSourceElementOffset * static_cast<int64_t>(getElementTypeSizeInBytes(mappedSourceType.getElementType())));
|
||||||
|
outputTensor =
|
||||||
|
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||||
|
insertSlice.getLoc(),
|
||||||
|
outputTensor.getType(),
|
||||||
|
hostTargetOffset,
|
||||||
|
deviceSourceOffset,
|
||||||
|
outputTensor,
|
||||||
|
mappedSource,
|
||||||
|
*sizeAttr)
|
||||||
|
.getOutput();
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
|
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
|
||||||
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
||||||
|
if (auto blueprint =
|
||||||
|
insertSlice.getSource().getDefiningOp<spatial::SpatBlueprintOp>()) {
|
||||||
|
std::optional<StringRef> modeAttr = blueprint.getMode();
|
||||||
|
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||||
|
FailureOr<Value> updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter,
|
||||||
|
blueprint,
|
||||||
|
hostTarget,
|
||||||
|
insertSlice.getMixedOffsets(),
|
||||||
|
mapper);
|
||||||
|
if (failed(updatedHostTarget))
|
||||||
|
return failure();
|
||||||
|
hostOutputTensors[resultIndex] = *updatedHostTarget;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
|
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
|
||||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
||||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource);
|
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource);
|
||||||
@@ -264,9 +718,18 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
|||||||
Operation* definingOp = operand.getDefiningOp();
|
Operation* definingOp = operand.getDefiningOp();
|
||||||
if (definingOp && definingOp->getBlock() == &oldBlock)
|
if (definingOp && definingOp->getBlock() == &oldBlock)
|
||||||
continue;
|
continue;
|
||||||
|
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||||
|
continue;
|
||||||
|
|
||||||
return computeBatchOp.emitOpError(
|
if (succeeded(materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper)))
|
||||||
"expected external tensor communication to be materialized in Spatial before batch lowering");
|
continue;
|
||||||
|
|
||||||
|
InFlightDiagnostic diagnostic =
|
||||||
|
computeBatchOp.emitOpError("expected external tensor communication to be materialized in Spatial before batch lowering");
|
||||||
|
diagnostic << " while cloning nested op '" << op.getName() << "' tensor operand #" << operandIndex;
|
||||||
|
if (definingOp)
|
||||||
|
diagnostic << " from external producer '" << definingOp->getName() << "'";
|
||||||
|
return diagnostic;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operation* cloned = rewriter.clone(op, mapper);
|
Operation* cloned = rewriter.clone(op, mapper);
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
#include "Common.hpp"
|
#include "Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
@@ -72,4 +73,117 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Op
|
|||||||
rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
|
rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult validateFragmentAssemblyMetadata(spatial::SpatBlueprintOp blueprint,
|
||||||
|
int64_t resultRank,
|
||||||
|
size_t operandCount,
|
||||||
|
ArrayRef<int64_t> operandIndices,
|
||||||
|
ArrayRef<int64_t> sourceOffsets,
|
||||||
|
ArrayRef<int64_t> flatOffsets,
|
||||||
|
ArrayRef<int64_t> flatSizes,
|
||||||
|
ArrayRef<int64_t> flatStrides) {
|
||||||
|
if (operandIndices.size() != sourceOffsets.size())
|
||||||
|
return blueprint.emitOpError("fragment assembly operand index and source offset counts must match");
|
||||||
|
if (flatOffsets.size() != flatSizes.size())
|
||||||
|
return blueprint.emitOpError("fragment assembly offset and size arrays must have matching lengths");
|
||||||
|
if (flatStrides.size() != flatOffsets.size())
|
||||||
|
return blueprint.emitOpError("fragment assembly stride and offset arrays must have matching lengths");
|
||||||
|
if (flatOffsets.size() != operandIndices.size() * static_cast<size_t>(resultRank))
|
||||||
|
return blueprint.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment");
|
||||||
|
|
||||||
|
for (auto [fragmentIndex, operandIndex] : llvm::enumerate(operandIndices)) {
|
||||||
|
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(operandCount))
|
||||||
|
return blueprint.emitOpError("fragment assembly operand index is out of range");
|
||||||
|
if (sourceOffsets[fragmentIndex] < 0)
|
||||||
|
return blueprint.emitOpError("fragment assembly source offsets must be nonnegative");
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<int64_t, 4> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
|
||||||
|
SmallVector<int64_t, 4> indices(shape.size(), 0);
|
||||||
|
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
|
||||||
|
indices[dim] = flatIndex % shape[dim];
|
||||||
|
flatIndex /= shape[dim];
|
||||||
|
}
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<SmallVector<int64_t, 4>>
|
||||||
|
getStaticSliceOffsetsForElementOffset(Operation* anchor,
|
||||||
|
ShapedType sourceType,
|
||||||
|
ArrayRef<int64_t> fragmentShape,
|
||||||
|
int64_t sourceElementOffset,
|
||||||
|
StringRef fieldName) {
|
||||||
|
if (!sourceType.hasStaticShape())
|
||||||
|
return (anchor->emitOpError() << fieldName << " requires a static source shape"), failure();
|
||||||
|
if (sourceElementOffset < 0)
|
||||||
|
return (anchor->emitOpError() << fieldName << " requires a nonnegative source element offset"), failure();
|
||||||
|
if (sourceType.getRank() != static_cast<int64_t>(fragmentShape.size()))
|
||||||
|
return (anchor->emitOpError() << fieldName << " requires fragment rank to match source rank"), failure();
|
||||||
|
|
||||||
|
int64_t sourceElementCount = sourceType.getNumElements();
|
||||||
|
int64_t fragmentElementCount = 1;
|
||||||
|
for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
|
||||||
|
if (fragmentShape[dim] < 0)
|
||||||
|
return (anchor->emitOpError() << fieldName << " requires nonnegative fragment sizes"), failure();
|
||||||
|
fragmentElementCount *= fragmentShape[dim];
|
||||||
|
}
|
||||||
|
if (sourceElementOffset + fragmentElementCount > sourceElementCount)
|
||||||
|
return (anchor->emitOpError() << fieldName << " exceeds the source tensor bounds"), failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> sliceOffsets = expandFlatElementIndex(sourceElementOffset, sourceType.getShape());
|
||||||
|
for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
|
||||||
|
if (sliceOffsets[dim] + fragmentShape[dim] > sourceType.getDimSize(dim))
|
||||||
|
return (anchor->emitOpError() << fieldName << " does not describe a valid unit-stride slice"), failure();
|
||||||
|
}
|
||||||
|
return sliceOffsets;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
forEachContiguousDestinationChunk(ArrayRef<int64_t> destShape,
|
||||||
|
ArrayRef<int64_t> baseOffsets,
|
||||||
|
ArrayRef<int64_t> sizes,
|
||||||
|
llvm::function_ref<LogicalResult(ArrayRef<int64_t>, int64_t, int64_t)> callback) {
|
||||||
|
int64_t rank = static_cast<int64_t>(sizes.size());
|
||||||
|
int64_t suffixStart = rank - 1;
|
||||||
|
while (suffixStart > 0 && sizes[suffixStart] == destShape[suffixStart])
|
||||||
|
--suffixStart;
|
||||||
|
if (sizes[suffixStart] == destShape[suffixStart] && suffixStart == 0)
|
||||||
|
suffixStart = 0;
|
||||||
|
else
|
||||||
|
++suffixStart;
|
||||||
|
|
||||||
|
int64_t chunkElements = 1;
|
||||||
|
for (int64_t dim = suffixStart; dim < rank; ++dim)
|
||||||
|
chunkElements *= sizes[dim];
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> prefixExtents(sizes.begin(), sizes.begin() + suffixStart);
|
||||||
|
SmallVector<int64_t, 4> current(prefixExtents.size(), 0);
|
||||||
|
int64_t sourceChunkOrdinal = 0;
|
||||||
|
|
||||||
|
auto visit = [&](auto&& visit, int64_t dim) -> LogicalResult {
|
||||||
|
if (dim == static_cast<int64_t>(prefixExtents.size())) {
|
||||||
|
SmallVector<int64_t, 4> chunkOffsets(baseOffsets.begin(), baseOffsets.end());
|
||||||
|
for (int64_t prefixDim = 0; prefixDim < static_cast<int64_t>(current.size()); ++prefixDim)
|
||||||
|
chunkOffsets[prefixDim] += current[prefixDim];
|
||||||
|
if (failed(callback(chunkOffsets, sourceChunkOrdinal * chunkElements, chunkElements)))
|
||||||
|
return failure();
|
||||||
|
++sourceChunkOrdinal;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t index = 0; index < prefixExtents[dim]; ++index) {
|
||||||
|
current[dim] = index;
|
||||||
|
if (failed(visit(visit, dim + 1)))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
if (prefixExtents.empty())
|
||||||
|
return callback(baseOffsets, 0, chunkElements);
|
||||||
|
return visit(visit, 0);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -1,10 +1,17 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir::spatial {
|
||||||
|
class SpatBlueprintOp;
|
||||||
|
}
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
mlir::FailureOr<mlir::IntegerAttr>
|
mlir::FailureOr<mlir::IntegerAttr>
|
||||||
@@ -29,6 +36,29 @@ mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operat
|
|||||||
|
|
||||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation);
|
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation);
|
||||||
|
|
||||||
|
mlir::LogicalResult validateFragmentAssemblyMetadata(onnx_mlir::spatial::SpatBlueprintOp blueprint,
|
||||||
|
int64_t resultRank,
|
||||||
|
size_t operandCount,
|
||||||
|
llvm::ArrayRef<int64_t> operandIndices,
|
||||||
|
llvm::ArrayRef<int64_t> sourceOffsets,
|
||||||
|
llvm::ArrayRef<int64_t> flatOffsets,
|
||||||
|
llvm::ArrayRef<int64_t> flatSizes,
|
||||||
|
llvm::ArrayRef<int64_t> flatStrides);
|
||||||
|
|
||||||
|
mlir::FailureOr<mlir::SmallVector<int64_t, 4>>
|
||||||
|
getStaticSliceOffsetsForElementOffset(mlir::Operation* anchor,
|
||||||
|
mlir::ShapedType sourceType,
|
||||||
|
llvm::ArrayRef<int64_t> fragmentShape,
|
||||||
|
int64_t sourceElementOffset,
|
||||||
|
llvm::StringRef fieldName);
|
||||||
|
|
||||||
|
mlir::LogicalResult
|
||||||
|
forEachContiguousDestinationChunk(llvm::ArrayRef<int64_t> destShape,
|
||||||
|
llvm::ArrayRef<int64_t> baseOffsets,
|
||||||
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
|
llvm::function_ref<mlir::LogicalResult(llvm::ArrayRef<int64_t>, int64_t, int64_t)>
|
||||||
|
callback);
|
||||||
|
|
||||||
inline mlir::tensor::EmptyOp
|
inline mlir::tensor::EmptyOp
|
||||||
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
||||||
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
|
return mlir::tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), shapedType.getElementType());
|
||||||
|
|||||||
@@ -17,10 +17,10 @@ std::optional<unsigned> getDirectComputeLikeInputIndex(Operation* owner, unsigne
|
|||||||
return operandNumber - inputBegin;
|
return operandNumber - inputBegin;
|
||||||
};
|
};
|
||||||
|
|
||||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
|
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner))
|
||||||
return getInputIndex(owner, compute.getInputs().size());
|
return getInputIndex(owner, compute.getInputs().size());
|
||||||
|
|
||||||
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner))
|
if (auto computeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(owner))
|
||||||
return getInputIndex(owner, computeBatch.getInputs().size());
|
return getInputIndex(owner, computeBatch.getInputs().size());
|
||||||
|
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
@@ -32,13 +32,13 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
|||||||
Value replacement) {
|
Value replacement) {
|
||||||
Block& body = owner->getRegion(0).front();
|
Block& body = owner->getRegion(0).front();
|
||||||
BlockArgument bodyArgument;
|
BlockArgument bodyArgument;
|
||||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner)) {
|
||||||
auto computeArg = compute.getInputArgument(inputIndex);
|
auto computeArg = compute.getInputArgument(inputIndex);
|
||||||
assert(computeArg && "expected compute input block argument");
|
assert(computeArg && "expected compute input block argument");
|
||||||
bodyArgument = *computeArg;
|
bodyArgument = *computeArg;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
|
auto batchArg = cast<spatial::SpatScheduledComputeBatch>(owner).getInputArgument(inputIndex);
|
||||||
assert(batchArg && "expected compute_batch input block argument");
|
assert(batchArg && "expected compute_batch input block argument");
|
||||||
bodyArgument = *batchArg;
|
bodyArgument = *batchArg;
|
||||||
}
|
}
|
||||||
@@ -46,10 +46,10 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
|||||||
|
|
||||||
rewriter.startOpModification(owner);
|
rewriter.startOpModification(owner);
|
||||||
bodyArgument.replaceAllUsesWith(replacement);
|
bodyArgument.replaceAllUsesWith(replacement);
|
||||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
|
if (auto compute = dyn_cast<spatial::SpatScheduledCompute>(owner))
|
||||||
compute.getInputsMutable().erase(inputIndex);
|
compute.getInputsMutable().erase(inputIndex);
|
||||||
else
|
else
|
||||||
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
|
cast<spatial::SpatScheduledComputeBatch>(owner).getInputsMutable().erase(inputIndex);
|
||||||
body.eraseArgument(bodyArgIndex);
|
body.eraseArgument(bodyArgIndex);
|
||||||
rewriter.finalizeOpModification(owner);
|
rewriter.finalizeOpModification(owner);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,6 +30,105 @@ static bool isChannelUseChainOp(Operation* op) {
|
|||||||
pim::PimTransposeOp>(op);
|
pim::PimTransposeOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value createStaticHostTargetOffset(IRRewriter& rewriter,
|
||||||
|
Location loc,
|
||||||
|
ShapedType destinationType,
|
||||||
|
ArrayRef<int64_t> fragmentOffsets) {
|
||||||
|
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
||||||
|
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
|
||||||
|
|
||||||
|
int64_t byteOffset = 0;
|
||||||
|
for (auto [dim, offset] : llvm::enumerate(fragmentOffsets))
|
||||||
|
byteOffset += offset * strides[dim] * elementBytes;
|
||||||
|
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), byteOffset);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Value> lowerFragmentAssemblyBlueprint(IRRewriter& rewriter,
|
||||||
|
spatial::SpatBlueprintOp blueprint,
|
||||||
|
IRMapping& mapping) {
|
||||||
|
auto resultType = dyn_cast<ShapedType>(blueprint.getOutput().getType());
|
||||||
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
|
return blueprint.emitOpError("fragment assembly lowering requires a static ranked tensor result");
|
||||||
|
|
||||||
|
std::optional<StringRef> modeAttr = blueprint.getMode();
|
||||||
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||||
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||||
|
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = blueprint.getFragmentStrides();
|
||||||
|
if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !sourceOffsetsAttr
|
||||||
|
|| !fragmentStridesAttr)
|
||||||
|
return blueprint.emitOpError("fragment assembly lowering requires explicit fragment metadata");
|
||||||
|
|
||||||
|
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||||
|
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||||
|
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
|
||||||
|
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
||||||
|
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||||
|
int64_t rank = resultType.getRank();
|
||||||
|
|
||||||
|
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
||||||
|
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
||||||
|
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||||
|
rank,
|
||||||
|
fragmentOperands.size(),
|
||||||
|
operandIndices,
|
||||||
|
sourceOffsets,
|
||||||
|
flatOffsets,
|
||||||
|
flatSizes,
|
||||||
|
flatStrides)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value currentOutput = createEmptyTensorFromShaped(rewriter, blueprint.getLoc(), resultType);
|
||||||
|
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||||
|
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> fragmentOffsets;
|
||||||
|
int64_t fragmentElements = 1;
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
|
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||||
|
if (flatStrides[flatIndex] != 1)
|
||||||
|
return blueprint.emitOpError("fragment assembly lowering only supports unit strides");
|
||||||
|
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||||
|
fragmentElements *= flatSizes[flatIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]);
|
||||||
|
auto sourceType = dyn_cast<ShapedType>(source.getType());
|
||||||
|
if (!sourceType || !sourceType.hasStaticShape())
|
||||||
|
return blueprint.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||||
|
|
||||||
|
int64_t fragmentBytes =
|
||||||
|
fragmentElements * static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
|
||||||
|
auto sizeAttr = pim::getCheckedI32Attr(rewriter,
|
||||||
|
blueprint.getOperation(),
|
||||||
|
fragmentBytes,
|
||||||
|
"fragment assembly host copy size");
|
||||||
|
if (failed(sizeAttr))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value hostTargetOffset = createStaticHostTargetOffset(rewriter, blueprint.getLoc(), resultType, fragmentOffsets);
|
||||||
|
auto deviceSourceOffsetBytes = pim::checkedMul(static_cast<uint64_t>(sourceOffsets[fragmentIndex]),
|
||||||
|
static_cast<uint64_t>(getElementTypeSizeInBytes(sourceType.getElementType())),
|
||||||
|
blueprint,
|
||||||
|
"fragment assembly device source offset");
|
||||||
|
if (failed(deviceSourceOffsetBytes))
|
||||||
|
return failure();
|
||||||
|
Value deviceSourceOffset = getOrCreateIndexConstant(rewriter,
|
||||||
|
rewriter.getInsertionBlock()->getParentOp(),
|
||||||
|
static_cast<int64_t>(*deviceSourceOffsetBytes));
|
||||||
|
currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||||
|
blueprint.getLoc(),
|
||||||
|
currentOutput.getType(),
|
||||||
|
hostTargetOffset,
|
||||||
|
deviceSourceOffset,
|
||||||
|
currentOutput,
|
||||||
|
source,
|
||||||
|
*sizeAttr)
|
||||||
|
.getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
|
return currentOutput;
|
||||||
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
|
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
|
||||||
for (Value operand : op->getOperands()) {
|
for (Value operand : op->getOperands()) {
|
||||||
@@ -55,7 +154,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatScheduledCompute computeOp, size_t& fallbackCoreId) {
|
||||||
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||||
return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id");
|
return pim::checkedI32(spatialCoreIdAttr.getInt(), computeOp, "spatial compute core id");
|
||||||
auto checkedCoreId =
|
auto checkedCoreId =
|
||||||
@@ -66,7 +165,7 @@ static FailureOr<int32_t> getPimCoreIdForComputeOp(spatial::SpatCompute computeO
|
|||||||
return *checkedCoreId;
|
return *checkedCoreId;
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
|
||||||
SmallVectorImpl<Operation*>& helperChain,
|
SmallVectorImpl<Operation*>& helperChain,
|
||||||
bool requireReturnUse = true) {
|
bool requireReturnUse = true) {
|
||||||
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
||||||
@@ -104,13 +203,13 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp,
|
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatScheduledCompute computeOp,
|
||||||
IRRewriter& rewriter,
|
IRRewriter& rewriter,
|
||||||
OperationFolder& constantFolder) {
|
OperationFolder& constantFolder) {
|
||||||
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||||
return false;
|
return false;
|
||||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
||||||
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
return isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||||
}))
|
}))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
@@ -131,6 +230,17 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
|||||||
mapping.map(*weightArg, weight);
|
mapping.map(*weightArg, weight);
|
||||||
}
|
}
|
||||||
for (Operation& op : block.without_terminator()) {
|
for (Operation& op : block.without_terminator()) {
|
||||||
|
if (auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(op)) {
|
||||||
|
std::optional<StringRef> modeAttr = blueprint.getMode();
|
||||||
|
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||||
|
auto lowered = lowerFragmentAssemblyBlueprint(rewriter, blueprint, mapping);
|
||||||
|
if (failed(lowered))
|
||||||
|
return false;
|
||||||
|
mapping.map(blueprint.getOutput(), *lowered);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
|
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
|
||||||
Operation* clonedOp = rewriter.clone(op, mapping);
|
Operation* clonedOp = rewriter.clone(op, mapping);
|
||||||
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||||
@@ -145,7 +255,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
|
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatScheduledCompute computeOp,
|
||||||
IRRewriter& rewriter,
|
IRRewriter& rewriter,
|
||||||
OperationFolder& constantFolder) {
|
OperationFolder& constantFolder) {
|
||||||
Location loc = computeOp->getLoc();
|
Location loc = computeOp->getLoc();
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -11,6 +15,108 @@ namespace raptor {
|
|||||||
|
|
||||||
} // namespace raptor
|
} // namespace raptor
|
||||||
|
|
||||||
|
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
|
||||||
|
SmallVector<OpFoldResult, 4> attrs;
|
||||||
|
attrs.reserve(values.size());
|
||||||
|
for (int64_t value : values)
|
||||||
|
attrs.push_back(builder.getIndexAttr(value));
|
||||||
|
return attrs;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
|
||||||
|
SmallVector<OpFoldResult, 4> strides;
|
||||||
|
strides.reserve(rank);
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim)
|
||||||
|
strides.push_back(builder.getIndexAttr(1));
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct LowerFragmentAssemblyBlueprintPattern
|
||||||
|
: OpConversionPattern<spatial::SpatBlueprintOp> {
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(spatial::SpatBlueprintOp op,
|
||||||
|
OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const override {
|
||||||
|
std::optional<StringRef> modeAttr = op.getMode();
|
||||||
|
if (!modeAttr || *modeAttr != "fragment_assembly")
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto resultType = dyn_cast<ShapedType>(op.getOutput().getType());
|
||||||
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
|
return op.emitOpError("fragment assembly lowering requires a static ranked tensor result");
|
||||||
|
|
||||||
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = op.getFragmentOperandIndices();
|
||||||
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = op.getFragmentSourceOffsets();
|
||||||
|
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = op.getFragmentStrides();
|
||||||
|
if (!operandIndicesAttr || !sourceOffsetsAttr || !fragmentStridesAttr)
|
||||||
|
return op.emitOpError("fragment assembly lowering requires explicit fragment metadata");
|
||||||
|
|
||||||
|
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||||
|
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||||
|
ArrayRef<int64_t> flatOffsets = op.getFragmentOffsets();
|
||||||
|
ArrayRef<int64_t> flatSizes = op.getFragmentSizes();
|
||||||
|
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||||
|
int64_t rank = resultType.getRank();
|
||||||
|
|
||||||
|
SmallVector<Value> fragmentOperands {adaptor.getInput()};
|
||||||
|
llvm::append_range(fragmentOperands, adaptor.getFragments());
|
||||||
|
if (failed(validateFragmentAssemblyMetadata(
|
||||||
|
op, rank, fragmentOperands.size(), operandIndices, sourceOffsets, flatOffsets, flatSizes, flatStrides)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value currentOutput =
|
||||||
|
tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), resultType.getElementType()).getResult();
|
||||||
|
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||||
|
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> fragmentOffsets;
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
|
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||||
|
if (flatStrides[flatIndex] != 1)
|
||||||
|
return op.emitOpError("fragment assembly lowering only supports unit strides");
|
||||||
|
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value source = fragmentOperands[operandIndex];
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
||||||
|
if (!sourceType || !sourceType.hasStaticShape())
|
||||||
|
return op.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> fragmentShape;
|
||||||
|
fragmentShape.reserve(rank);
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim)
|
||||||
|
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
|
||||||
|
|
||||||
|
Value fragment = source;
|
||||||
|
if (llvm::to_vector(sourceType.getShape()) != fragmentShape || sourceOffsets[fragmentIndex] != 0) {
|
||||||
|
FailureOr<SmallVector<int64_t, 4>> extractOffsets = getStaticSliceOffsetsForElementOffset(
|
||||||
|
op, sourceType, fragmentShape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
|
||||||
|
if (failed(extractOffsets))
|
||||||
|
return failure();
|
||||||
|
fragment = tensor::ExtractSliceOp::create(rewriter,
|
||||||
|
op.getLoc(),
|
||||||
|
source,
|
||||||
|
getStaticIndexAttrs(rewriter, *extractOffsets),
|
||||||
|
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||||
|
getUnitStrides(rewriter, rank));
|
||||||
|
}
|
||||||
|
|
||||||
|
currentOutput = tensor::InsertSliceOp::create(rewriter,
|
||||||
|
op.getLoc(),
|
||||||
|
fragment,
|
||||||
|
currentOutput,
|
||||||
|
getStaticIndexAttrs(rewriter, fragmentOffsets),
|
||||||
|
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||||
|
getUnitStrides(rewriter, rank))
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, currentOutput);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void populateInitialPatterns(RewritePatternSet& patterns) {
|
void populateInitialPatterns(RewritePatternSet& patterns) {
|
||||||
raptor::populateWithGenerated(patterns);
|
raptor::populateWithGenerated(patterns);
|
||||||
populateTransposeLoweringPatterns(patterns);
|
populateTransposeLoweringPatterns(patterns);
|
||||||
@@ -19,6 +125,7 @@ void populateInitialPatterns(RewritePatternSet& patterns) {
|
|||||||
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
|
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
|
||||||
raptor::populateWithGenerated(patterns);
|
raptor::populateWithGenerated(patterns);
|
||||||
populateTransposeLoweringPatterns(patterns);
|
populateTransposeLoweringPatterns(patterns);
|
||||||
|
patterns.add<LowerFragmentAssemblyBlueprintPattern>(patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -10,6 +10,14 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
static void copyRaptorDebugAttrs(Operation* source, Operation* target) {
|
||||||
|
for (NamedAttribute attr : source->getAttrs()) {
|
||||||
|
StringRef name = attr.getName().strref();
|
||||||
|
if (name.starts_with("raptor."))
|
||||||
|
target->setAttr(attr.getName(), attr.getValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
@@ -17,7 +25,8 @@ struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
|||||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getInput());
|
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getInput());
|
||||||
if (failed(sizeAttr))
|
if (failed(sizeAttr))
|
||||||
return failure();
|
return failure();
|
||||||
pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId());
|
auto send = pim::PimSendOp::create(rewriter, op.getLoc(), op.getInput(), *sizeAttr, op.getTargetCoreId());
|
||||||
|
copyRaptorDebugAttrs(op.getOperation(), send.getOperation());
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -37,9 +46,10 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
|
|||||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult());
|
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult());
|
||||||
if (failed(sizeAttr))
|
if (failed(sizeAttr))
|
||||||
return failure();
|
return failure();
|
||||||
Value received = pim::PimReceiveOp::create(
|
auto receive = pim::PimReceiveOp::create(
|
||||||
rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId())
|
rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId());
|
||||||
.getOutput();
|
copyRaptorDebugAttrs(op.getOperation(), receive.getOperation());
|
||||||
|
Value received = receive.getOutput();
|
||||||
rewriter.replaceOp(op, received);
|
rewriter.replaceOp(op, received);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
for (auto& uses : extractSliceOp->getUses()) {
|
for (auto& uses : extractSliceOp->getUses()) {
|
||||||
if (isa<spatial::SpatCompute>(uses.getOwner())) {
|
if (isa<spatial::SpatScheduledCompute>(uses.getOwner())) {
|
||||||
if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
|
if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
@@ -72,7 +72,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
|
|
||||||
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
||||||
|
|
||||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
|
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(uses.getOwner())) {
|
||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
||||||
if (!inputIndex)
|
if (!inputIndex)
|
||||||
return failure();
|
return failure();
|
||||||
@@ -92,7 +92,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
replaceAndEraseDirectComputeLikeInput(
|
replaceAndEraseDirectComputeLikeInput(
|
||||||
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||||
}
|
}
|
||||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
else if (auto spatComputeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(uses.getOwner())) {
|
||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||||
if (!inputIndex)
|
if (!inputIndex)
|
||||||
return failure();
|
return failure();
|
||||||
@@ -114,7 +114,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
{
|
{
|
||||||
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatScheduledCompute>()) {
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
||||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||||
@@ -125,7 +125,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
|
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
|
||||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
}
|
}
|
||||||
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatScheduledComputeBatch>()) {
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
||||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
||||||
@@ -179,7 +179,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
|||||||
|
|
||||||
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
||||||
auto argUser = argUses.getOwner();
|
auto argUser = argUses.getOwner();
|
||||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
|
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(argUser)) {
|
||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
|
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
|
||||||
if (!inputIndex)
|
if (!inputIndex)
|
||||||
return failure();
|
return failure();
|
||||||
@@ -191,7 +191,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
|||||||
|
|
||||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
|
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
|
||||||
}
|
}
|
||||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
|
else if (auto spatComputeBatch = dyn_cast<spatial::SpatScheduledComputeBatch>(argUser)) {
|
||||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
|
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
|
||||||
if (!inputIndex)
|
if (!inputIndex)
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ getCheckedByteOffset(int64_t elementOffset, size_t elementSize, Operation* ancho
|
|||||||
return pim::checkedCast<int64_t>(*byteOffset, anchor, fieldName);
|
return pim::checkedCast<int64_t>(*byteOffset, anchor, fieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
static LogicalResult collectHelperComputeChain(spatial::SpatScheduledCompute computeOp,
|
||||||
SmallVectorImpl<Operation*>& helperChain) {
|
SmallVectorImpl<Operation*>& helperChain) {
|
||||||
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
||||||
return failure();
|
return failure();
|
||||||
@@ -149,6 +149,40 @@ static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<std::pair<spatial::SpatBlueprintOp, size_t>, 4>>
|
||||||
|
analyzeTopLevelFragmentAssemblyUses(Value value) {
|
||||||
|
SmallVector<std::pair<spatial::SpatBlueprintOp, size_t>, 4> uses;
|
||||||
|
for (OpOperand& use : value.getUses()) {
|
||||||
|
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(use.getOwner());
|
||||||
|
if (!blueprint || blueprint->getParentOp() != blueprint->getParentOfType<func::FuncOp>())
|
||||||
|
return failure();
|
||||||
|
std::optional<StringRef> mode = blueprint.getMode();
|
||||||
|
if (!mode || *mode != "fragment_assembly")
|
||||||
|
return failure();
|
||||||
|
if (!blueprint.getOutput().hasOneUse() || !isa<func::ReturnOp>(*blueprint.getOutput().getUsers().begin()))
|
||||||
|
return failure();
|
||||||
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||||
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||||
|
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||||
|
if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr || !resultType || !resultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
SmallVector<Value> fragmentOperands {blueprint.getInput()};
|
||||||
|
llvm::append_range(fragmentOperands, blueprint.getFragments());
|
||||||
|
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||||
|
resultType.getRank(),
|
||||||
|
fragmentOperands.size(),
|
||||||
|
*operandIndicesAttr,
|
||||||
|
*sourceOffsetsAttr,
|
||||||
|
blueprint.getFragmentOffsets(),
|
||||||
|
blueprint.getFragmentSizes(),
|
||||||
|
*stridesAttr)))
|
||||||
|
return failure();
|
||||||
|
uses.emplace_back(blueprint, use.getOperandNumber());
|
||||||
|
}
|
||||||
|
return uses;
|
||||||
|
}
|
||||||
|
|
||||||
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||||
auto getConcatResult = [](Operation* op) -> Value {
|
auto getConcatResult = [](Operation* op) -> Value {
|
||||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||||
@@ -212,7 +246,7 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Operation*> helperChain;
|
SmallVector<Operation*> helperChain;
|
||||||
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
|
if (auto helperCompute = dyn_cast<spatial::SpatScheduledCompute>(currentUser)) {
|
||||||
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
|
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
@@ -559,6 +593,116 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FailureOr<SmallVector<std::pair<spatial::SpatBlueprintOp, size_t>, 4>> fragmentAssemblyUses =
|
||||||
|
analyzeTopLevelFragmentAssemblyUses(producedValue);
|
||||||
|
if (succeeded(fragmentAssemblyUses)) {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(storedValue.getType());
|
||||||
|
if (!sourceType || !sourceType.hasStaticShape()) {
|
||||||
|
producerOp->emitOpError("fragment assembly publication requires a static ranked tensor source");
|
||||||
|
return ReturnPathLoweringResult::Failure;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t elementSize = getElementTypeSizeInBytes(sourceType.getElementType());
|
||||||
|
for (auto [blueprint, operandNumber] : *fragmentAssemblyUses) {
|
||||||
|
rewriter.setInsertionPointAfterValue(storedValue);
|
||||||
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = blueprint.getFragmentOperandIndices();
|
||||||
|
std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = blueprint.getFragmentSourceOffsets();
|
||||||
|
std::optional<ArrayRef<int64_t>> stridesAttr = blueprint.getFragmentStrides();
|
||||||
|
if (!operandIndicesAttr || !sourceOffsetsAttr || !stridesAttr) {
|
||||||
|
blueprint.emitOpError(
|
||||||
|
"fragment assembly lowering requires explicit operand, source-offset, and stride metadata");
|
||||||
|
return ReturnPathLoweringResult::Failure;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t returnIndex = blueprint.getOutput().getUses().begin()->getOperandNumber();
|
||||||
|
Value outputTensor = outputTensors[returnIndex](rewriter, loc);
|
||||||
|
auto outputType = dyn_cast<RankedTensorType>(outputTensor.getType());
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||||
|
if (!outputType || !resultType || !resultType.hasStaticShape()) {
|
||||||
|
blueprint.emitOpError("fragment assembly lowering requires static ranked host outputs");
|
||||||
|
return ReturnPathLoweringResult::Failure;
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||||
|
ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
|
||||||
|
ArrayRef<int64_t> flatOffsets = blueprint.getFragmentOffsets();
|
||||||
|
ArrayRef<int64_t> flatSizes = blueprint.getFragmentSizes();
|
||||||
|
ArrayRef<int64_t> flatStrides = *stridesAttr;
|
||||||
|
int64_t rank = resultType.getRank();
|
||||||
|
if (failed(validateFragmentAssemblyMetadata(blueprint,
|
||||||
|
rank,
|
||||||
|
1 + blueprint.getFragments().size(),
|
||||||
|
operandIndices,
|
||||||
|
sourceOffsets,
|
||||||
|
flatOffsets,
|
||||||
|
flatSizes,
|
||||||
|
flatStrides)))
|
||||||
|
return ReturnPathLoweringResult::Failure;
|
||||||
|
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||||
|
if (operandIndices[fragmentIndex] != static_cast<int64_t>(operandNumber))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> fragmentOffsets;
|
||||||
|
SmallVector<int64_t, 4> fragmentSizes;
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
|
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||||
|
if (flatStrides[flatIndex] != 1) {
|
||||||
|
blueprint.emitOpError("fragment assembly lowering only supports unit strides");
|
||||||
|
return ReturnPathLoweringResult::Failure;
|
||||||
|
}
|
||||||
|
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||||
|
fragmentSizes.push_back(flatSizes[flatIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool failedChunk = false;
|
||||||
|
if (failed(forEachContiguousDestinationChunk(
|
||||||
|
outputType.getShape(),
|
||||||
|
fragmentOffsets,
|
||||||
|
fragmentSizes,
|
||||||
|
[&](ArrayRef<int64_t> chunkOffsets, int64_t relativeSourceOffset, int64_t chunkElements) -> LogicalResult {
|
||||||
|
auto hostOffset =
|
||||||
|
getCheckedByteOffset(computeFlatElementIndex(chunkOffsets, outputType.getShape()),
|
||||||
|
elementSize,
|
||||||
|
producerOp,
|
||||||
|
"fragment assembly host offset");
|
||||||
|
auto sourceOffset = getCheckedByteOffset(sourceOffsets[fragmentIndex] + relativeSourceOffset,
|
||||||
|
elementSize,
|
||||||
|
producerOp,
|
||||||
|
"fragment assembly source offset");
|
||||||
|
auto fragmentBytes =
|
||||||
|
getCheckedByteOffset(chunkElements, elementSize, producerOp, "fragment assembly host copy byte size");
|
||||||
|
if (failed(hostOffset) || failed(sourceOffset) || failed(fragmentBytes)) {
|
||||||
|
failedChunk = true;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto sizeAttr =
|
||||||
|
pim::getCheckedI32Attr(rewriter, producerOp, *fragmentBytes, "fragment assembly host copy byte size");
|
||||||
|
if (failed(sizeAttr)) {
|
||||||
|
failedChunk = true;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
outputTensor =
|
||||||
|
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||||
|
blueprint.getLoc(),
|
||||||
|
outputTensor.getType(),
|
||||||
|
getOrCreateIndexConstant(rewriter, producerOp, *hostOffset),
|
||||||
|
getOrCreateIndexConstant(rewriter, producerOp, *sourceOffset),
|
||||||
|
outputTensor,
|
||||||
|
storedValue,
|
||||||
|
*sizeAttr)
|
||||||
|
.getOutput();
|
||||||
|
return success();
|
||||||
|
})))
|
||||||
|
failedChunk = true;
|
||||||
|
if (failedChunk)
|
||||||
|
return ReturnPathLoweringResult::Failure;
|
||||||
|
}
|
||||||
|
markOpToRemove(blueprint.getOperation());
|
||||||
|
}
|
||||||
|
return ReturnPathLoweringResult::Handled;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
||||||
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||||
auto storedByteSize =
|
auto storedByteSize =
|
||||||
@@ -643,7 +787,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
}
|
}
|
||||||
|
|
||||||
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
|
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
|
||||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
|
spatial::SpatScheduledCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
|
||||||
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -656,7 +800,7 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
|
|||||||
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
|
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
|
||||||
Operation* onlyUser = *op->getUsers().begin();
|
Operation* onlyUser = *op->getUsers().begin();
|
||||||
isExclusivelyOwnedByReturnChain =
|
isExclusivelyOwnedByReturnChain =
|
||||||
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatScheduledCompute>(onlyUser)
|
||||||
|| isReturnHelperChainOp(onlyUser);
|
|| isReturnHelperChainOp(onlyUser);
|
||||||
}
|
}
|
||||||
if (!isExclusivelyOwnedByReturnChain)
|
if (!isExclusivelyOwnedByReturnChain)
|
||||||
@@ -669,7 +813,17 @@ void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp ret
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
if (auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(op)) {
|
||||||
|
std::optional<StringRef> mode = blueprint.getMode();
|
||||||
|
if (mode && *mode == "fragment_assembly") {
|
||||||
|
markOpToRemove(blueprint.getOperation());
|
||||||
|
for (Value operand : blueprint->getOperands())
|
||||||
|
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
|
||||||
markOpToRemove(computeOp);
|
markOpToRemove(computeOp);
|
||||||
if (!computeOp.getInputs().empty())
|
if (!computeOp.getInputs().empty())
|
||||||
for (Value input : computeOp.getInputs())
|
for (Value input : computeOp.getInputs())
|
||||||
|
|||||||
@@ -25,9 +25,11 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "Common/IR/ShapeUtils.hpp"
|
||||||
#include "Common/IR/ConstantUtils.hpp"
|
#include "Common/IR/ConstantUtils.hpp"
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Common/Support/CheckedArithmetic.hpp"
|
#include "Common/Support/CheckedArithmetic.hpp"
|
||||||
|
#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/Common.hpp"
|
#include "Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/Patterns.hpp"
|
#include "Conversion/SpatialToPim/Patterns.hpp"
|
||||||
@@ -97,6 +99,64 @@ static FailureOr<Value> createZeroedDeviceHVector(IRRewriter& rewriter,
|
|||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool isHostBackedMemRefValue(Value value) {
|
||||||
|
while (Operation* definingOp = value.getDefiningOp()) {
|
||||||
|
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||||
|
value = subviewOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||||
|
value = castOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = collapseOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = expandOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return isa<memref::GetGlobalOp>(definingOp);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isHostBackedTensorValue(Value value) {
|
||||||
|
while (Operation* definingOp = value.getDefiningOp()) {
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(extractSliceOp.getSource().getType());
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getResult().getType());
|
||||||
|
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return false;
|
||||||
|
if (!onnx_mlir::isContiguousSubviewWithDynamicOffsets(sourceType.getShape(),
|
||||||
|
extractSliceOp.getMixedOffsets(),
|
||||||
|
extractSliceOp.getStaticSizes(),
|
||||||
|
extractSliceOp.getStaticStrides())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
value = extractSliceOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = collapseOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = expandOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto castOp = dyn_cast<tensor::CastOp>(definingOp)) {
|
||||||
|
value = castOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(definingOp))
|
||||||
|
return isHostBackedMemRefValue(toTensorOp.getBuffer());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static FailureOr<Value>
|
static FailureOr<Value>
|
||||||
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
|
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
|
||||||
auto vectorType = cast<RankedTensorType>(vector.getType());
|
auto vectorType = cast<RankedTensorType>(vector.getType());
|
||||||
@@ -120,6 +180,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
|
|||||||
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
|
auto sizeAttr = pim::getCheckedI32Attr(rewriter, zeroed->getDefiningOp(), *byteSize, "device padding copy byte size");
|
||||||
if (failed(sizeAttr))
|
if (failed(sizeAttr))
|
||||||
return failure();
|
return failure();
|
||||||
|
if (isHostBackedTensorValue(vector)) {
|
||||||
|
return PimMemCopyHostToDevOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr)
|
||||||
|
.getOutput();
|
||||||
|
}
|
||||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
|
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroIndex, zeroIndex, *zeroed, vector, *sizeAttr).getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,6 +201,12 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
func::FuncOp funcOp = *entryFunc;
|
func::FuncOp funcOp = *entryFunc;
|
||||||
|
if (failed(verifyScheduledSpatialInvariants(funcOp))) {
|
||||||
|
funcOp.emitOpError(
|
||||||
|
"scheduled Spatial verification failed at the start of SpatialToPim");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
OperationFolder constantFolder(&getContext());
|
OperationFolder constantFolder(&getContext());
|
||||||
@@ -176,19 +246,19 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
for (auto computeOp : funcOp.getOps<spatial::SpatScheduledCompute>()) {
|
||||||
markOpToRemove(computeOp);
|
markOpToRemove(computeOp);
|
||||||
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
|
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
|
||||||
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
computeOp.emitOpError("failed to lower spat.scheduled_compute to pim.core");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
for (auto computeBatchOp : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
|
||||||
markOpToRemove(computeBatchOp);
|
markOpToRemove(computeBatchOp);
|
||||||
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
|
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
|
||||||
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
computeBatchOp.emitOpError("failed to lower spat.scheduled_compute_batch to pim.core_batch");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -374,7 +444,7 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
|
|||||||
};
|
};
|
||||||
|
|
||||||
for (auto& op : funcOp.getBody().getOps())
|
for (auto& op : funcOp.getBody().getOps())
|
||||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
if (auto computeOp = dyn_cast<spatial::SpatScheduledCompute>(op)) {
|
||||||
if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
|
if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
|
||||||
continue;
|
continue;
|
||||||
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
|
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
|
||||||
|
|||||||
@@ -41,8 +41,11 @@ private:
|
|||||||
|
|
||||||
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||||
mlir::LogicalResult
|
mlir::LogicalResult
|
||||||
lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder);
|
lowerComputeOp(spatial::SpatScheduledCompute computeOp,
|
||||||
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);
|
mlir::IRRewriter& rewriter,
|
||||||
|
mlir::OperationFolder& constantFolder);
|
||||||
|
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||||
|
mlir::IRRewriter& rewriter);
|
||||||
|
|
||||||
enum class ReturnPathLoweringResult {
|
enum class ReturnPathLoweringResult {
|
||||||
Handled,
|
Handled,
|
||||||
@@ -51,7 +54,7 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
||||||
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatScheduledCompute computeOp,
|
||||||
mlir::OpResult result,
|
mlir::OpResult result,
|
||||||
mlir::Value yieldValue,
|
mlir::Value yieldValue,
|
||||||
mlir::IRRewriter& rewriter);
|
mlir::IRRewriter& rewriter);
|
||||||
|
|||||||
@@ -13,10 +13,13 @@ using namespace bufferization;
|
|||||||
|
|
||||||
namespace onnx_mlir::pim {
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue,
|
||||||
|
Location loc,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const StaticValueKnowledge& knowledge) {
|
||||||
bool isContiguous =
|
bool isContiguous =
|
||||||
succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue));
|
succeeded(resolveContiguousAddress(memrefValue, knowledge)) || succeeded(compileContiguousAddressExpr(memrefValue));
|
||||||
if (isContiguous && isDeviceLocalPimAddress(memrefValue))
|
if (isContiguous && isDeviceLocalPimAddress(memrefValue, knowledge))
|
||||||
return memrefValue;
|
return memrefValue;
|
||||||
|
|
||||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||||
@@ -32,7 +35,7 @@ FailureOr<Value> materializeContiguousInputMemRef(Value memrefValue, Location lo
|
|||||||
if (failed(sizeAttr))
|
if (failed(sizeAttr))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (isHostBackedPimAddress(memrefValue)) {
|
if (isHostBackedPimAddress(memrefValue, knowledge)) {
|
||||||
return PimMemCopyHostToDevOp::create(
|
return PimMemCopyHostToDevOp::create(
|
||||||
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
|
rewriter, loc, contiguousType, zeroOffset, zeroOffset, contiguousBuffer, memrefValue, *sizeAttr)
|
||||||
.getOutput();
|
.getOutput();
|
||||||
|
|||||||
@@ -3,10 +3,15 @@
|
|||||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir::pim {
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
llvm::FailureOr<mlir::Value>
|
llvm::FailureOr<mlir::Value>
|
||||||
materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
materializeContiguousInputMemRef(mlir::Value memrefValue,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::RewriterBase& rewriter,
|
||||||
|
const onnx_mlir::StaticValueKnowledge& knowledge = {});
|
||||||
mlir::Value
|
mlir::Value
|
||||||
allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,26 @@ using namespace bufferization;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace pim {
|
namespace pim {
|
||||||
|
|
||||||
|
static StaticValueKnowledge getEnclosingBufferizationKnowledge(Operation* op) {
|
||||||
|
StaticValueKnowledge knowledge;
|
||||||
|
|
||||||
|
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
|
||||||
|
knowledge.indexValues[coreBatchOp.getLaneArgument()] = 0;
|
||||||
|
for (auto [index, weight] : llvm::enumerate(coreBatchOp.getWeights()))
|
||||||
|
knowledge.aliases[coreBatchOp.getWeightArgument(index)] = weight;
|
||||||
|
for (auto [index, input] : llvm::enumerate(coreBatchOp.getInputs()))
|
||||||
|
knowledge.aliases[coreBatchOp.getInputArgument(index)] = input;
|
||||||
|
return knowledge;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
|
||||||
|
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights()))
|
||||||
|
knowledge.aliases[coreOp.getWeightArgument(index)] = weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
return knowledge;
|
||||||
|
}
|
||||||
|
|
||||||
struct MemCopyHostToDevOpInterface
|
struct MemCopyHostToDevOpInterface
|
||||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
||||||
LogicalResult bufferize(Operation* op,
|
LogicalResult bufferize(Operation* op,
|
||||||
@@ -148,7 +168,8 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
|
|||||||
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
|
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
auto contiguous = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
auto contiguous =
|
||||||
|
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguous))
|
if (failed(contiguous))
|
||||||
return failure();
|
return failure();
|
||||||
inputs.push_back(*contiguous);
|
inputs.push_back(*contiguous);
|
||||||
@@ -182,7 +203,8 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
|
|||||||
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
auto contiguousInput =
|
||||||
|
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousInput))
|
if (failed(contiguousInput))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -410,7 +432,8 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
auto contiguousInput =
|
||||||
|
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousInput))
|
if (failed(contiguousInput))
|
||||||
return failure();
|
return failure();
|
||||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
@@ -456,7 +479,8 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
auto contiguousInput =
|
||||||
|
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousInput))
|
if (failed(contiguousInput))
|
||||||
return failure();
|
return failure();
|
||||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
@@ -497,10 +521,12 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
|
auto contiguousLhs =
|
||||||
|
materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousLhs))
|
if (failed(contiguousLhs))
|
||||||
return failure();
|
return failure();
|
||||||
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
|
auto contiguousRhs =
|
||||||
|
materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousRhs))
|
if (failed(contiguousRhs))
|
||||||
return failure();
|
return failure();
|
||||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
@@ -534,10 +560,12 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInter
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
|
auto contiguousLhs =
|
||||||
|
materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousLhs))
|
if (failed(contiguousLhs))
|
||||||
return failure();
|
return failure();
|
||||||
auto contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
|
auto contiguousRhs =
|
||||||
|
materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousRhs))
|
if (failed(contiguousRhs))
|
||||||
return failure();
|
return failure();
|
||||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
@@ -574,7 +602,8 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
auto contiguousInput =
|
||||||
|
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter, getEnclosingBufferizationKnowledge(op));
|
||||||
if (failed(contiguousInput))
|
if (failed(contiguousInput))
|
||||||
return failure();
|
return failure();
|
||||||
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
|
|||||||
@@ -116,6 +116,36 @@ lowerMemRefCopyToPimCopy(memref::CopyOp copyOp, PatternRewriter& rewriter, const
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyHostToDevOp copyOp, const StaticValueKnowledge& knowledge) {
|
||||||
|
bool sourceIsHost = isHostBackedPimAddress(copyOp.getHostSource(), knowledge);
|
||||||
|
bool targetIsHost = isHostBackedPimAddress(copyOp.getDeviceTarget(), knowledge);
|
||||||
|
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getHostSource(), knowledge);
|
||||||
|
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getDeviceTarget(), knowledge);
|
||||||
|
if (!sourceIsHost || !targetIsDevice || targetIsHost || sourceIsDevice)
|
||||||
|
return copyOp.emitOpError("pim.memcp_hd requires a host-backed source and a device-local target");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyDevToHostOp copyOp, const StaticValueKnowledge& knowledge) {
|
||||||
|
bool sourceIsHost = isHostBackedPimAddress(copyOp.getDeviceSource(), knowledge);
|
||||||
|
bool targetIsHost = isHostBackedPimAddress(copyOp.getHostTarget(), knowledge);
|
||||||
|
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getDeviceSource(), knowledge);
|
||||||
|
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getHostTarget(), knowledge);
|
||||||
|
if (!targetIsHost || !sourceIsDevice || sourceIsHost || targetIsDevice)
|
||||||
|
return copyOp.emitOpError("pim.memcp_dh requires a device-local source and a host-backed target");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyLoweredPimCopy(pim::PimMemCopyOp copyOp, const StaticValueKnowledge& knowledge) {
|
||||||
|
bool sourceIsHost = isHostBackedPimAddress(copyOp.getSource(), knowledge);
|
||||||
|
bool targetIsHost = isHostBackedPimAddress(copyOp.getTarget(), knowledge);
|
||||||
|
bool sourceIsDevice = isDeviceLocalPimAddress(copyOp.getSource(), knowledge);
|
||||||
|
bool targetIsDevice = isDeviceLocalPimAddress(copyOp.getTarget(), knowledge);
|
||||||
|
if (!sourceIsDevice || !targetIsDevice || sourceIsHost || targetIsHost)
|
||||||
|
return copyOp.emitOpError("pim.memcp requires device-local source and target operands");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimBufferizationPass)
|
||||||
StringRef getArgument() const override { return "bufferize-pim"; }
|
StringRef getArgument() const override { return "bufferize-pim"; }
|
||||||
@@ -129,6 +159,7 @@ struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<Mo
|
|||||||
private:
|
private:
|
||||||
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
||||||
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
|
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
|
||||||
|
LogicalResult verifyPimCopyAddressSpaces(ModuleOp moduleOp) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
static LogicalResult applyPatternsOnce(Operation* op, PatternApplicator& applicator, PatternRewriter& rewriter) {
|
static LogicalResult applyPatternsOnce(Operation* op, PatternApplicator& applicator, PatternRewriter& rewriter) {
|
||||||
@@ -240,6 +271,10 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (failed(verifyPimCopyAddressSpaces(moduleOp))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
annotateWeightsMemrefs(moduleOp, funcOp);
|
annotateWeightsMemrefs(moduleOp, funcOp);
|
||||||
|
|
||||||
@@ -346,6 +381,31 @@ LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp mod
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult PimBufferizationPass::verifyPimCopyAddressSpaces(ModuleOp moduleOp) const {
|
||||||
|
bool hasFailure = false;
|
||||||
|
auto verifyWithKnowledge = [&](auto coreLikeOp, const StaticValueKnowledge& initialKnowledge) {
|
||||||
|
(void) walkPimCoreBlockStructurally(
|
||||||
|
coreLikeOp.getBody().front(), initialKnowledge, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||||
|
if (auto copyOp = dyn_cast<pim::PimMemCopyOp>(&op); copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge)))
|
||||||
|
hasFailure = true;
|
||||||
|
if (auto copyOp = dyn_cast<pim::PimMemCopyHostToDevOp>(&op);
|
||||||
|
copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge)))
|
||||||
|
hasFailure = true;
|
||||||
|
if (auto copyOp = dyn_cast<pim::PimMemCopyDevToHostOp>(&op);
|
||||||
|
copyOp && failed(verifyLoweredPimCopy(copyOp, knowledge)))
|
||||||
|
hasFailure = true;
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
moduleOp.walk([&](pim::PimCoreOp coreOp) { verifyWithKnowledge(coreOp, seedCoreKnowledge(coreOp)); });
|
||||||
|
moduleOp.walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||||
|
StaticValueKnowledge knowledge = seedCoreBatchKnowledge(coreBatchOp, 0);
|
||||||
|
verifyWithKnowledge(coreBatchOp, knowledge);
|
||||||
|
});
|
||||||
|
return success(!hasFailure);
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
|
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -96,8 +96,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
|||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override {
|
||||||
auto coreOp = mapOp->getParentOfType<pim::PimCoreOp>();
|
if (!mapOp->getParentOfType<pim::PimCoreOp>() && !mapOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||||
if (!coreOp)
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
|
auto initType = dyn_cast<MemRefType>(mapOp.getInit().getType());
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ add_pim_library(OMPimVerification
|
|||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
OMPimCommon
|
OMPimCommon
|
||||||
|
OMPimCompilerOptions
|
||||||
OMPimBufferization
|
OMPimBufferization
|
||||||
PimOps
|
PimOps
|
||||||
SpatialOps
|
SpatialOps
|
||||||
|
|||||||
@@ -5,12 +5,17 @@
|
|||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
|
||||||
@@ -143,6 +148,479 @@ static bool isHostAddressableValue(Value value, const StaticValueKnowledge& know
|
|||||||
return isa_and_nonnull<memref::GetGlobalOp>(base.getDefiningOp());
|
return isa_and_nonnull<memref::GetGlobalOp>(base.getDefiningOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
enum class CommunicationEventKind { Send, Receive };
|
||||||
|
|
||||||
|
struct CommunicationEvent {
|
||||||
|
CommunicationEventKind kind = CommunicationEventKind::Send;
|
||||||
|
int64_t coreId = 0;
|
||||||
|
int64_t peerCoreId = 0;
|
||||||
|
int64_t size = 0;
|
||||||
|
uint64_t ordinal = 0;
|
||||||
|
std::optional<int64_t> minChannelId;
|
||||||
|
std::string materializer;
|
||||||
|
std::optional<int64_t> traceId;
|
||||||
|
std::optional<int64_t> commOrder;
|
||||||
|
std::optional<int64_t> traceClassId;
|
||||||
|
std::optional<int64_t> traceBlockOrdinal;
|
||||||
|
std::string traceKind;
|
||||||
|
std::string tracePhase;
|
||||||
|
std::string traceClassKind;
|
||||||
|
std::string tracePayload;
|
||||||
|
std::string traceMessages;
|
||||||
|
std::string tracePrevOp;
|
||||||
|
std::string traceNextOp;
|
||||||
|
Operation* op = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
using CommunicationEventVector = SmallVector<CommunicationEvent, 0>;
|
||||||
|
|
||||||
|
static StringRef getCommunicationEventKindName(CommunicationEventKind kind) {
|
||||||
|
return kind == CommunicationEventKind::Send ? "send" : "receive";
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr StringLiteral kRaptorMinChannelIdAttr = "raptor.min_channel_id";
|
||||||
|
constexpr StringLiteral kRaptorMaterializerAttr = "raptor.materializer";
|
||||||
|
constexpr StringLiteral kRaptorCommOrderAttr = "raptor.comm_order";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceIdAttr = "raptor.comm_trace_id";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceKindAttr = "raptor.comm_trace_kind";
|
||||||
|
constexpr StringLiteral kRaptorCommTracePhaseAttr = "raptor.comm_trace_phase";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceClassIdAttr = "raptor.comm_trace_class_id";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceClassKindAttr = "raptor.comm_trace_class_kind";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceBlockOrdinalAttr = "raptor.comm_trace_block_ordinal";
|
||||||
|
constexpr StringLiteral kRaptorCommTracePayloadAttr = "raptor.comm_trace_payload";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceMessagesAttr = "raptor.comm_trace_messages";
|
||||||
|
constexpr StringLiteral kRaptorCommTracePrevOpAttr = "raptor.comm_trace_prev_op";
|
||||||
|
constexpr StringLiteral kRaptorCommTraceNextOpAttr = "raptor.comm_trace_next_op";
|
||||||
|
|
||||||
|
static std::optional<int64_t> getNearestIntegerAttr(Operation* op, StringRef name) {
|
||||||
|
for (Operation* current = op; current; current = current->getParentOp())
|
||||||
|
if (auto attr = current->getAttrOfType<IntegerAttr>(name))
|
||||||
|
return attr.getInt();
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string getNearestStringAttr(Operation* op, StringRef name) {
|
||||||
|
for (Operation* current = op; current; current = current->getParentOp())
|
||||||
|
if (auto attr = current->getAttrOfType<StringAttr>(name))
|
||||||
|
return attr.getValue().str();
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string formatLocation(Location loc) {
|
||||||
|
std::string text;
|
||||||
|
llvm::raw_string_ostream os(text);
|
||||||
|
loc.print(os);
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string formatOperationSummary(Operation* op) {
|
||||||
|
std::string text;
|
||||||
|
llvm::raw_string_ostream os(text);
|
||||||
|
OpPrintingFlags flags;
|
||||||
|
flags.skipRegions();
|
||||||
|
flags.elideLargeElementsAttrs();
|
||||||
|
op->print(os, flags);
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string formatCommunicationEvent(const CommunicationEvent& event) {
|
||||||
|
std::string text;
|
||||||
|
llvm::raw_string_ostream os(text);
|
||||||
|
os << "core " << event.coreId << " " << getCommunicationEventKindName(event.kind) << " "
|
||||||
|
<< (event.kind == CommunicationEventKind::Send ? "to" : "from") << " " << event.peerCoreId
|
||||||
|
<< " size " << event.size << "B ordinal " << event.ordinal;
|
||||||
|
if (event.minChannelId)
|
||||||
|
os << " min_channel " << *event.minChannelId;
|
||||||
|
if (event.commOrder)
|
||||||
|
os << " comm_order " << *event.commOrder;
|
||||||
|
if (!event.materializer.empty())
|
||||||
|
os << " materializer " << event.materializer;
|
||||||
|
if (event.traceId)
|
||||||
|
os << " trace#" << *event.traceId;
|
||||||
|
if (!event.tracePhase.empty())
|
||||||
|
os << " phase " << event.tracePhase;
|
||||||
|
if (event.traceClassId)
|
||||||
|
os << " class " << event.traceClassKind << "#" << *event.traceClassId;
|
||||||
|
if (event.traceBlockOrdinal)
|
||||||
|
os << " block_ordinal " << *event.traceBlockOrdinal;
|
||||||
|
if (!event.tracePayload.empty())
|
||||||
|
os << " payload " << event.tracePayload;
|
||||||
|
if (!event.traceMessages.empty())
|
||||||
|
os << " messages {" << event.traceMessages << "}";
|
||||||
|
if (!event.tracePrevOp.empty() || !event.traceNextOp.empty())
|
||||||
|
os << " inserted_between [" << event.tracePrevOp << " | " << event.traceNextOp << "]";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool areMatchedCommunicationEvents(const CommunicationEvent& lhs, const CommunicationEvent& rhs) {
|
||||||
|
if (lhs.coreId != rhs.peerCoreId || lhs.peerCoreId != rhs.coreId || lhs.size != rhs.size)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return (lhs.kind == CommunicationEventKind::Send && rhs.kind == CommunicationEventKind::Receive)
|
||||||
|
|| (lhs.kind == CommunicationEventKind::Receive && rhs.kind == CommunicationEventKind::Send);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static std::optional<size_t> findMatchingCounterpartIndex(const CommunicationEventVector& events,
|
||||||
|
const CommunicationEvent& event,
|
||||||
|
size_t begin) {
|
||||||
|
for (size_t index = begin; index < events.size(); ++index)
|
||||||
|
if (areMatchedCommunicationEvents(event, events[index]))
|
||||||
|
return index;
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printCounterpartProbe(llvm::raw_ostream& os,
|
||||||
|
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
|
||||||
|
const DenseMap<int64_t, size_t>& programCounters,
|
||||||
|
const CommunicationEvent& blockedEvent) {
|
||||||
|
auto peerEventsIt = coreEvents.find(blockedEvent.peerCoreId);
|
||||||
|
if (peerEventsIt == coreEvents.end()) {
|
||||||
|
os << " no local stream was collected for peer core " << blockedEvent.peerCoreId << "\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const CommunicationEventVector& peerEvents = peerEventsIt->second;
|
||||||
|
size_t peerPc = 0;
|
||||||
|
auto peerPcIt = programCounters.find(blockedEvent.peerCoreId);
|
||||||
|
if (peerPcIt != programCounters.end())
|
||||||
|
peerPc = peerPcIt->second;
|
||||||
|
|
||||||
|
os << " counterpart probe for " << formatCommunicationEvent(blockedEvent) << "\n";
|
||||||
|
os << " peer core " << blockedEvent.peerCoreId << " current pc " << peerPc << " of " << peerEvents.size()
|
||||||
|
<< "\n";
|
||||||
|
|
||||||
|
std::optional<size_t> nextMatch = findMatchingCounterpartIndex(peerEvents, blockedEvent, peerPc);
|
||||||
|
std::optional<size_t> anyMatch = findMatchingCounterpartIndex(peerEvents, blockedEvent, 0);
|
||||||
|
|
||||||
|
if (!nextMatch && !anyMatch) {
|
||||||
|
os << " no matching counterpart exists anywhere in the peer stream\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!nextMatch && anyMatch) {
|
||||||
|
os << " matching counterpart exists only before the peer pc at ordinal " << *anyMatch
|
||||||
|
<< "; this usually means the static stream expansion or ordering metadata is inconsistent\n";
|
||||||
|
os << " " << formatCommunicationEvent(peerEvents[*anyMatch]) << "\n";
|
||||||
|
os << " op: " << formatOperationSummary(peerEvents[*anyMatch].op) << "\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const CommunicationEvent& match = peerEvents[*nextMatch];
|
||||||
|
os << " next matching counterpart is at peer ordinal " << *nextMatch << " (distance +"
|
||||||
|
<< (*nextMatch >= peerPc ? *nextMatch - peerPc : 0) << ")\n";
|
||||||
|
os << " " << formatCommunicationEvent(match) << "\n";
|
||||||
|
os << " op: " << formatOperationSummary(match.op) << "\n";
|
||||||
|
|
||||||
|
if (*nextMatch == peerPc)
|
||||||
|
return;
|
||||||
|
|
||||||
|
os << " peer operations blocking before that counterpart:\n";
|
||||||
|
size_t end = std::min(peerEvents.size(), std::min(*nextMatch + static_cast<size_t>(1), peerPc + static_cast<size_t>(12)));
|
||||||
|
for (size_t index = peerPc; index < end; ++index) {
|
||||||
|
os << (index == peerPc ? " pc => " : " ") << "#" << index << " "
|
||||||
|
<< formatCommunicationEvent(peerEvents[index]) << "\n";
|
||||||
|
os << " op: " << formatOperationSummary(peerEvents[index].op) << "\n";
|
||||||
|
}
|
||||||
|
if (end <= *nextMatch)
|
||||||
|
os << " ... " << (*nextMatch - end + 1) << " more peer communication event(s) before the counterpart\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
static CommunicationEvent makeCommunicationEvent(CommunicationEventKind kind,
|
||||||
|
int64_t coreId,
|
||||||
|
int64_t peerCoreId,
|
||||||
|
int64_t size,
|
||||||
|
uint64_t ordinal,
|
||||||
|
Operation* op) {
|
||||||
|
return CommunicationEvent {kind,
|
||||||
|
coreId,
|
||||||
|
peerCoreId,
|
||||||
|
size,
|
||||||
|
ordinal,
|
||||||
|
getNearestIntegerAttr(op, kRaptorMinChannelIdAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorMaterializerAttr),
|
||||||
|
getNearestIntegerAttr(op, kRaptorCommTraceIdAttr),
|
||||||
|
getNearestIntegerAttr(op, kRaptorCommOrderAttr),
|
||||||
|
getNearestIntegerAttr(op, kRaptorCommTraceClassIdAttr),
|
||||||
|
getNearestIntegerAttr(op, kRaptorCommTraceBlockOrdinalAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTraceKindAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTracePhaseAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTraceClassKindAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTracePayloadAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTraceMessagesAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTracePrevOpAttr),
|
||||||
|
getNearestStringAttr(op, kRaptorCommTraceNextOpAttr),
|
||||||
|
op};
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult appendCoreCommunicationEvents(Block& block,
|
||||||
|
int64_t coreId,
|
||||||
|
const StaticValueKnowledge& initialKnowledge,
|
||||||
|
SmallVectorImpl<CommunicationEvent>& events,
|
||||||
|
pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
|
return walkPimCoreBlock(block, initialKnowledge, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||||
|
if (auto sendOp = dyn_cast<pim::PimSendOp>(&op)) {
|
||||||
|
auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
|
||||||
|
if (failed(targetCoreId)) {
|
||||||
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("cannot statically resolve send target core for PIM communication deadlock check");
|
||||||
|
});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
events.push_back(makeCommunicationEvent(CommunicationEventKind::Send,
|
||||||
|
coreId,
|
||||||
|
*targetCoreId,
|
||||||
|
sendOp.getSize(),
|
||||||
|
static_cast<uint64_t>(events.size()),
|
||||||
|
&op));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(&op)) {
|
||||||
|
auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
|
||||||
|
if (failed(sourceCoreId)) {
|
||||||
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("cannot statically resolve receive source core for PIM communication deadlock check");
|
||||||
|
});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
events.push_back(makeCommunicationEvent(CommunicationEventKind::Receive,
|
||||||
|
coreId,
|
||||||
|
*sourceCoreId,
|
||||||
|
receiveOp.getSize(),
|
||||||
|
static_cast<uint64_t>(events.size()),
|
||||||
|
&op));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printCommunicationWindow(llvm::raw_ostream& os,
|
||||||
|
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
|
||||||
|
int64_t coreId,
|
||||||
|
size_t pc,
|
||||||
|
unsigned radius = 4) {
|
||||||
|
auto eventsIt = coreEvents.find(coreId);
|
||||||
|
if (eventsIt == coreEvents.end())
|
||||||
|
return;
|
||||||
|
|
||||||
|
const CommunicationEventVector& events = eventsIt->second;
|
||||||
|
size_t begin = pc > radius ? pc - radius : 0;
|
||||||
|
size_t end = std::min(events.size(), pc + static_cast<size_t>(radius) + 1);
|
||||||
|
os << " local stream for core " << coreId << " around pc " << pc << " of " << events.size() << ":\n";
|
||||||
|
for (size_t index = begin; index < end; ++index) {
|
||||||
|
os << (index == pc ? " => " : " ") << formatCommunicationEvent(events[index]) << "\n";
|
||||||
|
os << " op: " << formatOperationSummary(events[index].op) << "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printCommunicationDeadlockReport(const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
|
||||||
|
const DenseMap<int64_t, size_t>& programCounters,
|
||||||
|
ArrayRef<int64_t> cycle) {
|
||||||
|
llvm::errs() << "\n=== PIM static communication deadlock report ===\n";
|
||||||
|
llvm::errs() << "wait cycle:";
|
||||||
|
for (int64_t coreId : cycle)
|
||||||
|
llvm::errs() << " " << coreId;
|
||||||
|
if (!cycle.empty())
|
||||||
|
llvm::errs() << " -> " << cycle.front();
|
||||||
|
llvm::errs() << "\n\nblocked heads:\n";
|
||||||
|
|
||||||
|
for (int64_t coreId : cycle) {
|
||||||
|
auto eventsIt = coreEvents.find(coreId);
|
||||||
|
auto pcIt = programCounters.find(coreId);
|
||||||
|
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
|
||||||
|
continue;
|
||||||
|
const CommunicationEvent& event = eventsIt->second[pcIt->second];
|
||||||
|
llvm::errs() << " " << formatCommunicationEvent(event) << "\n";
|
||||||
|
llvm::errs() << " loc: " << formatLocation(event.op->getLoc()) << "\n";
|
||||||
|
llvm::errs() << " op : " << formatOperationSummary(event.op) << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::errs() << "\npeer counterpart probes:\n";
|
||||||
|
for (int64_t coreId : cycle) {
|
||||||
|
auto eventsIt = coreEvents.find(coreId);
|
||||||
|
auto pcIt = programCounters.find(coreId);
|
||||||
|
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
|
||||||
|
continue;
|
||||||
|
printCounterpartProbe(llvm::errs(), coreEvents, programCounters, eventsIt->second[pcIt->second]);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::errs() << "\nlocal communication streams:\n";
|
||||||
|
for (int64_t coreId : cycle) {
|
||||||
|
auto pcIt = programCounters.find(coreId);
|
||||||
|
if (pcIt == programCounters.end())
|
||||||
|
continue;
|
||||||
|
printCommunicationWindow(llvm::errs(), coreEvents, coreId, pcIt->second);
|
||||||
|
}
|
||||||
|
llvm::errs() << "=== end PIM static communication deadlock report ===\n\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
static void emitCommunicationDeadlockCycle(ModuleOp moduleOp,
|
||||||
|
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
|
||||||
|
const DenseMap<int64_t, size_t>& programCounters,
|
||||||
|
ArrayRef<int64_t> cycle) {
|
||||||
|
printCommunicationDeadlockReport(coreEvents, programCounters, cycle);
|
||||||
|
|
||||||
|
auto diagnostic = moduleOp.emitError()
|
||||||
|
<< "PIM communication deadlock check found a blocking send/receive cycle while statically simulating the "
|
||||||
|
"expanded per-core communication streams; see the PIM static communication deadlock report above";
|
||||||
|
|
||||||
|
for (int64_t coreId : cycle) {
|
||||||
|
auto eventsIt = coreEvents.find(coreId);
|
||||||
|
auto pcIt = programCounters.find(coreId);
|
||||||
|
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
const CommunicationEvent& event = eventsIt->second[pcIt->second];
|
||||||
|
Diagnostic& note = diagnostic.attachNote(event.op->getLoc());
|
||||||
|
note << formatCommunicationEvent(event);
|
||||||
|
if (!event.materializer.empty())
|
||||||
|
note << " emitted by " << event.materializer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<int64_t>> findCommunicationWaitCycle(
|
||||||
|
const DenseMap<int64_t, CommunicationEventVector>& coreEvents,
|
||||||
|
const DenseMap<int64_t, size_t>& programCounters) {
|
||||||
|
for (const auto& [startCoreId, events] : coreEvents) {
|
||||||
|
auto startPcIt = programCounters.find(startCoreId);
|
||||||
|
if (startPcIt == programCounters.end() || startPcIt->second >= events.size())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
DenseMap<int64_t, size_t> positionInPath;
|
||||||
|
SmallVector<int64_t, 8> path;
|
||||||
|
int64_t currentCoreId = startCoreId;
|
||||||
|
while (true) {
|
||||||
|
auto eventsIt = coreEvents.find(currentCoreId);
|
||||||
|
auto pcIt = programCounters.find(currentCoreId);
|
||||||
|
if (eventsIt == coreEvents.end() || pcIt == programCounters.end() || pcIt->second >= eventsIt->second.size())
|
||||||
|
break;
|
||||||
|
|
||||||
|
auto positionIt = positionInPath.find(currentCoreId);
|
||||||
|
if (positionIt != positionInPath.end()) {
|
||||||
|
SmallVector<int64_t> cycle;
|
||||||
|
for (size_t index = positionIt->second; index < path.size(); ++index)
|
||||||
|
cycle.push_back(path[index]);
|
||||||
|
return cycle;
|
||||||
|
}
|
||||||
|
|
||||||
|
positionInPath[currentCoreId] = path.size();
|
||||||
|
path.push_back(currentCoreId);
|
||||||
|
currentCoreId = eventsIt->second[pcIt->second].peerCoreId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyNoStaticCommunicationDeadlock(ModuleOp moduleOp,
|
||||||
|
pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
|
DenseMap<int64_t, CommunicationEventVector> coreEvents;
|
||||||
|
bool hasFailure = false;
|
||||||
|
|
||||||
|
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||||
|
if (funcOp.isExternal())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
||||||
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
||||||
|
int64_t coreId = coreOp.getCoreId();
|
||||||
|
if (failed(appendCoreCommunicationEvents(
|
||||||
|
coreOp.getBody().front(), coreId, StaticValueKnowledge {}, coreEvents[coreId], diagnostics)))
|
||||||
|
hasFailure = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
||||||
|
SmallVector<int32_t> coreIds = getBatchCoreIds(coreBatchOp);
|
||||||
|
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||||
|
for (size_t lane = 0; lane < laneCount; ++lane) {
|
||||||
|
StaticValueKnowledge laneKnowledge;
|
||||||
|
laneKnowledge.indexValues[coreBatchOp.getLaneArgument()] = static_cast<int64_t>(lane);
|
||||||
|
for (unsigned inputIndex = 0; inputIndex < coreBatchOp.getInputs().size(); ++inputIndex)
|
||||||
|
laneKnowledge.aliases[coreBatchOp.getInputArgument(inputIndex)] = coreBatchOp.getInputs()[inputIndex];
|
||||||
|
|
||||||
|
SmallVector<int32_t> laneCoreIds = getLaneChunkCoreIds(coreIds, laneCount, static_cast<unsigned>(lane));
|
||||||
|
for (int32_t coreId : laneCoreIds) {
|
||||||
|
if (failed(appendCoreCommunicationEvents(
|
||||||
|
coreBatchOp.getBody().front(), coreId, laneKnowledge, coreEvents[coreId], diagnostics)))
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hasFailure)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
DenseMap<int64_t, size_t> programCounters;
|
||||||
|
for (const auto& [coreId, events] : coreEvents)
|
||||||
|
programCounters[coreId] = 0;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
bool madeProgress = false;
|
||||||
|
for (const auto& [coreId, events] : coreEvents) {
|
||||||
|
size_t pc = programCounters[coreId];
|
||||||
|
if (pc >= events.size())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
const CommunicationEvent& event = events[pc];
|
||||||
|
auto peerEventsIt = coreEvents.find(event.peerCoreId);
|
||||||
|
if (peerEventsIt == coreEvents.end())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
size_t peerPc = programCounters[event.peerCoreId];
|
||||||
|
if (peerPc >= peerEventsIt->second.size())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
const CommunicationEvent& peerEvent = peerEventsIt->second[peerPc];
|
||||||
|
if (!areMatchedCommunicationEvents(event, peerEvent))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
++programCounters[coreId];
|
||||||
|
++programCounters[event.peerCoreId];
|
||||||
|
madeProgress = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (madeProgress)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
bool allDone = true;
|
||||||
|
for (const auto& [coreId, events] : coreEvents) {
|
||||||
|
if (programCounters[coreId] < events.size()) {
|
||||||
|
allDone = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (allDone)
|
||||||
|
return success();
|
||||||
|
|
||||||
|
auto cycle = findCommunicationWaitCycle(coreEvents, programCounters);
|
||||||
|
if (succeeded(cycle)) {
|
||||||
|
emitCommunicationDeadlockCycle(moduleOp, coreEvents, programCounters, *cycle);
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto diagnostic = moduleOp.emitError()
|
||||||
|
<< "PIM communication deadlock check stalled without finding a closed wait cycle; this usually means a "
|
||||||
|
"send/receive peer is missing or ordered after a finished core";
|
||||||
|
for (const auto& [coreId, events] : coreEvents) {
|
||||||
|
size_t pc = programCounters[coreId];
|
||||||
|
if (pc >= events.size())
|
||||||
|
continue;
|
||||||
|
const CommunicationEvent& event = events[pc];
|
||||||
|
diagnostic.attachNote(event.op->getLoc()) << formatCommunicationEvent(event);
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
||||||
|
|
||||||
@@ -212,11 +690,18 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool hasFailure = false;
|
||||||
|
if (pimDetectCommunicationDeadlock && failed(verifyNoStaticCommunicationDeadlock(moduleOp, diagnostics)))
|
||||||
|
hasFailure = true;
|
||||||
|
|
||||||
if (diagnostics.hasFailure()) {
|
if (diagnostics.hasFailure()) {
|
||||||
diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
|
diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
|
||||||
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
|
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
|
||||||
signalPassFailure();
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (hasFailure)
|
||||||
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ def SpatTensor :
|
|||||||
// Execution
|
// Execution
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def SpatCompute : SpatOp<"compute",
|
class SpatComputeLikeBase<string mnemonic> : SpatOp<mnemonic,
|
||||||
[SingleBlock, AttrSizedOperandSegments,
|
[SingleBlock, AttrSizedOperandSegments,
|
||||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||||
let summary = "Compute region with attached constant weights";
|
let summary = "Compute region with attached constant weights";
|
||||||
@@ -42,6 +42,12 @@ def SpatCompute : SpatOp<"compute",
|
|||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
let hasFolder = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatGraphCompute : SpatComputeLikeBase<"graph_compute"> {
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||||
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||||
@@ -50,16 +56,26 @@ def SpatCompute : SpatOp<"compute",
|
|||||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||||
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
|
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatGraphCompute>>
|
||||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let hasFolder = 1;
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatComputeBatch : SpatOp<"compute_batch",
|
def SpatScheduledCompute : SpatComputeLikeBase<"scheduled_compute"> {
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||||
|
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||||
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
|
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||||
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
|
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||||
|
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||||
|
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatScheduledCompute>>
|
||||||
|
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
class SpatComputeBatchLikeBase<string mnemonic> : SpatOp<mnemonic,
|
||||||
[SingleBlock, AttrSizedOperandSegments,
|
[SingleBlock, AttrSizedOperandSegments,
|
||||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||||
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
|
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
|
||||||
@@ -76,6 +92,11 @@ def SpatComputeBatch : SpatOp<"compute_batch",
|
|||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatGraphComputeBatch : SpatComputeBatchLikeBase<"graph_compute_batch"> {
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
std::optional<::mlir::BlockArgument> getLaneArgument();
|
std::optional<::mlir::BlockArgument> getLaneArgument();
|
||||||
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||||
@@ -86,21 +107,33 @@ def SpatComputeBatch : SpatOp<"compute_batch",
|
|||||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||||
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
|
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatGraphComputeBatch>>
|
||||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||||
}];
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
let hasVerifier = 1;
|
def SpatScheduledComputeBatch : SpatComputeBatchLikeBase<"scheduled_compute_batch"> {
|
||||||
let hasCustomAssemblyFormat = 1;
|
let extraClassDeclaration = [{
|
||||||
|
std::optional<::mlir::BlockArgument> getLaneArgument();
|
||||||
|
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||||
|
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||||
|
std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
|
||||||
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
|
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||||
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
|
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||||
|
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||||
|
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatScheduledComputeBatch>>
|
||||||
|
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatInParallelOp : SpatOp<"in_parallel", [
|
def SpatInParallelOp : SpatOp<"in_parallel", [
|
||||||
Pure,
|
Pure,
|
||||||
Terminator,
|
Terminator,
|
||||||
DeclareOpInterfaceMethods<InParallelOpInterface>,
|
DeclareOpInterfaceMethods<InParallelOpInterface>,
|
||||||
HasParent<"SpatComputeBatch">,
|
|
||||||
] # GraphRegionNoTerminator.traits> {
|
] # GraphRegionNoTerminator.traits> {
|
||||||
let summary = "Parallel combining terminator for resultful spat.compute_batch";
|
let summary = "Parallel combining terminator for resultful Spatial compute batches";
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$region);
|
let regions = (region SizedRegion<1>:$region);
|
||||||
|
|
||||||
@@ -159,6 +192,90 @@ def SpatConcatOp : SpatOp<"concat", []> {
|
|||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Planning
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def SpatConv2DPlanOp : SpatOp<"conv2d_plan", []> {
|
||||||
|
let summary = "Structured Conv2D planning op that preserves logical ONNX geometry";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SpatTensor:$input,
|
||||||
|
SpatTensor:$weight,
|
||||||
|
Optional<SpatTensor>:$bias,
|
||||||
|
DenseI64ArrayAttr:$pads,
|
||||||
|
DenseI64ArrayAttr:$strides,
|
||||||
|
DenseI64ArrayAttr:$dilations,
|
||||||
|
I64Attr:$group,
|
||||||
|
StrAttr:$logicalLayout
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatReluPlanOp : SpatOp<"relu_plan", []> {
|
||||||
|
let summary = "Layout-aware ReLU planning op";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SpatTensor:$input,
|
||||||
|
StrAttr:$logicalLayout
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatBlueprintOp : SpatOp<"blueprint", []> {
|
||||||
|
let summary = "Blueprint for assembling logical tensors from published fragments";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SpatTensor:$input,
|
||||||
|
Variadic<SpatTensor>:$fragments,
|
||||||
|
StrAttr:$logicalLayout,
|
||||||
|
StrAttr:$physicalLayout,
|
||||||
|
DenseI64ArrayAttr:$fragmentOffsets,
|
||||||
|
DenseI64ArrayAttr:$fragmentSizes,
|
||||||
|
StrAttr:$indexMap,
|
||||||
|
OptionalAttr<StrAttr>:$mode,
|
||||||
|
OptionalAttr<DenseI64ArrayAttr>:$fragmentOperandIndices,
|
||||||
|
OptionalAttr<DenseI64ArrayAttr>:$fragmentSourceOffsets,
|
||||||
|
OptionalAttr<DenseI64ArrayAttr>:$fragmentStrides,
|
||||||
|
OptionalAttr<StrAttr>:$conflictPolicy,
|
||||||
|
OptionalAttr<StrAttr>:$coveragePolicy
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatMaterializeLayoutOp : SpatOp<"materialize_layout", []> {
|
||||||
|
let summary = "Explicit layout conversion or materialization barrier";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SpatTensor:$input,
|
||||||
|
StrAttr:$logicalLayout,
|
||||||
|
StrAttr:$sourcePhysicalLayout,
|
||||||
|
StrAttr:$targetPhysicalLayout
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Communication
|
// Communication
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|||||||
@@ -29,11 +29,19 @@ std::optional<BlockArgument> insertBlockArgument(Region& body, unsigned argIdx,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
|
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
|
||||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
if (auto compute = dyn_cast<SpatGraphCompute>(op)) {
|
||||||
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
if (auto compute = dyn_cast<SpatScheduledCompute>(op)) {
|
||||||
|
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (auto batch = dyn_cast<SpatGraphComputeBatch>(op)) {
|
||||||
|
batch.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
cast<SpatScheduledComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||||
}
|
}
|
||||||
|
|
||||||
using CrossbarWeightSet = llvm::SetVector<Value, llvm::SmallVector<Value, 4>, llvm::SmallDenseSet<Value, 4>>;
|
using CrossbarWeightSet = llvm::SetVector<Value, llvm::SmallVector<Value, 4>, llvm::SmallDenseSet<Value, 4>>;
|
||||||
@@ -47,116 +55,205 @@ CrossbarWeightSet collectCrossbarWeights(Region& body) {
|
|||||||
return weights;
|
return weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
template <typename ComputeOpTy>
|
||||||
|
std::optional<BlockArgument> getComputeWeightArgument(ComputeOpTy compute, unsigned idx) {
|
||||||
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); }
|
return getBlockArgument(compute.getBody(), idx);
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
|
|
||||||
return getBlockArgument(getBody(), getWeights().size() + idx);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
template <typename ComputeOpTy>
|
||||||
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
std::optional<BlockArgument> getComputeInputArgument(ComputeOpTy compute, unsigned idx) {
|
||||||
auto index = std::distance(getWeights().begin(), existing);
|
return getBlockArgument(compute.getBody(), compute.getWeights().size() + idx);
|
||||||
return {
|
|
||||||
{*existing, *getWeightArgument(index)}
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned weightCount = getWeights().size();
|
template <typename ComputeOpTy>
|
||||||
unsigned inputCount = getInputs().size();
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
insertComputeWeight(ComputeOpTy compute, unsigned idx, Value weight, Location loc) {
|
||||||
|
if (auto existing = llvm::find(compute.getWeights(), weight); existing != compute.getWeights().end()) {
|
||||||
|
auto index = std::distance(compute.getWeights().begin(), existing);
|
||||||
|
return {{*existing, *getComputeWeightArgument(compute, index)}};
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned weightCount = compute.getWeights().size();
|
||||||
|
unsigned inputCount = compute.getInputs().size();
|
||||||
|
compute.getOperation()->insertOperands(idx, ValueRange {weight});
|
||||||
setComputeOperandSegmentSizes(
|
setComputeOperandSegmentSizes(
|
||||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
compute.getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||||
auto blockArg = insertBlockArgument(getBody(), idx, weight.getType(), loc);
|
auto blockArg = insertBlockArgument(compute.getBody(), idx, weight.getType(), loc);
|
||||||
if (!blockArg)
|
if (!blockArg)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
return std::make_tuple(compute.getOperation()->getOperand(idx), *blockArg);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigned idx, Value input, Location loc) {
|
template <typename ComputeBatchOpTy>
|
||||||
unsigned weightCount = getWeights().size();
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
unsigned inputCount = getInputs().size();
|
insertComputeBatchWeight(ComputeBatchOpTy batch, unsigned idx, Value weight, Location loc) {
|
||||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
if (auto existing = llvm::find(batch.getWeights(), weight); existing != batch.getWeights().end()) {
|
||||||
|
auto index = std::distance(batch.getWeights().begin(), existing);
|
||||||
|
return {{*existing, *batch.getWeightArgument(index)}};
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned weightCount = batch.getWeights().size();
|
||||||
|
unsigned inputCount = batch.getInputs().size();
|
||||||
|
batch.getOperation()->insertOperands(idx, ValueRange {weight});
|
||||||
setComputeOperandSegmentSizes(
|
setComputeOperandSegmentSizes(
|
||||||
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
batch.getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||||
auto blockArg = insertBlockArgument(getBody(), weightCount + idx, input.getType(), loc);
|
|
||||||
|
auto blockArg = insertBlockArgument(batch.getBody(), 1 + idx, weight.getType(), loc);
|
||||||
if (!blockArg)
|
if (!blockArg)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
return std::make_tuple(batch.getOperation()->getOperand(idx), *blockArg);
|
||||||
}
|
}
|
||||||
|
|
||||||
CrossbarWeightSet SpatCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
template <typename ComputeOpTy>
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
FailureOr<std::tuple<OpResult, SpatCompute>>
|
insertComputeInput(ComputeOpTy compute, unsigned idx, Value input, Location loc) {
|
||||||
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
unsigned weightCount = compute.getWeights().size();
|
||||||
if (idx > getNumResults())
|
unsigned inputCount = compute.getInputs().size();
|
||||||
return failure();
|
compute.getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||||
|
setComputeOperandSegmentSizes(
|
||||||
rewriter.setInsertionPoint(getOperation());
|
compute.getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||||
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
|
auto blockArg = insertBlockArgument(compute.getBody(), weightCount + idx, input.getType(), loc);
|
||||||
resultTypes.insert(resultTypes.begin() + idx, type);
|
if (!blockArg)
|
||||||
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
|
return std::nullopt;
|
||||||
newCompute->setAttrs((*this)->getAttrs());
|
return std::make_tuple(compute.getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||||
setComputeOperandSegmentSizes(newCompute.getOperation(),
|
|
||||||
static_cast<int32_t>(newCompute.getWeights().size()),
|
|
||||||
static_cast<int32_t>(newCompute.getInputs().size()));
|
|
||||||
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
|
|
||||||
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
|
||||||
getResult(oldResultIdx)
|
|
||||||
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
|
||||||
rewriter.eraseOp(getOperation());
|
|
||||||
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
template <typename ComputeOpTy>
|
||||||
|
void setComputeAsmBlockArgumentNames(ComputeOpTy compute, Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
if (region.empty())
|
if (region.empty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
for (unsigned index = 0; index < compute.getWeights().size(); ++index)
|
||||||
if (auto weightArg = getWeightArgument(index))
|
if (auto weightArg = compute.getWeightArgument(index))
|
||||||
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
||||||
|
|
||||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
for (unsigned index = 0; index < compute.getInputs().size(); ++index)
|
||||||
if (auto inputArg = getInputArgument(index))
|
if (auto inputArg = compute.getInputArgument(index))
|
||||||
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
|
template <typename ComputeOpTy>
|
||||||
|
FailureOr<std::tuple<OpResult, ComputeOpTy>>
|
||||||
|
insertComputeOutput(ComputeOpTy compute, RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
|
if (idx > compute.getNumResults())
|
||||||
|
return failure();
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) {
|
rewriter.setInsertionPoint(compute.getOperation());
|
||||||
|
SmallVector<Type> resultTypes(compute.getResultTypes().begin(), compute.getResultTypes().end());
|
||||||
|
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||||
|
auto newCompute =
|
||||||
|
ComputeOpTy::create(rewriter, compute.getLoc(), TypeRange(resultTypes), compute.getWeights(), compute.getInputs());
|
||||||
|
newCompute->setAttrs(compute->getAttrs());
|
||||||
|
setComputeOperandSegmentSizes(newCompute.getOperation(),
|
||||||
|
static_cast<int32_t>(newCompute.getWeights().size()),
|
||||||
|
static_cast<int32_t>(newCompute.getInputs().size()));
|
||||||
|
rewriter.inlineRegionBefore(compute.getBody(), newCompute.getBody(), newCompute.getBody().end());
|
||||||
|
for (unsigned oldResultIdx = 0; oldResultIdx < compute.getNumResults(); ++oldResultIdx)
|
||||||
|
compute.getResult(oldResultIdx)
|
||||||
|
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||||
|
rewriter.eraseOp(compute.getOperation());
|
||||||
|
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ComputeBatchOpTy>
|
||||||
|
FailureOr<std::tuple<OpResult, BlockArgument, ComputeBatchOpTy>>
|
||||||
|
insertComputeBatchOutput(ComputeBatchOpTy batch, RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
|
if (idx > batch.getNumResults())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(batch.getOperation());
|
||||||
|
SmallVector<Type> resultTypes(batch.getResultTypes().begin(), batch.getResultTypes().end());
|
||||||
|
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||||
|
auto newBatch =
|
||||||
|
ComputeBatchOpTy::create(rewriter, batch.getLoc(), TypeRange(resultTypes), batch.getLaneCountAttr(), batch.getWeights(), batch.getInputs());
|
||||||
|
newBatch->setAttrs(batch->getAttrs());
|
||||||
|
setComputeOperandSegmentSizes(newBatch.getOperation(),
|
||||||
|
static_cast<int32_t>(newBatch.getWeights().size()),
|
||||||
|
static_cast<int32_t>(newBatch.getInputs().size()));
|
||||||
|
rewriter.inlineRegionBefore(batch.getBody(), newBatch.getBody(), newBatch.getBody().end());
|
||||||
|
if (newBatch.getBody().empty()) {
|
||||||
|
rewriter.eraseOp(newBatch);
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto blockArg = newBatch.getBody().front().insertArgument(
|
||||||
|
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
|
||||||
|
for (unsigned oldResultIdx = 0; oldResultIdx < batch.getNumResults(); ++oldResultIdx)
|
||||||
|
batch.getResult(oldResultIdx)
|
||||||
|
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||||
|
rewriter.eraseOp(batch.getOperation());
|
||||||
|
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool isGraphComputeLike(Operation* op) { return isa<SpatGraphCompute, SpatGraphComputeBatch>(op); }
|
||||||
|
|
||||||
|
bool isGraphBatchComputeLike(Operation* op) { return isa<SpatGraphComputeBatch>(op); }
|
||||||
|
|
||||||
|
bool isScheduledComputeLike(Operation* op) { return isa<SpatScheduledCompute, SpatScheduledComputeBatch>(op); }
|
||||||
|
|
||||||
|
bool isScheduledBatchComputeLike(Operation* op) { return isa<SpatScheduledComputeBatch>(op); }
|
||||||
|
|
||||||
|
bool isAnySpatialComputeLike(Operation* op) {
|
||||||
|
return isa<SpatGraphCompute, SpatGraphComputeBatch, SpatScheduledCompute, SpatScheduledComputeBatch>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isAnySpatialComputeBatchLike(Operation* op) { return isa<SpatGraphComputeBatch, SpatScheduledComputeBatch>(op); }
|
||||||
|
|
||||||
|
std::optional<BlockArgument> SpatGraphCompute::getWeightArgument(unsigned idx) { return getComputeWeightArgument(*this, idx); }
|
||||||
|
std::optional<BlockArgument> SpatGraphCompute::getInputArgument(unsigned idx) { return getComputeInputArgument(*this, idx); }
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>> SpatGraphCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
|
return insertComputeWeight(*this, idx, weight, loc);
|
||||||
|
}
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>> SpatGraphCompute::insertInput(unsigned idx, Value input, Location loc) {
|
||||||
|
return insertComputeInput(*this, idx, input, loc);
|
||||||
|
}
|
||||||
|
CrossbarWeightSet SpatGraphCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||||
|
FailureOr<std::tuple<OpResult, SpatGraphCompute>>
|
||||||
|
SpatGraphCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
|
return insertComputeOutput(*this, rewriter, idx, type, loc);
|
||||||
|
}
|
||||||
|
void SpatGraphCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
|
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<BlockArgument> SpatScheduledCompute::getWeightArgument(unsigned idx) {
|
||||||
|
return getComputeWeightArgument(*this, idx);
|
||||||
|
}
|
||||||
|
std::optional<BlockArgument> SpatScheduledCompute::getInputArgument(unsigned idx) { return getComputeInputArgument(*this, idx); }
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
|
SpatScheduledCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
|
return insertComputeWeight(*this, idx, weight, loc);
|
||||||
|
}
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
|
SpatScheduledCompute::insertInput(unsigned idx, Value input, Location loc) {
|
||||||
|
return insertComputeInput(*this, idx, input, loc);
|
||||||
|
}
|
||||||
|
CrossbarWeightSet SpatScheduledCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||||
|
FailureOr<std::tuple<OpResult, SpatScheduledCompute>>
|
||||||
|
SpatScheduledCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
|
return insertComputeOutput(*this, rewriter, idx, type, loc);
|
||||||
|
}
|
||||||
|
void SpatScheduledCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
|
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<BlockArgument> SpatGraphComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
|
||||||
|
std::optional<BlockArgument> SpatGraphComputeBatch::getWeightArgument(unsigned idx) {
|
||||||
return getBlockArgument(getBody(), 1 + idx);
|
return getBlockArgument(getBody(), 1 + idx);
|
||||||
}
|
}
|
||||||
|
std::optional<BlockArgument> SpatGraphComputeBatch::getInputArgument(unsigned idx) {
|
||||||
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
|
|
||||||
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
|
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
|
||||||
}
|
}
|
||||||
|
std::optional<BlockArgument> SpatGraphComputeBatch::getOutputArgument(unsigned idx) {
|
||||||
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
|
|
||||||
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<std::tuple<Value, BlockArgument>>
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
SpatGraphComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
return insertComputeBatchWeight(*this, idx, weight, loc);
|
||||||
auto index = std::distance(getWeights().begin(), existing);
|
|
||||||
return {
|
|
||||||
{*existing, *getWeightArgument(index)}
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
unsigned weightCount = getWeights().size();
|
SpatGraphComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
|
||||||
unsigned inputCount = getInputs().size();
|
|
||||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
|
||||||
setComputeOperandSegmentSizes(
|
|
||||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
|
||||||
auto blockArg = insertBlockArgument(getBody(), 1 + idx, weight.getType(), loc);
|
|
||||||
if (!blockArg)
|
|
||||||
return std::nullopt;
|
|
||||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
|
|
||||||
unsigned weightCount = getWeights().size();
|
unsigned weightCount = getWeights().size();
|
||||||
unsigned inputCount = getInputs().size();
|
unsigned inputCount = getInputs().size();
|
||||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||||
@@ -167,52 +264,68 @@ std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(un
|
|||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||||
}
|
}
|
||||||
|
CrossbarWeightSet SpatGraphComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||||
CrossbarWeightSet SpatComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
FailureOr<std::tuple<OpResult, BlockArgument, SpatGraphComputeBatch>>
|
||||||
|
SpatGraphComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
|
return insertComputeBatchOutput(*this, rewriter, idx, type, loc);
|
||||||
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
|
||||||
if (idx > getNumResults())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(getOperation());
|
|
||||||
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
|
|
||||||
resultTypes.insert(resultTypes.begin() + idx, type);
|
|
||||||
auto newBatch =
|
|
||||||
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
|
|
||||||
newBatch->setAttrs((*this)->getAttrs());
|
|
||||||
setComputeOperandSegmentSizes(newBatch.getOperation(),
|
|
||||||
static_cast<int32_t>(newBatch.getWeights().size()),
|
|
||||||
static_cast<int32_t>(newBatch.getInputs().size()));
|
|
||||||
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
|
|
||||||
if (newBatch.getBody().empty()) {
|
|
||||||
rewriter.eraseOp(newBatch);
|
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
auto blockArg = newBatch.getBody().front().insertArgument(
|
void SpatGraphComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
|
|
||||||
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
|
||||||
getResult(oldResultIdx)
|
|
||||||
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
|
||||||
rewriter.eraseOp(getOperation());
|
|
||||||
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
|
||||||
if (region.empty())
|
if (region.empty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
if (auto laneArg = getLaneArgument())
|
if (auto laneArg = getLaneArgument())
|
||||||
setNameFn(*laneArg, "lane");
|
setNameFn(*laneArg, "lane");
|
||||||
|
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
|
||||||
|
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||||
|
auto outputArg = getOutputArgument(index);
|
||||||
|
if (!outputArg)
|
||||||
|
continue;
|
||||||
|
if (index == 0) {
|
||||||
|
setNameFn(*outputArg, "out");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
setNameFn(*outputArg, ("out" + std::to_string(index)).c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
std::optional<BlockArgument> SpatScheduledComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
|
||||||
if (auto weightArg = getWeightArgument(index))
|
std::optional<BlockArgument> SpatScheduledComputeBatch::getWeightArgument(unsigned idx) {
|
||||||
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
return getBlockArgument(getBody(), 1 + idx);
|
||||||
|
}
|
||||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
std::optional<BlockArgument> SpatScheduledComputeBatch::getInputArgument(unsigned idx) {
|
||||||
if (auto inputArg = getInputArgument(index))
|
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
|
||||||
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
}
|
||||||
|
std::optional<BlockArgument> SpatScheduledComputeBatch::getOutputArgument(unsigned idx) {
|
||||||
|
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
||||||
|
}
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
|
SpatScheduledComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
|
return insertComputeBatchWeight(*this, idx, weight, loc);
|
||||||
|
}
|
||||||
|
std::optional<std::tuple<Value, BlockArgument>>
|
||||||
|
SpatScheduledComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
|
||||||
|
unsigned weightCount = getWeights().size();
|
||||||
|
unsigned inputCount = getInputs().size();
|
||||||
|
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||||
|
setComputeOperandSegmentSizes(
|
||||||
|
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||||
|
auto blockArg = insertBlockArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
|
||||||
|
if (!blockArg)
|
||||||
|
return std::nullopt;
|
||||||
|
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||||
|
}
|
||||||
|
CrossbarWeightSet SpatScheduledComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||||
|
FailureOr<std::tuple<OpResult, BlockArgument, SpatScheduledComputeBatch>>
|
||||||
|
SpatScheduledComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
|
return insertComputeBatchOutput(*this, rewriter, idx, type, loc);
|
||||||
|
}
|
||||||
|
void SpatScheduledComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||||
|
if (region.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
if (auto laneArg = getLaneArgument())
|
||||||
|
setNameFn(*laneArg, "lane");
|
||||||
|
setComputeAsmBlockArgumentNames(*this, region, setNameFn);
|
||||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||||
auto outputArg = getOutputArgument(index);
|
auto outputArg = getOutputArgument(index);
|
||||||
if (!outputArg)
|
if (!outputArg)
|
||||||
@@ -231,7 +344,11 @@ void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
|
|||||||
builder.createBlock(bodyRegion);
|
builder.createBlock(bodyRegion);
|
||||||
}
|
}
|
||||||
|
|
||||||
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
|
OpResult SpatInParallelOp::getParentResult(int64_t idx) {
|
||||||
|
Operation* parent = getOperation()->getParentOp();
|
||||||
|
assert(isAnySpatialComputeBatchLike(parent) && "expected Spatial compute batch parent");
|
||||||
|
return parent->getResult(idx);
|
||||||
|
}
|
||||||
|
|
||||||
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
|
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
|
||||||
|
|
||||||
|
|||||||
@@ -26,3 +26,19 @@
|
|||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp.inc"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
bool isGraphComputeLike(mlir::Operation* op);
|
||||||
|
bool isGraphBatchComputeLike(mlir::Operation* op);
|
||||||
|
bool isScheduledComputeLike(mlir::Operation* op);
|
||||||
|
bool isScheduledBatchComputeLike(mlir::Operation* op);
|
||||||
|
bool isAnySpatialComputeLike(mlir::Operation* op);
|
||||||
|
bool isAnySpatialComputeBatchLike(mlir::Operation* op);
|
||||||
|
|
||||||
|
using SpatCompute = SpatGraphCompute;
|
||||||
|
using SpatComputeBatch = SpatGraphComputeBatch;
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -32,6 +32,14 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
|||||||
return parser.getBuilder().getI32IntegerAttr(value);
|
return parser.getBuilder().getI32IntegerAttr(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static ParseResult parseBareStringAttr(OpAsmParser& parser, StringAttr& attr) {
|
||||||
|
StringRef value;
|
||||||
|
if (parser.parseKeyword(&value))
|
||||||
|
return failure();
|
||||||
|
attr = parser.getBuilder().getStringAttr(value);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
||||||
printer << "(";
|
printer << "(";
|
||||||
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
||||||
@@ -115,6 +123,254 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename ComputeOpTy>
|
||||||
|
void printComputeLikeOp(ComputeOpTy op, OpAsmPrinter& printer) {
|
||||||
|
SmallVector<Value> weightArgs;
|
||||||
|
weightArgs.reserve(op.getWeights().size());
|
||||||
|
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
|
||||||
|
auto weightArg = op.getWeightArgument(index);
|
||||||
|
if (!weightArg)
|
||||||
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||||
|
weightArgs.push_back(*weightArg);
|
||||||
|
}
|
||||||
|
SmallVector<Value> inputArgs;
|
||||||
|
inputArgs.reserve(op.getInputs().size());
|
||||||
|
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
|
||||||
|
auto inputArg = op.getInputArgument(index);
|
||||||
|
if (!inputArg)
|
||||||
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||||
|
inputArgs.push_back(*inputArg);
|
||||||
|
}
|
||||||
|
|
||||||
|
printer << " ";
|
||||||
|
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
|
||||||
|
printer << " ";
|
||||||
|
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
|
||||||
|
|
||||||
|
if (auto coreIdAttr = op->template getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||||
|
printer << " coreId " << coreIdAttr.getInt();
|
||||||
|
printer << " crossbarWeights " << collectDistinctCrossbarWeights(op.getOperation()).size();
|
||||||
|
|
||||||
|
printer.printOptionalAttrDict(op->getAttrs(), {op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
||||||
|
|
||||||
|
printer << " : ";
|
||||||
|
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
|
||||||
|
printer << " ";
|
||||||
|
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
|
||||||
|
printer << " -> ";
|
||||||
|
printCompressedTypeSequence(printer, op.getResultTypes());
|
||||||
|
printer << " ";
|
||||||
|
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ComputeOpTy>
|
||||||
|
ParseResult parseComputeLikeOp(OpAsmParser& parser, OperationState& result) {
|
||||||
|
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||||
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||||
|
SmallVector<Type> weightTypes;
|
||||||
|
SmallVector<Type> inputTypes;
|
||||||
|
SmallVector<Type> outputTypes;
|
||||||
|
int32_t crossbarWeightCount = 0;
|
||||||
|
int32_t coreId = 0;
|
||||||
|
|
||||||
|
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||||
|
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||||
|
if (hasCoreId && parser.parseInteger(coreId))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||||
|
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||||
|
return failure();
|
||||||
|
(void) crossbarWeightCount;
|
||||||
|
|
||||||
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||||
|
|| parseCompressedRepeatedList(
|
||||||
|
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||||
|
|| parseCompressedRepeatedList(
|
||||||
|
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||||
|
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (weights.size() != weightTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||||
|
if (weightArgs.size() != weights.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||||
|
if (inputs.size() != inputTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||||
|
if (inputArgs.size() != inputs.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||||
|
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"coreId cannot be specified both positionally and in attr-dict");
|
||||||
|
|
||||||
|
auto& builder = parser.getBuilder();
|
||||||
|
result.addAttribute(
|
||||||
|
"operandSegmentSizes",
|
||||||
|
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||||
|
if (hasCoreId)
|
||||||
|
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
|
||||||
|
|
||||||
|
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||||
|
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||||
|
return failure();
|
||||||
|
result.addTypes(outputTypes);
|
||||||
|
|
||||||
|
Region* body = result.addRegion();
|
||||||
|
applyArgumentTypes(weightTypes, weightArgs);
|
||||||
|
applyArgumentTypes(inputTypes, inputArgs);
|
||||||
|
llvm::append_range(regionArgs, weightArgs);
|
||||||
|
llvm::append_range(regionArgs, inputArgs);
|
||||||
|
return parser.parseRegion(*body, regionArgs);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ComputeBatchOpTy>
|
||||||
|
void printComputeBatchLikeOp(ComputeBatchOpTy op, OpAsmPrinter& printer) {
|
||||||
|
auto laneArg = op.getLaneArgument();
|
||||||
|
SmallVector<Value> weightArgs;
|
||||||
|
weightArgs.reserve(op.getWeights().size());
|
||||||
|
for (unsigned index = 0; index < op.getWeights().size(); ++index) {
|
||||||
|
auto weightArg = op.getWeightArgument(index);
|
||||||
|
if (!weightArg)
|
||||||
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||||
|
weightArgs.push_back(*weightArg);
|
||||||
|
}
|
||||||
|
SmallVector<Value> inputArgs;
|
||||||
|
inputArgs.reserve(op.getInputs().size());
|
||||||
|
for (unsigned index = 0; index < op.getInputs().size(); ++index) {
|
||||||
|
auto inputArg = op.getInputArgument(index);
|
||||||
|
if (!inputArg)
|
||||||
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||||
|
inputArgs.push_back(*inputArg);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<BlockArgument> outputArgs;
|
||||||
|
if (!laneArg)
|
||||||
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||||
|
if (op.getNumResults() != 0) {
|
||||||
|
outputArgs.reserve(op.getNumResults());
|
||||||
|
for (unsigned index = 0; index < op.getNumResults(); ++index) {
|
||||||
|
auto outputArg = op.getOutputArgument(index);
|
||||||
|
if (!outputArg)
|
||||||
|
return printer.printGenericOp(op.getOperation(), /*printOpName=*/false);
|
||||||
|
outputArgs.push_back(*outputArg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
printer << " ";
|
||||||
|
printer.printOperand(*laneArg);
|
||||||
|
printer << " = 0 to " << op.getLaneCount();
|
||||||
|
printer << " ";
|
||||||
|
printBoundValueList(printer, weightArgs, op.getWeights(), ListDelimiter::Square);
|
||||||
|
printer << " ";
|
||||||
|
printBoundValueList(printer, inputArgs, op.getInputs(), ListDelimiter::Paren);
|
||||||
|
if (op.getNumResults() != 0) {
|
||||||
|
printer << " shared_outs";
|
||||||
|
printBlockArgumentList(printer, outputArgs);
|
||||||
|
}
|
||||||
|
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({op.getOperation(), 0, op.getLaneCount()}).size();
|
||||||
|
if (auto coreIdsAttr = op->template getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
||||||
|
printer << " coreIds ";
|
||||||
|
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
||||||
|
}
|
||||||
|
printer.printOptionalAttrDict(
|
||||||
|
op->getAttrs(),
|
||||||
|
{op.getLaneCountAttrName().getValue(), op.getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||||
|
printer << " : ";
|
||||||
|
printCompressedTypeList(printer, TypeRange(op.getWeights()), ListDelimiter::Square);
|
||||||
|
printer << " ";
|
||||||
|
printCompressedTypeList(printer, TypeRange(op.getInputs()), ListDelimiter::Paren);
|
||||||
|
printer << " -> ";
|
||||||
|
printCompressedTypeSequence(printer, op.getResultTypes());
|
||||||
|
printer << " ";
|
||||||
|
printer.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ComputeBatchOpTy>
|
||||||
|
ParseResult parseComputeBatchLikeOp(OpAsmParser& parser, OperationState& result) {
|
||||||
|
int64_t lowerBound = 0;
|
||||||
|
int32_t laneCount = 0;
|
||||||
|
OpAsmParser::Argument laneArg;
|
||||||
|
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||||
|
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||||
|
SmallVector<OpAsmParser::Argument> outputArgs;
|
||||||
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||||
|
SmallVector<Type> weightTypes;
|
||||||
|
SmallVector<Type> inputTypes;
|
||||||
|
SmallVector<Type> outputTypes;
|
||||||
|
int32_t crossbarWeightCount = 0;
|
||||||
|
SmallVector<int32_t> coreIds;
|
||||||
|
|
||||||
|
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||||
|
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||||
|
return failure();
|
||||||
|
if (lowerBound != 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
|
||||||
|
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||||
|
return failure();
|
||||||
|
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
|
||||||
|
if (parseBlockArgumentList(parser, outputArgs))
|
||||||
|
return failure();
|
||||||
|
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||||
|
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||||
|
return failure();
|
||||||
|
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||||
|
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||||
|
return failure();
|
||||||
|
(void) crossbarWeightCount;
|
||||||
|
|
||||||
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||||
|
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
||||||
|
|| parseCompressedRepeatedList(
|
||||||
|
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||||
|
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (weights.size() != weightTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||||
|
if (weightArgs.size() != weights.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||||
|
if (inputs.size() != inputTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||||
|
if (inputArgs.size() != inputs.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||||
|
if (outputArgs.size() != outputTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"number of shared output bindings and result types must match");
|
||||||
|
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"coreIds cannot be specified both positionally and in attr-dict");
|
||||||
|
|
||||||
|
auto& builder = parser.getBuilder();
|
||||||
|
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
||||||
|
result.addAttribute(
|
||||||
|
"operandSegmentSizes",
|
||||||
|
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||||
|
if (hasCoreIds)
|
||||||
|
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
||||||
|
|
||||||
|
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||||
|
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||||
|
return failure();
|
||||||
|
result.addTypes(outputTypes);
|
||||||
|
|
||||||
|
Region* body = result.addRegion();
|
||||||
|
applyBatchRegionArgumentTypes(
|
||||||
|
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
|
||||||
|
return parser.parseRegion(*body, regionArgs);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
void SpatYieldOp::print(OpAsmPrinter& printer) {
|
||||||
@@ -218,260 +474,146 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
void SpatBlueprintOp::print(OpAsmPrinter& printer) {
|
||||||
SmallVector<Value> weightArgs;
|
SmallVector<Value> operands {getInput()};
|
||||||
weightArgs.reserve(getWeights().size());
|
llvm::append_range(operands, getFragments());
|
||||||
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
|
||||||
auto weightArg = getWeightArgument(index);
|
|
||||||
if (!weightArg)
|
|
||||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
|
||||||
weightArgs.push_back(*weightArg);
|
|
||||||
}
|
|
||||||
SmallVector<Value> inputArgs;
|
|
||||||
inputArgs.reserve(getInputs().size());
|
|
||||||
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
|
||||||
auto inputArg = getInputArgument(index);
|
|
||||||
if (!inputArg)
|
|
||||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
|
||||||
inputArgs.push_back(*inputArg);
|
|
||||||
}
|
|
||||||
|
|
||||||
printer << " ";
|
printer << " fragments";
|
||||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||||
printer << " ";
|
printer << " layout " << getLogicalLayout();
|
||||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
printer << " physical " << getPhysicalLayout();
|
||||||
|
printer << " offsets ";
|
||||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
printCompressedIntegerList(printer, getFragmentOffsets());
|
||||||
printer << " coreId " << coreIdAttr.getInt();
|
printer << " sizes ";
|
||||||
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
|
printCompressedIntegerList(printer, getFragmentSizes());
|
||||||
|
printer << " map " << getIndexMap();
|
||||||
|
if (std::optional<StringRef> mode = getMode())
|
||||||
|
printer << " mode " << *mode;
|
||||||
|
if (std::optional<ArrayRef<int64_t>> operandIndices = getFragmentOperandIndices()) {
|
||||||
|
printer << " operandIndices ";
|
||||||
|
printCompressedIntegerList(printer, *operandIndices);
|
||||||
|
}
|
||||||
|
if (std::optional<ArrayRef<int64_t>> sourceOffsets = getFragmentSourceOffsets()) {
|
||||||
|
printer << " sourceOffsets ";
|
||||||
|
printCompressedIntegerList(printer, *sourceOffsets);
|
||||||
|
}
|
||||||
|
if (std::optional<ArrayRef<int64_t>> strides = getFragmentStrides()) {
|
||||||
|
printer << " strides ";
|
||||||
|
printCompressedIntegerList(printer, *strides);
|
||||||
|
}
|
||||||
|
if (std::optional<StringRef> conflictPolicy = getConflictPolicy())
|
||||||
|
printer << " conflict " << *conflictPolicy;
|
||||||
|
if (std::optional<StringRef> coveragePolicy = getCoveragePolicy())
|
||||||
|
printer << " coverage " << *coveragePolicy;
|
||||||
|
|
||||||
printer.printOptionalAttrDict((*this)->getAttrs(),
|
printer.printOptionalAttrDict((*this)->getAttrs(),
|
||||||
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
{getLogicalLayoutAttrName().getValue(),
|
||||||
|
getPhysicalLayoutAttrName().getValue(),
|
||||||
|
getFragmentOffsetsAttrName().getValue(),
|
||||||
|
getFragmentSizesAttrName().getValue(),
|
||||||
|
getIndexMapAttrName().getValue(),
|
||||||
|
getModeAttrName().getValue(),
|
||||||
|
getFragmentOperandIndicesAttrName().getValue(),
|
||||||
|
getFragmentSourceOffsetsAttrName().getValue(),
|
||||||
|
getFragmentStridesAttrName().getValue(),
|
||||||
|
getConflictPolicyAttrName().getValue(),
|
||||||
|
getCoveragePolicyAttrName().getValue()});
|
||||||
printer << " : ";
|
printer << " : ";
|
||||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
printCompressedTypeList(printer, TypeRange(operands), ListDelimiter::Paren);
|
||||||
printer << " ";
|
|
||||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
|
||||||
printer << " -> ";
|
printer << " -> ";
|
||||||
printCompressedTypeSequence(printer, getResultTypes());
|
printer.printType(getOutput().getType());
|
||||||
printer << " ";
|
|
||||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
ParseResult SpatBlueprintOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
SmallVector<OpAsmParser::UnresolvedOperand> operands;
|
||||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
SmallVector<Type> operandTypes;
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
Type outputType;
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
StringAttr logicalLayout;
|
||||||
SmallVector<Type> weightTypes;
|
StringAttr physicalLayout;
|
||||||
SmallVector<Type> inputTypes;
|
StringAttr indexMap;
|
||||||
SmallVector<Type> outputTypes;
|
StringAttr mode;
|
||||||
int32_t crossbarWeightCount = 0;
|
StringAttr conflictPolicy;
|
||||||
int32_t coreId = 0;
|
StringAttr coveragePolicy;
|
||||||
|
SmallVector<int64_t> fragmentOffsets;
|
||||||
|
SmallVector<int64_t> fragmentSizes;
|
||||||
|
SmallVector<int64_t> fragmentOperandIndices;
|
||||||
|
SmallVector<int64_t> fragmentSourceOffsets;
|
||||||
|
SmallVector<int64_t> fragmentStrides;
|
||||||
|
|
||||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
if (parser.parseKeyword("fragments")
|
||||||
|
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)
|
||||||
|
|| parser.parseKeyword("layout") || parseBareStringAttr(parser, logicalLayout)
|
||||||
|
|| parser.parseKeyword("physical") || parseBareStringAttr(parser, physicalLayout)
|
||||||
|
|| parser.parseKeyword("offsets") || parseCompressedIntegerList(parser, fragmentOffsets)
|
||||||
|
|| parser.parseKeyword("sizes") || parseCompressedIntegerList(parser, fragmentSizes)
|
||||||
|
|| parser.parseKeyword("map") || parseBareStringAttr(parser, indexMap))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
if (succeeded(parser.parseOptionalKeyword("mode")) && parseBareStringAttr(parser, mode))
|
||||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("operandIndices"))
|
||||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
&& parseCompressedIntegerList(parser, fragmentOperandIndices))
|
||||||
if (hasCoreId && parser.parseInteger(coreId))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("sourceOffsets"))
|
||||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
&& parseCompressedIntegerList(parser, fragmentSourceOffsets))
|
||||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("strides")) && parseCompressedIntegerList(parser, fragmentStrides))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("conflict")) && parseBareStringAttr(parser, conflictPolicy))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("coverage")) && parseBareStringAttr(parser, coveragePolicy))
|
||||||
return failure();
|
return failure();
|
||||||
(void) crossbarWeightCount;
|
|
||||||
|
|
||||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||||
|| parseCompressedRepeatedList(
|
|| parseCompressedRepeatedList(
|
||||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
parser, ListDelimiter::Paren, operandTypes, [&](Type& type) { return parser.parseType(type); })
|
||||||
|| parseCompressedRepeatedList(
|
|| parser.parseArrow() || parser.parseType(outputType))
|
||||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
|
||||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
if (operands.empty())
|
||||||
if (weights.size() != weightTypes.size())
|
return parser.emitError(parser.getCurrentLocation(), "spat.blueprint requires at least one fragment operand");
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
if (operands.size() != operandTypes.size())
|
||||||
if (weightArgs.size() != weights.size())
|
return parser.emitError(parser.getCurrentLocation(), "number of fragment operands and types must match");
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
|
||||||
if (inputs.size() != inputTypes.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
|
||||||
if (inputArgs.size() != inputs.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
|
||||||
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"coreId cannot be specified both positionally and in attr-dict");
|
|
||||||
|
|
||||||
auto& builder = parser.getBuilder();
|
auto& builder = parser.getBuilder();
|
||||||
result.addAttribute(
|
result.addAttribute("logicalLayout", logicalLayout);
|
||||||
"operandSegmentSizes",
|
result.addAttribute("physicalLayout", physicalLayout);
|
||||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
result.addAttribute("fragmentOffsets", builder.getDenseI64ArrayAttr(fragmentOffsets));
|
||||||
if (hasCoreId)
|
result.addAttribute("fragmentSizes", builder.getDenseI64ArrayAttr(fragmentSizes));
|
||||||
result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId));
|
result.addAttribute("indexMap", indexMap);
|
||||||
|
if (mode)
|
||||||
|
result.addAttribute("mode", mode);
|
||||||
|
if (!fragmentOperandIndices.empty())
|
||||||
|
result.addAttribute("fragmentOperandIndices", builder.getDenseI64ArrayAttr(fragmentOperandIndices));
|
||||||
|
if (!fragmentSourceOffsets.empty())
|
||||||
|
result.addAttribute("fragmentSourceOffsets", builder.getDenseI64ArrayAttr(fragmentSourceOffsets));
|
||||||
|
if (!fragmentStrides.empty())
|
||||||
|
result.addAttribute("fragmentStrides", builder.getDenseI64ArrayAttr(fragmentStrides));
|
||||||
|
if (conflictPolicy)
|
||||||
|
result.addAttribute("conflictPolicy", conflictPolicy);
|
||||||
|
if (coveragePolicy)
|
||||||
|
result.addAttribute("coveragePolicy", coveragePolicy);
|
||||||
|
|
||||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
if (parser.resolveOperands(operands, operandTypes, parser.getCurrentLocation(), result.operands))
|
||||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
|
||||||
return failure();
|
return failure();
|
||||||
result.addTypes(outputTypes);
|
result.addTypes(outputType);
|
||||||
|
return success();
|
||||||
Region* body = result.addRegion();
|
|
||||||
applyArgumentTypes(weightTypes, weightArgs);
|
|
||||||
applyArgumentTypes(inputTypes, inputArgs);
|
|
||||||
llvm::append_range(regionArgs, weightArgs);
|
|
||||||
llvm::append_range(regionArgs, inputArgs);
|
|
||||||
return parser.parseRegion(*body, regionArgs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
void SpatGraphCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
|
||||||
auto laneArg = getLaneArgument();
|
ParseResult SpatGraphCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
SmallVector<Value> weightArgs;
|
return parseComputeLikeOp<SpatGraphCompute>(parser, result);
|
||||||
weightArgs.reserve(getWeights().size());
|
|
||||||
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
|
||||||
auto weightArg = getWeightArgument(index);
|
|
||||||
if (!weightArg)
|
|
||||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
|
||||||
weightArgs.push_back(*weightArg);
|
|
||||||
}
|
}
|
||||||
SmallVector<Value> inputArgs;
|
void SpatScheduledCompute::print(OpAsmPrinter& printer) { printComputeLikeOp(*this, printer); }
|
||||||
inputArgs.reserve(getInputs().size());
|
ParseResult SpatScheduledCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
return parseComputeLikeOp<SpatScheduledCompute>(parser, result);
|
||||||
auto inputArg = getInputArgument(index);
|
|
||||||
if (!inputArg)
|
|
||||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
|
||||||
inputArgs.push_back(*inputArg);
|
|
||||||
}
|
}
|
||||||
|
void SpatGraphComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
|
||||||
SmallVector<BlockArgument> outputArgs;
|
ParseResult SpatGraphComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
if (!laneArg)
|
return parseComputeBatchLikeOp<SpatGraphComputeBatch>(parser, result);
|
||||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
|
||||||
if (getNumResults() != 0) {
|
|
||||||
outputArgs.reserve(getNumResults());
|
|
||||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
|
||||||
auto outputArg = getOutputArgument(index);
|
|
||||||
if (!outputArg)
|
|
||||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
|
||||||
outputArgs.push_back(*outputArg);
|
|
||||||
}
|
}
|
||||||
}
|
void SpatScheduledComputeBatch::print(OpAsmPrinter& printer) { printComputeBatchLikeOp(*this, printer); }
|
||||||
|
ParseResult SpatScheduledComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
printer << " ";
|
return parseComputeBatchLikeOp<SpatScheduledComputeBatch>(parser, result);
|
||||||
printer.printOperand(*laneArg);
|
|
||||||
printer << " = 0 to " << getLaneCount();
|
|
||||||
|
|
||||||
printer << " ";
|
|
||||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
|
||||||
printer << " ";
|
|
||||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
|
||||||
|
|
||||||
if (getNumResults() != 0) {
|
|
||||||
printer << " shared_outs";
|
|
||||||
printBlockArgumentList(printer, outputArgs);
|
|
||||||
}
|
|
||||||
|
|
||||||
printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({getOperation(), 0, getLaneCount()}).size();
|
|
||||||
|
|
||||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
|
||||||
printer << " coreIds ";
|
|
||||||
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
|
||||||
}
|
|
||||||
|
|
||||||
printer.printOptionalAttrDict(
|
|
||||||
(*this)->getAttrs(),
|
|
||||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
|
||||||
|
|
||||||
printer << " : ";
|
|
||||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
|
||||||
printer << " ";
|
|
||||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
|
||||||
printer << " -> ";
|
|
||||||
printCompressedTypeSequence(printer, getResultTypes());
|
|
||||||
printer << " ";
|
|
||||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
|
||||||
}
|
|
||||||
|
|
||||||
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
|
||||||
int64_t lowerBound = 0;
|
|
||||||
int32_t laneCount = 0;
|
|
||||||
OpAsmParser::Argument laneArg;
|
|
||||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
|
||||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
|
||||||
SmallVector<OpAsmParser::Argument> outputArgs;
|
|
||||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
|
||||||
SmallVector<Type> weightTypes;
|
|
||||||
SmallVector<Type> inputTypes;
|
|
||||||
SmallVector<Type> outputTypes;
|
|
||||||
int32_t crossbarWeightCount = 0;
|
|
||||||
SmallVector<int32_t> coreIds;
|
|
||||||
|
|
||||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
|
||||||
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
|
||||||
return failure();
|
|
||||||
if (lowerBound != 0)
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
|
|
||||||
|
|
||||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
|
|
||||||
if (parseBlockArgumentList(parser, outputArgs))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
|
||||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
|
||||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
|
||||||
return failure();
|
|
||||||
(void) crossbarWeightCount;
|
|
||||||
|
|
||||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
|
||||||
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
|
||||||
|| parseCompressedRepeatedList(
|
|
||||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
|
||||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (weights.size() != weightTypes.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
|
||||||
if (weightArgs.size() != weights.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
|
||||||
if (inputs.size() != inputTypes.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
|
||||||
if (inputArgs.size() != inputs.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
|
||||||
if (outputArgs.size() != outputTypes.size())
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"number of shared output bindings and result types must match");
|
|
||||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
|
||||||
"coreIds cannot be specified both positionally and in attr-dict");
|
|
||||||
|
|
||||||
auto& builder = parser.getBuilder();
|
|
||||||
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
|
||||||
result.addAttribute(
|
|
||||||
"operandSegmentSizes",
|
|
||||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
|
||||||
if (hasCoreIds)
|
|
||||||
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
|
||||||
|
|
||||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
|
||||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
|
||||||
return failure();
|
|
||||||
result.addTypes(outputTypes);
|
|
||||||
|
|
||||||
Region* body = result.addRegion();
|
|
||||||
applyBatchRegionArgumentTypes(
|
|
||||||
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
|
|
||||||
return parser.parseRegion(*body, regionArgs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatInParallelOp::print(OpAsmPrinter& printer) {
|
void SpatInParallelOp::print(OpAsmPrinter& printer) {
|
||||||
|
|||||||
@@ -10,8 +10,9 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
|
|
||||||
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
template <typename ComputeOpTy>
|
||||||
Block& block = getBody().front();
|
LogicalResult foldComputeLike(ComputeOpTy compute, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||||
|
Block& block = compute.getBody().front();
|
||||||
if (!llvm::hasSingleElement(block))
|
if (!llvm::hasSingleElement(block))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -22,7 +23,7 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m
|
|||||||
for (Value yieldedValue : yieldOp.getOperands()) {
|
for (Value yieldedValue : yieldOp.getOperands()) {
|
||||||
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
|
if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
|
||||||
if (blockArg.getOwner() == &block) {
|
if (blockArg.getOwner() == &block) {
|
||||||
results.push_back(getOperand(blockArg.getArgNumber()));
|
results.push_back(compute.getOperand(blockArg.getArgNumber()));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -31,5 +32,13 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatGraphCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||||
|
return foldComputeLike(*this, results);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatScheduledCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||||
|
return foldComputeLike(*this, results);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
|
|||||||
return shapedType.getShape();
|
return shapedType.getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
template <typename ComputeBatchOpTy>
|
||||||
|
static bool isBatchOutputArgument(ComputeBatchOpTy batchOp, Value value) {
|
||||||
if (batchOp.getNumResults() == 0)
|
if (batchOp.getNumResults() == 0)
|
||||||
return false;
|
return false;
|
||||||
auto blockArg = dyn_cast<BlockArgument>(value);
|
auto blockArg = dyn_cast<BlockArgument>(value);
|
||||||
@@ -58,8 +59,28 @@ static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind)
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool isStaticIndexExpr(Value value) {
|
||||||
|
if (matchConstantIndexValue(value))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
|
||||||
|
if (affineApply) {
|
||||||
|
if (!isSingleResultSymbolFreeAffineMap(affineApply.getAffineMap()))
|
||||||
|
return false;
|
||||||
|
return llvm::all_of(affineApply.getMapOperands(), isStaticIndexExpr);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto addOp = value.getDefiningOp<arith::AddIOp>())
|
||||||
|
return isStaticIndexExpr(addOp.getLhs()) && isStaticIndexExpr(addOp.getRhs());
|
||||||
|
|
||||||
|
if (auto mulOp = value.getDefiningOp<arith::MulIOp>())
|
||||||
|
return isStaticIndexExpr(mulOp.getLhs()) && isStaticIndexExpr(mulOp.getRhs());
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
|
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
|
||||||
if (value == laneArg || matchConstantIndexValue(value))
|
if (value == laneArg || isStaticIndexExpr(value))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
|
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
|
||||||
@@ -83,10 +104,15 @@ static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto addOp = value.getDefiningOp<arith::AddIOp>();
|
auto addOp = value.getDefiningOp<arith::AddIOp>();
|
||||||
if (!addOp)
|
if (addOp)
|
||||||
|
return (isSupportedLaneOffsetExpr(addOp.getLhs(), laneArg) && isStaticIndexExpr(addOp.getRhs()))
|
||||||
|
|| (isSupportedLaneOffsetExpr(addOp.getRhs(), laneArg) && isStaticIndexExpr(addOp.getLhs()));
|
||||||
|
|
||||||
|
auto mulOp = value.getDefiningOp<arith::MulIOp>();
|
||||||
|
if (!mulOp)
|
||||||
return false;
|
return false;
|
||||||
return (addOp.getLhs() == laneArg && matchConstantIndexValue(addOp.getRhs()))
|
return (isSupportedLaneOffsetExpr(mulOp.getLhs(), laneArg) && isStaticIndexExpr(mulOp.getRhs()))
|
||||||
|| (addOp.getRhs() == laneArg && matchConstantIndexValue(addOp.getLhs()));
|
|| (isSupportedLaneOffsetExpr(mulOp.getRhs(), laneArg) && isStaticIndexExpr(mulOp.getLhs()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult
|
static LogicalResult
|
||||||
@@ -158,17 +184,27 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
|
|||||||
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
|
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
InFlightDiagnostic diagnostic =
|
||||||
<< kind << " body may only directly reference external constants";
|
ownerOp->emitOpError() << kind << " body may not capture external values";
|
||||||
diagnostic.attachNote(op->getLoc())
|
diagnostic.attachNote(op->getLoc())
|
||||||
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
<< "owner='" << ownerOp->getName() << "' nestedOp='" << op->getName() << "' operand#"
|
||||||
|
<< operand.getOperandNumber() << " type=" << value.getType()
|
||||||
|
<< " category=" << (isa<TensorType>(value.getType()) ? "tensor" : (value.getType().isIndex() ? "index"
|
||||||
|
: "scalar"));
|
||||||
|
if (Operation* definingOp = value.getDefiningOp())
|
||||||
|
diagnostic.attachNote(definingOp->getLoc()) << "defining op is '" << definingOp->getName() << "'";
|
||||||
|
else if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||||
|
diagnostic.attachNote(blockArg.getOwner()->getParentOp()->getLoc())
|
||||||
|
<< "value is block argument #" << blockArg.getArgNumber() << " of '"
|
||||||
|
<< blockArg.getOwner()->getParentOp()->getName() << "'";
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
return success(!hasFailure);
|
return success(!hasFailure);
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
|
template <typename ComputeBatchOpTy>
|
||||||
|
static LogicalResult verifyBatchBody(ComputeBatchOpTy batchOp, Block& block) {
|
||||||
if (batchOp.getNumResults() == 0) {
|
if (batchOp.getNumResults() == 0) {
|
||||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||||
if (!yieldOp)
|
if (!yieldOp)
|
||||||
@@ -344,144 +380,406 @@ LogicalResult SpatConcatOp::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult verifyComputeResultsUses(Operation* op) {
|
static bool isKnownLogicalLayout(StringRef layout) { return layout == "nchw"; }
|
||||||
if (!isa<SpatCompute, SpatComputeBatch>(op))
|
|
||||||
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
|
static bool isKnownPhysicalLayout(StringRef layout) {
|
||||||
if (!llvm::all_of(op->getResults(), [](Value result) {
|
return layout == "dense_nchw" || layout == "nchw_row_strip" || layout == "fragmented";
|
||||||
return llvm::all_of(result.getUsers(), [](Operation* op) {
|
}
|
||||||
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
|
|
||||||
});
|
static LogicalResult verifyPlanTensorTypes(Operation* op, Value input, Value output, StringRef kind) {
|
||||||
})) {
|
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
||||||
return op->emitError("ComputeResult used directly inside another Compute");
|
auto outputType = dyn_cast<RankedTensorType>(output.getType());
|
||||||
|
if (!inputType || !outputType)
|
||||||
|
return op->emitOpError() << kind << " requires ranked tensor input and output types";
|
||||||
|
if (inputType.getElementType() != outputType.getElementType())
|
||||||
|
return op->emitOpError() << kind << " requires matching input/output element types";
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatConv2DPlanOp::verify() {
|
||||||
|
auto inputType = dyn_cast<RankedTensorType>(getInput().getType());
|
||||||
|
auto weightType = dyn_cast<RankedTensorType>(getWeight().getType());
|
||||||
|
auto outputType = dyn_cast<RankedTensorType>(getOutput().getType());
|
||||||
|
if (!inputType || !weightType || !outputType)
|
||||||
|
return emitError("requires ranked tensor input, weight, and output");
|
||||||
|
if (inputType.getRank() != 4 || weightType.getRank() != 4 || outputType.getRank() != 4)
|
||||||
|
return emitError("requires rank-4 input, weight, and output tensors");
|
||||||
|
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||||
|
return emitError("requires a known logical layout");
|
||||||
|
if (getPads().size() != 4)
|
||||||
|
return emitError("requires exactly four pad values");
|
||||||
|
if (getStrides().size() != 2)
|
||||||
|
return emitError("requires exactly two stride values");
|
||||||
|
if (getDilations().size() != 2)
|
||||||
|
return emitError("requires exactly two dilation values");
|
||||||
|
if (getGroup() < 1)
|
||||||
|
return emitError("requires group >= 1");
|
||||||
|
if (inputType.getElementType() != weightType.getElementType()
|
||||||
|
|| inputType.getElementType() != outputType.getElementType()) {
|
||||||
|
return emitError("requires matching input, weight, and output element types");
|
||||||
|
}
|
||||||
|
if (getBias()) {
|
||||||
|
auto biasType = dyn_cast<RankedTensorType>(getBias().getType());
|
||||||
|
if (!biasType)
|
||||||
|
return emitError("requires ranked tensor bias type");
|
||||||
|
if (biasType.getElementType() != outputType.getElementType())
|
||||||
|
return emitError("requires bias element type to match output element type");
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatCompute::verify() {
|
LogicalResult SpatReluPlanOp::verify() {
|
||||||
auto& block = getBody().front();
|
if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.relu_plan")))
|
||||||
unsigned expectedArgCount = getWeights().size() + getInputs().size();
|
return failure();
|
||||||
if (block.getNumArguments() != expectedArgCount)
|
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||||
return emitError("compute body must have weight and input block arguments");
|
return emitError("requires a known logical layout");
|
||||||
|
return success();
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
|
||||||
auto blockArg = getWeightArgument(weightIndex);
|
|
||||||
if (!blockArg || blockArg->getType() != weight.getType())
|
|
||||||
return emitError("compute weight block argument types must match weight operand types exactly");
|
|
||||||
}
|
}
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
|
||||||
auto blockArg = getInputArgument(inputIndex);
|
LogicalResult SpatBlueprintOp::verify() {
|
||||||
|
auto modeAttr = getModeAttr();
|
||||||
|
bool isFragmentAssembly = modeAttr && modeAttr.getValue() == "fragment_assembly";
|
||||||
|
if (!isFragmentAssembly && failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.blueprint")))
|
||||||
|
return failure();
|
||||||
|
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||||
|
return emitError("requires a known logical layout");
|
||||||
|
if (!isKnownPhysicalLayout(getPhysicalLayout()))
|
||||||
|
return emitError("requires a known physical layout");
|
||||||
|
|
||||||
|
auto logicalType = dyn_cast<RankedTensorType>(getOutput().getType());
|
||||||
|
if (!logicalType)
|
||||||
|
return emitError("requires ranked tensor output");
|
||||||
|
|
||||||
|
auto offsets = getFragmentOffsets();
|
||||||
|
auto sizes = getFragmentSizes();
|
||||||
|
if (offsets.size() != sizes.size())
|
||||||
|
return emitError("fragment offset and size arrays must have the same length");
|
||||||
|
int64_t rank = logicalType.getRank();
|
||||||
|
if (offsets.empty())
|
||||||
|
return success();
|
||||||
|
if (rank <= 0 || offsets.size() % rank != 0)
|
||||||
|
return emitError("fragment metadata must be a whole number of rank-sized fragments");
|
||||||
|
|
||||||
|
auto verifyBoundsOnly = [&](ArrayRef<int64_t> strideValues) -> LogicalResult {
|
||||||
|
ArrayRef<int64_t> shape = logicalType.getShape();
|
||||||
|
for (int64_t index = 0; index < static_cast<int64_t>(offsets.size()); ++index) {
|
||||||
|
int64_t dim = index % rank;
|
||||||
|
int64_t offset = offsets[index];
|
||||||
|
int64_t size = sizes[index];
|
||||||
|
int64_t stride = strideValues.empty() ? 1 : strideValues[index];
|
||||||
|
if (offset < 0 || size < 0 || stride < 0)
|
||||||
|
return emitError("fragment offsets, sizes, and strides must be non-negative");
|
||||||
|
int64_t logicalDim = shape[dim];
|
||||||
|
if (!ShapedType::isDynamic(logicalDim) && offset + size > logicalDim)
|
||||||
|
return emitError("fragment bounds must stay within the logical tensor shape");
|
||||||
|
if (stride != 1)
|
||||||
|
return emitError("fragment assembly currently requires unit strides");
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!isFragmentAssembly) {
|
||||||
|
if (failed(verifyBoundsOnly({})))
|
||||||
|
return failure();
|
||||||
|
if (!getFragments().empty())
|
||||||
|
return emitError("legacy blueprint does not accept extra fragment operands");
|
||||||
|
if (getFragmentSourceOffsetsAttr() || getFragmentStridesAttr() || getConflictPolicyAttr()
|
||||||
|
|| getCoveragePolicyAttr())
|
||||||
|
return emitError("legacy blueprint does not accept fragment assembly attributes");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto stridesAttr = getFragmentStridesAttr();
|
||||||
|
auto operandIndicesAttr = getFragmentOperandIndicesAttr();
|
||||||
|
auto sourceOffsetsAttr = getFragmentSourceOffsetsAttr();
|
||||||
|
if (!operandIndicesAttr)
|
||||||
|
return emitError("fragment assembly blueprint requires fragment operand indices");
|
||||||
|
if (!sourceOffsetsAttr)
|
||||||
|
return emitError("fragment assembly blueprint requires fragment source offsets");
|
||||||
|
if (!stridesAttr)
|
||||||
|
return emitError("fragment assembly blueprint requires fragment strides");
|
||||||
|
ArrayRef<int64_t> operandIndices = operandIndicesAttr.asArrayRef();
|
||||||
|
ArrayRef<int64_t> sourceOffsets = sourceOffsetsAttr.asArrayRef();
|
||||||
|
ArrayRef<int64_t> strides = stridesAttr.asArrayRef();
|
||||||
|
if (strides.size() != offsets.size())
|
||||||
|
return emitError("fragment stride and offset arrays must have the same length");
|
||||||
|
if (sourceOffsets.size() != operandIndices.size())
|
||||||
|
return emitError("fragment source offset count must match fragment operand index count");
|
||||||
|
if (!getConflictPolicyAttr() || !getCoveragePolicyAttr())
|
||||||
|
return emitError("fragment assembly blueprint requires conflict and coverage policies");
|
||||||
|
if (getConflictPolicy() != "disjoint")
|
||||||
|
return emitError("fragment assembly blueprint currently supports only conflict_policy=\"disjoint\"");
|
||||||
|
if (getCoveragePolicy() != "complete" && getCoveragePolicy() != "partial")
|
||||||
|
return emitError("fragment assembly blueprint coverage_policy must be \"complete\" or \"partial\"");
|
||||||
|
|
||||||
|
SmallVector<Value> operands;
|
||||||
|
operands.push_back(getInput());
|
||||||
|
llvm::append_range(operands, getFragments());
|
||||||
|
int64_t operandCount = static_cast<int64_t>(operands.size());
|
||||||
|
int64_t fragmentCount = static_cast<int64_t>(operandIndices.size());
|
||||||
|
if (operandCount == 0)
|
||||||
|
return emitError("fragment assembly blueprint requires at least one operand");
|
||||||
|
if (static_cast<int64_t>(offsets.size()) != fragmentCount * rank)
|
||||||
|
return emitError("fragment assembly metadata count must match operand count * result rank");
|
||||||
|
if (failed(verifyBoundsOnly(strides)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<std::pair<SmallVector<int64_t, 4>, SmallVector<int64_t, 4>>, 8> slices;
|
||||||
|
slices.reserve(static_cast<size_t>(fragmentCount));
|
||||||
|
SmallVector<int64_t, 8> fragmentCountsByOperand(static_cast<size_t>(operandCount), 0);
|
||||||
|
auto expandFlatElementIndex = [](int64_t flatIndex, ArrayRef<int64_t> shape) {
|
||||||
|
SmallVector<int64_t, 4> indices(shape.size(), 0);
|
||||||
|
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
|
||||||
|
indices[dim] = flatIndex % shape[dim];
|
||||||
|
flatIndex /= shape[dim];
|
||||||
|
}
|
||||||
|
return indices;
|
||||||
|
};
|
||||||
|
for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
|
||||||
|
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||||
|
if (operandIndex < 0 || operandIndex >= operandCount)
|
||||||
|
return emitError("fragment assembly operand index is out of range");
|
||||||
|
if (sourceOffsets[fragmentIndex] < 0)
|
||||||
|
return emitError("fragment assembly source offsets must be nonnegative");
|
||||||
|
|
||||||
|
auto operandType = dyn_cast<RankedTensorType>(operands[operandIndex].getType());
|
||||||
|
if (!operandType || !operandType.hasStaticShape())
|
||||||
|
return emitError("fragment assembly blueprint requires static ranked tensor operands");
|
||||||
|
if (operandType.getRank() != rank)
|
||||||
|
return emitError("fragment assembly blueprint requires operand/result rank match");
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> fragmentOffsets;
|
||||||
|
SmallVector<int64_t, 4> fragmentSizes;
|
||||||
|
fragmentOffsets.reserve(rank);
|
||||||
|
fragmentSizes.reserve(rank);
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
|
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||||
|
fragmentOffsets.push_back(offsets[flatIndex]);
|
||||||
|
fragmentSizes.push_back(sizes[flatIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
++fragmentCountsByOperand[static_cast<size_t>(operandIndex)];
|
||||||
|
int64_t fragmentElements = 1;
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim)
|
||||||
|
fragmentElements *= fragmentSizes[dim];
|
||||||
|
if (sourceOffsets[fragmentIndex] + fragmentElements > operandType.getNumElements())
|
||||||
|
return emitError("fragment assembly source offset exceeds the operand bounds");
|
||||||
|
SmallVector<int64_t, 4> sourceSliceOffsets =
|
||||||
|
expandFlatElementIndex(sourceOffsets[fragmentIndex], operandType.getShape());
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim)
|
||||||
|
if (sourceSliceOffsets[dim] + fragmentSizes[dim] > operandType.getDimSize(dim))
|
||||||
|
return emitError("fragment assembly source offset must describe a valid unit-stride slice");
|
||||||
|
|
||||||
|
for (const auto& [existingOffsets, existingSizes] : slices) {
|
||||||
|
bool overlaps = true;
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
|
int64_t begin = fragmentOffsets[dim];
|
||||||
|
int64_t end = begin + fragmentSizes[dim];
|
||||||
|
int64_t existingBegin = existingOffsets[dim];
|
||||||
|
int64_t existingEnd = existingBegin + existingSizes[dim];
|
||||||
|
if (end <= existingBegin || existingEnd <= begin) {
|
||||||
|
overlaps = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (overlaps)
|
||||||
|
return emitError("fragment assembly blueprint requires disjoint static slices");
|
||||||
|
}
|
||||||
|
slices.push_back({std::move(fragmentOffsets), std::move(fragmentSizes)});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t operandIndex = 0; operandIndex < operandCount; ++operandIndex) {
|
||||||
|
if (fragmentCountsByOperand[static_cast<size_t>(operandIndex)] == 0)
|
||||||
|
return emitError("fragment assembly blueprint requires every operand to contribute at least one fragment");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getCoveragePolicy() == "complete") {
|
||||||
|
int64_t covered = 0;
|
||||||
|
int64_t logicalElements = 1;
|
||||||
|
for (int64_t dimSize : logicalType.getShape()) {
|
||||||
|
if (ShapedType::isDynamic(dimSize))
|
||||||
|
return emitError("fragment assembly complete coverage requires static result shape");
|
||||||
|
logicalElements *= dimSize;
|
||||||
|
}
|
||||||
|
for (const auto& [ignoredOffsets, fragmentSizes] : slices) {
|
||||||
|
int64_t fragmentElements = 1;
|
||||||
|
for (int64_t dimSize : fragmentSizes)
|
||||||
|
fragmentElements *= dimSize;
|
||||||
|
covered += fragmentElements;
|
||||||
|
}
|
||||||
|
if (covered != logicalElements)
|
||||||
|
return emitError("fragment assembly complete coverage must cover the whole result exactly");
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatMaterializeLayoutOp::verify() {
|
||||||
|
if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.materialize_layout")))
|
||||||
|
return failure();
|
||||||
|
if (!isKnownLogicalLayout(getLogicalLayout()))
|
||||||
|
return emitError("requires a known logical layout");
|
||||||
|
if (!isKnownPhysicalLayout(getSourcePhysicalLayout()))
|
||||||
|
return emitError("requires a known source physical layout");
|
||||||
|
if (!isKnownPhysicalLayout(getTargetPhysicalLayout()))
|
||||||
|
return emitError("requires a known target physical layout");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult verifyComputeResultsUses(Operation* op) {
|
||||||
|
if (!isAnySpatialComputeLike(op))
|
||||||
|
return op->emitError("verifyComputeResultUses: op is not a Spatial compute-like operation");
|
||||||
|
if (!llvm::all_of(op->getResults(), [](Value result) {
|
||||||
|
return llvm::all_of(result.getUsers(), [](Operation* op) {
|
||||||
|
return !isAnySpatialComputeLike(op->getParentOp());
|
||||||
|
});
|
||||||
|
})) {
|
||||||
|
return op->emitError("compute result used directly inside another Spatial compute body");
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ComputeOpTy>
|
||||||
|
LogicalResult verifyComputeLikeOp(ComputeOpTy compute, StringRef opName) {
|
||||||
|
auto& block = compute.getBody().front();
|
||||||
|
unsigned expectedArgCount = compute.getWeights().size() + compute.getInputs().size();
|
||||||
|
if (block.getNumArguments() != expectedArgCount)
|
||||||
|
return compute.emitOpError("compute body must have weight and input block arguments");
|
||||||
|
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||||
|
auto blockArg = compute.getWeightArgument(weightIndex);
|
||||||
|
if (!blockArg || blockArg->getType() != weight.getType())
|
||||||
|
return compute.emitOpError("compute weight block argument types must match weight operand types exactly");
|
||||||
|
}
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
|
auto blockArg = compute.getInputArgument(inputIndex);
|
||||||
if (!blockArg || blockArg->getType() != input.getType())
|
if (!blockArg || blockArg->getType() != input.getType())
|
||||||
return emitError("compute input block argument types must match input operand types exactly");
|
return compute.emitOpError("compute input block argument types must match input operand types exactly");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (block.mightHaveTerminator()) {
|
if (block.mightHaveTerminator()) {
|
||||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||||
if (!yieldOp)
|
if (!yieldOp)
|
||||||
return emitError("ComputeOp must have a single yield operation");
|
return compute.emitOpError("ComputeOp must have a single yield operation");
|
||||||
|
|
||||||
auto resultTypes = getResultTypes();
|
auto resultTypes = compute.getResultTypes();
|
||||||
auto yieldTypes = yieldOp->getOperandTypes();
|
auto yieldTypes = yieldOp->getOperandTypes();
|
||||||
if (resultTypes.size() != yieldTypes.size())
|
if (resultTypes.size() != yieldTypes.size())
|
||||||
return emitError("ComputeOp must have same number of results as yieldOp operands");
|
return compute.emitOpError("ComputeOp must have same number of results as yieldOp operands");
|
||||||
|
|
||||||
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
|
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
|
||||||
auto resultType = std::get<0>(it);
|
auto resultType = std::get<0>(it);
|
||||||
auto yieldType = std::get<1>(it);
|
auto yieldType = std::get<1>(it);
|
||||||
|
|
||||||
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType)))
|
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType)))
|
||||||
return emitError("ComputeOp output must be of the same type as yieldOp operand");
|
return compute.emitOpError("ComputeOp output must be of the same type as yieldOp operand");
|
||||||
|
|
||||||
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
|
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
|
||||||
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
|
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
|
||||||
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding())
|
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding())
|
||||||
return emitError("ComputeOp output must have the same encoding as yieldOp operand");
|
return compute.emitOpError("ComputeOp output must have the same encoding as yieldOp operand");
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
return emitError("ComputeOp output has an encoding while yieldOp operand does not have one");
|
return compute.emitOpError("ComputeOp output has an encoding while yieldOp operand does not have one");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (dyn_cast<RankedTensorType>(yieldType)) {
|
else if (dyn_cast<RankedTensorType>(yieldType)) {
|
||||||
return emitError("ComputeOp output must not have an encoding if yieldOp operand has one");
|
return compute.emitOpError("ComputeOp output must not have an encoding if yieldOp operand has one");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
|
for (unsigned inputIndex = 0; inputIndex < compute.getInputs().size(); ++inputIndex)
|
||||||
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
|
if (auto inputArg = compute.getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
|
||||||
return emitError("ComputeOp block argument is not used");
|
return compute.emitOpError("ComputeOp block argument is not used");
|
||||||
if (failed(verifyStaticWeights(*this, "compute")))
|
if (failed(verifyStaticWeights(compute, opName)))
|
||||||
return failure();
|
return failure();
|
||||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
|
if (failed(verifyOnlyConstantExternalValues(compute.getOperation(), compute.getBody(), opName)))
|
||||||
return failure();
|
return failure();
|
||||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
if (failed(verifyComputeResultsUses(compute.getOperation())))
|
||||||
return failure();
|
return failure();
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatComputeBatch::verify() {
|
LogicalResult SpatGraphCompute::verify() { return verifyComputeLikeOp(*this, "spat.graph_compute"); }
|
||||||
int32_t count = getLaneCount();
|
|
||||||
|
LogicalResult SpatScheduledCompute::verify() { return verifyComputeLikeOp(*this, "spat.scheduled_compute"); }
|
||||||
|
|
||||||
|
template <typename ComputeBatchOpTy>
|
||||||
|
LogicalResult verifyComputeBatchLikeOp(ComputeBatchOpTy batch, StringRef opName) {
|
||||||
|
int32_t count = batch.getLaneCount();
|
||||||
if (count <= 0)
|
if (count <= 0)
|
||||||
return emitError("laneCount must be positive");
|
return batch.emitOpError("laneCount must be positive");
|
||||||
|
|
||||||
auto laneCountSz = static_cast<size_t>(count);
|
auto laneCountSz = static_cast<size_t>(count);
|
||||||
|
|
||||||
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
|
if (auto coreIdAttr = batch->getAttr(kCoreIdsAttrName)) {
|
||||||
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
||||||
if (!coreIdsAttr)
|
if (!coreIdsAttr)
|
||||||
return emitError("compute_batch coreIds attribute must be a dense i32 array");
|
return batch.emitOpError("compute_batch coreIds attribute must be a dense i32 array");
|
||||||
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
|
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
|
||||||
return emitError("compute_batch coreIds array length must match laneCount");
|
return batch.emitOpError("compute_batch coreIds array length must match laneCount");
|
||||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
||||||
return emitError("compute_batch coreIds values must be non-negative");
|
return batch.emitOpError("compute_batch coreIds values must be non-negative");
|
||||||
DenseSet<int32_t> seenCoreIds;
|
DenseSet<int32_t> seenCoreIds;
|
||||||
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
||||||
if (!seenCoreIds.insert(coreId).second)
|
if (!seenCoreIds.insert(coreId).second)
|
||||||
return emitError("compute_batch coreIds values must be unique");
|
return batch.emitOpError("compute_batch coreIds values must be unique");
|
||||||
}
|
}
|
||||||
|
|
||||||
Block& block = getBody().front();
|
Block& block = batch.getBody().front();
|
||||||
if (block.getNumArguments() == 0)
|
if (block.getNumArguments() == 0)
|
||||||
return emitError("compute_batch body must have exactly one lane block argument");
|
return batch.emitOpError("compute_batch body must have exactly one lane block argument");
|
||||||
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
|
unsigned expectedArgCount = 1 + batch.getWeights().size() + batch.getInputs().size() + batch.getNumResults();
|
||||||
if (block.getNumArguments() != expectedArgCount)
|
if (block.getNumArguments() != expectedArgCount)
|
||||||
return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
|
return batch.emitOpError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
|
||||||
auto laneArg = getLaneArgument();
|
auto laneArg = batch.getLaneArgument();
|
||||||
if (!laneArg || !laneArg->getType().isIndex())
|
if (!laneArg || !laneArg->getType().isIndex())
|
||||||
return emitError("compute_batch first block argument must have index type");
|
return batch.emitOpError("compute_batch first block argument must have index type");
|
||||||
|
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
for (auto [weightIndex, weight] : llvm::enumerate(batch.getWeights())) {
|
||||||
auto blockArg = getWeightArgument(weightIndex);
|
auto blockArg = batch.getWeightArgument(weightIndex);
|
||||||
if (!blockArg || blockArg->getType() != weight.getType())
|
if (!blockArg || blockArg->getType() != weight.getType())
|
||||||
return emitError("compute_batch weight block argument types must match weight operand types exactly");
|
return batch.emitOpError("compute_batch weight block argument types must match weight operand types exactly");
|
||||||
}
|
}
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) {
|
||||||
auto blockArg = getInputArgument(inputIndex);
|
auto blockArg = batch.getInputArgument(inputIndex);
|
||||||
if (!blockArg || blockArg->getType() != input.getType())
|
if (!blockArg || blockArg->getType() != input.getType())
|
||||||
return emitError("compute_batch input block argument types must match input operand types exactly");
|
return batch.emitOpError("compute_batch input block argument types must match input operand types exactly");
|
||||||
}
|
}
|
||||||
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
|
for (auto [resultIndex, resultType] : llvm::enumerate(batch.getResultTypes())) {
|
||||||
auto blockArg = getOutputArgument(resultIndex);
|
auto blockArg = batch.getOutputArgument(resultIndex);
|
||||||
if (!blockArg || blockArg->getType() != resultType)
|
if (!blockArg || blockArg->getType() != resultType)
|
||||||
return emitError("compute_batch output block argument types must match result types exactly");
|
return batch.emitOpError("compute_batch output block argument types must match result types exactly");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
if (failed(verifyComputeResultsUses(batch.getOperation())))
|
||||||
return failure();
|
return failure();
|
||||||
if (failed(verifyStaticWeights(*this, "compute_batch")))
|
if (failed(verifyStaticWeights(batch, opName)))
|
||||||
return failure();
|
return failure();
|
||||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
|
if (failed(verifyOnlyConstantExternalValues(batch.getOperation(), batch.getBody(), opName)))
|
||||||
return failure();
|
return failure();
|
||||||
return verifyBatchBody(*this, block);
|
return verifyBatchBody(batch, block);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatGraphComputeBatch::verify() { return verifyComputeBatchLikeOp(*this, "spat.graph_compute_batch"); }
|
||||||
|
|
||||||
|
LogicalResult SpatScheduledComputeBatch::verify() {
|
||||||
|
return verifyComputeBatchLikeOp(*this, "spat.scheduled_compute_batch");
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatInParallelOp::verify() {
|
LogicalResult SpatInParallelOp::verify() {
|
||||||
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
|
Operation* parent = getOperation()->getParentOp();
|
||||||
if (!batchOp)
|
if (!isAnySpatialComputeBatchLike(parent))
|
||||||
return emitOpError("expected spat.compute_batch parent");
|
return emitOpError("expected spat.graph_compute_batch or spat.scheduled_compute_batch parent");
|
||||||
if (batchOp.getNumResults() == 0)
|
if (parent->getNumResults() == 0)
|
||||||
return emitOpError("requires a resultful spat.compute_batch parent");
|
return emitOpError("requires a resultful spat.compute_batch parent");
|
||||||
|
|
||||||
auto laneArg = batchOp.getLaneArgument();
|
std::optional<BlockArgument> laneArg;
|
||||||
|
if (auto graphBatch = dyn_cast<SpatGraphComputeBatch>(parent))
|
||||||
|
laneArg = graphBatch.getLaneArgument();
|
||||||
|
else
|
||||||
|
laneArg = cast<SpatScheduledComputeBatch>(parent).getLaneArgument();
|
||||||
if (!laneArg)
|
if (!laneArg)
|
||||||
return emitOpError("expected compute_batch lane block argument");
|
return emitOpError("expected compute_batch lane block argument");
|
||||||
for (Operation& op : getRegion().front().getOperations()) {
|
for (Operation& op : getRegion().front().getOperations()) {
|
||||||
@@ -494,7 +792,10 @@ LogicalResult SpatInParallelOp::verify() {
|
|||||||
|
|
||||||
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
|
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
|
||||||
for (OpOperand& destination : destinations)
|
for (OpOperand& destination : destinations)
|
||||||
if (!isBatchOutputArgument(batchOp, destination.get()))
|
if ((isa<SpatGraphComputeBatch>(parent)
|
||||||
|
&& !isBatchOutputArgument(cast<SpatGraphComputeBatch>(parent), destination.get()))
|
||||||
|
|| (isa<SpatScheduledComputeBatch>(parent)
|
||||||
|
&& !isBatchOutputArgument(cast<SpatScheduledComputeBatch>(parent), destination.get())))
|
||||||
return op.emitOpError("may only insert into a compute_batch output block argument");
|
return op.emitOpError("may only insert into a compute_batch output block argument");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+2819
-169
File diff suppressed because it is too large
Load Diff
@@ -40,11 +40,10 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
using namespace onnx_mlir::compact_asm;
|
using namespace onnx_mlir::compact_asm;
|
||||||
using SpatCompute = spatial::SpatCompute;
|
using SpatCompute = spatial::SpatGraphCompute;
|
||||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
using SpatComputeBatch = spatial::SpatGraphComputeBatch;
|
||||||
using spatial::getProducerValueRef;
|
|
||||||
|
|
||||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
static std::optional<int32_t> getComputeCoreId(spatial::SpatScheduledCompute compute) {
|
||||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
||||||
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
|
auto checkedCoreId = pim::checkedI32(coreIdAttr.getInt(), compute, "merge compute core id");
|
||||||
if (failed(checkedCoreId))
|
if (failed(checkedCoreId))
|
||||||
@@ -187,32 +186,50 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
|||||||
SmallVector<int32_t> coreIds;
|
SmallVector<int32_t> coreIds;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//TODO Used for report refactor
|
||||||
|
struct CollectorConcatRow {
|
||||||
|
uint64_t computeId = 0;
|
||||||
|
int32_t coreId = -1;
|
||||||
|
uint64_t operandCount = 0;
|
||||||
|
};
|
||||||
|
|
||||||
uint64_t totalComputeOps = 0;
|
uint64_t totalComputeOps = 0;
|
||||||
uint64_t totalLogicalComputes = 0;
|
uint64_t totalLogicalComputes = 0;
|
||||||
uint64_t totalBatchComputeOps = 0;
|
uint64_t totalBatchComputeOps = 0;
|
||||||
uint64_t totalInstructionCount = 0;
|
uint64_t totalInstructionCount = 0;
|
||||||
uint64_t totalCrossbarCount = 0;
|
uint64_t totalCrossbarCount = 0;
|
||||||
uint64_t nextBatchId = 0;
|
uint64_t nextBatchId = 0;
|
||||||
|
//TODO Used for report refactor
|
||||||
std::vector<ReportRow> collectedData;
|
std::vector<ReportRow> collectedData;
|
||||||
|
//TODO Used for report refactor
|
||||||
|
std::vector<CollectorConcatRow> collectorConcatRows;
|
||||||
|
|
||||||
auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t {
|
auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t {
|
||||||
return static_cast<uint64_t>(spatial::collectDistinctCrossbarWeights(op).size());
|
return static_cast<uint64_t>(spatial::collectDistinctCrossbarWeights(op).size());
|
||||||
};
|
};
|
||||||
|
|
||||||
for (Operation& op : funcOp.getBody().front()) {
|
for (Operation& op : funcOp.getBody().front()) {
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
if (auto spatCompute = dyn_cast<spatial::SpatScheduledCompute>(&op)) {
|
||||||
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
|
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
|
||||||
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
|
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
|
||||||
SmallVector<int32_t> coreIds;
|
SmallVector<int32_t> coreIds;
|
||||||
if (auto coreId = getComputeCoreId(spatCompute))
|
if (auto coreId = getComputeCoreId(spatCompute))
|
||||||
coreIds.push_back(*coreId);
|
coreIds.push_back(*coreId);
|
||||||
collectedData.push_back({totalComputeOps++, 1, perInstanceCrossbarCount, numInst, false, coreIds});
|
uint64_t computeId = totalComputeOps++;
|
||||||
|
collectedData.push_back({computeId, 1, perInstanceCrossbarCount, numInst, false, coreIds});
|
||||||
|
uint64_t maxConcatOperands = 0;
|
||||||
|
spatCompute.getBody().walk([&](spatial::SpatConcatOp concatOp) {
|
||||||
|
maxConcatOperands = std::max<uint64_t>(maxConcatOperands, concatOp.getInputs().size());
|
||||||
|
});
|
||||||
|
//TODO 128 is a magic number
|
||||||
|
if (maxConcatOperands >= 128 && !coreIds.empty())
|
||||||
|
collectorConcatRows.push_back({computeId, coreIds.front(), maxConcatOperands});
|
||||||
totalLogicalComputes += 1;
|
totalLogicalComputes += 1;
|
||||||
totalInstructionCount += numInst;
|
totalInstructionCount += numInst;
|
||||||
totalCrossbarCount += perInstanceCrossbarCount;
|
totalCrossbarCount += perInstanceCrossbarCount;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
if (auto batch = dyn_cast<spatial::SpatScheduledComputeBatch>(&op)) {
|
||||||
uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody());
|
uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody());
|
||||||
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
|
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
|
||||||
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
|
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
|
||||||
@@ -238,9 +255,17 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
|||||||
{"Number of used crossbars", std::to_string(totalCrossbarCount) }
|
{"Number of used crossbars", std::to_string(totalCrossbarCount) }
|
||||||
};
|
};
|
||||||
printReportTotalsBlock(os, totalFields);
|
printReportTotalsBlock(os, totalFields);
|
||||||
if (!collectedData.empty())
|
if (!collectedData.empty() || !collectorConcatRows.empty())
|
||||||
os << "\n";
|
os << "\n";
|
||||||
|
|
||||||
|
if (!collectorConcatRows.empty()) {
|
||||||
|
os << "Collector concat materialization:\n";
|
||||||
|
for (const CollectorConcatRow& row : collectorConcatRows)
|
||||||
|
os << "\tmaterialization_kind = single_collector_concat, compute = " << row.computeId
|
||||||
|
<< ", concat_operand_count = " << row.operandCount << ", collector_core = " << row.coreId << "\n";
|
||||||
|
os << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
sortReportEntriesByFirstCore(collectedData);
|
sortReportEntriesByFirstCore(collectedData);
|
||||||
|
|
||||||
for (uint64_t cI = 0; cI < totalComputeOps; ++cI) {
|
for (uint64_t cI = 0; cI < totalComputeOps; ++cI) {
|
||||||
@@ -328,7 +353,17 @@ public:
|
|||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
func::FuncOp func = getOperation();
|
func::FuncOp func = getOperation();
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
|
||||||
|
func.emitOpError("logical Spatial graph verification failed at the start of MergeComputeNodes");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
mergeTriviallyConnectedComputes(func);
|
mergeTriviallyConnectedComputes(func);
|
||||||
|
if (failed(verifyLogicalSpatialGraphInvariants(func))) {
|
||||||
|
func.emitOpError("logical Spatial graph verification failed after trivial merge simplification");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const spatial::MergeScheduleResult* analysisResult = nullptr;
|
const spatial::MergeScheduleResult* analysisResult = nullptr;
|
||||||
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
|
analysisResult = &getAnalysis<spatial::MergeSchedulingAnalysis>().getResult();
|
||||||
@@ -342,8 +377,8 @@ public:
|
|||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (failed(verifySpatialCommunicationInvariants(func))) {
|
if (failed(verifyScheduledSpatialInvariants(func))) {
|
||||||
func.emitOpError("merged Spatial communication invariant verification failed");
|
func.emitOpError("scheduled Spatial verification failed after merge materialization");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
@@ -21,6 +22,7 @@
|
|||||||
|
|
||||||
#include "ComputeGraph.hpp"
|
#include "ComputeGraph.hpp"
|
||||||
#include "ComputeInstanceUtils.hpp"
|
#include "ComputeInstanceUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||||
#include "src/Support/TypeUtilities.hpp"
|
#include "src/Support/TypeUtilities.hpp"
|
||||||
@@ -35,9 +37,223 @@ uint64_t countComputeBodyOperationInstances(Region& body);
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
Cost getComputeBodyCost(Region& body) {
|
struct PimsimSchedulerCostModel {
|
||||||
constexpr Cost kOperationCost = 100;
|
static constexpr Cost kDefaultBitwidth = 8;
|
||||||
return checkedMultiply(static_cast<Cost>(countComputeBodyOperationInstances(body)), kOperationCost);
|
static constexpr Cost kCorePeriodNs = 1;
|
||||||
|
static constexpr Cost kLocalMemoryWidthBytes = 64;
|
||||||
|
static constexpr Cost kLocalMemoryLatencyCycles = 1;
|
||||||
|
static constexpr Cost kNetworkBusWidthBytes = 8;
|
||||||
|
static constexpr Cost kNetworkBaseLatencyNs = 2;
|
||||||
|
static constexpr Cost kNetworkPerHopLatencyNs = 1;
|
||||||
|
static constexpr Cost kVectorWidth = 16;
|
||||||
|
static constexpr Cost kVectorLatencyCycles = 4;
|
||||||
|
static constexpr Cost kDacResolutionBits = 1;
|
||||||
|
static constexpr Cost kDacLatencyCycles = 1;
|
||||||
|
static constexpr Cost kDacCount = 128;
|
||||||
|
static constexpr Cost kXbarReadLatencyNs = 30;
|
||||||
|
static constexpr Cost kSampleHoldLatencyCycles = 1;
|
||||||
|
static constexpr Cost kAdcLatencyCycles = 10;
|
||||||
|
static constexpr Cost kAdcCount = 2;
|
||||||
|
static constexpr Cost kShiftAdderLatencyCycles = 1;
|
||||||
|
static constexpr Cost kOutputBufferLatencyCycles = 1;
|
||||||
|
static constexpr Cost kInputBufferLatencyCycles = 0;
|
||||||
|
static constexpr Cost kFallbackOperationCost = 1;
|
||||||
|
|
||||||
|
static Cost ceilDiv(Cost numerator, Cost denominator) {
|
||||||
|
assert(denominator > 0 && "denominator must be positive");
|
||||||
|
return (numerator + denominator - 1) / denominator;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::optional<Cost> getStaticElementCount(Type type) {
|
||||||
|
auto shaped = dyn_cast<ShapedType>(type);
|
||||||
|
if (!shaped || !shaped.hasStaticShape())
|
||||||
|
return std::nullopt;
|
||||||
|
return static_cast<Cost>(shaped.getNumElements());
|
||||||
|
}
|
||||||
|
|
||||||
|
static Cost getBitwidthOrDefault(Type type) {
|
||||||
|
if (auto shaped = dyn_cast<ShapedType>(type))
|
||||||
|
type = shaped.getElementType();
|
||||||
|
if (auto intType = dyn_cast<IntegerType>(type))
|
||||||
|
return intType.getWidth();
|
||||||
|
if (auto floatType = dyn_cast<FloatType>(type))
|
||||||
|
return floatType.getWidth();
|
||||||
|
if (isa<IndexType>(type))
|
||||||
|
return 64;
|
||||||
|
return kDefaultBitwidth;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Cost getComputeBitwidth(Type type) {
|
||||||
|
return std::min(getBitwidthOrDefault(type), kDefaultBitwidth);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Cost getByteSize(Type type, Cost fallbackBitwidth = kDefaultBitwidth) {
|
||||||
|
auto elementCount = getStaticElementCount(type);
|
||||||
|
if (!elementCount)
|
||||||
|
return kFallbackOperationCost;
|
||||||
|
Cost bitwidth = fallbackBitwidth;
|
||||||
|
if (bitwidth <= 0)
|
||||||
|
bitwidth = getBitwidthOrDefault(type);
|
||||||
|
if (bitwidth <= 0)
|
||||||
|
bitwidth = kDefaultBitwidth;
|
||||||
|
return ceilDiv(checkedMultiply(*elementCount, bitwidth), static_cast<Cost>(8));
|
||||||
|
}
|
||||||
|
|
||||||
|
static Cost getVectorReadWriteCost(Cost readBytes, Cost writeBytes) {
|
||||||
|
Cost totalBytes = checkedAdd(readBytes, writeBytes);
|
||||||
|
return checkedMultiply(ceilDiv(totalBytes, kLocalMemoryWidthBytes),
|
||||||
|
checkedMultiply(kLocalMemoryLatencyCycles, kCorePeriodNs));
|
||||||
|
}
|
||||||
|
|
||||||
|
static Cost getVectorComputeCost(Cost elementCount) {
|
||||||
|
return checkedMultiply(ceilDiv(elementCount, kVectorWidth),
|
||||||
|
checkedMultiply(kVectorLatencyCycles, kCorePeriodNs));
|
||||||
|
}
|
||||||
|
|
||||||
|
static Cost getTensorMoveCost(Type type) {
|
||||||
|
return getVectorReadWriteCost(getByteSize(type), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::pair<Cost, Cost> estimateMeshShape() {
|
||||||
|
Cost coreCount = static_cast<Cost>(std::max<long>(1, coresCount.getValue()));
|
||||||
|
Cost rows = static_cast<Cost>(std::sqrt(static_cast<long double>(coreCount)));
|
||||||
|
if (rows == 0)
|
||||||
|
rows = 1;
|
||||||
|
while (rows > 1 && coreCount % rows != 0)
|
||||||
|
--rows;
|
||||||
|
Cost cols = ceilDiv(coreCount, rows);
|
||||||
|
return {rows, cols};
|
||||||
|
}
|
||||||
|
|
||||||
|
static Cost getAverageInterCoreLatencyNs() {
|
||||||
|
auto [rows, cols] = estimateMeshShape();
|
||||||
|
auto averageAxisDistance = [](Cost size) -> Cost {
|
||||||
|
if (size <= 1)
|
||||||
|
return 0;
|
||||||
|
return checkedMultiply(size, size) - 1;
|
||||||
|
};
|
||||||
|
Cost avgRow = averageAxisDistance(rows) / (static_cast<Cost>(3) * rows);
|
||||||
|
Cost avgCol = averageAxisDistance(cols) / (static_cast<Cost>(3) * cols);
|
||||||
|
return checkedAdd(kNetworkBaseLatencyNs, checkedMultiply(kNetworkPerHopLatencyNs, checkedAdd(avgRow, avgCol)));
|
||||||
|
}
|
||||||
|
|
||||||
|
static Cost getInterCoreTransferCostFromBytes(Cost bytes) {
|
||||||
|
Cost localRead = checkedMultiply(ceilDiv(bytes, kLocalMemoryWidthBytes),
|
||||||
|
checkedMultiply(kLocalMemoryLatencyCycles, kCorePeriodNs));
|
||||||
|
Cost localWrite = checkedMultiply(ceilDiv(bytes, kLocalMemoryWidthBytes),
|
||||||
|
checkedMultiply(kLocalMemoryLatencyCycles, kCorePeriodNs));
|
||||||
|
Cost payloadFlits = ceilDiv(bytes, kNetworkBusWidthBytes);
|
||||||
|
Cost averageNoCLatency = getAverageInterCoreLatencyNs();
|
||||||
|
Cost network = checkedMultiply(checkedAdd(static_cast<Cost>(2), payloadFlits), averageNoCLatency);
|
||||||
|
return checkedAdd(checkedAdd(localRead, localWrite), network);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Cost getUnaryVectorCost(Type inputType, Type outputType, bool scalarOutput = false) {
|
||||||
|
auto maybeElements = getStaticElementCount(inputType);
|
||||||
|
if (!maybeElements)
|
||||||
|
return kFallbackOperationCost;
|
||||||
|
Cost inputBytes = getByteSize(inputType, getComputeBitwidth(inputType));
|
||||||
|
Cost outputBytes = scalarOutput ? ceilDiv(getComputeBitwidth(outputType), static_cast<Cost>(8))
|
||||||
|
: getByteSize(outputType, getComputeBitwidth(outputType));
|
||||||
|
return checkedAdd(getVectorReadWriteCost(inputBytes, outputBytes), getVectorComputeCost(*maybeElements));
|
||||||
|
}
|
||||||
|
|
||||||
|
static Cost getBinaryVectorCost(Type lhsType, Type rhsType, Type outputType, bool scalarOutput = false) {
|
||||||
|
auto maybeElements = getStaticElementCount(lhsType);
|
||||||
|
if (!maybeElements)
|
||||||
|
return kFallbackOperationCost;
|
||||||
|
Cost readBytes = checkedAdd(getByteSize(lhsType, getComputeBitwidth(lhsType)),
|
||||||
|
getByteSize(rhsType, getComputeBitwidth(rhsType)));
|
||||||
|
Cost outputBytes = scalarOutput ? ceilDiv(getComputeBitwidth(outputType), static_cast<Cost>(8))
|
||||||
|
: getByteSize(outputType, getComputeBitwidth(outputType));
|
||||||
|
return checkedAdd(getVectorReadWriteCost(readBytes, outputBytes), getVectorComputeCost(*maybeElements));
|
||||||
|
}
|
||||||
|
|
||||||
|
static Cost getMatrixComputeLatency(Cost inputBitwidth) {
|
||||||
|
Cost xbarDim = static_cast<Cost>(crossbarSize.getValue());
|
||||||
|
Cost inputTimes = ceilDiv(inputBitwidth, kDacResolutionBits);
|
||||||
|
Cost dacTimes = ceilDiv(xbarDim, kDacCount);
|
||||||
|
Cost adcTimes = ceilDiv(xbarDim, kAdcCount);
|
||||||
|
Cost frontStage = kInputBufferLatencyCycles + kDacLatencyCycles + kXbarReadLatencyNs + kSampleHoldLatencyCycles;
|
||||||
|
Cost backPipe = std::max(kAdcLatencyCycles, checkedAdd(kShiftAdderLatencyCycles, kOutputBufferLatencyCycles));
|
||||||
|
Cost backStage = checkedAdd(checkedAdd(kAdcLatencyCycles, kShiftAdderLatencyCycles), kOutputBufferLatencyCycles);
|
||||||
|
backStage = checkedAdd(backStage, checkedMultiply(adcTimes - 1, backPipe));
|
||||||
|
Cost totalTimes = checkedMultiply(inputTimes, dacTimes);
|
||||||
|
Cost stagePipe = std::max(frontStage, backStage);
|
||||||
|
return checkedAdd(checkedAdd(frontStage, backStage),
|
||||||
|
checkedMultiply(totalTimes - 1, stagePipe));
|
||||||
|
}
|
||||||
|
|
||||||
|
static Cost getWvmmCost(Type inputType, Type outputType) {
|
||||||
|
Cost inputBitwidth = getComputeBitwidth(inputType);
|
||||||
|
Cost inputBytes = checkedMultiply(static_cast<Cost>(crossbarSize.getValue()),
|
||||||
|
ceilDiv(inputBitwidth, static_cast<Cost>(8)));
|
||||||
|
inputBytes = checkedMultiply(inputBytes, static_cast<Cost>(8));
|
||||||
|
Cost outputBytes = getByteSize(outputType, getComputeBitwidth(outputType));
|
||||||
|
return checkedAdd(getVectorReadWriteCost(inputBytes, outputBytes), getMatrixComputeLatency(inputBitwidth));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::optional<uint64_t> getStaticTripCount(scf::ForOp loop);
|
||||||
|
[[maybe_unused]] Cost getOperationCost(Operation& op);
|
||||||
|
|
||||||
|
[[maybe_unused]] Cost getRegionCost(Region& body) {
|
||||||
|
Cost cost = 0;
|
||||||
|
for (Block& block : body)
|
||||||
|
for (Operation& op : block)
|
||||||
|
cost = checkedAdd(cost, getOperationCost(op));
|
||||||
|
return cost;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[maybe_unused]] Cost getOperationCost(Operation& op) {
|
||||||
|
if (auto loop = dyn_cast<scf::ForOp>(&op)) {
|
||||||
|
std::optional<uint64_t> tripCount = getStaticTripCount(loop);
|
||||||
|
if (!tripCount)
|
||||||
|
return PimsimSchedulerCostModel::kFallbackOperationCost;
|
||||||
|
return checkedMultiply(getRegionCost(loop.getRegion()), static_cast<Cost>(*tripCount));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<SpatYieldOp, SpatInParallelOp, affine::AffineApplyOp, arith::ConstantOp,
|
||||||
|
tensor::EmptyOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp>(&op))
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
if (auto wvmm = dyn_cast<SpatVMMOp>(&op))
|
||||||
|
return PimsimSchedulerCostModel::getWvmmCost(wvmm.getInput().getType(), wvmm.getOutput().getType());
|
||||||
|
if (auto vvdmul = dyn_cast<SpatVVDMulOp>(&op))
|
||||||
|
return PimsimSchedulerCostModel::getBinaryVectorCost(
|
||||||
|
vvdmul.getLhs().getType(), vvdmul.getRhs().getType(), vvdmul.getOutput().getType(), /*scalarOutput=*/true);
|
||||||
|
if (auto vadd = dyn_cast<SpatVAddOp>(&op))
|
||||||
|
return PimsimSchedulerCostModel::getBinaryVectorCost(vadd.getLhs().getType(), vadd.getRhs().getType(),
|
||||||
|
vadd.getOutput().getType());
|
||||||
|
if (auto vsub = dyn_cast<SpatVSubOp>(&op))
|
||||||
|
return PimsimSchedulerCostModel::getBinaryVectorCost(vsub.getLhs().getType(), vsub.getRhs().getType(),
|
||||||
|
vsub.getOutput().getType());
|
||||||
|
if (auto vmul = dyn_cast<SpatVMulOp>(&op))
|
||||||
|
return PimsimSchedulerCostModel::getBinaryVectorCost(vmul.getLhs().getType(), vmul.getRhs().getType(),
|
||||||
|
vmul.getOutput().getType());
|
||||||
|
if (auto vmax = dyn_cast<SpatVMaxOp>(&op))
|
||||||
|
return PimsimSchedulerCostModel::getBinaryVectorCost(vmax.getLhs().getType(), vmax.getRhs().getType(),
|
||||||
|
vmax.getOutput().getType());
|
||||||
|
if (auto vavg = dyn_cast<SpatVAvgOp>(&op))
|
||||||
|
return PimsimSchedulerCostModel::getUnaryVectorCost(vavg.getInput().getType(), vavg.getOutput().getType(),
|
||||||
|
/*scalarOutput=*/true);
|
||||||
|
if (auto relu = dyn_cast<SpatReluOp>(&op))
|
||||||
|
return PimsimSchedulerCostModel::getUnaryVectorCost(relu.getInput().getType(), relu.getOutput().getType());
|
||||||
|
if (auto sigm = dyn_cast<SpatSigmoidOp>(&op))
|
||||||
|
return PimsimSchedulerCostModel::getUnaryVectorCost(sigm.getInput().getType(), sigm.getOutput().getType());
|
||||||
|
if (auto softmax = dyn_cast<SpatSoftmaxOp>(&op)) {
|
||||||
|
Cost unary = PimsimSchedulerCostModel::getUnaryVectorCost(softmax.getInput().getType(), softmax.getOutput().getType());
|
||||||
|
return checkedMultiply(unary, static_cast<Cost>(4));
|
||||||
|
}
|
||||||
|
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op))
|
||||||
|
return PimsimSchedulerCostModel::getTensorMoveCost(extract.getResult().getType());
|
||||||
|
if (auto insert = dyn_cast<tensor::InsertSliceOp>(&op))
|
||||||
|
return PimsimSchedulerCostModel::getTensorMoveCost(insert.getSource().getType());
|
||||||
|
|
||||||
|
Cost nestedCost = 0;
|
||||||
|
for (Region& region : op.getRegions())
|
||||||
|
nestedCost = checkedAdd(nestedCost, getRegionCost(region));
|
||||||
|
return checkedAdd(PimsimSchedulerCostModel::kFallbackOperationCost, nestedCost);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<uint64_t> getStaticTripCount(scf::ForOp loop) {
|
std::optional<uint64_t> getStaticTripCount(scf::ForOp loop) {
|
||||||
@@ -54,6 +270,11 @@ std::optional<uint64_t> getStaticTripCount(scf::ForOp loop) {
|
|||||||
return (distance + stride - 1) / stride;
|
return (distance + stride - 1) / stride;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Cost getComputeBodyCost(Region& body) {
|
||||||
|
constexpr Cost kOperationCost = 100;
|
||||||
|
return checkedMultiply(static_cast<Cost>(countComputeBodyOperationInstances(body)), kOperationCost);
|
||||||
|
}
|
||||||
|
|
||||||
uint64_t countOperationInstances(Operation& op) {
|
uint64_t countOperationInstances(Operation& op) {
|
||||||
if (auto loop = dyn_cast<scf::ForOp>(&op)) {
|
if (auto loop = dyn_cast<scf::ForOp>(&op)) {
|
||||||
std::optional<uint64_t> tripCount = getStaticTripCount(loop);
|
std::optional<uint64_t> tripCount = getStaticTripCount(loop);
|
||||||
@@ -149,7 +370,8 @@ std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, V
|
|||||||
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
|
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
|
||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
projectedCost = checkedAdd(projectedCost, static_cast<Cost>(getSizeInBytes(resultType)));
|
projectedCost = checkedAdd(
|
||||||
|
projectedCost, PimsimSchedulerCostModel::getInterCoreTransferCostFromBytes(static_cast<Cost>(getSizeInBytes(resultType))));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (projectedCost == 0)
|
if (projectedCost == 0)
|
||||||
@@ -162,7 +384,7 @@ Cost getInputTransferCost(const ComputeInstance& consumerInstance, Value input)
|
|||||||
if (auto batch = dyn_cast<SpatComputeBatch>(consumerInstance.op))
|
if (auto batch = dyn_cast<SpatComputeBatch>(consumerInstance.op))
|
||||||
if (std::optional<Cost> projectedCost = getBatchProjectedInputTransferCost(batch, input))
|
if (std::optional<Cost> projectedCost = getBatchProjectedInputTransferCost(batch, input))
|
||||||
return *projectedCost;
|
return *projectedCost;
|
||||||
return static_cast<Cost>(getSizeInBytes(inputType));
|
return PimsimSchedulerCostModel::getInterCoreTransferCostFromBytes(static_cast<Cost>(getSizeInBytes(inputType)));
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t getLaneOverlapCount(const ComputeInstance& lhs, const ComputeInstance& rhs) {
|
uint32_t getLaneOverlapCount(const ComputeInstance& lhs, const ComputeInstance& rhs) {
|
||||||
@@ -451,7 +673,7 @@ std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> ed
|
|||||||
continue;
|
continue;
|
||||||
auto inserted = edgeCosts.try_emplace({edge.source, edge.target}, edge.transferCost);
|
auto inserted = edgeCosts.try_emplace({edge.source, edge.target}, edge.transferCost);
|
||||||
if (!inserted.second)
|
if (!inserted.second)
|
||||||
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
|
inserted.first->second = checkedAdd(inserted.first->second, edge.transferCost);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ComputeGraphEdge> aggregatedEdges;
|
std::vector<ComputeGraphEdge> aggregatedEdges;
|
||||||
|
|||||||
+10
-3
@@ -23,7 +23,10 @@ MergeSchedulerKind getSchedulerKind() {
|
|||||||
llvm_unreachable("unknown merge scheduler kind");
|
llvm_unreachable("unknown merge scheduler kind");
|
||||||
}
|
}
|
||||||
|
|
||||||
void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result, unsigned long crossbarCapacity) {
|
void verifySchedule(const ComputeGraph& graph,
|
||||||
|
const MergeScheduleResult& result,
|
||||||
|
unsigned long crossbarCapacity,
|
||||||
|
size_t processorCount) {
|
||||||
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||||
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
|
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
|
||||||
|
|
||||||
@@ -79,7 +82,8 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result
|
|||||||
|
|
||||||
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost);
|
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost);
|
||||||
if (sourceCpu != targetCpu)
|
if (sourceCpu != targetCpu)
|
||||||
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
|
earliestTargetStart = addOrMax(
|
||||||
|
earliestTargetStart, getPeftTransferTime(edge.transferCost, sourceCpu, targetCpu, processorCount));
|
||||||
if (targetStart < earliestTargetStart) {
|
if (targetStart < earliestTargetStart) {
|
||||||
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
|
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
|
||||||
graph.nodes[edge.source].originalOrder,
|
graph.nodes[edge.source].originalOrder,
|
||||||
@@ -115,7 +119,10 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
|
|||||||
static_cast<unsigned long>(crossbarCountInCore.getValue()),
|
static_cast<unsigned long>(crossbarCountInCore.getValue()),
|
||||||
entryOp->getContext()});
|
entryOp->getContext()});
|
||||||
}
|
}
|
||||||
verifySchedule(graph, schedule, static_cast<unsigned long>(crossbarCountInCore.getValue()));
|
verifySchedule(graph,
|
||||||
|
schedule,
|
||||||
|
static_cast<unsigned long>(crossbarCountInCore.getValue()),
|
||||||
|
options.processorCount);
|
||||||
return schedule;
|
return schedule;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
#include "llvm/Support/ErrorHandling.h"
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@@ -21,6 +22,63 @@ struct ScheduledTask {
|
|||||||
Time endTime = 0;
|
Time endTime = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MeshModel {
|
||||||
|
size_t rows = 1;
|
||||||
|
size_t cols = 1;
|
||||||
|
long double averageDistance = 0.0L;
|
||||||
|
|
||||||
|
static MeshModel infer(size_t processorCount) {
|
||||||
|
MeshModel model;
|
||||||
|
if (processorCount == 0)
|
||||||
|
return model;
|
||||||
|
|
||||||
|
model.rows = static_cast<size_t>(std::sqrt(static_cast<long double>(processorCount)));
|
||||||
|
if (model.rows == 0)
|
||||||
|
model.rows = 1;
|
||||||
|
while (model.rows > 1 && processorCount % model.rows != 0)
|
||||||
|
--model.rows;
|
||||||
|
model.cols = (processorCount + model.rows - 1) / model.rows;
|
||||||
|
|
||||||
|
auto averageAxisDistance = [](size_t size) -> long double {
|
||||||
|
if (size <= 1)
|
||||||
|
return 0.0L;
|
||||||
|
return static_cast<long double>(size * size - 1) / (3.0L * static_cast<long double>(size));
|
||||||
|
};
|
||||||
|
model.averageDistance = averageAxisDistance(model.rows) + averageAxisDistance(model.cols);
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<size_t, size_t> getCoord(size_t processor) const {
|
||||||
|
return {processor / cols, processor % cols};
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getDistance(size_t lhs, size_t rhs) const {
|
||||||
|
auto [lhsRow, lhsCol] = getCoord(lhs);
|
||||||
|
auto [rhsRow, rhsCol] = getCoord(rhs);
|
||||||
|
size_t rowDistance = lhsRow > rhsRow ? lhsRow - rhsRow : rhsRow - lhsRow;
|
||||||
|
size_t colDistance = lhsCol > rhsCol ? lhsCol - rhsCol : rhsCol - lhsCol;
|
||||||
|
return rowDistance + colDistance;
|
||||||
|
}
|
||||||
|
|
||||||
|
Time scaleTransferCost(Time transferCost, size_t sourceProcessor, size_t targetProcessor) const {
|
||||||
|
if (sourceProcessor == targetProcessor || transferCost == 0)
|
||||||
|
return 0;
|
||||||
|
long double distance = static_cast<long double>(getDistance(sourceProcessor, targetProcessor));
|
||||||
|
long double scale = averageDistance > 0.0L ? distance / averageDistance : 1.0L;
|
||||||
|
scale = std::max(0.25L, scale);
|
||||||
|
return static_cast<Time>(std::ceil(static_cast<long double>(transferCost) * scale));
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getCenterDistance(size_t processor) const {
|
||||||
|
auto [row, col] = getCoord(processor);
|
||||||
|
size_t centerRow = rows / 2;
|
||||||
|
size_t centerCol = cols / 2;
|
||||||
|
size_t rowDistance = row > centerRow ? row - centerRow : centerRow - row;
|
||||||
|
size_t colDistance = col > centerCol ? col - centerCol : centerCol - col;
|
||||||
|
return rowDistance + colDistance;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
||||||
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
|
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
|
||||||
std::queue<size_t> readySinks;
|
std::queue<size_t> readySinks;
|
||||||
@@ -77,11 +135,16 @@ void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount) {
|
||||||
|
return MeshModel::infer(processorCount).scaleTransferCost(transferCost, sourceProcessor, targetProcessor);
|
||||||
|
}
|
||||||
|
|
||||||
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
|
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
|
||||||
const size_t nodeCount = graph.nodes.size();
|
const size_t nodeCount = graph.nodes.size();
|
||||||
const size_t processorCount = options.processorCount;
|
const size_t processorCount = options.processorCount;
|
||||||
if (processorCount == 0)
|
if (processorCount == 0)
|
||||||
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
|
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
|
||||||
|
MeshModel mesh = MeshModel::infer(processorCount);
|
||||||
|
|
||||||
verifyOctTableSize(nodeCount, processorCount);
|
verifyOctTableSize(nodeCount, processorCount);
|
||||||
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
||||||
@@ -89,7 +152,6 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
||||||
// If graph.nodes[task] is modified to hold a vector of costs per processor, access it here.
|
// If graph.nodes[task] is modified to hold a vector of costs per processor, access it here.
|
||||||
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].cost; };
|
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].cost; };
|
||||||
|
|
||||||
std::vector<Time> oct(nodeCount * processorCount, 0);
|
std::vector<Time> oct(nodeCount * processorCount, 0);
|
||||||
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
||||||
|
|
||||||
@@ -177,6 +239,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
Time bestEft = 0;
|
Time bestEft = 0;
|
||||||
Time bestOeft = std::numeric_limits<Time>::max();
|
Time bestOeft = std::numeric_limits<Time>::max();
|
||||||
unsigned int bestOverlapCount = 0;
|
unsigned int bestOverlapCount = 0;
|
||||||
|
size_t bestCenterDistance = std::numeric_limits<size_t>::max();
|
||||||
bool crossbarRejected = false;
|
bool crossbarRejected = false;
|
||||||
|
|
||||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
@@ -191,7 +254,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
Time dataReady = 0;
|
Time dataReady = 0;
|
||||||
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
||||||
const ScheduledTask& predSchedule = schedules[pred];
|
const ScheduledTask& predSchedule = schedules[pred];
|
||||||
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
|
Time commPenalty = getPeftTransferTime(comm, predSchedule.processor, processor, processorCount);
|
||||||
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -218,6 +281,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
|
|
||||||
Time eft = addOrMax(est, computeCost);
|
Time eft = addOrMax(est, computeCost);
|
||||||
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||||
|
size_t centerDistance = mesh.getCenterDistance(processor);
|
||||||
|
|
||||||
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
||||||
|| (oeft == bestOeft && eft == bestEft && est < bestEst)) {
|
|| (oeft == bestOeft && eft == bestEft && est < bestEst)) {
|
||||||
@@ -226,13 +290,25 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
bestEft = eft;
|
bestEft = eft;
|
||||||
bestOeft = oeft;
|
bestOeft = oeft;
|
||||||
bestOverlapCount = overlapCount;
|
bestOverlapCount = overlapCount;
|
||||||
|
bestCenterDistance = centerDistance;
|
||||||
}
|
}
|
||||||
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapCount < bestOverlapCount) {
|
else if (oeft == bestOeft && eft == bestEft && est == bestEst
|
||||||
|
&& centerDistance < bestCenterDistance) {
|
||||||
bestProcessor = processor;
|
bestProcessor = processor;
|
||||||
bestEst = est;
|
bestEst = est;
|
||||||
bestEft = eft;
|
bestEft = eft;
|
||||||
bestOeft = oeft;
|
bestOeft = oeft;
|
||||||
bestOverlapCount = overlapCount;
|
bestOverlapCount = overlapCount;
|
||||||
|
bestCenterDistance = centerDistance;
|
||||||
|
}
|
||||||
|
else if (oeft == bestOeft && eft == bestEft && est == bestEst
|
||||||
|
&& centerDistance == bestCenterDistance && overlapCount < bestOverlapCount) {
|
||||||
|
bestProcessor = processor;
|
||||||
|
bestEst = est;
|
||||||
|
bestEft = eft;
|
||||||
|
bestOeft = oeft;
|
||||||
|
bestOverlapCount = overlapCount;
|
||||||
|
bestCenterDistance = centerDistance;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ struct PeftScheduleOptions {
|
|||||||
mlir::MLIRContext* context = nullptr;
|
mlir::MLIRContext* context = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount);
|
||||||
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options);
|
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options);
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
@@ -8,6 +8,8 @@
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
|
std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
|
||||||
|
std::unique_ptr<mlir::Pass> createSpatialLayoutPlanningPass();
|
||||||
|
std::unique_ptr<mlir::Pass> createLowerSpatialPlansPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
|
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
|
||||||
|
|
||||||
|
|||||||
@@ -72,6 +72,8 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
|
|||||||
void PimAccelerator::registerPasses(int optLevel) const {
|
void PimAccelerator::registerPasses(int optLevel) const {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
|
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
|
||||||
registerPass(createONNXToSpatialPass);
|
registerPass(createONNXToSpatialPass);
|
||||||
|
registerPass(createSpatialLayoutPlanningPass);
|
||||||
|
registerPass(createLowerSpatialPlansPass);
|
||||||
registerPass(createSpatialToGraphvizPass);
|
registerPass(createSpatialToGraphvizPass);
|
||||||
registerPass(createSpatialToPimPass);
|
registerPass(createSpatialToPimPass);
|
||||||
registerPass(createPimBufferizationPass);
|
registerPass(createPimBufferizationPass);
|
||||||
|
|||||||
@@ -6,14 +6,14 @@ from onnx import TensorProto
|
|||||||
|
|
||||||
# ONNX dtype -> (ctype, printf, ONNX_TYPE_*)
|
# ONNX dtype -> (ctype, printf, ONNX_TYPE_*)
|
||||||
DTYPES = {
|
DTYPES = {
|
||||||
TensorProto.FLOAT: ("float", "%g", "ONNX_TYPE_FLOAT"),
|
TensorProto.FLOAT: ("float", "%.9g", "ONNX_TYPE_FLOAT"),
|
||||||
TensorProto.DOUBLE: ("double", "%g", "ONNX_TYPE_DOUBLE"),
|
TensorProto.DOUBLE: ("double", "%.17g", "ONNX_TYPE_DOUBLE"),
|
||||||
TensorProto.INT64: ("int64_t", "%lld", "ONNX_TYPE_INT64"),
|
TensorProto.INT64: ("int64_t", "%lld", "ONNX_TYPE_INT64"),
|
||||||
TensorProto.INT32: ("int32_t", "%d", "ONNX_TYPE_INT32"),
|
TensorProto.INT32: ("int32_t", "%d", "ONNX_TYPE_INT32"),
|
||||||
TensorProto.UINT8: ("uint8_t", "%u", "ONNX_TYPE_UINT8"),
|
TensorProto.UINT8: ("uint8_t", "%u", "ONNX_TYPE_UINT8"),
|
||||||
TensorProto.INT8: ("int8_t", "%d", "ONNX_TYPE_INT8"),
|
TensorProto.INT8: ("int8_t", "%d", "ONNX_TYPE_INT8"),
|
||||||
TensorProto.BOOL: ("uint8_t", "%u", "ONNX_TYPE_BOOL"), # stored as byte
|
TensorProto.BOOL: ("uint8_t", "%u", "ONNX_TYPE_BOOL"),
|
||||||
TensorProto.FLOAT16: ("uint16_t", "%u", "ONNX_TYPE_FLOAT16"), # raw 16-bit
|
TensorProto.FLOAT16: ("uint16_t", "%u", "ONNX_TYPE_FLOAT16"),
|
||||||
TensorProto.BFLOAT16:("uint16_t", "%u", "ONNX_TYPE_BFLOAT16"),
|
TensorProto.BFLOAT16:("uint16_t", "%u", "ONNX_TYPE_BFLOAT16"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,295 @@
|
|||||||
|
#!/usr/bin/env python3.13
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||||
|
VALIDATION_DIR = SCRIPT_DIR.parent
|
||||||
|
REPO_ROOT = VALIDATION_DIR.parent
|
||||||
|
if str(VALIDATION_DIR) not in sys.path:
|
||||||
|
sys.path.insert(0, str(VALIDATION_DIR))
|
||||||
|
|
||||||
|
from onnx_utils import _ONNX_TO_NP, onnx_io, write_inputs_to_memory_bin
|
||||||
|
from validate_one import (
|
||||||
|
MODE_COMPILE_ONLY,
|
||||||
|
build_dump_ranges,
|
||||||
|
parse_pim_simulator_outputs,
|
||||||
|
run_pim_simulator,
|
||||||
|
sanitize_output_name,
|
||||||
|
validate_network,
|
||||||
|
)
|
||||||
|
from yolo_real_image_validation import save_tensor_csv
|
||||||
|
|
||||||
|
IMAGENET_MEAN = np.asarray([0.485, 0.456, 0.406], dtype=np.float32)
|
||||||
|
IMAGENET_STD = np.asarray([0.229, 0.224, 0.225], dtype=np.float32)
|
||||||
|
DEFAULT_VGG_MODEL = VALIDATION_DIR / "networks" / "vgg16" / "depth_35" / "vgg16_depth_35.onnx"
|
||||||
|
DEFAULT_RESNET_MODEL = VALIDATION_DIR / "networks" / "resnet" / "resnet18_torchvision.onnx"
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_default_paths():
|
||||||
|
return {
|
||||||
|
"raptor_path": REPO_ROOT / "build_release" / "Release" / "bin" / "onnx-mlir",
|
||||||
|
"onnx_include_dir": REPO_ROOT / "onnx-mlir" / "include",
|
||||||
|
"simulator_dir": REPO_ROOT / "backend-simulators" / "pim" / "pim-simulator",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_model_path(network: str | None, model: Path | None) -> Path:
|
||||||
|
if model is not None:
|
||||||
|
return model.resolve()
|
||||||
|
if network == "resnet":
|
||||||
|
return DEFAULT_RESNET_MODEL.resolve()
|
||||||
|
if network == "vgg":
|
||||||
|
return DEFAULT_VGG_MODEL.resolve()
|
||||||
|
raise SystemExit("Pass --model or select a default with --network {resnet,vgg}.")
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_local_artifacts(args, model_path: Path):
|
||||||
|
validate_network(
|
||||||
|
network_onnx_path=model_path,
|
||||||
|
raptor_path=args.raptor_path,
|
||||||
|
onnx_include_dir=args.onnx_include_dir,
|
||||||
|
simulator_dir=args.simulator_dir,
|
||||||
|
crossbar_size=args.crossbar_size,
|
||||||
|
crossbar_count=args.crossbar_count,
|
||||||
|
core_count=args.core_count,
|
||||||
|
command_timeout_seconds=args.command_timeout_seconds,
|
||||||
|
mode=MODE_COMPILE_ONLY,
|
||||||
|
verbose=args.verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_existing_artifacts(model_dir: Path):
|
||||||
|
required_paths = [
|
||||||
|
model_dir / "runner" / "build" / "runner",
|
||||||
|
model_dir / "raptor" / "pim" / "config.json",
|
||||||
|
model_dir / "raptor" / "pim" / "memory.bin",
|
||||||
|
]
|
||||||
|
missing = [str(path) for path in required_paths if not path.exists()]
|
||||||
|
if missing:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"Missing compiled local artifacts. Re-run without --skip-compile or restore these paths:\n "
|
||||||
|
+ "\n ".join(missing)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_classification_image(image_path: Path) -> tuple[Image.Image, np.ndarray]:
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
width, height = image.size
|
||||||
|
scale = 256.0 / min(width, height)
|
||||||
|
resized_size = (
|
||||||
|
max(1, int(round(width * scale))),
|
||||||
|
max(1, int(round(height * scale))),
|
||||||
|
)
|
||||||
|
resized = image.resize(resized_size, Image.Resampling.BILINEAR)
|
||||||
|
|
||||||
|
left = (resized.width - 224) // 2
|
||||||
|
top = (resized.height - 224) // 2
|
||||||
|
cropped = resized.crop((left, top, left + 224, top + 224))
|
||||||
|
|
||||||
|
array = np.asarray(cropped, dtype=np.float32) / 255.0
|
||||||
|
array = (array - IMAGENET_MEAN) / IMAGENET_STD
|
||||||
|
chw = np.transpose(array, (2, 0, 1))
|
||||||
|
tensor = np.expand_dims(chw.astype(np.float32, copy=False), axis=0)
|
||||||
|
return image, tensor
|
||||||
|
|
||||||
|
|
||||||
|
def load_labels(labels_path: Path | None) -> list[str] | None:
|
||||||
|
if labels_path is None:
|
||||||
|
return None
|
||||||
|
labels = [line.strip() for line in labels_path.read_text().splitlines()]
|
||||||
|
return labels or None
|
||||||
|
|
||||||
|
|
||||||
|
def softmax(values: np.ndarray) -> np.ndarray:
|
||||||
|
shifted = values - np.max(values)
|
||||||
|
exp = np.exp(shifted)
|
||||||
|
denom = exp.sum()
|
||||||
|
if not math.isfinite(float(denom)) or denom <= 0.0:
|
||||||
|
raise RuntimeError("Softmax received non-finite output scores.")
|
||||||
|
return exp / denom
|
||||||
|
|
||||||
|
|
||||||
|
def decode_classification_output(output: np.ndarray, labels: list[str] | None, top_k: int):
|
||||||
|
scores = np.asarray(output, dtype=np.float64).reshape(-1)
|
||||||
|
probabilities = softmax(scores)
|
||||||
|
limit = min(top_k, probabilities.size)
|
||||||
|
top_indices = np.argsort(probabilities)[-limit:][::-1]
|
||||||
|
results = []
|
||||||
|
for index in top_indices:
|
||||||
|
label = None
|
||||||
|
if labels is not None and 0 <= int(index) < len(labels):
|
||||||
|
label = labels[int(index)]
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"index": int(index),
|
||||||
|
"label": label,
|
||||||
|
"probability": float(probabilities[int(index)]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def render_result_line(result) -> str:
|
||||||
|
name = result["label"] if result["label"] else f'class {result["index"]}'
|
||||||
|
return f'{name}: {result["probability"] * 100.0:.2f}%'
|
||||||
|
|
||||||
|
|
||||||
|
def draw_classification_panel(image: Image.Image, results, output_path: Path):
|
||||||
|
annotated = image.copy()
|
||||||
|
draw = ImageDraw.Draw(annotated)
|
||||||
|
lines = [render_result_line(result) for result in results]
|
||||||
|
if not lines:
|
||||||
|
lines = ["No predictions"]
|
||||||
|
|
||||||
|
padding = 10
|
||||||
|
line_gap = 4
|
||||||
|
max_width = 0
|
||||||
|
line_heights = []
|
||||||
|
for line in lines:
|
||||||
|
left, top, right, bottom = draw.textbbox((0, 0), line)
|
||||||
|
max_width = max(max_width, right - left)
|
||||||
|
line_heights.append(bottom - top)
|
||||||
|
|
||||||
|
panel_height = padding * 2 + sum(line_heights) + line_gap * (len(lines) - 1)
|
||||||
|
panel_width = padding * 2 + max_width
|
||||||
|
origin_x = 12
|
||||||
|
origin_y = 12
|
||||||
|
draw.rounded_rectangle(
|
||||||
|
(origin_x, origin_y, origin_x + panel_width, origin_y + panel_height),
|
||||||
|
radius=10,
|
||||||
|
fill=(0, 0, 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
y = origin_y + padding
|
||||||
|
for line, line_height in zip(lines, line_heights):
|
||||||
|
draw.text((origin_x + padding, y), line, fill=(255, 255, 255))
|
||||||
|
y += line_height + line_gap
|
||||||
|
|
||||||
|
annotated.save(output_path)
|
||||||
|
|
||||||
|
|
||||||
|
def run_reference_and_simulator(args, model_path: Path, tensor: np.ndarray):
|
||||||
|
model_dir = model_path.parent
|
||||||
|
runner_build_dir = model_dir / "runner" / "build"
|
||||||
|
runner_path = runner_build_dir / "runner"
|
||||||
|
pim_dir = model_dir / "raptor" / "pim"
|
||||||
|
simulation_dir = model_dir / "classification_demo" / "simulation"
|
||||||
|
reference_dir = model_dir / "classification_demo" / "reference"
|
||||||
|
inputs_dir = model_dir / "classification_demo" / "inputs"
|
||||||
|
|
||||||
|
simulation_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
reference_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
inputs_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
input_descriptors, output_descriptors = onnx_io(model_path)
|
||||||
|
if len(input_descriptors) != 1:
|
||||||
|
raise RuntimeError(f"Expected one classification input tensor, found {len(input_descriptors)}")
|
||||||
|
if len(output_descriptors) != 1:
|
||||||
|
raise RuntimeError(f"Expected one classification output tensor, found {len(output_descriptors)}")
|
||||||
|
|
||||||
|
input_index, _input_name, _input_dtype, input_shape = input_descriptors[0]
|
||||||
|
if list(tensor.shape) != list(input_shape):
|
||||||
|
raise RuntimeError(f"Preprocessed tensor shape {list(tensor.shape)} does not match model input {input_shape}")
|
||||||
|
|
||||||
|
input_csv = inputs_dir / "in0.csv"
|
||||||
|
save_tensor_csv(tensor, input_csv)
|
||||||
|
|
||||||
|
runner_cmd = [
|
||||||
|
str(runner_path),
|
||||||
|
f"--in{input_index}-csv-file",
|
||||||
|
str(input_csv),
|
||||||
|
f"--in{input_index}-shape",
|
||||||
|
"x".join(str(dim) for dim in tensor.shape),
|
||||||
|
"--save-csv-dir",
|
||||||
|
str(reference_dir),
|
||||||
|
]
|
||||||
|
subprocess.run(runner_cmd, cwd=runner_build_dir, check=True)
|
||||||
|
|
||||||
|
write_inputs_to_memory_bin(pim_dir / "memory.bin", pim_dir / "config.json", [tensor])
|
||||||
|
dump_ranges = build_dump_ranges(pim_dir / "config.json", output_descriptors)
|
||||||
|
output_bin_path = simulation_dir / "out.bin"
|
||||||
|
run_pim_simulator(
|
||||||
|
args.simulator_dir,
|
||||||
|
pim_dir,
|
||||||
|
output_bin_path,
|
||||||
|
dump_ranges,
|
||||||
|
timeout_sec=args.command_timeout_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_index, output_name, output_dtype_code, output_shape = output_descriptors[0]
|
||||||
|
output_dtype = np.dtype(_ONNX_TO_NP[output_dtype_code])
|
||||||
|
reference_csv = reference_dir / f"output{output_index}_{sanitize_output_name(output_name)}.csv"
|
||||||
|
reference_output = np.loadtxt(reference_csv, delimiter=",", dtype=output_dtype).reshape(output_shape)
|
||||||
|
simulator_output = parse_pim_simulator_outputs(output_bin_path, output_descriptors)[0]
|
||||||
|
return reference_output, simulator_output
|
||||||
|
|
||||||
|
|
||||||
|
def print_topk(title: str, results):
|
||||||
|
print(title)
|
||||||
|
for rank, result in enumerate(results, start=1):
|
||||||
|
label_text = result["label"] if result["label"] else f'class {result["index"]}'
|
||||||
|
print(f' {rank}. {label_text} ({result["probability"] * 100.0:.2f}%) [index={result["index"]}]')
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
defaults = resolve_default_paths()
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Run a VGG or ResNet ONNX model through the Raptor simulator and annotate the image with top classification results.")
|
||||||
|
parser.add_argument("--model", type=Path, default=None)
|
||||||
|
parser.add_argument("--network", choices=("resnet", "vgg"), default=None)
|
||||||
|
parser.add_argument("--image", type=Path, required=True)
|
||||||
|
parser.add_argument("--labels", type=Path, default=None)
|
||||||
|
parser.add_argument("--output", type=Path, required=True)
|
||||||
|
parser.add_argument("--raptor-path", type=Path, default=defaults["raptor_path"])
|
||||||
|
parser.add_argument("--onnx-include-dir", type=Path, default=defaults["onnx_include_dir"])
|
||||||
|
parser.add_argument("--simulator-dir", type=Path, default=defaults["simulator_dir"])
|
||||||
|
parser.add_argument("--crossbar-size", type=int, default=2048)
|
||||||
|
parser.add_argument("--crossbar-count", type=int, default=256)
|
||||||
|
parser.add_argument("--core-count", type=int, default=1000)
|
||||||
|
parser.add_argument("--top-k", type=int, default=5)
|
||||||
|
parser.add_argument("--command-timeout-seconds", type=float, default=7200.0)
|
||||||
|
parser.add_argument("--skip-compile", action="store_true")
|
||||||
|
parser.add_argument("--verbose", action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
args.model = resolve_model_path(args.network, args.model)
|
||||||
|
args.image = args.image.resolve()
|
||||||
|
args.output = args.output.resolve()
|
||||||
|
args.labels = args.labels.resolve() if args.labels else None
|
||||||
|
args.raptor_path = args.raptor_path.resolve()
|
||||||
|
args.onnx_include_dir = args.onnx_include_dir.resolve()
|
||||||
|
args.simulator_dir = args.simulator_dir.resolve()
|
||||||
|
|
||||||
|
if not args.skip_compile:
|
||||||
|
ensure_local_artifacts(args, args.model)
|
||||||
|
else:
|
||||||
|
ensure_existing_artifacts(args.model.parent)
|
||||||
|
|
||||||
|
original_image, tensor = preprocess_classification_image(args.image)
|
||||||
|
labels = load_labels(args.labels)
|
||||||
|
reference_output, simulator_output = run_reference_and_simulator(args, args.model, tensor)
|
||||||
|
reference_results = decode_classification_output(reference_output, labels, args.top_k)
|
||||||
|
simulator_results = decode_classification_output(simulator_output, labels, args.top_k)
|
||||||
|
|
||||||
|
print_topk("Reference top-k:", reference_results)
|
||||||
|
print_topk("Simulator top-k:", simulator_results)
|
||||||
|
|
||||||
|
reference_scores = np.asarray(reference_output, dtype=np.float64).reshape(-1)
|
||||||
|
simulator_scores = np.asarray(simulator_output, dtype=np.float64).reshape(-1)
|
||||||
|
max_abs_diff = float(np.max(np.abs(reference_scores - simulator_scores)))
|
||||||
|
print(f"Max absolute score diff: {max_abs_diff:.6e}")
|
||||||
|
|
||||||
|
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
draw_classification_panel(original_image, simulator_results, args.output)
|
||||||
|
print(f"Annotated image saved to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user