11 Commits

Author SHA1 Message Date
NiccoloN 7f3c7464b4 update cost model of batch lanes to consider only a slice of the shared batch input
Validate Operations / validate-operations (push) Waiting to run
2026-05-22 22:16:19 +02:00
NiccoloN c77ffa9c56 better MaterializeMergeSchedule.cpp with %lane indexed batch computes
support for tensors of index values
2026-05-22 21:52:28 +02:00
NiccoloN 495186503c fix cmake magic once again 2026-05-22 19:21:56 +02:00
NiccoloN 2c1da813b5 fix much stuff 2026-05-22 18:53:38 +02:00
NiccoloN 8337a11ce9 automatic code reformat 2026-05-22 15:23:48 +02:00
ilgeco d136136d22 Fix add of input in random order for compute_batch
Validate Operations / validate-operations (push) Waiting to run
2026-05-22 15:21:02 +02:00
NiccoloN 074eb183c7 saner SpatialToPimPass architecture
Validate Operations / validate-operations (push) Waiting to run
2026-05-22 07:27:54 +02:00
NiccoloN 43ed3914b8 better MaterializeMergeSchedule.cpp (something still broken downstream)
Validate Operations / validate-operations (push) Waiting to run
2026-05-22 06:56:39 +02:00
ilgeco 6aaf1c0870 Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone
Validate Operations / validate-operations (push) Has been cancelled
2026-05-21 14:44:19 +02:00
ilgeco fe35b3ed43 Equivalent Class but broken 2026-05-21 14:43:59 +02:00
NiccoloN 90a9339686 better cmake to keep IDEs analyses happy
Validate Operations / validate-operations (push) Has been cancelled
2026-05-21 14:13:54 +02:00
69 changed files with 2546 additions and 2159 deletions
+92 -24
View File
@@ -3,31 +3,99 @@ cmake_minimum_required(VERSION 3.20.0)
project(raptor) project(raptor)
# Add symlink to PIM as accelerator in onnx-mlir # Materialize a CMake shim directory
function(raptor_ensure_symlink link_path target_path) function(raptor_write_external_cmake_shim shim_dir external_source_dir description)
get_filename_component(link_parent "${link_path}" DIRECTORY) get_filename_component(real_external_source_dir "${external_source_dir}" REALPATH)
file(RELATIVE_PATH relative_external_source_dir "${shim_dir}" "${real_external_source_dir}")
if(NOT EXISTS "${link_parent}") if (NOT EXISTS "${real_external_source_dir}/CMakeLists.txt")
message(FATAL_ERROR "Directory not found: ${link_parent}") message(FATAL_ERROR
endif() "External CMake source directory not found or missing CMakeLists.txt:\n"
" ${real_external_source_dir}"
if(NOT EXISTS "${link_path}")
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
file(CREATE_LINK
"${target_path}"
"${link_path}"
SYMBOLIC
) )
endif() endif ()
if (IS_SYMLINK "${shim_dir}")
message(STATUS "Removing old full-directory symlink: ${shim_dir}")
file(REMOVE "${shim_dir}")
endif ()
if (EXISTS "${shim_dir}" AND NOT IS_DIRECTORY "${shim_dir}")
message(FATAL_ERROR "Expected directory or absent path, got file: ${shim_dir}")
endif ()
file(MAKE_DIRECTORY "${shim_dir}")
set(shim_file "${shim_dir}/CMakeLists.txt")
set(shim_contents
"get_filename_component(raptor_external_source_dir
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
REALPATH
)
add_subdirectory(
\"\${raptor_external_source_dir}\"
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
)
if (DEFINED PIM_ENABLED)
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
endif ()
"
)
if (EXISTS "${shim_file}")
file(READ "${shim_file}" old_contents)
else ()
set(old_contents "")
endif ()
if (NOT old_contents STREQUAL shim_contents)
file(WRITE "${shim_file}" "${shim_contents}")
message(STATUS "Wrote CMake shim for ${description}: ${shim_file}")
else ()
message(STATUS "CMake shim already up to date for ${description}")
endif ()
# Mirror the external tree's first-level entries into the shim directory
# so legacy includes like src/Accelerators/PIM/Compiler/... keep working.
file(GLOB children RELATIVE "${real_external_source_dir}" "${real_external_source_dir}/*")
foreach (child IN LISTS children)
if (child STREQUAL "CMakeLists.txt")
continue()
endif ()
set(real_child "${real_external_source_dir}/${child}")
set(shim_child "${shim_dir}/${child}")
if (IS_SYMLINK "${shim_child}")
file(READ_SYMLINK "${shim_child}" existing_link_target)
if (existing_link_target STREQUAL real_child)
continue()
endif ()
file(REMOVE_RECURSE "${shim_child}")
elseif (EXISTS "${shim_child}")
# Do not delete real files/directories. This protects the generated shim.
continue()
endif ()
file(CREATE_LINK
"${real_child}"
"${shim_child}"
SYMBOLIC
)
endforeach ()
endfunction() endfunction()
raptor_ensure_symlink( raptor_write_external_cmake_shim(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM" "${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM" "${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
"PIM accelerator"
) )
raptor_ensure_symlink(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM" raptor_write_external_cmake_shim(
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM" "${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
"PIM accelerator tests"
) )
# Patch onnx-mlir sources for PIM accelerator support. # Patch onnx-mlir sources for PIM accelerator support.
@@ -38,21 +106,21 @@ function(raptor_apply_patch file_path anchor replacement description)
# Already applied replacement text is present # Already applied replacement text is present
string(FIND "${contents}" "${replacement}" already_applied_pos) string(FIND "${contents}" "${replacement}" already_applied_pos)
if(NOT already_applied_pos EQUAL -1) if (NOT already_applied_pos EQUAL -1)
message(STATUS "Patch already applied: ${description}") message(STATUS "Patch already applied: ${description}")
return() return()
endif() endif ()
# Anchor must exist for the patch to be applicable # Anchor must exist for the patch to be applicable
string(FIND "${contents}" "${anchor}" anchor_pos) string(FIND "${contents}" "${anchor}" anchor_pos)
if(anchor_pos EQUAL -1) if (anchor_pos EQUAL -1)
message(FATAL_ERROR message(FATAL_ERROR
"Patch anchor not found onnx-mlir may have changed.\n" "Patch anchor not found onnx-mlir may have changed.\n"
" Patch : ${description}\n" " Patch : ${description}\n"
" File : ${file_path}\n" " File : ${file_path}\n"
" Anchor: ${anchor}" " Anchor: ${anchor}"
) )
endif() endif ()
string(REPLACE "${anchor}" "${replacement}" patched "${contents}") string(REPLACE "${anchor}" "${replacement}" patched "${contents}")
file(WRITE "${file_path}" "${patched}") file(WRITE "${file_path}" "${patched}")
@@ -299,10 +299,11 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
if in_path.contains(&waiting_for) { if in_path.contains(&waiting_for) {
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap(); let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
let cycle = &path[cycle_start..]; let cycle = &path[cycle_start..];
let format_core = |core: &i32| (core - 1).to_string();
let cycle_str = cycle let cycle_str = cycle
.iter() .iter()
.map(|c| c.to_string()) .map(format_core)
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(" -> "); .join(" -> ");
@@ -311,19 +312,19 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
.copied() .copied()
.chain(std::iter::once(waiting_for)) .chain(std::iter::once(waiting_for))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for); let cycle_msg = format!("{} -> {}", cycle_str, waiting_for - 1);
let states_msg = cycle let states_msg = cycle
.iter() .iter()
.filter_map(|core| { .filter_map(|core| {
states.get(core).map(|state| match state { states.get(core).map(|state| match state {
CoreState::SendingTo(target, size) => { CoreState::SendingTo(target, size) => {
format!("core {} send {}B -> {}", core, size, target) format!("core {} send {}B -> {}", core - 1, size, target - 1)
} }
CoreState::ReceivingFrom(source, size) => { CoreState::ReceivingFrom(source, size) => {
format!("core {} recv {}B <- {}", core, size, source) format!("core {} recv {}B <- {}", core - 1, size, source - 1)
} }
CoreState::Working => format!("core {} working", core), CoreState::Working => format!("core {} working", core - 1),
CoreState::Halted => format!("core {} halted", core), CoreState::Halted => format!("core {} halted", core - 1),
}) })
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
+53
View File
@@ -10,6 +10,56 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT}) set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT}) set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
set(PIM_GENERATED_PATH_SHIM_TARGET "")
get_filename_component(PIM_BIN_ROOT_NAME "${PIM_BIN_ROOT}" NAME)
if (PIM_BIN_ROOT_NAME STREQUAL "raptor-external")
get_filename_component(PIM_GENERATED_PATH_SHIM_ROOT "${PIM_BIN_ROOT}" DIRECTORY)
set(PIM_GENERATED_PATH_SHIM_OUTPUTS)
function(add_pim_generated_path_shim relative_path)
set(real_file "${PIM_BIN_ROOT}/${relative_path}")
set(shim_file "${PIM_GENERATED_PATH_SHIM_ROOT}/${relative_path}")
get_filename_component(shim_dir "${shim_file}" DIRECTORY)
add_custom_command(
OUTPUT "${shim_file}"
DEPENDS "${real_file}"
COMMAND "${CMAKE_COMMAND}" -E make_directory "${shim_dir}"
COMMAND "${CMAKE_COMMAND}" -E rm -f "${shim_file}"
COMMAND "${CMAKE_COMMAND}" -E create_symlink "${real_file}" "${shim_file}"
VERBATIM
)
list(APPEND PIM_GENERATED_PATH_SHIM_OUTPUTS "${shim_file}")
set(PIM_GENERATED_PATH_SHIM_OUTPUTS "${PIM_GENERATED_PATH_SHIM_OUTPUTS}" PARENT_SCOPE)
endfunction()
file(GLOB_RECURSE pim_generated_path_scan_sources
CONFIGURE_DEPENDS
"${PIM_SRC_ROOT}/*.cpp"
"${PIM_SRC_ROOT}/*.hpp"
)
set(pim_generated_path_shims)
foreach (source_file IN LISTS pim_generated_path_scan_sources)
file(READ "${source_file}" source_contents)
string(REGEX MATCHALL "#include \"src/Accelerators/PIM/[^\"]+\\.inc\"" source_inc_matches "${source_contents}")
foreach (inc_match IN LISTS source_inc_matches)
string(REGEX REPLACE "^#include \"src/Accelerators/PIM/(.+)\"$" "\\1" relative_inc_path "${inc_match}")
list(APPEND pim_generated_path_shims "${relative_inc_path}")
endforeach ()
endforeach ()
list(REMOVE_DUPLICATES pim_generated_path_shims)
foreach (relative_inc_path IN LISTS pim_generated_path_shims)
add_pim_generated_path_shim("${relative_inc_path}")
endforeach ()
add_custom_target(OMPimGeneratedPathShims DEPENDS ${PIM_GENERATED_PATH_SHIM_OUTPUTS})
set(PIM_GENERATED_PATH_SHIM_TARGET OMPimGeneratedPathShims)
endif ()
set(PIM_PUBLIC_INCLUDE_DIRS set(PIM_PUBLIC_INCLUDE_DIRS
${ONNX_MLIR_SRC_ROOT}/include ${ONNX_MLIR_SRC_ROOT}/include
${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_SRC_ROOT}
@@ -37,6 +87,9 @@ set(PIM_GENERATED_INCLUDE_DIRS
function(add_pim_library name) function(add_pim_library name)
add_onnx_mlir_library(${name} STATIC ${ARGN}) add_onnx_mlir_library(${name} STATIC ${ARGN})
if (PIM_GENERATED_PATH_SHIM_TARGET)
add_dependencies(${name} ${PIM_GENERATED_PATH_SHIM_TARGET})
endif ()
endfunction() endfunction()
add_subdirectory(Dialect) add_subdirectory(Dialect)
+1 -1
View File
@@ -264,7 +264,7 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
return mlir::failure(); return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8; byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
value = resolveAlias(subviewOp.getSource(), knowledge); value = resolveAlias(subviewOp.getSource(), knowledge);
continue; continue;
} }
+25
View File
@@ -1,4 +1,5 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
@@ -35,6 +36,30 @@ int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
return numElements; return numElements;
} }
bool hasByteSizedElementType(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return true;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
return false;
}
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return mlir::IndexType::kInternalStorageBitWidth / 8;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return static_cast<size_t>(intType.getWidth() / 8);
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return static_cast<size_t>(floatType.getWidth() / 8);
llvm_unreachable("expected byte-sized integer, float, or index element type");
}
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
return static_cast<size_t>(shapedType.getNumElements()) * getElementTypeSizeInBytes(shapedType.getElementType());
}
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape, bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets, llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes, llvm::ArrayRef<int64_t> sizes,
+11
View File
@@ -1,8 +1,13 @@
#pragma once #pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <cstddef>
namespace onnx_mlir { namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape); llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
@@ -14,6 +19,12 @@ int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t>
int64_t getNumElements(llvm::ArrayRef<int64_t> shape); int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool hasByteSizedElementType(mlir::Type elementType);
size_t getElementTypeSizeInBytes(mlir::Type elementType);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape, bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets, llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes, llvm::ArrayRef<int64_t> sizes,
+7 -4
View File
@@ -21,13 +21,15 @@ namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy> template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
mlir::Value weightArg = parentOp.getWeightArgument(weightIndex); auto weightArg = parentOp.getWeightArgument(weightIndex);
if (!weightArg)
return false;
bool found = false; bool found = false;
parentOp.walk([&](mlir::Operation* op) { parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op)) if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeight() == weightArg; found |= mvmOp.getWeight() == *weightArg;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op)) else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeight() == weightArg; found |= vmmOp.getWeight() == *weightArg;
}); });
return found; return found;
} }
@@ -38,7 +40,8 @@ void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpO
llvm::SmallSet<unsigned, 8> visited; llvm::SmallSet<unsigned, 8> visited;
auto walkWeight = [&](mlir::Value weight) { auto walkWeight = [&](mlir::Value weight) {
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) { for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
if (parentOp.getWeightArgument(weightIndex) != weight) auto weightArg = parentOp.getWeightArgument(weightIndex);
if (!weightArg || *weightArg != weight)
continue; continue;
if (visited.insert(weightIndex).second) if (visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex)); callback(parentOp->getOpOperand(weightIndex));
+3 -3
View File
@@ -13,7 +13,8 @@
namespace onnx_mlir::pim { namespace onnx_mlir::pim {
struct CappedDiagnosticReporter { struct CappedDiagnosticReporter {
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {} explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8)
: maxReportedFailures(maxReportedFailures) {}
template <typename EmitFn> template <typename EmitFn>
void report(mlir::Operation* op, EmitFn&& emit) { void report(mlir::Operation* op, EmitFn&& emit) {
@@ -24,8 +25,7 @@ struct CappedDiagnosticReporter {
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const { void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
if (numFailures > maxReportedFailures) if (numFailures > maxReportedFailures)
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
<< failureDescription;
} }
bool hasFailure() const { return numFailures != 0; } bool hasFailure() const { return numFailures != 0; }
+41 -7
View File
@@ -28,23 +28,47 @@ static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_
return laneCoreIds; return laneCoreIds;
} }
static Value getOrCloneCapturedValue(OpBuilder& builder, Block& oldBlock, Value value, IRMapping& mapper) {
if (Value mapped = mapper.lookupOrNull(value))
return mapped;
if (auto blockArgument = dyn_cast<BlockArgument>(value)) {
assert(blockArgument.getOwner() != &oldBlock && "expected block argument to be mapped before cloning");
assert(false && "unexpected captured block argument while scalarizing pim.core_batch");
}
Operation* definingOp = value.getDefiningOp();
assert(definingOp && "expected captured value to be defined by an operation");
assert(definingOp->getBlock() != &oldBlock && "expected in-block value to be mapped before cloning");
for (Value operand : definingOp->getOperands())
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
Operation* cloned = builder.clone(*definingOp, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
return mapper.lookup(value);
}
static void cloneScalarizedLaneBody(OpBuilder& builder, static void cloneScalarizedLaneBody(OpBuilder& builder,
pim::PimCoreBatchOp coreBatchOp, pim::PimCoreBatchOp coreBatchOp,
unsigned lane, unsigned lane,
OperationFolder& constantFolder) { OperationFolder& constantFolder) {
Block& oldBlock = coreBatchOp.getBody().front(); Block& oldBlock = coreBatchOp.getBody().front();
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount()); size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightCount = coreBatchOp.getWeights().size(); size_t weightCount = coreBatchOp.getWeights().size();
IRMapping mapper; IRMapping mapper;
for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) { for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
if (blockArg.getType().isIndex()) { if (blockArg.getType().isIndex()) {
mapper.map(blockArg, getOrCreateHostIndexConstant(coreBatchOp, static_cast<int64_t>(lane), constantFolder)); mapper.map(blockArg, getOrCreateHostIndexConstant(anchorOp, static_cast<int64_t>(lane), constantFolder));
continue; continue;
} }
if (argIndex <= weightCount) { if (argIndex <= weightCount) {
mapper.map(blockArg, coreBatchOp.getWeights()[argIndex - 1]); auto scalarCoreOp = cast<pim::PimCoreOp>(anchorOp);
mapper.map(blockArg, scalarCoreOp.getWeightArgument(argIndex - 1));
continue; continue;
} }
@@ -57,8 +81,10 @@ static void cloneScalarizedLaneBody(OpBuilder& builder,
if (isa<pim::PimHaltOp>(op)) if (isa<pim::PimHaltOp>(op))
continue; continue;
for (Value operand : op.getOperands())
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) { if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
pim::PimSendOp::create( pim::PimSendOp::create(
builder, builder,
sendBatchOp.getLoc(), sendBatchOp.getLoc(),
@@ -78,7 +104,6 @@ static void cloneScalarizedLaneBody(OpBuilder& builder,
} }
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) { if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
auto scalarReceive = pim::PimReceiveOp::create( auto scalarReceive = pim::PimReceiveOp::create(
builder, builder,
receiveBatchOp.getLoc(), receiveBatchOp.getLoc(),
@@ -106,8 +131,8 @@ static void cloneScalarizedLaneBody(OpBuilder& builder,
builder, builder,
memcpBatchOp.getLoc(), memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(), memcpBatchOp.getOutput().getType(),
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder), getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getHostSourceOffset(), constantFolder), getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
mapper.lookup(memcpBatchOp.getDeviceTarget()), mapper.lookup(memcpBatchOp.getDeviceTarget()),
mapper.lookup(memcpBatchOp.getHostSource()), mapper.lookup(memcpBatchOp.getHostSource()),
memcpBatchOp.getSizeAttr()); memcpBatchOp.getSizeAttr());
@@ -141,7 +166,16 @@ LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
auto scalarCore = auto scalarCore =
pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId)); pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end()); SmallVector<Type> weightTypes;
SmallVector<Location> weightLocs;
weightTypes.reserve(weights.size());
weightLocs.reserve(weights.size());
for (Value weight : weights) {
weightTypes.push_back(weight.getType());
weightLocs.push_back(weight.getLoc());
}
Block* block =
builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end(), TypeRange(weightTypes), weightLocs);
builder.setInsertionPointToEnd(block); builder.setInsertionPointToEnd(block);
for (unsigned lane : lanes) for (unsigned lane : lanes)
cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder); cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);
+18 -28
View File
@@ -41,23 +41,10 @@ using namespace mlir;
using namespace onnx_mlir; using namespace onnx_mlir;
using namespace onnx_mlir::compact_asm; using namespace onnx_mlir::compact_asm;
static size_t getElementTypeSizeInBytes(mlir::Type elementType) {
if (elementType.isIndex())
return sizeof(int64_t);
if (elementType.isIntOrFloat())
return elementType.getIntOrFloatBitWidth() / 8;
llvm_unreachable("unsupported shaped element type");
}
static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
}
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) { MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType()); auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape()); assert("Only static shape is supported" && type.hasStaticShape());
size_t allocSize = type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()); size_t allocSize = getShapedTypeSizeInBytes(type);
MemEntry memEntry = {0, allocSize}; MemEntry memEntry = {0, allocSize};
return &memEntries.emplace_back(memEntry, value).first; return &memEntries.emplace_back(memEntry, value).first;
} }
@@ -450,7 +437,8 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
const StaticValueKnowledge& knowledge) const { const StaticValueKnowledge& knowledge) const {
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge); size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size(); size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(receiveTensorOp.getOutputBuffer().getType()))
/ receiveTensorOp.getSourceCoreIds().size();
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds())) for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize); emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
} }
@@ -463,7 +451,8 @@ void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const { void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge); size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size(); size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendTensorOp.getInput().getType()))
/ sendTensorOp.getTargetCoreIds().size();
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds())) for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize); emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
} }
@@ -474,7 +463,7 @@ void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKno
int64_t axis = concatOp.getAxis(); int64_t axis = concatOp.getAxis();
ArrayRef<int64_t> outputShape = outputType.getShape(); ArrayRef<int64_t> outputShape = outputType.getShape();
size_t elementSize = outputType.getElementTypeBitWidth() / 8; size_t elementSize = getElementTypeSizeInBytes(outputType.getElementType());
size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge); size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
size_t outerCount = 1; size_t outerCount = 1;
@@ -526,7 +515,7 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvaddOp.getLhs())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvaddOp.getLhs().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -541,7 +530,7 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvsubOp.getLhs())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvsubOp.getLhs().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -556,7 +545,7 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmulOp.getLhs())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmulOp.getLhs().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -571,7 +560,7 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmaxOp.getLhs())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmaxOp.getLhs().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -586,7 +575,7 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 2; instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvdmulOp.getLhs())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvdmulOp.getLhs().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -601,7 +590,7 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge
instruction.r1 = 1; instruction.r1 = 1;
instruction.r2OrImm = 1; instruction.r2OrImm = 1;
instruction.generic1 = 1; instruction.generic1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vavgOp.getInput())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vavgOp.getInput().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -614,7 +603,7 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vrelu; instruction.opcode = pim_binary::Opcode::vrelu;
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vreluOp.getInput())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vreluOp.getInput().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -627,7 +616,7 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vtanh; instruction.opcode = pim_binary::Opcode::vtanh;
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vtanhOp.getInput())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vtanhOp.getInput().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -640,7 +629,7 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vsigm; instruction.opcode = pim_binary::Opcode::vsigm;
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsigmOp.getInput())); instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsigmOp.getInput().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -653,7 +642,8 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
instruction.opcode = pim_binary::Opcode::vsoftmax; instruction.opcode = pim_binary::Opcode::vsoftmax;
instruction.rd = 0; instruction.rd = 0;
instruction.r1 = 1; instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsoftmaxOp.getInput())); instruction.generic3 =
static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsoftmaxOp.getInput().getType())));
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -666,7 +656,7 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const Stati
auto srcType = cast<ShapedType>(transposeOp.getInput().getType()); auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
auto srcShape = srcType.getShape(); auto srcShape = srcType.getShape();
size_t rank = srcShape.size(); size_t rank = srcShape.size();
size_t elementSize = srcType.getElementTypeBitWidth() / 8; size_t elementSize = getElementTypeSizeInBytes(srcType.getElementType());
size_t totalElements = srcType.getNumElements(); size_t totalElements = srcType.getNumElements();
// Read permutation. Destination dim i corresponds to source dim perm[i]. // Read permutation. Destination dim i corresponds to source dim perm[i].
+9 -9
View File
@@ -1,7 +1,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#define DEBUG_TYPE "PimCompilerOptions" #define DEBUG_TYPE "PimCompilerOptions"
namespace onnx_mlir { namespace onnx_mlir {
@@ -15,13 +15,13 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
llvm::cl::init(EmitPimCodegen), llvm::cl::init(EmitPimCodegen),
llvm::cl::cat(OnnxMlirOptions)); llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler( llvm::cl::opt<PimMergeSchedulerType>
"pim-merge-scheduler", pimMergeScheduler("pim-merge-scheduler",
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"), llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")), llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")), llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
llvm::cl::init(MergeSchedulerPeft), llvm::cl::init(MergeSchedulerPeft),
llvm::cl::cat(OnnxMlirOptions)); llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> llvm::cl::opt<bool>
pimOnlyCodegen("pim-only-codegen", pimOnlyCodegen("pim-only-codegen",
+1 -1
View File
@@ -208,7 +208,7 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
int64_t numCols = shape[1]; int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8; size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin"; std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str(); auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
@@ -99,15 +99,17 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>; using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) { if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithValues( detail::invokeWithValues(std::forward<BodyFn>(body),
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {}); detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
return computeOp; return computeOp;
} }
else { else {
auto bodyResult = detail::invokeWithValues( auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {}); detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) { if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp); rewriter.eraseOp(computeOp);
@@ -422,9 +422,13 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
SmallVector<Value> vmmOutputs; SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlices[coreId].size()); vmmOutputs.reserve(aHSlices[coreId].size());
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size())) for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size())) {
vmmOutputs.push_back(spatial::SpatVMMOp::create( auto weightArg = computeOp.getWeightArgument(aHSliceId);
rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId))); auto inputArg = computeOp.getInputArgument(aHSliceId);
if (!weightArg || !inputArg)
return failure();
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, *weightArg, *inputArg));
}
if (vmmOutputs.empty()) { if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs"); gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure(); return failure();
@@ -558,29 +562,31 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(body); rewriter.setInsertionPointToEnd(body);
Value lane = batchOp.getLaneArgument(); auto lane = batchOp.getLaneArgument();
Value weight = batchOp.getWeightArgument(0); auto weight = batchOp.getWeightArgument(0);
Value packedInput = batchOp.getInputArgument(0); auto packedInput = batchOp.getInputArgument(0);
Value packedOutput = batchOp.getOutputArgument(0); auto packedOutput = batchOp.getOutputArgument(0);
if (!lane || !weight || !packedInput || !packedOutput)
return failure();
SmallVector<OpFoldResult> inputOffsets {lane, rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> inputOffsets {*lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))}; SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value row = Value row =
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, packedInput, inputOffsets, inputSizes, unitStrides) tensor::ExtractSliceOp::create(rewriter, loc, aRowType, *packedInput, inputOffsets, inputSizes, unitStrides)
.getResult(); .getResult();
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, weight, row).getResult(); Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, *weight, row).getResult();
Value laneResult = vmmResult; Value laneResult = vmmResult;
if (sharedBias) if (sharedBias)
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult(); laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc); auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))}; SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, tensor::ParallelInsertSliceOp::create(
unitStrides); rewriter, loc, laneResult, *packedOutput, outputOffsets, outputSizes, unitStrides);
rewriter.setInsertionPointAfter(batchOp); rewriter.setInsertionPointAfter(batchOp);
rewriter.replaceOp(gemmOp, batchOp.getResults()); rewriter.replaceOp(gemmOp, batchOp.getResults());
@@ -38,23 +38,16 @@ static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end()); return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
} }
static Value collapseBatchDims(Value value, static Value
int64_t batchSize, collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
auto type = cast<RankedTensorType>(value.getType()); auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2 || type.getRank() == 3) if (type.getRank() == 2 || type.getRank() == 3)
return value; return value;
auto collapsedType = auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding()); SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
SmallVector<ReassociationIndices> reassociation = { ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {}, ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}};
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
};
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim) for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
reassociation.front().push_back(dim); reassociation.front().push_back(dim);
@@ -72,19 +65,14 @@ static Value collapseBatchDims(Value value,
return collapseCompute.getResult(0); return collapseCompute.getResult(0);
} }
static Value expandBatchDims(Value value, static Value
RankedTensorType outputType, expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) {
size_t batchRank,
PatternRewriter& rewriter,
Location loc) {
if (cast<RankedTensorType>(value.getType()) == outputType) if (cast<RankedTensorType>(value.getType()) == outputType)
return value; return value;
SmallVector<ReassociationIndices> reassociation = { SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
ReassociationIndices {}, ReassociationIndices {static_cast<int64_t>(batchRank)},
ReassociationIndices {static_cast<int64_t>(batchRank)}, ReassociationIndices {static_cast<int64_t>(batchRank + 1)}};
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
};
for (size_t dim = 0; dim < batchRank; ++dim) for (size_t dim = 0; dim < batchRank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim)); reassociation.front().push_back(static_cast<int64_t>(dim));
@@ -58,24 +58,21 @@ static Value buildNearestResizeLoop(Value input,
Value outputC = channelLoop.getInductionVar(); Value outputC = channelLoop.getInductionVar();
Value outputChannelAcc = channelLoop.getRegionIterArgs().front(); Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
Value inputC = Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc}); auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
rewriter.setInsertionPointToStart(heightLoop.getBody()); rewriter.setInsertionPointToStart(heightLoop.getBody());
Value outputH = heightLoop.getInductionVar(); Value outputH = heightLoop.getInductionVar();
Value outputHeightAcc = heightLoop.getRegionIterArgs().front(); Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
Value inputH = Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc}); auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
rewriter.setInsertionPointToStart(widthLoop.getBody()); rewriter.setInsertionPointToStart(widthLoop.getBody());
Value outputW = widthLoop.getInductionVar(); Value outputW = widthLoop.getInductionVar();
Value outputWidthAcc = widthLoop.getRegionIterArgs().front(); Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
Value inputW = Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW}; SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
Value inputSlice = Value inputSlice =
@@ -114,8 +111,8 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric" if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getNearestMode() != "floor") || resizeOp.getNearestMode() != "floor")
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(resizeOp,
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor."); "resize lowering currently supports only nearest + asymmetric + floor.");
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; }) if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; })) || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
@@ -27,13 +27,16 @@ static bool canPromoteInputBlockArgument(BlockArgument arg) {
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser); return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
} }
static bool canPromoteInputBlockArgument(std::optional<BlockArgument> arg) {
return arg && canPromoteInputBlockArgument(*arg);
}
static bool isDirectConstantValue(Value value) { static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp()); return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
} }
template <typename ComputeOpTy> template <typename ComputeOpTy>
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) { static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
Block& block = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input)) if (!isWeightLikeComputeOperand(input))
continue; continue;
@@ -94,8 +97,8 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
} }
llvm::append_range(newBlockArgTypes, newInputTypes); llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs); llvm::append_range(newBlockArgLocs, newInputLocs);
auto* newBlock = auto* newBlock = rewriter.createBlock(
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs); &newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes( newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())}); {static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock); rewriter.setInsertionPointToStart(newBlock);
@@ -104,20 +107,30 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
bodyRewriter.setInsertionPointToStart(newBlock); bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper; IRMapping mapper;
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex)); auto oldWeightArg = compute.getWeightArgument(weightIndex);
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
if (!oldWeightArg || !newWeightArg)
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg);
}
size_t newInputIdx = 0; size_t newInputIdx = 0;
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) { for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
BlockArgument oldArg = compute.getInputArgument(oldInputIdx); auto oldArg = compute.getInputArgument(oldInputIdx);
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing compute input block argument during rewrite");
if (!promoteInput[oldInputIdx]) { if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++)); auto newInputArg = newCompute.getInputArgument(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute input block argument");
mapper.map(*oldArg, *newInputArg);
continue; continue;
} }
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper); auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue)) if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand"); return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(oldArg, *clonedValue); mapper.map(*oldArg, *clonedValue);
} }
for (Operation& op : oldBlock.without_terminator()) for (Operation& op : oldBlock.without_terminator())
@@ -184,12 +197,15 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())), rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
newWeights, newWeights,
newInputs); newInputs);
auto laneArg = compute.getLaneArgument();
if (!laneArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
SmallVector<Type> newBlockArgTypes; SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs; SmallVector<Location> newBlockArgLocs;
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults()); newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults()); newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
newBlockArgTypes.push_back(compute.getLaneArgument().getType()); newBlockArgTypes.push_back(laneArg->getType());
newBlockArgLocs.push_back(compute.getLaneArgument().getLoc()); newBlockArgLocs.push_back(laneArg->getLoc());
for (Value weight : newWeights) { for (Value weight : newWeights) {
newBlockArgTypes.push_back(weight.getType()); newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc()); newBlockArgLocs.push_back(weight.getLoc());
@@ -197,8 +213,11 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
llvm::append_range(newBlockArgTypes, newInputTypes); llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs); llvm::append_range(newBlockArgLocs, newInputLocs);
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) { for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument");
newBlockArgTypes.push_back(resultType); newBlockArgTypes.push_back(resultType);
newBlockArgLocs.push_back(compute.getOutputArgument(resultIndex).getLoc()); newBlockArgLocs.push_back(outputArg->getLoc());
} }
auto* newBlock = rewriter.createBlock( auto* newBlock = rewriter.createBlock(
@@ -211,24 +230,41 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
bodyRewriter.setInsertionPointToStart(newBlock); bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper; IRMapping mapper;
mapper.map(compute.getLaneArgument(), newCompute.getLaneArgument()); auto newLaneArg = newCompute.getLaneArgument();
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) if (!newLaneArg)
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex)); return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument");
mapper.map(*laneArg, *newLaneArg);
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
auto oldWeightArg = compute.getWeightArgument(weightIndex);
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
if (!oldWeightArg || !newWeightArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg);
}
size_t newInputIdx = 0; size_t newInputIdx = 0;
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) { for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
BlockArgument oldArg = compute.getInputArgument(oldInputIdx); auto oldArg = compute.getInputArgument(oldInputIdx);
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch input block argument during rewrite");
if (!promoteInput[oldInputIdx]) { if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++)); auto newInputArg = newCompute.getInputArgument(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch input block argument");
mapper.map(*oldArg, *newInputArg);
continue; continue;
} }
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper); auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue)) if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand"); return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
mapper.map(oldArg, *clonedValue); mapper.map(*oldArg, *clonedValue);
}
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
} }
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
mapper.map(compute.getOutputArgument(resultIndex), newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
for (Operation& op : oldBlock) for (Operation& op : oldBlock)
rewriter.clone(op, mapper); rewriter.clone(op, mapper);
@@ -1,12 +1,14 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -97,22 +99,75 @@ static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveT
return success(); return success();
} }
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
if (!result.hasOneUse())
return failure();
auto returnOp = dyn_cast<func::ReturnOp>(*result.getUsers().begin());
if (!returnOp)
return failure();
return result.getUses().begin()->getOperandNumber();
}
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
if (scale == 1)
return base;
auto scaleValue = arith::ConstantIndexOp::create(rewriter, loc, scale).getResult();
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
}
static Value createHostTargetOffset(IRRewriter& rewriter,
tensor::ParallelInsertSliceOp insertSlice,
ShapedType destinationType,
IRMapping& mapper) {
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
SmallVector<int64_t> strides(destinationType.getRank(), 1);
ArrayRef<int64_t> shape = destinationType.getShape();
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
Value totalOffset;
Location loc = insertSlice.getLoc();
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
int64_t scale = strides[dim] * elementBytes;
Value scaledOffset;
if (auto attr = dyn_cast<Attribute>(offset)) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
assert(intAttr && "expected integer offset attribute");
scaledOffset = arith::ConstantIndexOp::create(rewriter, loc, intAttr.getInt() * scale).getResult();
}
else {
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
}
totalOffset =
totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() : scaledOffset;
}
if (!totalOffset)
totalOffset = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
return totalOffset;
}
} // namespace } // namespace
LogicalResult LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) { IRRewriter& rewriter) {
Location loc = computeBatchOp.getLoc(); Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front(); Block& oldBlock = computeBatchOp.getBody().front();
if (computeBatchOp.getNumResults() != 0)
return computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results; "
"materialize explicit communication before lowering to PIM");
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator()); auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
if (!oldYield || oldYield.getNumOperands() != 0) auto inParallelOp = dyn_cast<spatial::SpatInParallelOp>(oldBlock.getTerminator());
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield"); if (computeBatchOp.getNumResults() == 0) {
if (!oldYield || oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
}
else if (!inParallelOp) {
return computeBatchOp.emitOpError(
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
}
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId); SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
SmallVector<Value> batchInputs; SmallVector<Value> batchInputs;
if (!computeBatchOp.getInputs().empty()) if (!computeBatchOp.getInputs().empty())
@@ -128,9 +183,22 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())}); {static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<unsigned> returnOperandIndices;
if (computeBatchOp.getNumResults() != 0) {
returnOperandIndices.resize(computeBatchOp.getNumResults());
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
if (failed(returnOperandIndex))
return computeBatchOp.emitOpError(
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
returnOperandIndices[resultIndex] = *returnOperandIndex;
}
}
SmallVector<Type> blockArgTypes; SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs; SmallVector<Location> blockArgLocs;
for (BlockArgument arg : oldBlock.getArguments()) { unsigned inputArgLimit = 1 + computeBatchOp.getWeights().size() + computeBatchOp.getInputs().size();
for (BlockArgument arg : oldBlock.getArguments().take_front(inputArgLimit)) {
blockArgTypes.push_back(arg.getType()); blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc()); blockArgLocs.push_back(arg.getLoc());
} }
@@ -139,11 +207,20 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
IRMapping mapper; IRMapping mapper;
rewriter.setInsertionPointToStart(newBlock); rewriter.setInsertionPointToStart(newBlock);
mapper.map(computeBatchOp.getLaneArgument(), coreBatchOp.getLaneArgument()); auto oldLaneArg = computeBatchOp.getLaneArgument();
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex) if (!oldLaneArg)
mapper.map(computeBatchOp.getWeightArgument(weightIndex), coreBatchOp.getWeightArgument(weightIndex)); return computeBatchOp.emitOpError("expected compute_batch lane block argument before lowering");
mapper.map(*oldLaneArg, coreBatchOp.getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex) {
auto oldWeightArg = computeBatchOp.getWeightArgument(weightIndex);
if (!oldWeightArg)
return computeBatchOp.emitOpError("expected compute_batch weight block arguments before lowering");
mapper.map(*oldWeightArg, coreBatchOp.getWeightArgument(weightIndex));
}
for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) { for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) {
BlockArgument oldArg = computeBatchOp.getInputArgument(inputIndex); auto oldArg = computeBatchOp.getInputArgument(inputIndex);
if (!oldArg)
return computeBatchOp.emitOpError("expected compute_batch input block arguments before lowering");
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex); BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
auto newArgType = cast<ShapedType>(newArg.getType()); auto newArgType = cast<ShapedType>(newArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
@@ -156,7 +233,7 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, newArg)) getTensorSizeInBytesAttr(rewriter, newArg))
.getOutput(); .getOutput();
mapper.map(oldArg, copied); mapper.map(*oldArg, copied);
} }
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value { auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
@@ -178,11 +255,55 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
return copied; return copied;
}; };
SmallVector<Value> hostOutputTensors(returnOperandIndices.size());
auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value {
Value& hostOutputTensor = hostOutputTensors[resultIndex];
if (hostOutputTensor)
return hostOutputTensor;
hostOutputTensor = outputTensors[returnOperandIndices[resultIndex]](rewriter, resultLoc);
return hostOutputTensor;
};
rewriter.setInsertionPointToEnd(newBlock); rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : oldBlock) { for (Operation& op : oldBlock) {
if (isa<spatial::SpatYieldOp>(op)) if (isa<spatial::SpatYieldOp>(op))
continue; continue;
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
if (!firstOutputArg)
return computeBatchOp.emitOpError("expected compute_batch output block arguments before lowering");
for (Operation& nestedOp : parallelOp.getRegion().front()) {
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&nestedOp);
if (!insertSlice)
return parallelOp.emitOpError("expected only tensor.parallel_insert_slice in spat.in_parallel");
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
if (!outputArg || outputArg.getOwner() != &oldBlock)
return insertSlice.emitOpError("expected compute_batch output block argument destination");
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
if (resultIndex >= returnOperandIndices.size())
return insertSlice.emitOpError("result index out of range while lowering host batch output");
Value mappedSource = mapper.lookup(insertSlice.getSource());
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
Value zeroOffset = arith::ConstantIndexOp::create(rewriter, insertSlice.getLoc(), 0).getResult();
pim::PimMemCopyDevToHostOp::create(rewriter,
insertSlice.getLoc(),
hostTarget.getType(),
hostTargetOffset,
zeroOffset,
hostTarget,
mappedSource,
getTensorSizeInBytesAttr(rewriter, mappedSource));
}
continue;
}
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) { if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds()); FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds());
if (failed(targetCoreIds)) if (failed(targetCoreIds))
@@ -1,10 +0,0 @@
#pragma once
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
namespace onnx_mlir {
mlir::LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -6,7 +6,6 @@ add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp SpatialToPimPass.cpp
BatchCoreLoweringPatterns.cpp BatchCoreLoweringPatterns.cpp
ChannelLoweringPatterns.cpp ChannelLoweringPatterns.cpp
Cleanup.cpp
Common.cpp Common.cpp
ComputeLikeRegionUtils.cpp ComputeLikeRegionUtils.cpp
CoreLoweringPatterns.cpp CoreLoweringPatterns.cpp
@@ -1,42 +0,0 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
while (!pendingOps.empty()) {
bool erasedAnyOp = false;
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
Operation* opToRemove = *it;
if (!opToRemove->use_empty()) {
++it;
continue;
}
rewriter.eraseOp(opToRemove);
it = pendingOps.erase(it);
erasedAnyOp = true;
}
if (erasedAnyOp)
continue;
for (Operation* opToRemove : pendingOps) {
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
for (Operation* user : opToRemove->getUsers()) {
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
}
}
return failure();
}
return success();
}
} // namespace onnx_mlir
@@ -1,11 +0,0 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
namespace onnx_mlir {
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -55,10 +55,6 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
return returnValue; return returnValue;
} }
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
}
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) { IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType())))); return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
} }
@@ -20,8 +20,6 @@ namespace onnx_mlir {
*/ */
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape); size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value); mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
template <class T> template <class T>
@@ -1,3 +1,5 @@
#include <cassert>
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -29,9 +31,17 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
unsigned inputIndex, unsigned inputIndex,
Value replacement) { Value replacement) {
Block& body = owner->getRegion(0).front(); Block& body = owner->getRegion(0).front();
BlockArgument bodyArgument = isa<spatial::SpatCompute>(owner) BlockArgument bodyArgument;
? cast<spatial::SpatCompute>(owner).getInputArgument(inputIndex) if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
: cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex); auto computeArg = compute.getInputArgument(inputIndex);
assert(computeArg && "expected compute input block argument");
bodyArgument = *computeArg;
}
else {
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
assert(batchArg && "expected compute_batch input block argument");
bodyArgument = *batchArg;
}
unsigned bodyArgIndex = bodyArgument.getArgNumber(); unsigned bodyArgIndex = bodyArgument.getArgNumber();
rewriter.startOpModification(owner); rewriter.startOpModification(owner);
@@ -6,9 +6,9 @@
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -131,8 +131,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
rewriter.setInsertionPoint(computeOp); rewriter.setInsertionPoint(computeOp);
IRMapping mapping; IRMapping mapping;
for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights())) for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights())) {
mapping.map(computeOp.getWeightArgument(weightIndex), weight); auto weightArg = computeOp.getWeightArgument(weightIndex);
if (!weightArg)
return false;
mapping.map(*weightArg, weight);
}
for (Operation& op : block.without_terminator()) { for (Operation& op : block.without_terminator()) {
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder); cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
Operation* clonedOp = rewriter.clone(op, mapping); Operation* clonedOp = rewriter.clone(op, mapping);
@@ -148,15 +152,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
} // namespace } // namespace
void markOpToRemove(CoreLoweringState& state, Operation* op) { LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
if (!llvm::is_contained(state.operationsToRemove, op)) IRRewriter& rewriter,
state.operationsToRemove.push_back(op); OperationFolder& constantFolder) {
}
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc(); Location loc = computeOp->getLoc();
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, state.constantFolder)) if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, constantFolder))
return success(); return success();
SmallVector<Operation*> helperChain; SmallVector<Operation*> helperChain;
@@ -167,31 +168,33 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator()); auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) { for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
BlockArgument blockArg = computeOp.getInputArgument(inputIndex); auto blockArg = computeOp.getInputArgument(inputIndex);
if (!blockArg)
return computeOp.emitOpError("expected compute input block arguments during lowering");
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp()); auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
if (receiveOp && !blockArg.use_empty()) { if (receiveOp && !blockArg->use_empty()) {
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
auto outputType = cast<ShapedType>(blockArg.getType()); auto outputType = cast<ShapedType>(blockArg->getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg); auto sizeAttr = getTensorSizeInBytesAttr(rewriter, *blockArg);
Value received = Value received =
PimReceiveOp::create( PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId()) rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
.getOutput(); .getOutput();
blockArg.replaceAllUsesWith(received); blockArg->replaceAllUsesWith(received);
markOpToRemove(state, receiveOp); markOpToRemove(receiveOp);
continue; continue;
} }
auto receiveTensorOp = dyn_cast_or_null<spatial::SpatChannelReceiveTensorOp>(input.getDefiningOp()); auto receiveTensorOp = dyn_cast_or_null<spatial::SpatChannelReceiveTensorOp>(input.getDefiningOp());
if (receiveTensorOp && !blockArg.use_empty()) { if (receiveTensorOp && !blockArg->use_empty()) {
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds()); FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds());
if (failed(sourceCoreIds)) if (failed(sourceCoreIds))
return receiveTensorOp.emitOpError("expected constant sourceCoreIds"); return receiveTensorOp.emitOpError("expected constant sourceCoreIds");
for (int32_t& sourceCoreId : *sourceCoreIds) for (int32_t& sourceCoreId : *sourceCoreIds)
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId); sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
auto outputType = cast<ShapedType>(blockArg.getType()); auto outputType = cast<ShapedType>(blockArg->getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType);
Value received = PimReceiveTensorOp::create(rewriter, Value received = PimReceiveTensorOp::create(rewriter,
receiveTensorOp.getLoc(), receiveTensorOp.getLoc(),
@@ -199,8 +202,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
outputBuffer, outputBuffer,
rewriter.getDenseI32ArrayAttr(*sourceCoreIds)) rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
.getOutput(); .getOutput();
blockArg.replaceAllUsesWith(received); blockArg->replaceAllUsesWith(received);
markOpToRemove(state, receiveTensorOp); markOpToRemove(receiveTensorOp);
} }
} }
@@ -211,9 +214,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
if (result.use_empty()) if (result.use_empty())
continue; continue;
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
ReturnPathLoweringResult returnPathResult = ReturnPathLoweringResult returnPathResult =
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter); lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, rewriter);
if (returnPathResult == ReturnPathLoweringResult::Failure) if (returnPathResult == ReturnPathLoweringResult::Failure)
return failure(); return failure();
if (returnPathResult == ReturnPathLoweringResult::Handled) if (returnPathResult == ReturnPathLoweringResult::Handled)
@@ -237,19 +239,19 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
if (!computeOp.getWeights().empty()) if (!computeOp.getWeights().empty())
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end()); computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
auto coreOp = PimCoreOp::create(rewriter, auto coreOp = PimCoreOp::create(
loc, rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
ValueRange(computeWeights),
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
auto& coreOpBlocks = coreOp.getBody().getBlocks(); auto& coreOpBlocks = coreOp.getBody().getBlocks();
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) { for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
BlockArgument blockArg = computeOp.getInputArgument(inputIndex); auto blockArg = computeOp.getInputArgument(inputIndex);
if (blockArg.use_empty()) if (!blockArg)
return computeOp.emitOpError("expected compute input block arguments during input materialization");
if (blockArg->use_empty())
continue; continue;
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) { if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, state.constantFolder)); blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder));
continue; continue;
} }
@@ -261,13 +263,13 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
PimMemCopyHostToDevOp::create(rewriter, PimMemCopyHostToDevOp::create(rewriter,
loc, loc,
outputBuffer.getType(), outputBuffer.getType(),
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder), getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder), getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
outputBuffer, outputBuffer,
input, input,
getTensorSizeInBytesAttr(rewriter, input)) getTensorSizeInBytesAttr(rewriter, input))
.getOutput(); .getOutput();
blockArg.replaceAllUsesWith(copied); blockArg->replaceAllUsesWith(copied);
} }
if (!computeOp.getInputs().empty()) if (!computeOp.getInputs().empty())
block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size()); block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size());
@@ -1,23 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
struct CoreLoweringState {
size_t& nextCoreId;
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
mlir::OperationFolder& constantFolder;
};
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
mlir::LogicalResult
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -77,8 +77,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
if (!inputIndex) if (!inputIndex)
return failure(); return failure();
auto BBArgValue = spatCompute.getInputArgument(*inputIndex); auto BBArgValue = spatCompute.getInputArgument(*inputIndex);
if (!BBArgValue)
return failure();
if (BBArgValue.use_empty()) if (BBArgValue->use_empty())
continue; continue;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
@@ -95,8 +97,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
if (!inputIndex) if (!inputIndex)
return failure(); return failure();
auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex); auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex);
if (!BBArgValue)
return failure();
if (BBArgValue.use_empty()) if (BBArgValue->use_empty())
continue; continue;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
@@ -141,152 +145,6 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
} }
}; };
// Turns runtime constants consumed by compute regions into private globals and local loads.
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
Location loc = constantOp.getLoc();
if (hasWeightAlways(constantOp))
return failure();
if (!isa<func::FuncOp>(constantOp->getParentOp()))
return failure();
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
if (isa<spatial::SpatCompute>(op))
return false;
if (isa<func::FuncOp>(op->getParentOp()))
return true;
return false;
}))
return failure();
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
if (constRankedTensorType) {
mlir::MemRefType memRefType =
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
loc,
constantOp->getParentOfType<ModuleOp>(),
"const",
memRefType,
constantOp.getValueAttr(),
rewriter.getUnitAttr());
std::string argName = globalOp.getSymName().str();
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
replaceAndEraseDirectComputeLikeInput(rewriter,
spatComputeBatch.getOperation(),
BBArgIndex,
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
}
else {
{
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatCompute.getOperation());
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatComputeBatch.getOperation());
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
}
}
}
}
}
else if (constantOp.getType().isIntOrIndexOrFloat()) {
Value hostConstant = constantOp.getResult();
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, hostConstant);
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, hostConstant);
}
else if (constUsers->getParentOfType<spatial::SpatCompute>()) {
constUses.set(hostConstant);
}
else {
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
assert(batchParent && "Global Constant used direcly not within a compute");
constUses.set(hostConstant);
}
}
}
if (constantOp->use_empty())
rewriter.eraseOp(constantOp);
return success();
}
};
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly. // Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> { struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
@@ -363,8 +221,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
} // namespace } // namespace
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) { void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>( patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
patterns.getContext());
} }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -9,9 +9,9 @@
#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/FoldUtils.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -44,11 +44,6 @@ static bool isReturnHelperChainOp(Operation* op) {
pim::PimTransposeOp>(op); pim::PimTransposeOp>(op);
} }
static void markOpToRemove(ReturnPathState& state, Operation* op) {
if (!llvm::is_contained(state.operationsToRemove, op))
state.operationsToRemove.push_back(op);
}
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) { static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
std::string name = baseName.str(); std::string name = baseName.str();
unsigned suffix = 0; unsigned suffix = 0;
@@ -390,9 +385,7 @@ static Value emitHostCopy(IRRewriter& rewriter,
} // namespace } // namespace
void addReturnOutputBuffers(func::ReturnOp returnOp, void raptor::SpatialToPimPass::addReturnOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
IRRewriter& rewriter,
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
outputTensors.reserve(returnOp->getNumOperands()); outputTensors.reserve(returnOp->getNumOperands());
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) { for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
Value currentReturnValue = returnValue; Value currentReturnValue = returnValue;
@@ -427,8 +420,8 @@ void addReturnOutputBuffers(func::ReturnOp returnOp,
} }
} }
ReturnPathLoweringResult lowerProducedValueReturnPath( raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerProducedValueReturnPath(
Operation* producerOp, Value producedValue, Value storedValue, ReturnPathState& state, IRRewriter& rewriter) { Operation* producerOp, Value producedValue, Value storedValue, IRRewriter& rewriter) {
Location loc = producerOp->getLoc(); Location loc = producerOp->getLoc();
OperationFolder constantFolder(producerOp->getContext()); OperationFolder constantFolder(producerOp->getContext());
auto storedTensorType = cast<TensorType>(storedValue.getType()); auto storedTensorType = cast<TensorType>(storedValue.getType());
@@ -437,13 +430,13 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
Value currentStoredValue = storedValue; Value currentStoredValue = storedValue;
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue); cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
for (Operation* op : returnUse->helperChain) for (Operation* op : returnUse->helperChain)
markOpToRemove(state, op); markOpToRemove(op);
auto storedType = cast<ShapedType>(currentStoredValue.getType()); auto storedType = cast<ShapedType>(currentStoredValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8; size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType());
if (auto storedOp = currentStoredValue.getDefiningOp()) if (auto storedOp = currentStoredValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp); rewriter.setInsertionPointAfter(storedOp);
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc); Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
emitHostCopy(rewriter, emitHostCopy(rewriter,
loc, loc,
outputTensor, outputTensor,
@@ -462,9 +455,9 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
if (isa<func::ReturnOp>(resultUser)) { if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber(); size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8; size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
rewriter.setInsertionPointAfterValue(storedValue); rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc); Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy(rewriter, emitHostCopy(rewriter,
loc, loc,
outputTensor, outputTensor,
@@ -478,13 +471,13 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
} }
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) { if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8; size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
for (Operation* concatOp : concatReturnUse->concatChain) for (Operation* concatOp : concatReturnUse->concatChain)
markOpToRemove(state, concatOp); markOpToRemove(concatOp);
if (concatReturnUse->helperChain.empty()) { if (concatReturnUse->helperChain.empty()) {
rewriter.setInsertionPointAfterValue(storedValue); rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc); Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType()); auto outputType = cast<ShapedType>(outputTensor.getType());
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape()); int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
emitHostCopy(rewriter, emitHostCopy(rewriter,
@@ -505,7 +498,7 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
return ReturnPathLoweringResult::Failure; return ReturnPathLoweringResult::Failure;
} }
rewriter.setInsertionPointAfterValue(storedValue); rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc); Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType()); auto outputType = cast<ShapedType>(outputTensor.getType());
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) { for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape()); SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
@@ -553,12 +546,12 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
return ReturnPathLoweringResult::NotReturnPath; return ReturnPathLoweringResult::NotReturnPath;
} }
ReturnPathLoweringResult lowerComputeResultReturnPath( raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) { spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, state, rewriter); return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
} }
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) { void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void { auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
if (!op) if (!op)
return; return;
@@ -575,13 +568,13 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
if (isReturnHelperChainOp(op)) { if (isReturnHelperChainOp(op)) {
Value source = op->getOperand(0); Value source = op->getOperand(0);
markOpToRemove(state, op); markOpToRemove(op);
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain); markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
return; return;
} }
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) { if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
markOpToRemove(state, computeOp); markOpToRemove(computeOp);
if (!computeOp.getInputs().empty()) if (!computeOp.getInputs().empty())
for (Value input : computeOp.getInputs()) for (Value input : computeOp.getInputs())
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain); markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
@@ -589,33 +582,33 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
} }
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) { if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
markOpToRemove(state, concatOp); markOpToRemove(concatOp);
for (Value operand : concatOp.getOperands()) for (Value operand : concatOp.getOperands())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return; return;
} }
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) { if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
markOpToRemove(state, concatOp); markOpToRemove(concatOp);
for (Value operand : concatOp.getInputs()) for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return; return;
} }
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) { if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
markOpToRemove(state, concatOp); markOpToRemove(concatOp);
for (Value operand : concatOp.getInputs()) for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return; return;
} }
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) { if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) {
markOpToRemove(state, receiveOp); markOpToRemove(receiveOp);
return; return;
} }
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op)) if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
markOpToRemove(state, receiveTensorOp); markOpToRemove(receiveTensorOp);
}; };
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end()); SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
@@ -624,7 +617,7 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
size_t orderWithinReturn = it.index(); size_t orderWithinReturn = it.index();
Operation* returnOperand = it.value().getDefiningOp(); Operation* returnOperand = it.value().getDefiningOp();
rewriter.setInsertionPoint(returnOp); rewriter.setInsertionPoint(returnOp);
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc); Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); }); rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
markOwnedReturnChain(returnOperand, markOwnedReturnChain); markOwnedReturnChain(returnOperand, markOwnedReturnChain);
} }
@@ -1,43 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include <functional>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
struct ReturnPathState {
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
};
enum class ReturnPathLoweringResult {
Handled,
NotReturnPath,
Failure
};
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp,
mlir::IRRewriter& rewriter,
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors);
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
mlir::OpResult result,
mlir::Value yieldValue,
ReturnPathState& state,
mlir::IRRewriter& rewriter);
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
mlir::Value producedValue,
mlir::Value storedValue,
ReturnPathState& state,
mlir::IRRewriter& rewriter);
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
} // namespace onnx_mlir
@@ -14,7 +14,6 @@
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
@@ -24,54 +23,28 @@
#include <cassert> #include <cassert>
#include <utility> #include <utility>
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp" #include "Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" #include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp" #include "Conversion/SpatialToPim/PhaseVerification.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp" #include "Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp" #include "Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp" #include "Pass/PIMPasses.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp" #include "SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
using namespace mlir; using namespace mlir;
using namespace onnx_mlir; using namespace onnx_mlir;
using namespace pim; using namespace pim;
namespace onnx_mlir { namespace onnx_mlir {
namespace raptor {
namespace {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc" #include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> { } // namespace raptor
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPimPass() = default;
SpatialToPimPass(const SpatialToPimPass& pass) {}
void runOnOperation() final;
private:
SmallVector<OutputTensorFactory> outputTensors;
size_t coreId = 0;
SmallVector<Operation*> operationsToRemove;
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void markOpToRemove(Operation* op);
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
};
} // namespace
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) { static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>(); auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
@@ -151,8 +124,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput(); return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
} }
void SpatialToPimPass::runOnOperation() { void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
coreId = 0; coreId = 0;
outputTensors.clear();
operationsToRemove.clear();
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext(); MLIRContext* ctx = moduleOp.getContext();
@@ -198,18 +173,16 @@ void SpatialToPimPass::runOnOperation() {
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns)); walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator()); auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addReturnOutputBuffers(returnOp, rewriter, outputTensors); addReturnOutputBuffers(returnOp, rewriter);
ReturnPathState returnPathState {outputTensors, operationsToRemove};
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) { if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering"); funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
signalPassFailure(); signalPassFailure();
return; return;
} }
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove, constantFolder};
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) { for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp); markOpToRemove(computeOp);
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) { if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
computeOp.emitOpError("failed to lower spat.compute to pim.core"); computeOp.emitOpError("failed to lower spat.compute to pim.core");
signalPassFailure(); signalPassFailure();
return; return;
@@ -218,7 +191,7 @@ void SpatialToPimPass::runOnOperation() {
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) { for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
markOpToRemove(computeBatchOp); markOpToRemove(computeBatchOp);
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) { if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch"); computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
signalPassFailure(); signalPassFailure();
return; return;
@@ -267,14 +240,8 @@ void SpatialToPimPass::runOnOperation() {
} }
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter); enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState); replaceReturnWithOutputBuffers(returnOp, rewriter);
eraseOpsToRemove();
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM");
signalPassFailure();
return;
}
RewritePatternSet finalTensorPackingPatterns(ctx); RewritePatternSet finalTensorPackingPatterns(ctx);
populateTensorPackingPatterns(finalTensorPackingPatterns); populateTensorPackingPatterns(finalTensorPackingPatterns);
@@ -316,7 +283,7 @@ void SpatialToPimPass::runOnOperation() {
dumpModule(moduleOp, "pim0"); dumpModule(moduleOp, "pim0");
} }
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
OperationFolder constantFolder(funcOp.getContext()); OperationFolder constantFolder(funcOp.getContext());
funcOp.walk([&](PimVMMOp vmmOp) { funcOp.walk([&](PimVMMOp vmmOp) {
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType()); auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
@@ -350,16 +317,17 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
}); });
} }
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
IRRewriter& rewriter) {
Location loc = funcOp.getLoc(); Location loc = funcOp.getLoc();
OperationFolder constantFolder(funcOp.getContext()); OperationFolder constantFolder(funcOp.getContext());
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) { auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(inputTensor.getType()); auto tensorType = cast<ShapedType>(inputTensor.getType());
Type elementType = tensorType.getElementType(); Type elementType = tensorType.getElementType();
if (!elementType.isIntOrFloat()) if (!hasByteSizedElementType(elementType))
return; return;
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8; size_t elementByteSize = getElementTypeSizeInBytes(elementType);
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp()); rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType); auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
@@ -394,11 +362,18 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
return success(); return success();
} }
void SpatialToPimPass::markOpToRemove(Operation* op) { void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
if (!llvm::is_contained(operationsToRemove, op)) if (!llvm::is_contained(operationsToRemove, op))
operationsToRemove.push_back(op); operationsToRemove.push_back(op);
} }
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); } void raptor::SpatialToPimPass::eraseOpsToRemove() {
for (Operation* op : operationsToRemove) {
op->dropAllUses();
op->erase();
}
}
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<raptor::SpatialToPimPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -0,0 +1,72 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/StringRef.h"
#include <functional>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
namespace raptor {
struct SpatialToPimPass : mlir::PassWrapper<SpatialToPimPass, mlir::OperationPass<mlir::ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
llvm::StringRef getArgument() const override { return "convert-spatial-to-pim"; }
llvm::StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPimPass() = default;
SpatialToPimPass(const SpatialToPimPass& pass) {}
void runOnOperation() final;
private:
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
llvm::SmallVector<OutputTensorFactory> outputTensors;
size_t coreId = 0;
llvm::SmallVector<mlir::Operation*> operationsToRemove;
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
mlir::LogicalResult
lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder);
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);
enum class ReturnPathLoweringResult {
Handled,
NotReturnPath,
Failure
};
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
mlir::OpResult result,
mlir::Value yieldValue,
mlir::IRRewriter& rewriter);
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
mlir::Value producedValue,
mlir::Value storedValue,
mlir::IRRewriter& rewriter);
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
void markOpToRemove(mlir::Operation* op);
void eraseOpsToRemove();
void enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
};
} // namespace raptor
} // namespace onnx_mlir
+2 -2
View File
@@ -1,7 +1,7 @@
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include <string> #include <string>
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
+4 -5
View File
@@ -56,7 +56,8 @@ static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<O
return parser.parseRParen(); return parser.parseRParen();
} }
static void printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) { static void
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
printCompressedValueList(printer, arguments, delimiter); printCompressedValueList(printer, arguments, delimiter);
printer << " = "; printer << " = ";
printCompressedValueList(printer, operands, delimiter); printCompressedValueList(printer, operands, delimiter);
@@ -82,10 +83,8 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult { auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
switch (currentDelimiter) { switch (currentDelimiter) {
case ListDelimiter::Paren: case ListDelimiter::Paren: return parser.parseRParen();
return parser.parseRParen(); case ListDelimiter::Square: return parser.parseRSquare();
case ListDelimiter::Square:
return parser.parseRSquare();
} }
llvm_unreachable("unsupported delimiter"); llvm_unreachable("unsupported delimiter");
}; };
+21 -10
View File
@@ -1,11 +1,14 @@
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/LogicalResult.h" #include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -40,7 +43,18 @@ static bool isDefinedInsideRegion(Value value, Region& region) {
static bool isConstantExternalValue(Value value) { static bool isConstantExternalValue(Value value) {
Operation* definingOp = value.getDefiningOp(); Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>(); if (!definingOp)
return false;
if (definingOp->hasTrait<OpTrait::ConstantLike>())
return true;
auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(definingOp);
if (!getGlobalOp)
return false;
auto moduleOp = definingOp->getParentOfType<ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
return globalOp && globalOp.getConstant();
} }
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) { static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
@@ -52,8 +66,8 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
|| isExplicitHostOperand(op, operand.getOperandNumber())) || isExplicitHostOperand(op, operand.getOperandNumber()))
continue; continue;
InFlightDiagnostic diagnostic = InFlightDiagnostic diagnostic = ownerOp->emitOpError()
ownerOp->emitOpError() << kind << " body may only directly reference external constants"; << kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc()) diagnostic.attachNote(op->getLoc())
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName(); << "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
hasFailure = true; hasFailure = true;
@@ -139,10 +153,9 @@ LogicalResult PimCoreOp::verify() {
Block& block = getBody().front(); Block& block = getBody().front();
if (block.getNumArguments() != getWeights().size()) if (block.getNumArguments() != getWeights().size())
return emitError("core body must have one block argument per weight"); return emitError("core body must have one block argument per weight");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType()) if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core weight block argument types must match weight operand types exactly"); return emitError("core weight block argument types must match weight operand types exactly");
}
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core"); return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
} }
@@ -155,14 +168,12 @@ LogicalResult PimCoreBatchOp::verify() {
return emitError("core_batch body must have lane, weight, and input block arguments"); return emitError("core_batch body must have lane, weight, and input block arguments");
if (!getLaneArgument().getType().isIndex()) if (!getLaneArgument().getType().isIndex())
return emitError("core_batch first block argument must have index type"); return emitError("core_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType()) if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core_batch weight block argument types must match weight operand types exactly"); return emitError("core_batch weight block argument types must match weight operand types exactly");
} for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
if (getInputArgument(inputIndex).getType() != input.getType()) if (getInputArgument(inputIndex).getType() != input.getType())
return emitError("core_batch input block argument types must match input operand types exactly"); return emitError("core_batch input block argument types must match input operand types exactly");
}
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch"); return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
} }
@@ -17,7 +17,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
auto shapedType = cast<ShapedType>(memrefValue.getType()); auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType()); auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType); Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8; auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
return PimMemCopyOp::create(rewriter, return PimMemCopyOp::create(rewriter,
loc, loc,
@@ -1,9 +1,10 @@
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir; using namespace mlir;
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) { IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
auto type = mlir::cast<MemRefType>(memref.getType()); auto type = mlir::cast<MemRefType>(memref.getType());
int32_t sizeInBytes = static_cast<int32_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8); int32_t sizeInBytes = static_cast<int32_t>(getShapedTypeSizeInBytes(type));
return builder.getI32IntegerAttr(sizeInBytes); return builder.getI32IntegerAttr(sizeInBytes);
} }
@@ -9,6 +9,7 @@
#include <limits> #include <limits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
using namespace mlir; using namespace mlir;
@@ -23,11 +24,12 @@ static bool isSupportedAliasOp(Operation* op) {
} }
static bool isCandidateAllocType(MemRefType type) { static bool isCandidateAllocType(MemRefType type) {
return type && type.hasStaticShape() && type.getLayout().isIdentity() && type.getElementTypeBitWidth() > 0; return type && type.hasStaticShape() && type.getLayout().isIdentity()
&& hasByteSizedElementType(type.getElementType());
} }
static uint64_t getTypeSizeBytes(MemRefType type) { static uint64_t getTypeSizeBytes(MemRefType type) {
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8); return static_cast<uint64_t>(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()));
} }
static FailureOr<uint64_t> static FailureOr<uint64_t>
@@ -50,10 +52,9 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
pendingValues.push_back(result); pendingValues.push_back(result);
if (auto forOp = dyn_cast<scf::ForOp>(user)) { if (auto forOp = dyn_cast<scf::ForOp>(user)) {
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) { for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs()))
if (initArg == value) if (initArg == value)
pendingValues.push_back(forOp.getResult(index)); pendingValues.push_back(forOp.getResult(index));
}
} }
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) { if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
+18 -6
View File
@@ -43,8 +43,14 @@ def SpatCompute : SpatOp<"compute",
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
::mlir::BlockArgument getWeightArgument(unsigned idx); std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
::mlir::BlockArgument getInputArgument(unsigned idx); std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}]; }];
let hasVerifier = 1; let hasVerifier = 1;
@@ -70,10 +76,16 @@ def SpatComputeBatch : SpatOp<"compute_batch",
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
::mlir::BlockArgument getLaneArgument(); std::optional<::mlir::BlockArgument> getLaneArgument();
::mlir::BlockArgument getWeightArgument(unsigned idx); std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
::mlir::BlockArgument getInputArgument(unsigned idx); std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
::mlir::BlockArgument getOutputArgument(unsigned idx); std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}]; }];
let hasVerifier = 1; let hasVerifier = 1;
+157 -22
View File
@@ -1,16 +1,90 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include <string> #include <string>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace spatial { namespace spatial {
namespace {
BlockArgument SpatCompute::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); } std::optional<BlockArgument> getBatchBodyArgument(Region& body, unsigned argIdx) {
if (body.empty())
return std::nullopt;
BlockArgument SpatCompute::getInputArgument(unsigned idx) { Block& block = body.front();
return getBody().front().getArgument(getWeights().size() + idx); if (argIdx >= block.getNumArguments())
return std::nullopt;
return block.getArgument(argIdx);
}
std::optional<BlockArgument> insertBatchBodyArgument(Region& body, unsigned argIdx, Type type, Location loc) {
if (body.empty())
return std::nullopt;
return body.insertArgument(argIdx, type, loc);
}
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
if (auto compute = dyn_cast<SpatCompute>(op)) {
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
return;
}
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
}
} // namespace
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), idx);
}
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), getWeights().size() + idx);
}
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBatchBodyArgument(getBody(), idx, weight.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
}
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
auto blockArg = insertBatchBodyArgument(getBody(), weightCount + idx, input.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
}
FailureOr<std::tuple<OpResult, SpatCompute>>
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > getNumResults())
return failure();
rewriter.setInsertionPoint(getOperation());
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
newCompute->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes(newCompute.getOperation(),
static_cast<int32_t>(newCompute.getWeights().size()),
static_cast<int32_t>(newCompute.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx)
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
} }
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
@@ -18,42 +92,105 @@ void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn s
return; return;
for (unsigned index = 0; index < getWeights().size(); ++index) for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str()); if (auto weightArg = getWeightArgument(index))
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index) for (unsigned index = 0; index < getInputs().size(); ++index)
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str()); if (auto inputArg = getInputArgument(index))
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
} }
BlockArgument SpatComputeBatch::getLaneArgument() { return getBody().front().getArgument(0); } std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBatchBodyArgument(getBody(), 0); }
BlockArgument SpatComputeBatch::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); } std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), 1 + idx);
BlockArgument SpatComputeBatch::getInputArgument(unsigned idx) {
return getBody().front().getArgument(1 + getWeights().size() + idx);
} }
BlockArgument SpatComputeBatch::getOutputArgument(unsigned idx) { std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
return getBody().front().getArgument(1 + getWeights().size() + getInputs().size() + idx); return getBatchBodyArgument(getBody(), 1 + getWeights().size() + idx);
}
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
}
std::optional<std::tuple<Value, BlockArgument>>
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(idx, ValueRange {weight});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
auto blockArg = insertBatchBodyArgument(getBody(), 1 + idx, weight.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
}
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
unsigned weightCount = getWeights().size();
unsigned inputCount = getInputs().size();
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
setComputeOperandSegmentSizes(
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
auto blockArg = insertBatchBodyArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
if (!blockArg)
return std::nullopt;
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
}
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > getNumResults())
return failure();
rewriter.setInsertionPoint(getOperation());
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
resultTypes.insert(resultTypes.begin() + idx, type);
auto newBatch =
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
newBatch->setAttrs((*this)->getAttrs());
setComputeOperandSegmentSizes(newBatch.getOperation(),
static_cast<int32_t>(newBatch.getWeights().size()),
static_cast<int32_t>(newBatch.getInputs().size()));
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
if (newBatch.getBody().empty()) {
rewriter.eraseOp(newBatch);
return failure();
}
auto blockArg = newBatch.getBody().front().insertArgument(
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
getResult(oldResultIdx)
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
rewriter.eraseOp(getOperation());
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
} }
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty()) if (region.empty())
return; return;
setNameFn(getLaneArgument(), "lane"); if (auto laneArg = getLaneArgument())
setNameFn(*laneArg, "lane");
for (unsigned index = 0; index < getWeights().size(); ++index) for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str()); if (auto weightArg = getWeightArgument(index))
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index) for (unsigned index = 0; index < getInputs().size(); ++index)
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str()); if (auto inputArg = getInputArgument(index))
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getNumResults(); ++index) { for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
continue;
if (index == 0) { if (index == 0) {
setNameFn(getOutputArgument(index), "out"); setNameFn(*outputArg, "out");
continue; continue;
} }
setNameFn(getOutputArgument(index), ("out" + std::to_string(index)).c_str()); setNameFn(*outputArg, ("out" + std::to_string(index)).c_str());
} }
} }
@@ -65,9 +202,7 @@ void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); } OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
return getRegion().front().getOperations();
}
void SpatialDialect::initialize() { void SpatialDialect::initialize() {
addTypes< addTypes<
+3
View File
@@ -5,12 +5,15 @@
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionKindInterface.h" #include "mlir/IR/RegionKindInterface.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include <map> #include <map>
#include <optional>
#include <string> #include <string>
#include <tuple>
/// Include the auto-generated header files containing the declarations /// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc"
+51 -26
View File
@@ -104,16 +104,13 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
return failure(); return failure();
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult { auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
switch (currentDelimiter) { switch (currentDelimiter) {
case ListDelimiter::Paren: case ListDelimiter::Paren: return parser.parseRParen();
return parser.parseRParen(); case ListDelimiter::Square: return parser.parseRSquare();
case ListDelimiter::Square:
return parser.parseRSquare();
} }
llvm_unreachable("unsupported delimiter"); llvm_unreachable("unsupported delimiter");
}; };
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) { if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
return failure(); return failure();
}
return success(); return success();
} }
@@ -221,17 +218,26 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
} }
void SpatCompute::print(OpAsmPrinter& printer) { void SpatCompute::print(OpAsmPrinter& printer) {
printer << " ";
SmallVector<Value> weightArgs; SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size()); weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index) for (unsigned index = 0; index < getWeights().size(); ++index) {
weightArgs.push_back(getWeightArgument(index)); auto weightArg = getWeightArgument(index);
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); if (!weightArg)
printer << " "; return printer.printGenericOp(getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs; SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size()); inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index) for (unsigned index = 0; index < getInputs().size(); ++index) {
inputArgs.push_back(getInputArgument(index)); auto inputArg = getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
printer << " ";
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren); printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
@@ -312,29 +318,48 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
} }
void SpatComputeBatch::print(OpAsmPrinter& printer) { void SpatComputeBatch::print(OpAsmPrinter& printer) {
auto laneArg = getLaneArgument();
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index) {
auto weightArg = getWeightArgument(index);
if (!weightArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
weightArgs.push_back(*weightArg);
}
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index) {
auto inputArg = getInputArgument(index);
if (!inputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
inputArgs.push_back(*inputArg);
}
SmallVector<BlockArgument> outputArgs;
if (!laneArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
if (getNumResults() != 0) {
outputArgs.reserve(getNumResults());
for (unsigned index = 0; index < getNumResults(); ++index) {
auto outputArg = getOutputArgument(index);
if (!outputArg)
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
outputArgs.push_back(*outputArg);
}
}
printer << " "; printer << " ";
printer.printOperand(getLaneArgument()); printer.printOperand(*laneArg);
printer << " = 0 to " << getLaneCount(); printer << " = 0 to " << getLaneCount();
printer << " "; printer << " ";
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index)
weightArgs.push_back(getWeightArgument(index));
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " "; printer << " ";
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index)
inputArgs.push_back(getInputArgument(index));
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren); printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (getNumResults() != 0) { if (getNumResults() != 0) {
printer << " shared_outs"; printer << " shared_outs";
SmallVector<BlockArgument> outputArgs;
outputArgs.reserve(getNumResults());
for (unsigned index = 0; index < getNumResults(); ++index)
outputArgs.push_back(getOutputArgument(index));
printBlockArgumentList(printer, outputArgs); printBlockArgumentList(printer, outputArgs);
} }
+41 -19
View File
@@ -107,8 +107,11 @@ static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
return false; return false;
unsigned argNumber = blockArg.getArgNumber(); unsigned argNumber = blockArg.getArgNumber();
unsigned firstOutputArg = batchOp.getOutputArgument(0).getArgNumber(); auto firstOutputArg = batchOp.getOutputArgument(0);
return argNumber >= firstOutputArg && argNumber < firstOutputArg + batchOp.getNumResults(); if (!firstOutputArg)
return false;
unsigned firstOutputArgNumber = firstOutputArg->getArgNumber();
return argNumber >= firstOutputArgNumber && argNumber < firstOutputArgNumber + batchOp.getNumResults();
} }
static bool isConstantIndexLike(Value value) { static bool isConstantIndexLike(Value value) {
@@ -120,6 +123,15 @@ static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
if (value == laneArg || isConstantIndexLike(value)) if (value == laneArg || isConstantIndexLike(value))
return true; return true;
auto extractOp = value.getDefiningOp<tensor::ExtractOp>();
if (extractOp) {
auto constantTensor = extractOp.getTensor().getDefiningOp<arith::ConstantOp>();
auto denseAttr = constantTensor ? dyn_cast<DenseIntElementsAttr>(constantTensor.getValue()) : nullptr;
if (!denseAttr || denseAttr.getType().getRank() != 1 || extractOp.getIndices().size() != 1)
return false;
return isSupportedLaneOffsetExpr(extractOp.getIndices().front(), laneArg);
}
auto addOp = value.getDefiningOp<arith::AddIOp>(); auto addOp = value.getDefiningOp<arith::AddIOp>();
if (!addOp) if (!addOp)
return false; return false;
@@ -263,9 +275,9 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
continue; continue;
InFlightDiagnostic diagnostic = ownerOp->emitOpError() InFlightDiagnostic diagnostic = ownerOp->emitOpError()
<< kind << " body may only directly reference external constants"; << kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc()) << "non-constant external operand #" << operand.getOperandNumber() diagnostic.attachNote(op->getLoc())
<< " is used by " << op->getName(); << "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
hasFailure = true; hasFailure = true;
} }
}); });
@@ -284,10 +296,12 @@ static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel"); return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
} }
BlockArgument laneArg = batchOp.getLaneArgument(); auto laneArg = batchOp.getLaneArgument();
if (!laneArg)
return batchOp.emitError("compute_batch body must have a lane block argument");
for (auto& bodyOp : block) { for (auto& bodyOp : block) {
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp)) if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, laneArg, "tensor.extract_slice"))) if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, *laneArg, "tensor.extract_slice")))
return failure(); return failure();
} }
return success(); return success();
@@ -449,11 +463,13 @@ LogicalResult SpatCompute::verify() {
return emitError("compute body must have weight and input block arguments"); return emitError("compute body must have weight and input block arguments");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType()) auto blockArg = getWeightArgument(weightIndex);
if (!blockArg || blockArg->getType() != weight.getType())
return emitError("compute weight block argument types must match weight operand types exactly"); return emitError("compute weight block argument types must match weight operand types exactly");
} }
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
if (getInputArgument(inputIndex).getType() != input.getType()) auto blockArg = getInputArgument(inputIndex);
if (!blockArg || blockArg->getType() != input.getType())
return emitError("compute input block argument types must match input operand types exactly"); return emitError("compute input block argument types must match input operand types exactly");
} }
@@ -490,7 +506,7 @@ LogicalResult SpatCompute::verify() {
} }
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex) for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
if (getInputArgument(inputIndex).use_empty()) if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
return emitError("ComputeOp block argument is not used"); return emitError("ComputeOp block argument is not used");
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute"))) if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
return failure(); return failure();
@@ -567,24 +583,28 @@ LogicalResult SpatComputeBatch::verify() {
} }
Block& block = getBody().front(); Block& block = getBody().front();
if (block.getNumArguments() == 0)
return emitError("compute_batch body must have exactly one lane block argument");
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults(); unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
if (block.getNumArguments() != expectedArgCount) if (block.getNumArguments() != expectedArgCount)
return emitError("compute_batch body must have lane, weight, input, and output block arguments"); return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
if (!getLaneArgument().getType().isIndex()) auto laneArg = getLaneArgument();
if (!laneArg || !laneArg->getType().isIndex())
return emitError("compute_batch first block argument must have index type"); return emitError("compute_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType()) auto blockArg = getWeightArgument(weightIndex);
if (!blockArg || blockArg->getType() != weight.getType())
return emitError("compute_batch weight block argument types must match weight operand types exactly"); return emitError("compute_batch weight block argument types must match weight operand types exactly");
} }
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
BlockArgument blockArg = getInputArgument(inputIndex); auto blockArg = getInputArgument(inputIndex);
if (blockArg.getType() != input.getType()) if (!blockArg || blockArg->getType() != input.getType())
return emitError("compute_batch input block argument types must match input operand types exactly"); return emitError("compute_batch input block argument types must match input operand types exactly");
} }
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) { for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
BlockArgument blockArg = getOutputArgument(resultIndex); auto blockArg = getOutputArgument(resultIndex);
if (blockArg.getType() != resultType) if (!blockArg || blockArg->getType() != resultType)
return emitError("compute_batch output block argument types must match result types exactly"); return emitError("compute_batch output block argument types must match result types exactly");
} }
@@ -602,13 +622,15 @@ LogicalResult SpatInParallelOp::verify() {
if (batchOp.getNumResults() == 0) if (batchOp.getNumResults() == 0)
return emitOpError("requires a resultful spat.compute_batch parent"); return emitOpError("requires a resultful spat.compute_batch parent");
BlockArgument laneArg = batchOp.getLaneArgument(); auto laneArg = batchOp.getLaneArgument();
if (!laneArg)
return emitOpError("expected compute_batch lane block argument");
for (Operation& op : getRegion().front().getOperations()) { for (Operation& op : getRegion().front().getOperations()) {
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op); auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
if (!insertSliceOp) if (!insertSliceOp)
return emitOpError("expected only tensor.parallel_insert_slice ops"); return emitOpError("expected only tensor.parallel_insert_slice ops");
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, laneArg, "tensor.parallel_insert_slice"))) if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, *laneArg, "tensor.parallel_insert_slice")))
return failure(); return failure();
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations(); MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
@@ -1,6 +1,6 @@
#include "DCPAnalysis.hpp"
#include "../Scheduling/ComputeGraph.hpp" #include "../Scheduling/ComputeGraph.hpp"
#include "../Scheduling/DcpScheduler.hpp" #include "../Scheduling/DcpScheduler.hpp"
#include "DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -11,15 +11,15 @@ using DCPAnalysisResult = MergeScheduleResult;
struct DCPAnalysis { struct DCPAnalysis {
private: private:
DCPAnalysisResult result; DCPAnalysisResult result;
mlir::Operation *entryOp; mlir::Operation* entryOp;
DCPAnalysisResult run(); DCPAnalysisResult run();
public: public:
DCPAnalysis(mlir::Operation *op) DCPAnalysis(mlir::Operation* op)
: entryOp(op) { : entryOp(op) {
result = run(); result = run();
} }
DCPAnalysisResult &getResult() { return result; } DCPAnalysisResult& getResult() { return result; }
}; };
} // namespace spatial } // namespace spatial
File diff suppressed because it is too large Load Diff
@@ -10,8 +10,7 @@ namespace spatial {
class MergeScheduleMaterializer { class MergeScheduleMaterializer {
public: public:
mlir::LogicalResult mlir::LogicalResult run(mlir::func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId);
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
}; };
} // namespace spatial } // namespace spatial
@@ -57,8 +57,7 @@ bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != n
class ScopedMergePhaseTimer { class ScopedMergePhaseTimer {
public: public:
explicit ScopedMergePhaseTimer(StringRef phaseName) explicit ScopedMergePhaseTimer(StringRef phaseName)
: enabled(isMergeProfilingEnabled()), : enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
phase(phaseName.str()) {
if (enabled) if (enabled)
start = std::chrono::steady_clock::now(); start = std::chrono::steady_clock::now();
} }
@@ -130,15 +129,12 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
MergeIrCounts counts = collectMergeIrCounts(funcOp); MergeIrCounts counts = collectMergeIrCounts(funcOp);
llvm::errs() << "[merge-profile] " << phaseName << " counts:" llvm::errs() << "[merge-profile] " << phaseName << " counts:"
<< " compute=" << counts.topLevelComputeCount << " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
<< " compute_batch=" << counts.topLevelComputeBatchCount
<< " scalar_send=" << counts.scalarChannelSendCount << " scalar_send=" << counts.scalarChannelSendCount
<< " scalar_recv=" << counts.scalarChannelReceiveCount << " scalar_recv=" << counts.scalarChannelReceiveCount
<< " tensor_send=" << counts.tensorChannelSendCount << " tensor_send=" << counts.tensorChannelSendCount
<< " tensor_recv=" << counts.tensorChannelReceiveCount << " tensor_recv=" << counts.tensorChannelReceiveCount << " wvmm=" << counts.wvmmCount
<< " wvmm=" << counts.wvmmCount << " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n";
<< " vadd=" << counts.vaddCount
<< " scf_for=" << counts.scfForCount << "\n";
} }
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) { static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
@@ -167,7 +163,8 @@ bool isTrivialSerialMergeCandidate(SpatCompute compute) {
return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size(); return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size();
} }
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights, ValueRange sourceWeights) { SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights,
ValueRange sourceWeights) {
DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices; DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices;
for (auto [weightIndex, weight] : llvm::enumerate(targetWeights)) for (auto [weightIndex, weight] : llvm::enumerate(targetWeights))
targetWeightIndices[weight].push_back(weightIndex); targetWeightIndices[weight].push_back(weightIndex);
@@ -226,18 +223,32 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
newBody->addArgument(input.getType(), loc); newBody->addArgument(input.getType(), loc);
IRMapping mapper; IRMapping mapper;
for (auto [weightIndex, _] : llvm::enumerate(compute.getWeights())) for (auto [weightIndex, _] : llvm::enumerate(compute.getWeights())) {
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex)); auto oldWeightArg = compute.getWeightArgument(weightIndex);
for (auto [inputIndex, _] : llvm::enumerate(compute.getInputs())) auto newWeightArg = newCompute.getWeightArgument(weightIndex);
mapper.map(compute.getInputArgument(inputIndex), newCompute.getInputArgument(inputIndex)); assert(oldWeightArg && newWeightArg && "expected compute weight block arguments");
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights())) mapper.map(*oldWeightArg, *newWeightArg);
mapper.map(child.getWeightArgument(oldIndex), newCompute.getWeightArgument(childWeightToNewIndex[oldIndex])); }
for (auto [inputIndex, _] : llvm::enumerate(compute.getInputs())) {
auto oldInputArg = compute.getInputArgument(inputIndex);
auto newInputArg = newCompute.getInputArgument(inputIndex);
assert(oldInputArg && newInputArg && "expected compute input block arguments");
mapper.map(*oldInputArg, *newInputArg);
}
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights())) {
auto oldWeightArg = child.getWeightArgument(oldIndex);
auto newWeightArg = newCompute.getWeightArgument(childWeightToNewIndex[oldIndex]);
assert(oldWeightArg && newWeightArg && "expected child compute weight block arguments");
mapper.map(*oldWeightArg, *newWeightArg);
}
rewriter.setInsertionPointToEnd(newBody); rewriter.setInsertionPointToEnd(newBody);
auto computeYield = cast<spatial::SpatYieldOp>(compute.getBody().front().getTerminator()); auto computeYield = cast<spatial::SpatYieldOp>(compute.getBody().front().getTerminator());
for (Operation& op : compute.getBody().front().without_terminator()) for (Operation& op : compute.getBody().front().without_terminator())
rewriter.clone(op, mapper); rewriter.clone(op, mapper);
mapper.map(child.getInputArgument(childInputIndex), mapper.lookupOrDefault(computeYield.getOperand(usedResult))); auto childInputArg = child.getInputArgument(childInputIndex);
assert(childInputArg && "expected child compute input block argument");
mapper.map(*childInputArg, mapper.lookupOrDefault(computeYield.getOperand(usedResult)));
rewriter.setInsertionPointToEnd(newBody); rewriter.setInsertionPointToEnd(newBody);
for (auto& op : child.getBody().front()) for (auto& op : child.getBody().front())
@@ -649,12 +660,12 @@ public:
emitMergeIrCounts("after-materialization", func); emitMergeIrCounts("after-materialization", func);
if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) { /*if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) {
signalPassFailure(); signalPassFailure();
return; return;
} }
emitMergeIrCounts("after-post-merge-compaction", func); emitMergeIrCounts("after-post-merge-compaction", func);*/
{ {
ScopedMergePhaseTimer timer("cleanup-topological-sort-report"); ScopedMergePhaseTimer timer("cleanup-topological-sort-report");
@@ -267,212 +267,6 @@ bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end(); return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
} }
struct BatchYieldInfo {
Value yieldedValue;
tensor::ParallelInsertSliceOp insertSlice;
};
static bool isHostOnlyBatchResultUser(Operation* user) {
return isa<func::ReturnOp,
spatial::SpatConcatOp,
tensor::ExtractSliceOp,
tensor::CastOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp>(user);
}
static FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> collectBatchYieldInfo(SpatComputeBatch batchOp) {
Block& block = batchOp.getBody().front();
auto inParallel = dyn_cast<spatial::SpatInParallelOp>(block.getTerminator());
if (!inParallel)
return failure();
DenseMap<BlockArgument, BatchYieldInfo> batchYieldByOutputArg;
for (Operation& op : inParallel.getRegion().front()) {
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
if (!insertSlice)
return failure();
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
if (!outputArg || outputArg.getOwner() != &block)
return failure();
batchYieldByOutputArg[outputArg] = {insertSlice.getSource(), insertSlice};
}
return batchYieldByOutputArg;
}
static FailureOr<SpatComputeBatch> cloneBatchAsResultless(SpatComputeBatch batchOp, IRRewriter& rewriter) {
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
if (!coreIdsAttr)
return failure();
Block& oldBlock = batchOp.getBody().front();
rewriter.setInsertionPoint(batchOp);
auto newBatch = SpatComputeBatch::create(rewriter,
batchOp.getLoc(),
TypeRange {},
rewriter.getI32IntegerAttr(batchOp.getLaneCount()),
batchOp.getWeights(),
batchOp.getInputs());
newBatch.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, coreIdsAttr);
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
blockArgLocs.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
blockArgTypes.push_back(batchOp.getLaneArgument().getType());
blockArgLocs.push_back(batchOp.getLaneArgument().getLoc());
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex) {
blockArgTypes.push_back(batchOp.getWeightArgument(weightIndex).getType());
blockArgLocs.push_back(batchOp.getWeightArgument(weightIndex).getLoc());
}
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex) {
blockArgTypes.push_back(batchOp.getInputArgument(inputIndex).getType());
blockArgLocs.push_back(batchOp.getInputArgument(inputIndex).getLoc());
}
Block* newBlock =
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
mapper.map(batchOp.getLaneArgument(), newBatch.getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch.getWeightArgument(weightIndex));
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
mapper.map(batchOp.getInputArgument(inputIndex), newBatch.getInputArgument(inputIndex));
for (Operation& op : oldBlock.without_terminator()) {
Operation* cloned = rewriter.clone(op, mapper);
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(oldResult, newResult);
}
return newBatch;
}
static LogicalResult materializeBatchResultCommunication(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext());
SmallVector<SpatComputeBatch> batches(funcOp.getOps<SpatComputeBatch>());
for (auto batchOp : batches) {
if (batchOp.getNumResults() == 0)
continue;
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
if (!coreIdsAttr)
return batchOp.emitOpError("missing coreIds while materializing batch result communication");
FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> batchYieldInfo = collectBatchYieldInfo(batchOp);
if (failed(batchYieldInfo))
return batchOp.emitOpError("failed to collect per-result yielded values from compute_batch body");
FailureOr<SpatComputeBatch> newBatch = cloneBatchAsResultless(batchOp, rewriter);
if (failed(newBatch))
return batchOp.emitOpError("failed to clone resultful compute_batch as resultless");
Block& oldBlock = batchOp.getBody().front();
Block& newBlock = newBatch->getBody().front();
IRMapping mapper;
mapper.map(batchOp.getLaneArgument(), newBatch->getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch->getWeightArgument(weightIndex));
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
mapper.map(batchOp.getInputArgument(inputIndex), newBatch->getInputArgument(inputIndex));
auto oldIt = oldBlock.begin();
auto newIt = newBlock.begin();
for (; oldIt != oldBlock.end() && newIt != newBlock.end(); ++oldIt, ++newIt)
for (auto [oldResult, newResult] : llvm::zip(oldIt->getResults(), newIt->getResults()))
mapper.map(oldResult, newResult);
SmallVector<int32_t> sourceCoreIds(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
rewriter.setInsertionPointToEnd(&newBlock);
for (unsigned resultIndex = 0; resultIndex < batchOp.getNumResults(); ++resultIndex) {
BlockArgument outputArg = batchOp.getOutputArgument(resultIndex);
auto yieldInfoIt = batchYieldInfo->find(outputArg);
if (yieldInfoIt == batchYieldInfo->end())
return batchOp.emitOpError(
"missing yielded value for compute_batch result during communication materialization");
Value mappedYieldedValue = mapper.lookup(yieldInfoIt->second.yieldedValue);
DenseMap<int32_t, SmallVector<OpOperand*>> computeUsesByTargetCore;
SmallVector<OpOperand*> hostUses;
for (OpOperand& use : batchOp.getResult(resultIndex).getUses()) {
if (auto computeOp = dyn_cast<SpatCompute>(use.getOwner())) {
auto coreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
if (!coreIdAttr)
return batchOp.emitOpError("compute user of compute_batch result is missing coreId");
computeUsesByTargetCore[static_cast<int32_t>(coreIdAttr.getInt())].push_back(&use);
continue;
}
if (isHostOnlyBatchResultUser(use.getOwner())) {
hostUses.push_back(&use);
continue;
}
return batchOp.emitOpError("unsupported user of compute_batch result during communication materialization")
<< ": " << use.getOwner()->getName();
}
auto createReceiveForUses = [&](ArrayRef<OpOperand*> uses, ArrayRef<int32_t> targetCoreIds) -> LogicalResult {
if (uses.empty())
return success();
SmallVector<int64_t> channelIds;
channelIds.reserve(sourceCoreIds.size());
for ([[maybe_unused]] int32_t sourceCoreId : sourceCoreIds)
channelIds.push_back(nextChannelId++);
SmallVector<Value> sendChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
SmallVector<Value> sendSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
SmallVector<Value> sendTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
spatial::SpatChannelSendBatchOp::create(rewriter,
batchOp.getLoc(),
sendChannelIdValues,
sendSourceCoreIdValues,
sendTargetCoreIdValues,
mappedYieldedValue);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(newBatch->getOperation());
SmallVector<Value> receiveChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
SmallVector<Value> receiveSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
SmallVector<Value> receiveTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
auto received = spatial::SpatChannelReceiveTensorOp::create(rewriter,
batchOp.getLoc(),
batchOp.getResult(resultIndex).getType(),
receiveChannelIdValues,
receiveSourceCoreIdValues,
receiveTargetCoreIdValues);
for (OpOperand* use : uses)
use->set(received.getOutput());
rewriter.setInsertionPointToEnd(&newBlock);
return success();
};
for (auto& [targetCoreId, uses] : computeUsesByTargetCore) {
SmallVector<int32_t> targetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), targetCoreId);
if (failed(createReceiveForUses(uses, targetCoreIds)))
return failure();
}
if (!hostUses.empty()) {
SmallVector<int32_t> hostTargetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), 0);
if (failed(createReceiveForUses(hostUses, hostTargetCoreIds)))
return failure();
}
}
rewriter.setInsertionPointToEnd(&newBlock);
spatial::SpatYieldOp::create(rewriter, batchOp.getLoc(), ValueRange {});
rewriter.eraseOp(batchOp);
}
return success();
}
void rebatchEquivalentComputes(func::FuncOp funcOp) { void rebatchEquivalentComputes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext()); IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext()); OperationFolder constantFolder(funcOp.getContext());
@@ -731,11 +525,6 @@ LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextC
ScopedMergePhaseTimer timer("cleanup-dead-packing-ops"); ScopedMergePhaseTimer timer("cleanup-dead-packing-ops");
cleanupDeadPackingOps(funcOp); cleanupDeadPackingOps(funcOp);
} }
{
ScopedMergePhaseTimer timer("materialize-batch-result-communication");
if (failed(materializeBatchResultCommunication(funcOp, nextChannelId)))
return failure();
}
return success(); return success();
} }
@@ -7,6 +7,6 @@
namespace onnx_mlir { namespace onnx_mlir {
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t &nextChannelId); mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -707,8 +707,10 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc()); Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
if (packedInput) { if (packedInput) {
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder); SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder); SmallVector<Value> sourceCoreIdValues =
SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder); createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
SmallVector<Value> targetCoreIdValues =
createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
spatial::SpatChannelSendTensorOp::create( spatial::SpatChannelSendTensorOp::create(
rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput); rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput);
for (auto op : run.ops) for (auto op : run.ops)
@@ -7,7 +7,7 @@
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include <algorithm> #include <algorithm>
#include <limits> #include <iterator>
#include <optional> #include <optional>
#include <queue> #include <queue>
#include <utility> #include <utility>
@@ -64,6 +64,49 @@ bool isUsedAsWeightOnly(Operation* producerOp) {
return true; return true;
} }
bool isLaneOffset(OpFoldResult offset, Value laneArg) {
auto offsetValue = llvm::dyn_cast<Value>(offset);
return offsetValue == laneArg;
}
std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
auto inputIt = llvm::find(batch.getInputs(), input);
if (inputIt == batch.getInputs().end())
return std::nullopt;
size_t inputIndex = std::distance(batch.getInputs().begin(), inputIt);
std::optional<BlockArgument> inputArg = batch.getInputArgument(inputIndex);
std::optional<BlockArgument> laneArg = batch.getLaneArgument();
if (!inputArg || !laneArg)
return std::nullopt;
Weight projectedCost = 0;
for (Operation* user : inputArg->getUsers()) {
auto extract = dyn_cast<tensor::ExtractSliceOp>(user);
if (!extract || extract.getSource() != *inputArg)
return std::nullopt;
if (extract.getMixedOffsets().empty() || !isLaneOffset(extract.getMixedOffsets().front(), *laneArg))
return std::nullopt;
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
if (!resultType || !resultType.hasStaticShape())
return std::nullopt;
projectedCost = checkedAdd(projectedCost, static_cast<Weight>(getSizeInBytes(resultType)));
}
if (projectedCost == 0)
return std::nullopt;
return projectedCost;
}
Weight getInputTransferCost(const ComputeInstance& consumerInstance, Value input) {
auto inputType = cast<ShapedType>(input.getType());
if (auto batch = dyn_cast<SpatComputeBatch>(consumerInstance.op))
if (std::optional<Weight> projectedCost = getBatchProjectedInputTransferCost(batch, input))
return *projectedCost;
return static_cast<Weight>(getSizeInBytes(inputType));
}
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) { std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights; llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
for (const ComputeGraphEdge& edge : edges) { for (const ComputeGraphEdge& edge : edges) {
@@ -136,15 +179,16 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges; llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) { for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
for (Value input : getComputeInstanceInputs(node.instance)) { llvm::SmallVector<Value, 4> inputs = getComputeInstanceInputs(node.instance);
for (Value input : inputs) {
Weight transferCost = getInputTransferCost(node.instance, input);
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp()); if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) { producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) { for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane)); auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane));
if (producerIt == graph.instanceToIndex.end()) if (producerIt == graph.instanceToIndex.end())
continue; continue;
rawEdges.push_back( rawEdges.push_back({producerIt->second, targetIndex, transferCost});
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
} }
continue; continue;
} }
@@ -155,8 +199,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
auto producerIt = graph.instanceToIndex.find(*producerInstance); auto producerIt = graph.instanceToIndex.find(*producerInstance);
if (producerIt == graph.instanceToIndex.end()) if (producerIt == graph.instanceToIndex.end())
continue; continue;
rawEdges.push_back( rawEdges.push_back({producerIt->second, targetIndex, transferCost});
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
} }
} }
@@ -39,11 +39,11 @@ struct ComputeGraph {
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex; llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
}; };
ComputeGraph buildComputeGraph(mlir::Operation *entryOp); ComputeGraph buildComputeGraph(mlir::Operation* entryOp);
bool verifyAcyclic(const ComputeGraph &graph); bool verifyAcyclic(const ComputeGraph& graph);
Weight getComputeInstanceWeight(const ComputeInstance &instance); Weight getComputeInstanceWeight(const ComputeInstance& instance);
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance); CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance);
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -11,11 +11,11 @@ namespace onnx_mlir {
namespace spatial { namespace spatial {
struct ComputeInstance { struct ComputeInstance {
mlir::Operation *op = nullptr; mlir::Operation* op = nullptr;
uint32_t laneStart = 0; uint32_t laneStart = 0;
uint32_t laneCount = 1; uint32_t laneCount = 1;
bool operator==(const ComputeInstance &other) const { bool operator==(const ComputeInstance& other) const {
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount; return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
} }
}; };
@@ -29,16 +29,15 @@ namespace llvm {
template <> template <>
struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> { struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
static onnx_mlir::spatial::ComputeInstance getEmptyKey() { static onnx_mlir::spatial::ComputeInstance getEmptyKey() {
return {DenseMapInfo<mlir::Operation *>::getEmptyKey(), UINT32_MAX, UINT32_MAX}; return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
} }
static onnx_mlir::spatial::ComputeInstance getTombstoneKey() { static onnx_mlir::spatial::ComputeInstance getTombstoneKey() {
return {DenseMapInfo<mlir::Operation *>::getTombstoneKey(), UINT32_MAX, UINT32_MAX}; return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
} }
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance &value) { static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance& value) {
return llvm::hash_combine(value.op, value.laneStart, value.laneCount); return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
} }
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs, static bool isEqual(const onnx_mlir::spatial::ComputeInstance& lhs, const onnx_mlir::spatial::ComputeInstance& rhs) {
const onnx_mlir::spatial::ComputeInstance &rhs) {
return lhs == rhs; return lhs == rhs;
} }
}; };
@@ -27,15 +27,15 @@ ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex)
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane); ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value, std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value,
const ComputeInstance *consumerInstance = nullptr); const ComputeInstance* consumerInstance = nullptr);
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value, std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value,
const ComputeInstance *consumerInstance = nullptr); const ComputeInstance* consumerInstance = nullptr);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance); llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance& instance);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance); llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance& instance);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance); llvm::SmallVector<mlir::Value, 4> getComputeInstanceOutputValues(const ComputeInstance& instance);
llvm::SmallVector<mlir::Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance); llvm::SmallVector<mlir::Type, 4> getComputeInstanceOutputTypes(const ComputeInstance& instance);
mlir::Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance); mlir::Block& getComputeInstanceTemplateBlock(const ComputeInstance& instance);
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -10,8 +10,8 @@
#include <queue> #include <queue>
#include <vector> #include <vector>
#include "DcpScheduler.hpp"
#include "../DCPGraph/Graph.hpp" #include "../DCPGraph/Graph.hpp"
#include "DcpScheduler.hpp"
namespace onnx_mlir { namespace onnx_mlir {
namespace spatial { namespace spatial {
@@ -47,7 +47,7 @@ struct WindowScheduleResult {
size_t maxMergeGroupSize = 0; size_t maxMergeGroupSize = 0;
}; };
size_t getSchedulingCpuBudget(const DcpScheduleOptions &options) { size_t getSchedulingCpuBudget(const DcpScheduleOptions& options) {
if (options.processorCount > 0) if (options.processorCount > 0)
return options.processorCount; return options.processorCount;
return std::numeric_limits<size_t>::max(); return std::numeric_limits<size_t>::max();
@@ -72,7 +72,7 @@ std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
for (auto [key, weight] : edgeWeights) for (auto [key, weight] : edgeWeights)
aggregatedEdges.push_back( aggregatedEdges.push_back(
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)}); {static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
llvm::sort(aggregatedEdges, [](const IndexedEdge &lhs, const IndexedEdge &rhs) { llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
if (std::get<0>(lhs) != std::get<0>(rhs)) if (std::get<0>(lhs) != std::get<0>(rhs))
return std::get<0>(lhs) < std::get<0>(rhs); return std::get<0>(lhs) < std::get<0>(rhs);
return std::get<1>(lhs) < std::get<1>(rhs); return std::get<1>(lhs) < std::get<1>(rhs);
@@ -80,7 +80,7 @@ std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
return aggregatedEdges; return aggregatedEdges;
} }
VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) { VirtualGraph buildInitialVirtualGraph(const ComputeGraph& graph) {
VirtualGraph virtualGraph; VirtualGraph virtualGraph;
virtualGraph.nodes.reserve(graph.nodes.size()); virtualGraph.nodes.reserve(graph.nodes.size());
for (auto [index, node] : llvm::enumerate(graph.nodes)) { for (auto [index, node] : llvm::enumerate(graph.nodes)) {
@@ -93,14 +93,14 @@ VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) {
std::vector<IndexedEdge> edges; std::vector<IndexedEdge> edges;
edges.reserve(graph.edges.size()); edges.reserve(graph.edges.size());
for (const ComputeGraphEdge &edge : graph.edges) for (const ComputeGraphEdge& edge : graph.edges)
edges.push_back( edges.push_back(
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)}); {static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
virtualGraph.edges = aggregateEdges(edges); virtualGraph.edges = aggregateEdges(edges);
return virtualGraph; return virtualGraph;
} }
TimingInfo computeTiming(const VirtualGraph &graph) { TimingInfo computeTiming(const VirtualGraph& graph) {
TimingInfo timing; TimingInfo timing;
size_t nodeCount = graph.nodes.size(); size_t nodeCount = graph.nodes.size();
timing.aest.assign(nodeCount, 0); timing.aest.assign(nodeCount, 0);
@@ -122,7 +122,7 @@ TimingInfo computeTiming(const VirtualGraph &graph) {
} }
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) { auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
const VirtualNode &node = graph.nodes[nodeIndex]; const VirtualNode& node = graph.nodes[nodeIndex];
if (!node.originalNodeIndices.empty()) if (!node.originalNodeIndices.empty())
return node.originalNodeIndices.front(); return node.originalNodeIndices.front();
return nodeIndex; return nodeIndex;
@@ -181,7 +181,7 @@ TimingInfo computeTiming(const VirtualGraph &graph) {
return timing; return timing;
} }
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph &graph) { std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
std::vector<std::vector<size_t>> adjacency(graph.nodes.size()); std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
for (auto [start, end, weight] : graph.edges) { for (auto [start, end, weight] : graph.edges) {
(void) weight; (void) weight;
@@ -191,14 +191,14 @@ std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph &gr
adjacency[startIndex].push_back(endIndex); adjacency[startIndex].push_back(endIndex);
adjacency[endIndex].push_back(startIndex); adjacency[endIndex].push_back(startIndex);
} }
for (auto &neighbours : adjacency) { for (auto& neighbours : adjacency) {
llvm::sort(neighbours); llvm::sort(neighbours);
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end()); neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
} }
return adjacency; return adjacency;
} }
std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const TimingInfo &timing, size_t windowSize) { std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
std::vector<size_t> ranked(timing.aest.size()); std::vector<size_t> ranked(timing.aest.size());
std::iota(ranked.begin(), ranked.end(), 0); std::iota(ranked.begin(), ranked.end(), 0);
auto isHigherPriority = [&](size_t lhs, size_t rhs) { auto isHigherPriority = [&](size_t lhs, size_t rhs) {
@@ -240,7 +240,7 @@ std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const Timing
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); }; auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare); std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
auto addToWindow = [&](size_t node, const std::vector<char> &eligible) { auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
if (inWindow[node]) if (inWindow[node])
return; return;
inWindow[node] = true; inWindow[node] = true;
@@ -288,7 +288,7 @@ std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const Timing
return selected; return selected;
} }
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph &graph, const std::vector<int64_t> &nodeToWindowIndex) { std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
std::vector<IndexedEdge> windowEdges; std::vector<IndexedEdge> windowEdges;
windowEdges.reserve(graph.edges.size()); windowEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) { for (auto [start, end, weight] : graph.edges) {
@@ -301,10 +301,10 @@ std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph &graph, const std::
return aggregateEdges(windowEdges); return aggregateEdges(windowEdges);
} }
WindowScheduleResult scheduleWindow(const VirtualGraph &graph, WindowScheduleResult scheduleWindow(const VirtualGraph& graph,
llvm::ArrayRef<size_t> selectedNodes, llvm::ArrayRef<size_t> selectedNodes,
const DcpScheduleOptions &options, const DcpScheduleOptions& options,
mlir::MLIRContext *context) { mlir::MLIRContext* context) {
std::vector<Weight> windowWeights; std::vector<Weight> windowWeights;
std::vector<CrossbarUsage> windowCrossbarUsage; std::vector<CrossbarUsage> windowCrossbarUsage;
std::vector<int64_t> windowNodeOrderKeys; std::vector<int64_t> windowNodeOrderKeys;
@@ -338,17 +338,17 @@ WindowScheduleResult scheduleWindow(const VirtualGraph &graph,
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size()); result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
std::vector<size_t> mergeGroup; std::vector<size_t> mergeGroup;
mergeGroup.reserve(scheduledTasks.size()); mergeGroup.reserve(scheduledTasks.size());
for (const auto &task : scheduledTasks) for (const auto& task : scheduledTasks)
mergeGroup.push_back(selectedNodes[task.nodeIndex]); mergeGroup.push_back(selectedNodes[task.nodeIndex]);
result.mergeGroups.push_back(std::move(mergeGroup)); result.mergeGroups.push_back(std::move(mergeGroup));
} }
return result; return result;
} }
bool coarsenGraph(const VirtualGraph &graph, bool coarsenGraph(const VirtualGraph& graph,
llvm::ArrayRef<std::vector<size_t>> mergeGroups, llvm::ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph &coarsenedGraph, VirtualGraph& coarsenedGraph,
std::vector<size_t> &oldToNewNode) { std::vector<size_t>& oldToNewNode) {
TimingInfo timing = computeTiming(graph); TimingInfo timing = computeTiming(graph);
std::vector<size_t> topologicalRank(graph.nodes.size()); std::vector<size_t> topologicalRank(graph.nodes.size());
std::iota(topologicalRank.begin(), topologicalRank.end(), 0); std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
@@ -358,7 +358,7 @@ bool coarsenGraph(const VirtualGraph &graph,
std::vector<std::vector<size_t>> orderedMergeGroups; std::vector<std::vector<size_t>> orderedMergeGroups;
orderedMergeGroups.reserve(mergeGroups.size()); orderedMergeGroups.reserve(mergeGroups.size());
for (const auto &mergeGroup : mergeGroups) { for (const auto& mergeGroup : mergeGroups) {
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end()); orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) { std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
if (topologicalRank[lhs] != topologicalRank[rhs]) if (topologicalRank[lhs] != topologicalRank[rhs])
@@ -395,7 +395,7 @@ bool coarsenGraph(const VirtualGraph &graph,
continue; continue;
} }
auto &newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)]; auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
if (newNodeIndex.has_value()) { if (newNodeIndex.has_value()) {
oldToNewNode[nodeIndex] = *newNodeIndex; oldToNewNode[nodeIndex] = *newNodeIndex;
continue; continue;
@@ -403,8 +403,9 @@ bool coarsenGraph(const VirtualGraph &graph,
VirtualNode mergedNode; VirtualNode mergedNode;
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) { for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
const VirtualNode &memberNode = graph.nodes[memberIndex]; const VirtualNode& memberNode = graph.nodes[memberIndex];
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end()); mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(),
memberNode.originalNodeIndices.end());
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight); mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage); mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
} }
@@ -437,7 +438,7 @@ bool coarsenGraph(const VirtualGraph &graph,
return true; return true;
} }
size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &options) { size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions& options) {
size_t windowSize = std::min(options.criticalWindowSize, nodeCount); size_t windowSize = std::min(options.criticalWindowSize, nodeCount);
CPU maxCpuCount = std::max<CPU>(1, static_cast<CPU>(getSchedulingCpuBudget(options))); CPU maxCpuCount = std::max<CPU>(1, static_cast<CPU>(getSchedulingCpuBudget(options)));
if (nodeCount > static_cast<size_t>(maxCpuCount)) if (nodeCount > static_cast<size_t>(maxCpuCount))
@@ -445,7 +446,7 @@ size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &op
return windowSize; return windowSize;
} }
void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) { void assignFeasibleAest(const ComputeGraph& graph, MergeScheduleResult& result) {
llvm::DenseMap<ComputeInstance, size_t> nodeIndexByInstance; llvm::DenseMap<ComputeInstance, size_t> nodeIndexByInstance;
nodeIndexByInstance.reserve(graph.nodes.size()); nodeIndexByInstance.reserve(graph.nodes.size());
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes)) for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
@@ -458,7 +459,7 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result)
std::vector<std::vector<ScheduledEdge>> scheduledChildren(graph.nodes.size()); std::vector<std::vector<ScheduledEdge>> scheduledChildren(graph.nodes.size());
std::vector<size_t> incomingEdgeCount(graph.nodes.size(), 0); std::vector<size_t> incomingEdgeCount(graph.nodes.size(), 0);
for (const ComputeGraphEdge &edge : graph.edges) { for (const ComputeGraphEdge& edge : graph.edges) {
const ComputeInstance sourceInstance = graph.nodes[edge.source].instance; const ComputeInstance sourceInstance = graph.nodes[edge.source].instance;
const ComputeInstance targetInstance = graph.nodes[edge.target].instance; const ComputeInstance targetInstance = graph.nodes[edge.target].instance;
const size_t sourceCpu = result.computeToCpuMap.lookup(sourceInstance); const size_t sourceCpu = result.computeToCpuMap.lookup(sourceInstance);
@@ -473,15 +474,15 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result)
} }
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu; llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
for (const ComputeGraphNode &node : graph.nodes) { for (const ComputeGraphNode& node : graph.nodes) {
size_t cpu = result.computeToCpuMap.lookup(node.instance); size_t cpu = result.computeToCpuMap.lookup(node.instance);
size_t slot = result.computeToCpuSlotMap.lookup(node.instance); size_t slot = result.computeToCpuSlotMap.lookup(node.instance);
tasksByCpu[cpu].push_back({slot, nodeIndexByInstance.lookup(node.instance)}); tasksByCpu[cpu].push_back({slot, nodeIndexByInstance.lookup(node.instance)});
} }
for (auto &entry : tasksByCpu) { for (auto& entry : tasksByCpu) {
auto &scheduledTasks = entry.second; auto& scheduledTasks = entry.second;
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) { llvm::sort(scheduledTasks, [](const auto& lhs, const auto& rhs) {
if (lhs.first != rhs.first) if (lhs.first != rhs.first)
return lhs.first < rhs.first; return lhs.first < rhs.first;
return lhs.second < rhs.second; return lhs.second < rhs.second;
@@ -512,7 +513,7 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result)
readyNodes.pop(); readyNodes.pop();
processedNodeCount++; processedNodeCount++;
for (const ScheduledEdge &edge : scheduledChildren[sourceIndex]) { for (const ScheduledEdge& edge : scheduledChildren[sourceIndex]) {
startTimes[edge.target] = std::max(startTimes[edge.target], addOrMax(startTimes[sourceIndex], edge.delay)); startTimes[edge.target] = std::max(startTimes[edge.target], addOrMax(startTimes[sourceIndex], edge.delay));
assert(incomingEdgeCount[edge.target] > 0 && "scheduled incoming edge count underflow"); assert(incomingEdgeCount[edge.target] > 0 && "scheduled incoming edge count underflow");
incomingEdgeCount[edge.target]--; incomingEdgeCount[edge.target]--;
@@ -528,7 +529,7 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result)
result.computeToAestMap[node.instance] = startTimes[nodeIndex]; result.computeToAestMap[node.instance] = startTimes[nodeIndex];
} }
MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const ComputeGraph &originalGraph) { MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph& graph, const ComputeGraph& originalGraph) {
MergeScheduleResult result; MergeScheduleResult result;
TimingInfo timing = computeTiming(graph); TimingInfo timing = computeTiming(graph);
@@ -542,7 +543,7 @@ MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const
std::vector<size_t> originalNodeToCpu(originalGraph.nodes.size(), 0); std::vector<size_t> originalNodeToCpu(originalGraph.nodes.size(), 0);
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) { for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
const VirtualNode &virtualNode = graph.nodes[virtualNodeIndex]; const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
for (size_t originalIndex : virtualNode.originalNodeIndices) for (size_t originalIndex : virtualNode.originalNodeIndices)
originalNodeToCpu[originalIndex] = cpu; originalNodeToCpu[originalIndex] = cpu;
} }
@@ -556,17 +557,17 @@ MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const
result.computeToCpuSlotMap[node.instance] = nextCpuSlot[cpu]++; result.computeToCpuSlotMap[node.instance] = nextCpuSlot[cpu]++;
result.cpuToLastComputeMap[cpu] = node.instance; result.cpuToLastComputeMap[cpu] = node.instance;
} }
for (const auto &[cpu, lastCompute] : result.cpuToLastComputeMap) for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
result.isLastComputeOfCpu.insert(lastCompute); result.isLastComputeOfCpu.insert(lastCompute);
assignFeasibleAest(originalGraph, result); assignFeasibleAest(originalGraph, result);
return result; return result;
} }
MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const ComputeGraph &graph) { MergeScheduleResult buildResultFromScheduledGraph(GraphDCP& graphDCP, const ComputeGraph& graph) {
MergeScheduleResult result; MergeScheduleResult result;
result.dominanceOrderCompute.reserve(graph.nodes.size()); result.dominanceOrderCompute.reserve(graph.nodes.size());
for (const ComputeGraphNode &node : graph.nodes) for (const ComputeGraphNode& node : graph.nodes)
result.dominanceOrderCompute.push_back(node.instance); result.dominanceOrderCompute.push_back(node.instance);
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) { for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
@@ -589,7 +590,8 @@ MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const Comp
return result; return result;
} }
MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) { MergeScheduleResult
runLegacyDcp(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context) {
llvm::SmallVector<Weight> nodeWeights; llvm::SmallVector<Weight> nodeWeights;
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage; llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
llvm::SmallVector<int64_t> nodeOrderKeys; llvm::SmallVector<int64_t> nodeOrderKeys;
@@ -599,12 +601,12 @@ MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOpt
nodeOrderKeys.reserve(graph.nodes.size()); nodeOrderKeys.reserve(graph.nodes.size());
edges.reserve(graph.edges.size()); edges.reserve(graph.edges.size());
for (const ComputeGraphNode &node : graph.nodes) { for (const ComputeGraphNode& node : graph.nodes) {
nodeWeights.push_back(node.weight); nodeWeights.push_back(node.weight);
nodeCrossbarUsage.push_back(node.crossbarUsage); nodeCrossbarUsage.push_back(node.crossbarUsage);
nodeOrderKeys.push_back(static_cast<int64_t>(node.originalOrder)); nodeOrderKeys.push_back(static_cast<int64_t>(node.originalOrder));
} }
for (const ComputeGraphEdge &edge : graph.edges) { for (const ComputeGraphEdge& edge : graph.edges) {
edges.push_back( edges.push_back(
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)}); {static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
} }
@@ -617,11 +619,11 @@ MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOpt
return buildResultFromScheduledGraph(graphDCP, graph); return buildResultFromScheduledGraph(graphDCP, graph);
} }
bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOptions &options) { bool needsExactScheduledBatches(const ComputeGraph& graph, const DcpScheduleOptions& options) {
if (options.processorCount == 0 || !options.allowFallbackForAutoCoreCount) if (options.processorCount == 0 || !options.allowFallbackForAutoCoreCount)
return false; return false;
size_t schedulingCpuBudget = getSchedulingCpuBudget(options); size_t schedulingCpuBudget = getSchedulingCpuBudget(options);
return llvm::any_of(graph.nodes, [&](const ComputeGraphNode &node) { return llvm::any_of(graph.nodes, [&](const ComputeGraphNode& node) {
auto batch = dyn_cast<SpatComputeBatch>(node.instance.op); auto batch = dyn_cast<SpatComputeBatch>(node.instance.op);
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget; return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
}); });
@@ -630,7 +632,7 @@ bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOpti
} // namespace } // namespace
MergeScheduleResult MergeScheduleResult
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) { runDcpScheduler(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context) {
if (needsExactScheduledBatches(graph, options)) if (needsExactScheduledBatches(graph, options))
return runLegacyDcp(graph, options, context); return runLegacyDcp(graph, options, context);
@@ -15,7 +15,7 @@ struct DcpScheduleOptions {
}; };
MergeScheduleResult MergeScheduleResult
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context); runDcpScheduler(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context);
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -19,6 +19,7 @@ struct MergeScheduleResult {
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap; llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu; llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap; llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
}; };
} // namespace spatial } // namespace spatial
@@ -1,13 +1,13 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include <limits> #include <limits>
#include <vector> #include <vector>
#include "ComputeGraph.hpp"
#include "../DCPGraph/DCPAnalysis.hpp" #include "../DCPGraph/DCPAnalysis.hpp"
#include "ComputeGraph.hpp"
#include "DcpScheduler.hpp" #include "DcpScheduler.hpp"
#include "MergeSchedulingAnalysis.hpp" #include "MergeSchedulingAnalysis.hpp"
#include "PeftScheduler.hpp" #include "PeftScheduler.hpp"
@@ -20,15 +20,13 @@ namespace {
MergeSchedulerKind getSchedulerKind() { MergeSchedulerKind getSchedulerKind() {
switch (pimMergeScheduler.getValue()) { switch (pimMergeScheduler.getValue()) {
case MergeSchedulerPeft: case MergeSchedulerPeft: return MergeSchedulerKind::Peft;
return MergeSchedulerKind::Peft; case MergeSchedulerDcp: return MergeSchedulerKind::Dcp;
case MergeSchedulerDcp:
return MergeSchedulerKind::Dcp;
} }
llvm_unreachable("unknown merge scheduler kind"); llvm_unreachable("unknown merge scheduler kind");
} }
void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result, CrossbarUsage crossbarCapacity) { void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result, CrossbarUsage crossbarCapacity) {
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu; llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
tasksByCpu.reserve(result.cpuToLastComputeMap.size()); tasksByCpu.reserve(result.cpuToLastComputeMap.size());
@@ -45,9 +43,9 @@ void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result
{result.computeToCpuSlotMap.lookup(instance), nodeIndex}); {result.computeToCpuSlotMap.lookup(instance), nodeIndex});
} }
for (auto &entry : tasksByCpu) { for (auto& entry : tasksByCpu) {
auto &scheduledTasks = entry.second; auto& scheduledTasks = entry.second;
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) { llvm::sort(scheduledTasks, [](const auto& lhs, const auto& rhs) {
if (lhs.first != rhs.first) if (lhs.first != rhs.first)
return lhs.first < rhs.first; return lhs.first < rhs.first;
return lhs.second < rhs.second; return lhs.second < rhs.second;
@@ -70,7 +68,7 @@ void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result
llvm::report_fatal_error("merge scheduling: missing last-compute marker"); llvm::report_fatal_error("merge scheduling: missing last-compute marker");
} }
for (const ComputeGraphEdge &edge : graph.edges) { for (const ComputeGraphEdge& edge : graph.edges) {
const ComputeInstance source = graph.nodes[edge.source].instance; const ComputeInstance source = graph.nodes[edge.source].instance;
const ComputeInstance target = graph.nodes[edge.target].instance; const ComputeInstance target = graph.nodes[edge.target].instance;
const size_t sourceCpu = result.computeToCpuMap.lookup(source); const size_t sourceCpu = result.computeToCpuMap.lookup(source);
@@ -97,8 +95,8 @@ void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result
} // namespace } // namespace
MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation *op) MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation* op)
: entryOp(op) { : entryOp(op) {
result = run(); result = run();
} }
@@ -115,20 +113,17 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
MergeScheduleResult schedule; MergeScheduleResult schedule;
if (options.kind == MergeSchedulerKind::Peft) { if (options.kind == MergeSchedulerKind::Peft) {
schedule = runPeftScheduler( schedule = runPeftScheduler(graph,
graph, PeftScheduleOptions {options.processorCount,
PeftScheduleOptions {options.processorCount, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()), static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
entryOp->getContext()}); entryOp->getContext()});
} }
else { else {
schedule = runDcpScheduler( schedule = runDcpScheduler(graph,
graph, DcpScheduleOptions {options.processorCount,
DcpScheduleOptions { dcpCriticalWindowSize.getValue(),
options.processorCount, options.allowDcpFallbackForAutoCoreCount},
dcpCriticalWindowSize.getValue(), entryOp->getContext());
options.allowDcpFallbackForAutoCoreCount
},
entryOp->getContext());
} }
verifySchedule(graph, schedule, static_cast<CrossbarUsage>(crossbarCountInCore.getValue())); verifySchedule(graph, schedule, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()));
@@ -22,11 +22,11 @@ struct MergeSchedulingOptions {
class MergeSchedulingAnalysis { class MergeSchedulingAnalysis {
public: public:
explicit MergeSchedulingAnalysis(mlir::Operation *op); explicit MergeSchedulingAnalysis(mlir::Operation* op);
MergeScheduleResult &getResult() { return result; } MergeScheduleResult& getResult() { return result; }
private: private:
mlir::Operation *entryOp = nullptr; mlir::Operation* entryOp = nullptr;
MergeScheduleResult result; MergeScheduleResult result;
MergeScheduleResult run(); MergeScheduleResult run();
@@ -19,7 +19,6 @@ struct ScheduledTask {
size_t processor = std::numeric_limits<size_t>::max(); size_t processor = std::numeric_limits<size_t>::max();
Time startTime = 0; Time startTime = 0;
Time endTime = 0; Time endTime = 0;
size_t slot = 0;
}; };
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) { std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
@@ -243,7 +242,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
llvm::report_fatal_error(llvm::StringRef(message)); llvm::report_fatal_error(llvm::StringRef(message));
} }
schedules[task] = {bestProcessor, bestEst, bestEft, 0}; schedules[task] = {bestProcessor, bestEst, bestEft};
scheduled[task] = true; scheduled[task] = true;
++scheduledCount; ++scheduledCount;
processorCrossbars[bestProcessor] = addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage); processorCrossbars[bestProcessor] = addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
@@ -274,7 +273,65 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder; return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
}); });
// 5. Populate Final Result // 5. Check if equal schedule in two level
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
for (size_t currentProcessor = 0; currentProcessor < processorCount - 1; ++currentProcessor) {
for (size_t controlProcessor = currentProcessor; controlProcessor < processorCount; ++controlProcessor) {
if (tasksByProcessor[currentProcessor].size() != tasksByProcessor[controlProcessor].size())
continue;
auto& currentTasks = tasksByProcessor[currentProcessor];
auto& controlTasks = tasksByProcessor[controlProcessor];
bool equalSchedule = true;
for (auto [currentTask, controlTask] : llvm::zip(currentTasks, controlTasks)) {
const ComputeInstance currentComputeInstance = graph.nodes[currentTask].instance;
const ComputeInstance controlComputeInstance = graph.nodes[controlTask].instance;
if (currentComputeInstance.op != controlComputeInstance.op
|| currentComputeInstance.laneCount != controlComputeInstance.laneCount) {
equalSchedule = false;
break;
}
}
if (equalSchedule) {
equivalentClass[currentProcessor].push_back(controlProcessor);
equivalentClass[controlProcessor].push_back(currentProcessor);
}
}
}
/*{
llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n";
std::vector<bool> visited(processorCount, false);
size_t uniqueClassCount = 0;
for (size_t i = 0; i < processorCount; ++i) {
if (visited[i])
continue;
// We found a new unique schedule (equivalence class)
++uniqueClassCount;
visited[i] = true;
llvm::dbgs() << "Class " << uniqueClassCount << ": CPUs { " << i;
// Find and mark all identical companions
auto it = equivalentClass.find(i);
if (it != equivalentClass.end()) {
for (size_t eqCpu : it->second) {
if (!visited[eqCpu]) {
llvm::dbgs() << ", " << eqCpu;
visited[eqCpu] = true;
}
}
}
llvm::dbgs() << " }\n";
}
llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n";
llvm::dbgs() << "--------------------------------------\n";
}*/
// 6. Populate Final Result
MergeScheduleResult result; MergeScheduleResult result;
result.dominanceOrderCompute.reserve(nodeCount); result.dominanceOrderCompute.reserve(nodeCount);
@@ -296,8 +353,9 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
} }
} }
result.equivalentClass = equivalentClass;
return result; return result;
} }
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -11,10 +11,10 @@ namespace spatial {
struct PeftScheduleOptions { struct PeftScheduleOptions {
size_t processorCount = 0; size_t processorCount = 0;
CrossbarUsage crossbarCapacity = 0; CrossbarUsage crossbarCapacity = 0;
mlir::MLIRContext *context = nullptr; mlir::MLIRContext* context = nullptr;
}; };
MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options); MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options);
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -120,7 +120,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
rewriter.setInsertionPoint(mapOp); rewriter.setInsertionPoint(mapOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName()); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8; auto sizeInBytes = getShapedTypeSizeInBytes(initType);
pim::PimMemCopyOp::create(rewriter, pim::PimMemCopyOp::create(rewriter,
mapOp.getLoc(), mapOp.getLoc(),
initType, initType,
@@ -176,9 +176,9 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes)) if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure(); return failure();
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8; if (!hasByteSizedElementType(sourceType.getElementType()))
if (elementByteWidth <= 0)
return failure(); return failure();
const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth; const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
if (size != totalBytes) if (size != totalBytes)
@@ -31,13 +31,6 @@ static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
return false; return false;
} }
static int64_t getValueSizeInBytes(Value value) {
auto type = dyn_cast<ShapedType>(value.getType());
if (!type || !type.hasStaticShape())
return -1;
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
template <typename CoreOpTy> template <typename CoreOpTy>
static void materializeHostConstantsInCore(CoreOpTy coreOp, static void materializeHostConstantsInCore(CoreOpTy coreOp,
IRRewriter& rewriter, IRRewriter& rewriter,
@@ -82,7 +75,9 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
continue; continue;
} }
int64_t totalBytes = getValueSizeInBytes(originalValue); int64_t totalBytes = -1;
if (auto type = dyn_cast<ShapedType>(originalValue.getType()); type && type.hasStaticShape())
totalBytes = static_cast<int64_t>(getShapedTypeSizeInBytes(type));
if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) { if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) {
op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets"); op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets");
hasFailure = true; hasFailure = true;
+4 -3
View File
@@ -8,8 +8,8 @@
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -211,8 +211,9 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) { if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics); (void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
(void) withScalarCoreFromBatchLane( (void) withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) {
coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) { return verifyCoreOperands(scalarCore, diagnostics); }); return verifyCoreOperands(scalarCore, diagnostics);
});
continue; continue;
} }
+1 -1
View File
@@ -30,7 +30,7 @@ python3 validation/operations/gen_tests.py
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes | | Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|---------------|-------------------------|-----------|------------|----------|--------|-------|------|-------|------------------------------| |---------------|-------------------------|-----------|------------|----------|--------|-------|------|-------|------------------------------|
| Default | `gemm/` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Hand-crafted, square weights | | Simple | `gemm/simple` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Square weights |
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N | | Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
| With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector | | With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector |
| transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight | | transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight |
+13
View File
@@ -185,6 +185,18 @@ def conv_depthwise_grouped():
# GEMM tests # GEMM tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def gemm_simple():
"""Simple GEMM with square weights: [10, 132] @ [132, 132]."""
B, K, N = 10, 132, 132
W = numpy_helper.from_array(np.random.default_rng(41).uniform(-1, 1, (K, N)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
graph = helper.make_graph([node], "gemm_simple", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/simple", "gemm_simple.onnx")
def gemm_non_square(): def gemm_non_square():
"""GEMM with non-square weight matrix: [B, K] @ [K, N], K != N.""" """GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
B, K, N = 4, 128, 64 B, K, N = 4, 128, 64
@@ -823,6 +835,7 @@ def div_after_gemm():
if __name__ == "__main__": if __name__ == "__main__":
print("Generating GEMM tests:") print("Generating GEMM tests:")
gemm_simple()
gemm_non_square() gemm_non_square()
gemm_with_bias() gemm_with_bias()
gemm_transB() gemm_transB()