Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d136136d22 | |||
| 074eb183c7 | |||
| 43ed3914b8 | |||
| 6aaf1c0870 | |||
| fe35b3ed43 | |||
| 90a9339686 |
+92
-24
@@ -3,31 +3,99 @@ cmake_minimum_required(VERSION 3.20.0)
|
||||
|
||||
project(raptor)
|
||||
|
||||
# Add symlink to PIM as accelerator in onnx-mlir
|
||||
function(raptor_ensure_symlink link_path target_path)
|
||||
get_filename_component(link_parent "${link_path}" DIRECTORY)
|
||||
# Materialize a CMake shim directory
|
||||
function(raptor_write_external_cmake_shim shim_dir external_source_dir description)
|
||||
get_filename_component(real_external_source_dir "${external_source_dir}" REALPATH)
|
||||
file(RELATIVE_PATH relative_external_source_dir "${shim_dir}" "${real_external_source_dir}")
|
||||
|
||||
if(NOT EXISTS "${link_parent}")
|
||||
message(FATAL_ERROR "Directory not found: ${link_parent}")
|
||||
endif()
|
||||
|
||||
if(NOT EXISTS "${link_path}")
|
||||
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
|
||||
file(CREATE_LINK
|
||||
"${target_path}"
|
||||
"${link_path}"
|
||||
SYMBOLIC
|
||||
if (NOT EXISTS "${real_external_source_dir}/CMakeLists.txt")
|
||||
message(FATAL_ERROR
|
||||
"External CMake source directory not found or missing CMakeLists.txt:\n"
|
||||
" ${real_external_source_dir}"
|
||||
)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
if (IS_SYMLINK "${shim_dir}")
|
||||
message(STATUS "Removing old full-directory symlink: ${shim_dir}")
|
||||
file(REMOVE "${shim_dir}")
|
||||
endif ()
|
||||
|
||||
if (EXISTS "${shim_dir}" AND NOT IS_DIRECTORY "${shim_dir}")
|
||||
message(FATAL_ERROR "Expected directory or absent path, got file: ${shim_dir}")
|
||||
endif ()
|
||||
|
||||
file(MAKE_DIRECTORY "${shim_dir}")
|
||||
|
||||
set(shim_file "${shim_dir}/CMakeLists.txt")
|
||||
set(shim_contents
|
||||
"get_filename_component(raptor_external_source_dir
|
||||
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
|
||||
REALPATH
|
||||
)
|
||||
add_subdirectory(
|
||||
\"\${raptor_external_source_dir}\"
|
||||
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
|
||||
)
|
||||
if (DEFINED PIM_ENABLED)
|
||||
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
|
||||
endif ()
|
||||
"
|
||||
)
|
||||
|
||||
if (EXISTS "${shim_file}")
|
||||
file(READ "${shim_file}" old_contents)
|
||||
else ()
|
||||
set(old_contents "")
|
||||
endif ()
|
||||
|
||||
if (NOT old_contents STREQUAL shim_contents)
|
||||
file(WRITE "${shim_file}" "${shim_contents}")
|
||||
message(STATUS "Wrote CMake shim for ${description}: ${shim_file}")
|
||||
else ()
|
||||
message(STATUS "CMake shim already up to date for ${description}")
|
||||
endif ()
|
||||
|
||||
# Mirror the external tree's first-level entries into the shim directory
|
||||
# so legacy includes like src/Accelerators/PIM/Compiler/... keep working.
|
||||
file(GLOB children RELATIVE "${real_external_source_dir}" "${real_external_source_dir}/*")
|
||||
|
||||
foreach (child IN LISTS children)
|
||||
if (child STREQUAL "CMakeLists.txt")
|
||||
continue()
|
||||
endif ()
|
||||
|
||||
set(real_child "${real_external_source_dir}/${child}")
|
||||
set(shim_child "${shim_dir}/${child}")
|
||||
|
||||
if (IS_SYMLINK "${shim_child}")
|
||||
file(READ_SYMLINK "${shim_child}" existing_link_target)
|
||||
if (existing_link_target STREQUAL real_child)
|
||||
continue()
|
||||
endif ()
|
||||
file(REMOVE_RECURSE "${shim_child}")
|
||||
elseif (EXISTS "${shim_child}")
|
||||
# Do not delete real files/directories. This protects the generated shim.
|
||||
continue()
|
||||
endif ()
|
||||
|
||||
file(CREATE_LINK
|
||||
"${real_child}"
|
||||
"${shim_child}"
|
||||
SYMBOLIC
|
||||
)
|
||||
endforeach ()
|
||||
endfunction()
|
||||
|
||||
raptor_ensure_symlink(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
|
||||
raptor_write_external_cmake_shim(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
|
||||
"PIM accelerator"
|
||||
)
|
||||
raptor_ensure_symlink(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
|
||||
|
||||
raptor_write_external_cmake_shim(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
|
||||
"PIM accelerator tests"
|
||||
)
|
||||
|
||||
# Patch onnx-mlir sources for PIM accelerator support.
|
||||
@@ -38,21 +106,21 @@ function(raptor_apply_patch file_path anchor replacement description)
|
||||
|
||||
# Already applied – replacement text is present
|
||||
string(FIND "${contents}" "${replacement}" already_applied_pos)
|
||||
if(NOT already_applied_pos EQUAL -1)
|
||||
if (NOT already_applied_pos EQUAL -1)
|
||||
message(STATUS "Patch already applied: ${description}")
|
||||
return()
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
# Anchor must exist for the patch to be applicable
|
||||
string(FIND "${contents}" "${anchor}" anchor_pos)
|
||||
if(anchor_pos EQUAL -1)
|
||||
if (anchor_pos EQUAL -1)
|
||||
message(FATAL_ERROR
|
||||
"Patch anchor not found – onnx-mlir may have changed.\n"
|
||||
" Patch : ${description}\n"
|
||||
" File : ${file_path}\n"
|
||||
" Anchor: ${anchor}"
|
||||
)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
string(REPLACE "${anchor}" "${replacement}" patched "${contents}")
|
||||
file(WRITE "${file_path}" "${patched}")
|
||||
|
||||
@@ -299,10 +299,11 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
|
||||
if in_path.contains(&waiting_for) {
|
||||
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
|
||||
let cycle = &path[cycle_start..];
|
||||
let format_core = |core: &i32| (core - 1).to_string();
|
||||
|
||||
let cycle_str = cycle
|
||||
.iter()
|
||||
.map(|c| c.to_string())
|
||||
.map(format_core)
|
||||
.collect::<Vec<_>>()
|
||||
.join(" -> ");
|
||||
|
||||
@@ -311,19 +312,19 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
|
||||
.copied()
|
||||
.chain(std::iter::once(waiting_for))
|
||||
.collect::<Vec<_>>();
|
||||
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
|
||||
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for - 1);
|
||||
let states_msg = cycle
|
||||
.iter()
|
||||
.filter_map(|core| {
|
||||
states.get(core).map(|state| match state {
|
||||
CoreState::SendingTo(target, size) => {
|
||||
format!("core {} send {}B -> {}", core, size, target)
|
||||
format!("core {} send {}B -> {}", core - 1, size, target - 1)
|
||||
}
|
||||
CoreState::ReceivingFrom(source, size) => {
|
||||
format!("core {} recv {}B <- {}", core, size, source)
|
||||
format!("core {} recv {}B <- {}", core - 1, size, source - 1)
|
||||
}
|
||||
CoreState::Working => format!("core {} working", core),
|
||||
CoreState::Halted => format!("core {} halted", core),
|
||||
CoreState::Working => format!("core {} working", core - 1),
|
||||
CoreState::Halted => format!("core {} halted", core - 1),
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
|
||||
@@ -28,23 +28,47 @@ static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_
|
||||
return laneCoreIds;
|
||||
}
|
||||
|
||||
static Value getOrCloneCapturedValue(OpBuilder& builder, Block& oldBlock, Value value, IRMapping& mapper) {
|
||||
if (Value mapped = mapper.lookupOrNull(value))
|
||||
return mapped;
|
||||
|
||||
if (auto blockArgument = dyn_cast<BlockArgument>(value)) {
|
||||
assert(blockArgument.getOwner() != &oldBlock && "expected block argument to be mapped before cloning");
|
||||
assert(false && "unexpected captured block argument while scalarizing pim.core_batch");
|
||||
}
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
assert(definingOp && "expected captured value to be defined by an operation");
|
||||
assert(definingOp->getBlock() != &oldBlock && "expected in-block value to be mapped before cloning");
|
||||
|
||||
for (Value operand : definingOp->getOperands())
|
||||
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
|
||||
|
||||
Operation* cloned = builder.clone(*definingOp, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
return mapper.lookup(value);
|
||||
}
|
||||
|
||||
static void cloneScalarizedLaneBody(OpBuilder& builder,
|
||||
pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
OperationFolder& constantFolder) {
|
||||
Block& oldBlock = coreBatchOp.getBody().front();
|
||||
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
|
||||
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||
size_t weightCount = coreBatchOp.getWeights().size();
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
if (blockArg.getType().isIndex()) {
|
||||
mapper.map(blockArg, getOrCreateHostIndexConstant(coreBatchOp, static_cast<int64_t>(lane), constantFolder));
|
||||
mapper.map(blockArg, getOrCreateHostIndexConstant(anchorOp, static_cast<int64_t>(lane), constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (argIndex <= weightCount) {
|
||||
mapper.map(blockArg, coreBatchOp.getWeights()[argIndex - 1]);
|
||||
auto scalarCoreOp = cast<pim::PimCoreOp>(anchorOp);
|
||||
mapper.map(blockArg, scalarCoreOp.getWeightArgument(argIndex - 1));
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -57,8 +81,10 @@ static void cloneScalarizedLaneBody(OpBuilder& builder,
|
||||
if (isa<pim::PimHaltOp>(op))
|
||||
continue;
|
||||
|
||||
for (Value operand : op.getOperands())
|
||||
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
||||
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
|
||||
pim::PimSendOp::create(
|
||||
builder,
|
||||
sendBatchOp.getLoc(),
|
||||
@@ -78,7 +104,6 @@ static void cloneScalarizedLaneBody(OpBuilder& builder,
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
|
||||
auto scalarReceive = pim::PimReceiveOp::create(
|
||||
builder,
|
||||
receiveBatchOp.getLoc(),
|
||||
@@ -106,8 +131,8 @@ static void cloneScalarizedLaneBody(OpBuilder& builder,
|
||||
builder,
|
||||
memcpBatchOp.getLoc(),
|
||||
memcpBatchOp.getOutput().getType(),
|
||||
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
|
||||
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
|
||||
getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
|
||||
getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
|
||||
mapper.lookup(memcpBatchOp.getDeviceTarget()),
|
||||
mapper.lookup(memcpBatchOp.getHostSource()),
|
||||
memcpBatchOp.getSizeAttr());
|
||||
@@ -141,7 +166,16 @@ LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
|
||||
|
||||
auto scalarCore =
|
||||
pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
|
||||
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Location> weightLocs;
|
||||
weightTypes.reserve(weights.size());
|
||||
weightLocs.reserve(weights.size());
|
||||
for (Value weight : weights) {
|
||||
weightTypes.push_back(weight.getType());
|
||||
weightLocs.push_back(weight.getLoc());
|
||||
}
|
||||
Block* block =
|
||||
builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end(), TypeRange(weightTypes), weightLocs);
|
||||
builder.setInsertionPointToEnd(block);
|
||||
for (unsigned lane : lanes)
|
||||
cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -97,22 +99,75 @@ static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveT
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
||||
if (!result.hasOneUse())
|
||||
return failure();
|
||||
|
||||
auto returnOp = dyn_cast<func::ReturnOp>(*result.getUsers().begin());
|
||||
if (!returnOp)
|
||||
return failure();
|
||||
return result.getUses().begin()->getOperandNumber();
|
||||
}
|
||||
|
||||
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
|
||||
if (scale == 1)
|
||||
return base;
|
||||
|
||||
auto scaleValue = arith::ConstantIndexOp::create(rewriter, loc, scale).getResult();
|
||||
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
|
||||
}
|
||||
|
||||
static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||
tensor::ParallelInsertSliceOp insertSlice,
|
||||
ShapedType destinationType,
|
||||
IRMapping& mapper) {
|
||||
int64_t elementBytes = destinationType.getElementTypeBitWidth() / 8;
|
||||
SmallVector<int64_t> strides(destinationType.getRank(), 1);
|
||||
ArrayRef<int64_t> shape = destinationType.getShape();
|
||||
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
|
||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||
|
||||
Value totalOffset;
|
||||
Location loc = insertSlice.getLoc();
|
||||
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
|
||||
int64_t scale = strides[dim] * elementBytes;
|
||||
Value scaledOffset;
|
||||
if (auto attr = dyn_cast<Attribute>(offset)) {
|
||||
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||
assert(intAttr && "expected integer offset attribute");
|
||||
scaledOffset = arith::ConstantIndexOp::create(rewriter, loc, intAttr.getInt() * scale).getResult();
|
||||
}
|
||||
else {
|
||||
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
|
||||
}
|
||||
|
||||
totalOffset = totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult()
|
||||
: scaledOffset;
|
||||
}
|
||||
|
||||
if (!totalOffset)
|
||||
totalOffset = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
return totalOffset;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||
IRRewriter& rewriter) {
|
||||
Location loc = computeBatchOp.getLoc();
|
||||
Block& oldBlock = computeBatchOp.getBody().front();
|
||||
if (computeBatchOp.getNumResults() != 0)
|
||||
return computeBatchOp.emitOpError(
|
||||
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results; "
|
||||
"materialize explicit communication before lowering to PIM");
|
||||
|
||||
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
if (!oldYield || oldYield.getNumOperands() != 0)
|
||||
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
|
||||
auto inParallelOp = dyn_cast<spatial::SpatInParallelOp>(oldBlock.getTerminator());
|
||||
if (computeBatchOp.getNumResults() == 0) {
|
||||
if (!oldYield || oldYield.getNumOperands() != 0)
|
||||
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
|
||||
}
|
||||
else if (!inParallelOp) {
|
||||
return computeBatchOp.emitOpError(
|
||||
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
|
||||
}
|
||||
|
||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
|
||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
||||
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||
SmallVector<Value> batchInputs;
|
||||
if (!computeBatchOp.getInputs().empty())
|
||||
@@ -128,9 +183,24 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
SmallVector<Value> hostOutputTensors;
|
||||
if (computeBatchOp.getNumResults() != 0) {
|
||||
hostOutputTensors.resize(computeBatchOp.getNumResults());
|
||||
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
||||
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
|
||||
if (failed(returnOperandIndex))
|
||||
return computeBatchOp.emitOpError(
|
||||
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
||||
|
||||
hostOutputTensors[resultIndex] = outputTensors[*returnOperandIndex](rewriter, loc);
|
||||
result.replaceAllUsesWith(hostOutputTensors[resultIndex]);
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
for (BlockArgument arg : oldBlock.getArguments()) {
|
||||
unsigned inputArgLimit = 1 + computeBatchOp.getWeights().size() + computeBatchOp.getInputs().size();
|
||||
for (BlockArgument arg : oldBlock.getArguments().take_front(inputArgLimit)) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(arg.getLoc());
|
||||
}
|
||||
@@ -183,6 +253,38 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
if (isa<spatial::SpatYieldOp>(op))
|
||||
continue;
|
||||
|
||||
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
|
||||
unsigned firstOutputArg = computeBatchOp.getOutputArgument(0).getArgNumber();
|
||||
for (Operation& nestedOp : parallelOp.getRegion().front()) {
|
||||
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&nestedOp);
|
||||
if (!insertSlice)
|
||||
return parallelOp.emitOpError("expected only tensor.parallel_insert_slice in spat.in_parallel");
|
||||
|
||||
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
|
||||
if (!outputArg || outputArg.getOwner() != &oldBlock)
|
||||
return insertSlice.emitOpError("expected compute_batch output block argument destination");
|
||||
|
||||
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg;
|
||||
if (resultIndex >= hostOutputTensors.size())
|
||||
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
||||
|
||||
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
||||
auto hostTarget = hostOutputTensors[resultIndex];
|
||||
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
||||
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
|
||||
Value zeroOffset = arith::ConstantIndexOp::create(rewriter, insertSlice.getLoc(), 0).getResult();
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
insertSlice.getLoc(),
|
||||
hostTarget.getType(),
|
||||
hostTargetOffset,
|
||||
zeroOffset,
|
||||
hostTarget,
|
||||
mappedSource,
|
||||
getTensorSizeInBytesAttr(rewriter, mappedSource));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
|
||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds());
|
||||
if (failed(targetCoreIds))
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -6,7 +6,6 @@ add_pim_library(OMSpatialToPim
|
||||
SpatialToPimPass.cpp
|
||||
BatchCoreLoweringPatterns.cpp
|
||||
ChannelLoweringPatterns.cpp
|
||||
Cleanup.cpp
|
||||
Common.cpp
|
||||
ComputeLikeRegionUtils.cpp
|
||||
CoreLoweringPatterns.cpp
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
|
||||
while (!pendingOps.empty()) {
|
||||
bool erasedAnyOp = false;
|
||||
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
|
||||
Operation* opToRemove = *it;
|
||||
if (!opToRemove->use_empty()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
rewriter.eraseOp(opToRemove);
|
||||
it = pendingOps.erase(it);
|
||||
erasedAnyOp = true;
|
||||
}
|
||||
|
||||
if (erasedAnyOp)
|
||||
continue;
|
||||
|
||||
for (Operation* opToRemove : pendingOps) {
|
||||
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
|
||||
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
|
||||
for (Operation* user : opToRemove->getUsers()) {
|
||||
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
|
||||
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
|
||||
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,11 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -6,9 +6,9 @@
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -148,15 +148,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
||||
|
||||
} // namespace
|
||||
|
||||
void markOpToRemove(CoreLoweringState& state, Operation* op) {
|
||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
||||
state.operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder) {
|
||||
Location loc = computeOp->getLoc();
|
||||
|
||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, state.constantFolder))
|
||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, constantFolder))
|
||||
return success();
|
||||
|
||||
SmallVector<Operation*> helperChain;
|
||||
@@ -179,7 +176,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(received);
|
||||
markOpToRemove(state, receiveOp);
|
||||
markOpToRemove(receiveOp);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -200,7 +197,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(received);
|
||||
markOpToRemove(state, receiveTensorOp);
|
||||
markOpToRemove(receiveTensorOp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,9 +208,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
|
||||
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
|
||||
ReturnPathLoweringResult returnPathResult =
|
||||
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
|
||||
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, rewriter);
|
||||
if (returnPathResult == ReturnPathLoweringResult::Failure)
|
||||
return failure();
|
||||
if (returnPathResult == ReturnPathLoweringResult::Handled)
|
||||
@@ -240,7 +236,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
auto coreOp = PimCoreOp::create(rewriter,
|
||||
loc,
|
||||
ValueRange(computeWeights),
|
||||
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
|
||||
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
@@ -249,7 +245,7 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
continue;
|
||||
|
||||
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
|
||||
blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, state.constantFolder));
|
||||
blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -261,8 +257,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
PimMemCopyHostToDevOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder),
|
||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder),
|
||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
|
||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
|
||||
outputBuffer,
|
||||
input,
|
||||
getTensorSizeInBytesAttr(rewriter, input))
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct CoreLoweringState {
|
||||
size_t& nextCoreId;
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
||||
mlir::OperationFolder& constantFolder;
|
||||
};
|
||||
|
||||
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -141,152 +141,6 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
};
|
||||
|
||||
// Turns runtime constants consumed by compute regions into private globals and local loads.
|
||||
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
|
||||
Location loc = constantOp.getLoc();
|
||||
|
||||
if (hasWeightAlways(constantOp))
|
||||
return failure();
|
||||
|
||||
if (!isa<func::FuncOp>(constantOp->getParentOp()))
|
||||
return failure();
|
||||
|
||||
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
|
||||
if (isa<spatial::SpatCompute>(op))
|
||||
return false;
|
||||
if (isa<func::FuncOp>(op->getParentOp()))
|
||||
return true;
|
||||
return false;
|
||||
}))
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
|
||||
|
||||
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
|
||||
|
||||
if (constRankedTensorType) {
|
||||
mlir::MemRefType memRefType =
|
||||
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
|
||||
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
|
||||
loc,
|
||||
constantOp->getParentOfType<ModuleOp>(),
|
||||
"const",
|
||||
memRefType,
|
||||
constantOp.getValueAttr(),
|
||||
rewriter.getUnitAttr());
|
||||
std::string argName = globalOp.getSymName().str();
|
||||
|
||||
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
||||
|
||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter,
|
||||
spatComputeBatch.getOperation(),
|
||||
BBArgIndex,
|
||||
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||
}
|
||||
else {
|
||||
{
|
||||
|
||||
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatCompute.getOperation());
|
||||
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (constantOp.getType().isIntOrIndexOrFloat()) {
|
||||
Value hostConstant = constantOp.getResult();
|
||||
|
||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, hostConstant);
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, hostConstant);
|
||||
}
|
||||
else if (constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||
constUses.set(hostConstant);
|
||||
}
|
||||
else {
|
||||
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
|
||||
assert(batchParent && "Global Constant used direcly not within a compute");
|
||||
constUses.set(hostConstant);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (constantOp->use_empty())
|
||||
rewriter.eraseOp(constantOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
|
||||
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
@@ -363,7 +217,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
|
||||
} // namespace
|
||||
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
|
||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
|
||||
@@ -9,9 +9,9 @@
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -44,11 +44,6 @@ static bool isReturnHelperChainOp(Operation* op) {
|
||||
pim::PimTransposeOp>(op);
|
||||
}
|
||||
|
||||
static void markOpToRemove(ReturnPathState& state, Operation* op) {
|
||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
||||
state.operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
|
||||
std::string name = baseName.str();
|
||||
unsigned suffix = 0;
|
||||
@@ -390,9 +385,7 @@ static Value emitHostCopy(IRRewriter& rewriter,
|
||||
|
||||
} // namespace
|
||||
|
||||
void addReturnOutputBuffers(func::ReturnOp returnOp,
|
||||
IRRewriter& rewriter,
|
||||
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
|
||||
void raptor::SpatialToPimPass::addReturnOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
|
||||
outputTensors.reserve(returnOp->getNumOperands());
|
||||
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
|
||||
Value currentReturnValue = returnValue;
|
||||
@@ -427,8 +420,8 @@ void addReturnOutputBuffers(func::ReturnOp returnOp,
|
||||
}
|
||||
}
|
||||
|
||||
ReturnPathLoweringResult lowerProducedValueReturnPath(
|
||||
Operation* producerOp, Value producedValue, Value storedValue, ReturnPathState& state, IRRewriter& rewriter) {
|
||||
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerProducedValueReturnPath(
|
||||
Operation* producerOp, Value producedValue, Value storedValue, IRRewriter& rewriter) {
|
||||
Location loc = producerOp->getLoc();
|
||||
OperationFolder constantFolder(producerOp->getContext());
|
||||
auto storedTensorType = cast<TensorType>(storedValue.getType());
|
||||
@@ -437,13 +430,13 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
||||
Value currentStoredValue = storedValue;
|
||||
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
|
||||
for (Operation* op : returnUse->helperChain)
|
||||
markOpToRemove(state, op);
|
||||
markOpToRemove(op);
|
||||
|
||||
auto storedType = cast<ShapedType>(currentStoredValue.getType());
|
||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||
if (auto storedOp = currentStoredValue.getDefiningOp())
|
||||
rewriter.setInsertionPointAfter(storedOp);
|
||||
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
@@ -464,7 +457,7 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||
size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
|
||||
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
@@ -480,11 +473,11 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
||||
size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8;
|
||||
for (Operation* concatOp : concatReturnUse->concatChain)
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
|
||||
if (concatReturnUse->helperChain.empty()) {
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
||||
emitHostCopy(rewriter,
|
||||
@@ -505,7 +498,7 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
||||
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
||||
@@ -553,12 +546,15 @@ ReturnPathLoweringResult lowerProducedValueReturnPath(
|
||||
return ReturnPathLoweringResult::NotReturnPath;
|
||||
}
|
||||
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
|
||||
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, state, rewriter);
|
||||
raptor::SpatialToPimPass::ReturnPathLoweringResult
|
||||
raptor::SpatialToPimPass::lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||
OpResult result,
|
||||
Value yieldValue,
|
||||
IRRewriter& rewriter) {
|
||||
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
||||
}
|
||||
|
||||
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
|
||||
void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
|
||||
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
|
||||
if (!op)
|
||||
return;
|
||||
@@ -575,13 +571,13 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
||||
|
||||
if (isReturnHelperChainOp(op)) {
|
||||
Value source = op->getOperand(0);
|
||||
markOpToRemove(state, op);
|
||||
markOpToRemove(op);
|
||||
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||
markOpToRemove(state, computeOp);
|
||||
markOpToRemove(computeOp);
|
||||
if (!computeOp.getInputs().empty())
|
||||
for (Value input : computeOp.getInputs())
|
||||
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||
@@ -589,33 +585,33 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getOperands())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) {
|
||||
markOpToRemove(state, receiveOp);
|
||||
markOpToRemove(receiveOp);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
|
||||
markOpToRemove(state, receiveTensorOp);
|
||||
markOpToRemove(receiveTensorOp);
|
||||
};
|
||||
|
||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||
@@ -624,7 +620,7 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
||||
size_t orderWithinReturn = it.index();
|
||||
Operation* returnOperand = it.value().getDefiningOp();
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc);
|
||||
Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
|
||||
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
|
||||
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
|
||||
}
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
||||
|
||||
struct ReturnPathState {
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
||||
};
|
||||
|
||||
enum class ReturnPathLoweringResult {
|
||||
Handled,
|
||||
NotReturnPath,
|
||||
Failure
|
||||
};
|
||||
|
||||
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp,
|
||||
mlir::IRRewriter& rewriter,
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors);
|
||||
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||
mlir::OpResult result,
|
||||
mlir::Value yieldValue,
|
||||
ReturnPathState& state,
|
||||
mlir::IRRewriter& rewriter);
|
||||
|
||||
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
|
||||
mlir::Value producedValue,
|
||||
mlir::Value storedValue,
|
||||
ReturnPathState& state,
|
||||
mlir::IRRewriter& rewriter);
|
||||
|
||||
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -14,7 +14,6 @@
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
@@ -24,54 +23,28 @@
|
||||
#include <cassert>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "Conversion/SpatialToPim/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||
#include "Conversion/SpatialToPim/PhaseVerification.hpp"
|
||||
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "Dialect/Pim/PimOps.hpp"
|
||||
#include "Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "Pass/PIMPasses.h"
|
||||
#include "SpatialToPimPass.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace pim;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
namespace raptor {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
||||
|
||||
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
||||
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
||||
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
||||
|
||||
SpatialToPimPass() = default;
|
||||
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
SmallVector<OutputTensorFactory> outputTensors;
|
||||
size_t coreId = 0;
|
||||
SmallVector<Operation*> operationsToRemove;
|
||||
|
||||
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
|
||||
void markOpToRemove(Operation* op);
|
||||
|
||||
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace raptor
|
||||
|
||||
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
||||
@@ -151,8 +124,10 @@ padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector,
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
||||
}
|
||||
|
||||
void SpatialToPimPass::runOnOperation() {
|
||||
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
coreId = 0;
|
||||
outputTensors.clear();
|
||||
operationsToRemove.clear();
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
|
||||
@@ -198,18 +173,16 @@ void SpatialToPimPass::runOnOperation() {
|
||||
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
||||
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
|
||||
ReturnPathState returnPathState {outputTensors, operationsToRemove};
|
||||
addReturnOutputBuffers(returnOp, rewriter);
|
||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove, constantFolder};
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
markOpToRemove(computeOp);
|
||||
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
|
||||
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
|
||||
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
||||
signalPassFailure();
|
||||
return;
|
||||
@@ -218,7 +191,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
|
||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||
markOpToRemove(computeBatchOp);
|
||||
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
|
||||
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
|
||||
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
||||
signalPassFailure();
|
||||
return;
|
||||
@@ -267,14 +240,8 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
|
||||
|
||||
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
|
||||
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
|
||||
funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
replaceReturnWithOutputBuffers(returnOp, rewriter);
|
||||
eraseOpsToRemove();
|
||||
|
||||
RewritePatternSet finalTensorPackingPatterns(ctx);
|
||||
populateTensorPackingPatterns(finalTensorPackingPatterns);
|
||||
@@ -316,7 +283,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
dumpModule(moduleOp, "pim0");
|
||||
}
|
||||
|
||||
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||
@@ -350,7 +317,8 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
||||
});
|
||||
}
|
||||
|
||||
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
|
||||
IRRewriter& rewriter) {
|
||||
Location loc = funcOp.getLoc();
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
|
||||
@@ -394,11 +362,18 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
if (!llvm::is_contained(operationsToRemove, op))
|
||||
operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
|
||||
void raptor::SpatialToPimPass::eraseOpsToRemove() {
|
||||
for (Operation* op : operationsToRemove) {
|
||||
op->dropAllUses();
|
||||
op->erase();
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<raptor::SpatialToPimPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace raptor {
|
||||
|
||||
struct SpatialToPimPass : mlir::PassWrapper<SpatialToPimPass, mlir::OperationPass<mlir::ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
||||
llvm::StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
||||
llvm::StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
||||
|
||||
SpatialToPimPass() = default;
|
||||
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
||||
|
||||
llvm::SmallVector<OutputTensorFactory> outputTensors;
|
||||
size_t coreId = 0;
|
||||
llvm::SmallVector<mlir::Operation*> operationsToRemove;
|
||||
|
||||
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp,
|
||||
mlir::IRRewriter& rewriter);
|
||||
mlir::LogicalResult lowerComputeOp(spatial::SpatCompute computeOp,
|
||||
mlir::IRRewriter& rewriter,
|
||||
mlir::OperationFolder& constantFolder);
|
||||
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||
mlir::IRRewriter& rewriter);
|
||||
|
||||
enum class ReturnPathLoweringResult {
|
||||
Handled,
|
||||
NotReturnPath,
|
||||
Failure
|
||||
};
|
||||
|
||||
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||
mlir::OpResult result,
|
||||
mlir::Value yieldValue,
|
||||
mlir::IRRewriter& rewriter);
|
||||
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
|
||||
mlir::Value producedValue,
|
||||
mlir::Value storedValue,
|
||||
mlir::IRRewriter& rewriter);
|
||||
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
||||
|
||||
void markOpToRemove(mlir::Operation* op);
|
||||
void eraseOpsToRemove();
|
||||
|
||||
void enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
};
|
||||
|
||||
} // namespace raptor
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,4 +1,6 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
@@ -6,6 +8,7 @@
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -40,7 +43,18 @@ static bool isDefinedInsideRegion(Value value, Region& region) {
|
||||
|
||||
static bool isConstantExternalValue(Value value) {
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
|
||||
if (!definingOp)
|
||||
return false;
|
||||
if (definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
return true;
|
||||
|
||||
auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(definingOp);
|
||||
if (!getGlobalOp)
|
||||
return false;
|
||||
|
||||
auto moduleOp = definingOp->getParentOfType<ModuleOp>();
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
return globalOp && globalOp.getConstant();
|
||||
}
|
||||
|
||||
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
|
||||
|
||||
@@ -120,6 +120,15 @@ static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
|
||||
if (value == laneArg || isConstantIndexLike(value))
|
||||
return true;
|
||||
|
||||
auto extractOp = value.getDefiningOp<tensor::ExtractOp>();
|
||||
if (extractOp) {
|
||||
auto constantTensor = extractOp.getTensor().getDefiningOp<arith::ConstantOp>();
|
||||
auto denseAttr = constantTensor ? dyn_cast<DenseIntElementsAttr>(constantTensor.getValue()) : nullptr;
|
||||
if (!denseAttr || denseAttr.getType().getRank() != 1 || extractOp.getIndices().size() != 1)
|
||||
return false;
|
||||
return isSupportedLaneOffsetExpr(extractOp.getIndices().front(), laneArg);
|
||||
}
|
||||
|
||||
auto addOp = value.getDefiningOp<arith::AddIOp>();
|
||||
if (!addOp)
|
||||
return false;
|
||||
|
||||
+1210
-1134
File diff suppressed because it is too large
Load Diff
@@ -267,212 +267,6 @@ bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
||||
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
|
||||
}
|
||||
|
||||
struct BatchYieldInfo {
|
||||
Value yieldedValue;
|
||||
tensor::ParallelInsertSliceOp insertSlice;
|
||||
};
|
||||
|
||||
static bool isHostOnlyBatchResultUser(Operation* user) {
|
||||
return isa<func::ReturnOp,
|
||||
spatial::SpatConcatOp,
|
||||
tensor::ExtractSliceOp,
|
||||
tensor::CastOp,
|
||||
tensor::CollapseShapeOp,
|
||||
tensor::ExpandShapeOp>(user);
|
||||
}
|
||||
|
||||
static FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> collectBatchYieldInfo(SpatComputeBatch batchOp) {
|
||||
Block& block = batchOp.getBody().front();
|
||||
auto inParallel = dyn_cast<spatial::SpatInParallelOp>(block.getTerminator());
|
||||
if (!inParallel)
|
||||
return failure();
|
||||
|
||||
DenseMap<BlockArgument, BatchYieldInfo> batchYieldByOutputArg;
|
||||
for (Operation& op : inParallel.getRegion().front()) {
|
||||
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
|
||||
if (!insertSlice)
|
||||
return failure();
|
||||
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
|
||||
if (!outputArg || outputArg.getOwner() != &block)
|
||||
return failure();
|
||||
batchYieldByOutputArg[outputArg] = {insertSlice.getSource(), insertSlice};
|
||||
}
|
||||
return batchYieldByOutputArg;
|
||||
}
|
||||
|
||||
static FailureOr<SpatComputeBatch> cloneBatchAsResultless(SpatComputeBatch batchOp, IRRewriter& rewriter) {
|
||||
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
if (!coreIdsAttr)
|
||||
return failure();
|
||||
|
||||
Block& oldBlock = batchOp.getBody().front();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto newBatch = SpatComputeBatch::create(rewriter,
|
||||
batchOp.getLoc(),
|
||||
TypeRange {},
|
||||
rewriter.getI32IntegerAttr(batchOp.getLaneCount()),
|
||||
batchOp.getWeights(),
|
||||
batchOp.getInputs());
|
||||
newBatch.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, coreIdsAttr);
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
|
||||
blockArgLocs.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
|
||||
blockArgTypes.push_back(batchOp.getLaneArgument().getType());
|
||||
blockArgLocs.push_back(batchOp.getLaneArgument().getLoc());
|
||||
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex) {
|
||||
blockArgTypes.push_back(batchOp.getWeightArgument(weightIndex).getType());
|
||||
blockArgLocs.push_back(batchOp.getWeightArgument(weightIndex).getLoc());
|
||||
}
|
||||
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex) {
|
||||
blockArgTypes.push_back(batchOp.getInputArgument(inputIndex).getType());
|
||||
blockArgLocs.push_back(batchOp.getInputArgument(inputIndex).getLoc());
|
||||
}
|
||||
|
||||
Block* newBlock =
|
||||
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
mapper.map(batchOp.getLaneArgument(), newBatch.getLaneArgument());
|
||||
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
|
||||
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch.getWeightArgument(weightIndex));
|
||||
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
|
||||
mapper.map(batchOp.getInputArgument(inputIndex), newBatch.getInputArgument(inputIndex));
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator()) {
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
}
|
||||
|
||||
return newBatch;
|
||||
}
|
||||
|
||||
static LogicalResult materializeBatchResultCommunication(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
SmallVector<SpatComputeBatch> batches(funcOp.getOps<SpatComputeBatch>());
|
||||
|
||||
for (auto batchOp : batches) {
|
||||
if (batchOp.getNumResults() == 0)
|
||||
continue;
|
||||
|
||||
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
if (!coreIdsAttr)
|
||||
return batchOp.emitOpError("missing coreIds while materializing batch result communication");
|
||||
|
||||
FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> batchYieldInfo = collectBatchYieldInfo(batchOp);
|
||||
if (failed(batchYieldInfo))
|
||||
return batchOp.emitOpError("failed to collect per-result yielded values from compute_batch body");
|
||||
|
||||
FailureOr<SpatComputeBatch> newBatch = cloneBatchAsResultless(batchOp, rewriter);
|
||||
if (failed(newBatch))
|
||||
return batchOp.emitOpError("failed to clone resultful compute_batch as resultless");
|
||||
|
||||
Block& oldBlock = batchOp.getBody().front();
|
||||
Block& newBlock = newBatch->getBody().front();
|
||||
IRMapping mapper;
|
||||
mapper.map(batchOp.getLaneArgument(), newBatch->getLaneArgument());
|
||||
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
|
||||
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch->getWeightArgument(weightIndex));
|
||||
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
|
||||
mapper.map(batchOp.getInputArgument(inputIndex), newBatch->getInputArgument(inputIndex));
|
||||
auto oldIt = oldBlock.begin();
|
||||
auto newIt = newBlock.begin();
|
||||
for (; oldIt != oldBlock.end() && newIt != newBlock.end(); ++oldIt, ++newIt)
|
||||
for (auto [oldResult, newResult] : llvm::zip(oldIt->getResults(), newIt->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
|
||||
SmallVector<int32_t> sourceCoreIds(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
rewriter.setInsertionPointToEnd(&newBlock);
|
||||
|
||||
for (unsigned resultIndex = 0; resultIndex < batchOp.getNumResults(); ++resultIndex) {
|
||||
BlockArgument outputArg = batchOp.getOutputArgument(resultIndex);
|
||||
auto yieldInfoIt = batchYieldInfo->find(outputArg);
|
||||
if (yieldInfoIt == batchYieldInfo->end())
|
||||
return batchOp.emitOpError(
|
||||
"missing yielded value for compute_batch result during communication materialization");
|
||||
Value mappedYieldedValue = mapper.lookup(yieldInfoIt->second.yieldedValue);
|
||||
|
||||
DenseMap<int32_t, SmallVector<OpOperand*>> computeUsesByTargetCore;
|
||||
SmallVector<OpOperand*> hostUses;
|
||||
for (OpOperand& use : batchOp.getResult(resultIndex).getUses()) {
|
||||
if (auto computeOp = dyn_cast<SpatCompute>(use.getOwner())) {
|
||||
auto coreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
|
||||
if (!coreIdAttr)
|
||||
return batchOp.emitOpError("compute user of compute_batch result is missing coreId");
|
||||
computeUsesByTargetCore[static_cast<int32_t>(coreIdAttr.getInt())].push_back(&use);
|
||||
continue;
|
||||
}
|
||||
if (isHostOnlyBatchResultUser(use.getOwner())) {
|
||||
hostUses.push_back(&use);
|
||||
continue;
|
||||
}
|
||||
return batchOp.emitOpError("unsupported user of compute_batch result during communication materialization")
|
||||
<< ": " << use.getOwner()->getName();
|
||||
}
|
||||
|
||||
auto createReceiveForUses = [&](ArrayRef<OpOperand*> uses, ArrayRef<int32_t> targetCoreIds) -> LogicalResult {
|
||||
if (uses.empty())
|
||||
return success();
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
channelIds.reserve(sourceCoreIds.size());
|
||||
for ([[maybe_unused]] int32_t sourceCoreId : sourceCoreIds)
|
||||
channelIds.push_back(nextChannelId++);
|
||||
SmallVector<Value> sendChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
|
||||
SmallVector<Value> sendSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> sendTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
|
||||
|
||||
spatial::SpatChannelSendBatchOp::create(rewriter,
|
||||
batchOp.getLoc(),
|
||||
sendChannelIdValues,
|
||||
sendSourceCoreIdValues,
|
||||
sendTargetCoreIdValues,
|
||||
mappedYieldedValue);
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointAfter(newBatch->getOperation());
|
||||
SmallVector<Value> receiveChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
|
||||
SmallVector<Value> receiveSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> receiveTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
|
||||
auto received = spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||
batchOp.getLoc(),
|
||||
batchOp.getResult(resultIndex).getType(),
|
||||
receiveChannelIdValues,
|
||||
receiveSourceCoreIdValues,
|
||||
receiveTargetCoreIdValues);
|
||||
for (OpOperand* use : uses)
|
||||
use->set(received.getOutput());
|
||||
rewriter.setInsertionPointToEnd(&newBlock);
|
||||
return success();
|
||||
};
|
||||
|
||||
for (auto& [targetCoreId, uses] : computeUsesByTargetCore) {
|
||||
SmallVector<int32_t> targetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), targetCoreId);
|
||||
if (failed(createReceiveForUses(uses, targetCoreIds)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (!hostUses.empty()) {
|
||||
SmallVector<int32_t> hostTargetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), 0);
|
||||
if (failed(createReceiveForUses(hostUses, hostTargetCoreIds)))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.setInsertionPointToEnd(&newBlock);
|
||||
spatial::SpatYieldOp::create(rewriter, batchOp.getLoc(), ValueRange {});
|
||||
rewriter.eraseOp(batchOp);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
@@ -731,11 +525,6 @@ LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextC
|
||||
ScopedMergePhaseTimer timer("cleanup-dead-packing-ops");
|
||||
cleanupDeadPackingOps(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("materialize-batch-result-communication");
|
||||
if (failed(materializeBatchResultCommunication(funcOp, nextChannelId)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ struct MergeScheduleResult {
|
||||
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
@@ -19,7 +19,6 @@ struct ScheduledTask {
|
||||
size_t processor = std::numeric_limits<size_t>::max();
|
||||
Time startTime = 0;
|
||||
Time endTime = 0;
|
||||
size_t slot = 0;
|
||||
};
|
||||
|
||||
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
||||
@@ -243,7 +242,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
llvm::report_fatal_error(llvm::StringRef(message));
|
||||
}
|
||||
|
||||
schedules[task] = {bestProcessor, bestEst, bestEft, 0};
|
||||
schedules[task] = {bestProcessor, bestEst, bestEft};
|
||||
scheduled[task] = true;
|
||||
++scheduledCount;
|
||||
processorCrossbars[bestProcessor] = addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
|
||||
@@ -274,7 +273,65 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
|
||||
});
|
||||
|
||||
// 5. Populate Final Result
|
||||
// 5. Check if equal schedule in two level
|
||||
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
|
||||
for (size_t currentProcessor = 0; currentProcessor < processorCount - 1; ++currentProcessor) {
|
||||
for (size_t controlProcessor = currentProcessor; controlProcessor < processorCount; ++controlProcessor) {
|
||||
if (tasksByProcessor[currentProcessor].size() != tasksByProcessor[controlProcessor].size())
|
||||
continue;
|
||||
auto& currentTasks = tasksByProcessor[currentProcessor];
|
||||
auto& controlTasks = tasksByProcessor[controlProcessor];
|
||||
bool equalSchedule = true;
|
||||
|
||||
for (auto [currentTask, controlTask] : llvm::zip(currentTasks, controlTasks)) {
|
||||
const ComputeInstance currentComputeInstance = graph.nodes[currentTask].instance;
|
||||
const ComputeInstance controlComputeInstance = graph.nodes[controlTask].instance;
|
||||
if (currentComputeInstance.op != controlComputeInstance.op
|
||||
|| currentComputeInstance.laneCount != controlComputeInstance.laneCount) {
|
||||
equalSchedule = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (equalSchedule) {
|
||||
equivalentClass[currentProcessor].push_back(controlProcessor);
|
||||
equivalentClass[controlProcessor].push_back(currentProcessor);
|
||||
}
|
||||
}
|
||||
}
|
||||
/*{
|
||||
llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n";
|
||||
std::vector<bool> visited(processorCount, false);
|
||||
size_t uniqueClassCount = 0;
|
||||
|
||||
for (size_t i = 0; i < processorCount; ++i) {
|
||||
if (visited[i])
|
||||
continue;
|
||||
|
||||
// We found a new unique schedule (equivalence class)
|
||||
++uniqueClassCount;
|
||||
visited[i] = true;
|
||||
|
||||
llvm::dbgs() << "Class " << uniqueClassCount << ": CPUs { " << i;
|
||||
|
||||
// Find and mark all identical companions
|
||||
auto it = equivalentClass.find(i);
|
||||
if (it != equivalentClass.end()) {
|
||||
for (size_t eqCpu : it->second) {
|
||||
if (!visited[eqCpu]) {
|
||||
llvm::dbgs() << ", " << eqCpu;
|
||||
visited[eqCpu] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
llvm::dbgs() << " }\n";
|
||||
}
|
||||
|
||||
llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n";
|
||||
llvm::dbgs() << "--------------------------------------\n";
|
||||
}*/
|
||||
|
||||
// 6. Populate Final Result
|
||||
MergeScheduleResult result;
|
||||
result.dominanceOrderCompute.reserve(nodeCount);
|
||||
|
||||
@@ -296,8 +353,9 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
}
|
||||
}
|
||||
|
||||
result.equivalentClass = equivalentClass;
|
||||
|
||||
return result;
|
||||
}
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
|
||||
Reference in New Issue
Block a user