// Files
// Raptor/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp
// 2026-05-04 13:42:43 +02:00
//
// 1240 lines
// 50 KiB
// C++

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_os_ostream.h"
#include <cassert>
#include <filesystem>
#include <optional>
#include <string>
#include <utility>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
using namespace onnx_mlir;
using namespace pim;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
// Module pass that lowers the Spatial dialect ("spat") down to PIM dialect
// ops. The members below are per-run scratch state; runOnOperation() re-seeds
// `coreId` on entry.
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)

  StringRef getArgument() const override { return "convert-spatial-to-pim"; }
  StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }

  SpatialToPimPass() = default;
  // Copy constructor deliberately copies nothing: the scratch state is
  // rebuilt on each runOnOperation() invocation.
  SpatialToPimPass(const SpatialToPimPass& pass) {}

  void runOnOperation() final;

private:
  // Deferred builders for the host-side output buffers, indexed by the
  // operand position in the entry function's func.return.
  SmallVector<std::function<Value(IRRewriter& rewriter, Location loc)>> outputTensors;
  // Fallback counter for assigning PIM core ids to ops that carry no
  // explicit core-id attribute (set to 1 at the start of runOnOperation()).
  size_t coreId = 0;
  // Ops whose uses are being rewritten away; erased at the end of the run
  // once they become use-free.
  SmallVector<Operation*> operationsToRemove;
  void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);
  LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
  void markOpToRemove(Operation* op);
  void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter);
  void runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter);
  void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
  void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
} // namespace
// True for ops that merely reshape, slice, cast, or transpose a value on its
// way to a channel/return use — i.e. valid links of a "helper chain".
static bool isChannelUseChainOp(Operation* op) {
  return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp,
      tensor::CastOp, tosa::ReshapeOp, ONNXTransposeOp, pim::PimTransposeOp>(op);
}
// Clones tensor.empty / arith.constant producers of `op`'s not-yet-mapped
// operands at the current insertion point and records the clones in
// `mapping`, leaving the insertion point after the last clone.
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
  for (Value value : op->getOperands()) {
    if (mapping.lookupOrNull(value) != nullptr)
      continue;
    Operation* producer = value.getDefiningOp();
    if (!producer || !isa<tensor::EmptyOp, arith::ConstantOp>(producer))
      continue;
    Operation* copy = rewriter.clone(*producer, mapping);
    for (auto [oldResult, newResult] : llvm::zip(producer->getResults(), copy->getResults()))
      mapping.map(oldResult, newResult);
    rewriter.setInsertionPointAfter(copy);
  }
}
// Spatial core ids map 1:1 onto PIM core ids; only the integer width changes.
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) {
  return static_cast<int32_t>(spatialCoreId);
}
// Returns the PIM core id for a compute op: the explicit core-id attribute
// when present, otherwise the next value of the fallback counter.
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
  auto explicitId = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
  if (explicitId)
    return static_cast<int32_t>(explicitId.getInt());
  return static_cast<int32_t>(fallbackCoreId++);
}
// Returns one PIM core id per lane of the batch op: taken verbatim from the
// explicit attribute when present, otherwise consumed lane-by-lane from the
// fallback counter.
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
  auto explicitIds = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName);
  if (explicitIds)
    return SmallVector<int32_t>(explicitIds.asArrayRef().begin(), explicitIds.asArrayRef().end());
  int32_t laneCount = computeBatchOp.getLaneCount();
  SmallVector<int32_t> coreIds;
  coreIds.reserve(static_cast<size_t>(laneCount));
  for (int32_t lane = 0; lane < laneCount; ++lane)
    coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
  return coreIds;
}
// Rewrites spat.channel_send into pim.send carrying the input's byte size and
// the translated target core id, then removes the original op.
static void lowerChannelSend(spatial::SpatChannelSendOp sendOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(sendOp);
  PimSendOp::create(rewriter,
      sendOp.getLoc(),
      sendOp.getInput(),
      getTensorSizeInBytesAttr(rewriter, sendOp.getInput()),
      rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sendOp.getTargetCoreId())));
  rewriter.eraseOp(sendOp);
}
// Rewrites spat.channel_receive into pim.receive that fills a freshly
// created empty tensor; a receive with no users is simply erased.
static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
  if (receiveOp->use_empty()) {
    rewriter.eraseOp(receiveOp);
    return;
  }
  rewriter.setInsertionPoint(receiveOp);
  Location loc = receiveOp.getLoc();
  auto resultType = cast<ShapedType>(receiveOp.getResult().getType());
  auto buffer = createEmptyTensorFromShaped(rewriter, loc, resultType);
  auto byteSizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
  auto sourceCoreAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
  auto pimReceive = PimReceiveOp::create(rewriter, loc, buffer.getType(), buffer, byteSizeAttr, sourceCoreAttr);
  rewriter.replaceOp(receiveOp, pimReceive.getOutput());
}
// Expands spat.channel_send_many into one pim.send per (input, target) pair,
// then removes the original op.
static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(sendManyOp);
  Location loc = sendManyOp.getLoc();
  for (auto [value, targetId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds()))
    PimSendOp::create(rewriter,
        loc,
        value,
        getTensorSizeInBytesAttr(rewriter, value),
        rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(targetId)));
  rewriter.eraseOp(sendManyOp);
}
// Expands spat.channel_receive_many into one pim.receive per result, each
// writing into its own freshly created empty tensor, and replaces all
// results at once.
static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(receiveManyOp);
  Location loc = receiveManyOp.getLoc();
  SmallVector<Value> newResults;
  newResults.reserve(receiveManyOp.getNumResults());
  for (auto [result, sourceId] : llvm::zip(receiveManyOp.getOutputs(), receiveManyOp.getSourceCoreIds())) {
    auto buffer = createEmptyTensorFromShaped(rewriter, loc, cast<ShapedType>(result.getType()));
    auto pimReceive = PimReceiveOp::create(rewriter,
        loc,
        buffer.getType(),
        buffer,
        getTensorSizeInBytesAttr(rewriter, result),
        rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sourceId)));
    newResults.push_back(pimReceive.getOutput());
  }
  rewriter.replaceOp(receiveManyOp, ValueRange(newResults));
}
// Lowers spat.extract_rows: each result becomes a 1xN tensor.extract_slice of
// row `i` of the input. Memref inputs are bridged back to tensors via
// bufferization.to_tensor first. On an unsupported input the op is left in
// place; the pass-level legality check reports leftovers.
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
  Value input = extractRowsOp.getInput();
  RankedTensorType inputType;
  if (auto tensorType = dyn_cast<RankedTensorType>(input.getType())) {
    inputType = tensorType;
  }
  else if (auto memRefType = dyn_cast<MemRefType>(input.getType())) {
    // Rebuild the equivalent ranked tensor type and wrap the memref so the
    // slice ops below can operate on tensors.
    inputType = RankedTensorType::get(memRefType.getShape(), memRefType.getElementType());
    rewriter.setInsertionPoint(extractRowsOp);
    input = bufferization::ToTensorOp::create(
        rewriter, extractRowsOp.getLoc(), inputType, input, rewriter.getUnitAttr(), rewriter.getUnitAttr())
        .getResult();
  }
  else {
    extractRowsOp.emitOpError("requires a ranked tensor or memref input during Spatial-to-PIM lowering");
    return;
  }
  // NOTE(review): assumes a rank-2 (rows x cols) input — dim 1 is the row
  // width; confirm against the op's verifier.
  int64_t numCols = inputType.getDimSize(1);
  SmallVector<Value> replacements;
  replacements.reserve(extractRowsOp.getNumResults());
  rewriter.setInsertionPoint(extractRowsOp);
  for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
    auto outputType = dyn_cast<RankedTensorType>(output.getType());
    if (!outputType) {
      extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering");
      return;
    }
    // Slice out row `rowIndex`: offset (row, 0), size (1, numCols), unit strides.
    SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)), rewriter.getIndexAttr(0)};
    SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
    SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
    auto rowSlice = tensor::ExtractSliceOp::create(
        rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides);
    replacements.push_back(rowSlice.getResult());
  }
  rewriter.replaceOp(extractRowsOp, ValueRange(replacements));
}
// Replaces spat.concat with the equivalent tensor.concat along the same axis.
static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(concatOp);
  auto tensorConcat =
      tensor::ConcatOp::create(rewriter, concatOp.getLoc(), concatOp.getAxis(), concatOp.getInputs());
  rewriter.replaceOp(concatOp, tensorConcat.getResult());
}
// Matches a "helper" spat.compute: exactly one input and one result, whose
// body is a linear chain of reshape/slice/transpose ops from the single block
// argument to spat.yield (with only tensor.empty/arith.constant allowed on
// the side). When `requireReturnUse` is set, the result must additionally be
// the sole operand feeding func.return. On success the chain is written to
// `helperChain` in execution order.
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
    SmallVectorImpl<Operation*>& helperChain,
    bool requireReturnUse = true) {
  if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
    return failure();
  if (requireReturnUse
      && (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin())))
    return failure();
  Block& block = computeOp.getBody().front();
  if (block.getNumArguments() != 1)
    return failure();
  auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
  if (!yieldOp || yieldOp.getNumOperands() != 1)
    return failure();
  // Walk backwards from the yielded value to the block argument; every link
  // must be a single-operand reshape-like op defined in this block.
  SmallVector<Operation*> reverseChain;
  Value currentValue = yieldOp.getOperands().front();
  Value blockArg = block.getArgument(0);
  while (currentValue != blockArg) {
    Operation* definingOp = currentValue.getDefiningOp();
    if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp))
      return failure();
    reverseChain.push_back(definingOp);
    currentValue = definingOp->getOperand(0);
  }
  // Any op outside the chain must be a harmless constant/empty producer.
  SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
  for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
    if (!chainSet.contains(&op)
        && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
      return failure();
  // The chain was collected back-to-front; flip it into execution order.
  helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
  return success();
}
// Inlines a spat.compute with no inputs and a single result whose users are
// all batch ops (spat.compute_batch / pim.core_batch): clones its body before
// the op and replaces the result with the yielded value. Returns true when
// inlining happened; the now-dead compute op itself is left for the caller's
// removal tracking.
static bool inlineInputlessHelperComputeForBatchUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
  if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
    return false;
  if (!llvm::all_of(computeOp.getResult(0).getUsers(),
      [](Operation* user) { return isa<spatial::SpatComputeBatch, pim::PimCoreBatchOp>(user); }))
    return false;
  Block& block = computeOp.getBody().front();
  if (block.getNumArguments() != 0)
    return false;
  auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
  if (!yieldOp || yieldOp.getNumOperands() != 1)
    return false;
  rewriter.setInsertionPoint(computeOp);
  // Clone the body ops (plus any constant/empty operands) in order,
  // remapping results so later clones see the cloned values.
  IRMapping mapping;
  for (Operation& op : block.without_terminator()) {
    cloneMappedHelperOperands(&op, mapping, rewriter);
    Operation* clonedOp = rewriter.clone(op, mapping);
    for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
      mapping.map(originalResult, newResult);
    rewriter.setInsertionPointAfter(clonedOp);
  }
  Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
  computeOp.getResult(0).replaceAllUsesWith(replacement);
  return true;
}
// Result of analyzeReturnUse: the func.return operand index the value
// ultimately feeds, plus the reshape-like ops traversed on the way there.
struct ReturnUseInfo {
  size_t returnIndex;
  SmallVector<Operation*> helperChain;
};
// Result of analyzeConcatReturnUse: the func.return operand index where the
// concatenated value lands (`returnIndex`), the per-dimension element offsets
// of the analyzed value inside the concatenated tensor (`sliceOffsets`), the
// static shape of that concatenated tensor (`concatShape`), and the
// reshape-like ops between the concats and the return (`helperChain`).
struct ConcatReturnUseInfo {
  size_t returnIndex;
  SmallVector<int64_t> sliceOffsets;
  SmallVector<int64_t> concatShape;
  SmallVector<Operation*> helperChain;
};
// Row-major linearization: folds per-dimension `indices` within `shape` into
// a single flat element index.
static int64_t computeFlatElementIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> shape) {
  int64_t flat = 0;
  size_t dim = 0;
  for (int64_t extent : shape) {
    flat = flat * extent + indices[dim];
    ++dim;
  }
  return flat;
}
// Inverse of computeFlatElementIndex: de-linearizes `flatIndex` into
// row-major per-dimension indices for `shape`.
static SmallVector<int64_t> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
  SmallVector<int64_t> indices(shape.size(), 0);
  for (size_t i = shape.size(); i > 0; --i) {
    size_t dim = i - 1;
    indices[dim] = flatIndex % shape[dim];
    flatIndex /= shape[dim];
  }
  return indices;
}
// Follows `value` through a single-use chain of reshape-like ops until it
// reaches func.return. Returns the return operand index and the traversed
// chain; std::nullopt if the value has multiple uses anywhere along the way
// or ends up somewhere other than the return.
static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
  auto uses = value.getUses();
  if (rangeLength(uses) != 1)
    return std::nullopt;
  SmallVector<Operation*> helperChain;
  Value currentValue = value;
  Operation* currentUser = uses.begin()->getOwner();
  // Hop across reshape/slice/transpose links, requiring single use at each.
  while (isChannelUseChainOp(currentUser)) {
    helperChain.push_back(currentUser);
    auto currentUses = currentUser->getResult(0).getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentValue = currentUser->getResult(0);
    currentUser = currentUses.begin()->getOwner();
  }
  if (!isa<func::ReturnOp>(currentUser))
    return std::nullopt;
  return ReturnUseInfo {
      currentValue.getUses().begin()->getOperandNumber(),
      std::move(helperChain),
  };
}
// Follows `value` through single-use tensor.concat ops (accumulating its
// element offsets within the growing concatenated shape), optionally through
// one helper spat.compute whose body is a pure reshape chain, then through a
// trailing reshape chain, down to func.return. Returns std::nullopt when any
// link has multiple uses, a dynamic shape, or the walk does not end at the
// return.
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
  auto uses = value.getUses();
  if (rangeLength(uses) != 1 || !isa<tensor::ConcatOp>(uses.begin()->getOwner()))
    return std::nullopt;
  auto valueType = dyn_cast<ShapedType>(value.getType());
  if (!valueType || !valueType.hasStaticShape())
    return std::nullopt;
  SmallVector<int64_t> sliceOffsets(valueType.getRank(), 0);
  SmallVector<int64_t> concatShape(valueType.getShape().begin(), valueType.getShape().end());
  Value currentValue = value;
  Operation* currentUser = uses.begin()->getOwner();
  while (auto concatOp = dyn_cast<tensor::ConcatOp>(currentUser)) {
    // Our offset along the concat axis is the sum of the extents of the
    // operands placed before us in the operand list.
    size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
    int64_t axis = concatOp.getDim();
    for (Value operand : concatOp.getOperands().take_front(operandIndex))
      sliceOffsets[axis] += cast<ShapedType>(operand.getType()).getShape()[axis];
    auto concatType = dyn_cast<ShapedType>(concatOp.getResult().getType());
    if (!concatType || !concatType.hasStaticShape())
      return std::nullopt;
    concatShape.assign(concatType.getShape().begin(), concatType.getShape().end());
    currentValue = concatOp.getResult();
    auto currentUses = currentValue.getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentUser = currentUses.begin()->getOwner();
  }
  // An intervening helper spat.compute is accepted only if its sole input is
  // our value and its body is a pure reshape chain.
  SmallVector<Operation*> helperChain;
  if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
    if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
      return std::nullopt;
    if (failed(collectHelperComputeChain(helperCompute, helperChain)))
      return std::nullopt;
    currentValue = helperCompute.getResult(0);
    auto currentUses = currentValue.getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentUser = currentUses.begin()->getOwner();
  }
  // Trailing reshape/slice/transpose links after the concats/helper.
  while (isChannelUseChainOp(currentUser)) {
    helperChain.push_back(currentUser);
    auto currentUses = currentUser->getResult(0).getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentValue = currentUser->getResult(0);
    currentUser = currentUses.begin()->getOwner();
  }
  if (!isa<func::ReturnOp>(currentUser))
    return std::nullopt;
  return ConcatReturnUseInfo {
      currentValue.getUses().begin()->getOperandNumber(),
      std::move(sliceOffsets),
      std::move(concatShape),
      std::move(helperChain),
  };
}
// Propagates the element coordinates `sourceIndices` (within `sourceShape`)
// through each op of `helperChain`, producing the coordinates of the same
// element in the chain's final result. Fails on dynamic shapes, non-unit
// strides, elements outside a slice window, or unsupported ops.
static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndices,
    ArrayRef<int64_t> sourceShape,
    ArrayRef<Operation*> helperChain,
    SmallVectorImpl<int64_t>& mappedIndices) {
  SmallVector<int64_t> currentIndices(sourceIndices.begin(), sourceIndices.end());
  SmallVector<int64_t> currentShape(sourceShape.begin(), sourceShape.end());
  // Reshape-like ops preserve row-major element order, so mapping is a
  // linearize (old shape) then de-linearize (new shape).
  auto reshapeToResultShape = [&](Operation* op) -> LogicalResult {
    auto resultType = dyn_cast<ShapedType>(op->getResult(0).getType());
    if (!resultType || !resultType.hasStaticShape())
      return failure();
    int64_t flatIndex = computeFlatElementIndex(currentIndices, currentShape);
    currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
    currentIndices = expandFlatElementIndex(flatIndex, currentShape);
    return success();
  };
  for (Operation* op : helperChain) {
    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
      auto hasStaticValues = [](ArrayRef<int64_t> values) {
        return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
      };
      if (!hasStaticValues(extractSliceOp.getStaticOffsets())
          || !hasStaticValues(extractSliceOp.getStaticSizes())
          || !hasStaticValues(extractSliceOp.getStaticStrides()))
        return failure();
      // The element must lie inside the slice window; translate into
      // window-relative coordinates. Only unit strides are supported.
      SmallVector<int64_t> nextIndices;
      nextIndices.reserve(currentIndices.size());
      for (auto [index, offset, size, stride] : llvm::zip_equal(currentIndices,
          extractSliceOp.getStaticOffsets(),
          extractSliceOp.getStaticSizes(),
          extractSliceOp.getStaticStrides())) {
        if (stride != 1 || index < offset || index >= offset + size)
          return failure();
        nextIndices.push_back(index - offset);
      }
      auto resultType = dyn_cast<ShapedType>(extractSliceOp.getResult().getType());
      if (!resultType || !resultType.hasStaticShape())
        return failure();
      currentIndices = std::move(nextIndices);
      currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
      continue;
    }
    // Transposes permute both the coordinates and the shape.
    if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) {
      SmallVector<int64_t> nextIndices(currentIndices.size());
      SmallVector<int64_t> nextShape(currentShape.size());
      for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange<IntegerAttr>())) {
        int64_t sourceIndex = attr.getInt();
        nextIndices[destIndex] = currentIndices[sourceIndex];
        nextShape[destIndex] = currentShape[sourceIndex];
      }
      currentIndices = std::move(nextIndices);
      currentShape = std::move(nextShape);
      continue;
    }
    if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op)) {
      SmallVector<int64_t> nextIndices(currentIndices.size());
      SmallVector<int64_t> nextShape(currentShape.size());
      for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermutation().getAsRange<IntegerAttr>())) {
        int64_t sourceIndex = attr.getInt();
        nextIndices[destIndex] = currentIndices[sourceIndex];
        nextShape[destIndex] = currentShape[sourceIndex];
      }
      currentIndices = std::move(nextIndices);
      currentShape = std::move(nextShape);
      continue;
    }
    // Pure element-order-preserving reshapes/casts.
    if (isa<tensor::CastOp, tosa::ReshapeOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp>(op)) {
      if (failed(reshapeToResultShape(op)))
        return failure();
      continue;
    }
    return failure();
  }
  mappedIndices.assign(currentIndices.begin(), currentIndices.end());
  return success();
}
// Clones every op of `helperChain` (plus any tensor.empty/arith.constant
// operands) after `sourceValue`, remapping values along the way. On return,
// `clonedValue` holds the last clone's first result — or `sourceValue`
// itself when the chain is empty.
static void cloneHelperChain(Value sourceValue,
    ArrayRef<Operation*> helperChain,
    IRRewriter& rewriter,
    Value& clonedValue) {
  IRMapping mapping;
  mapping.map(sourceValue, sourceValue);
  clonedValue = sourceValue;
  rewriter.setInsertionPointAfterValue(sourceValue);
  for (Operation* link : helperChain) {
    cloneMappedHelperOperands(link, mapping, rewriter);
    Operation* copy = rewriter.clone(*link, mapping);
    for (auto [oldResult, newResult] : llvm::zip(link->getResults(), copy->getResults()))
      mapping.map(oldResult, newResult);
    clonedValue = copy->getResult(0);
    rewriter.setInsertionPointAfter(copy);
  }
}
// Emits a pim.memcopy_dev_to_host that moves `sizeInBytes` bytes from
// `sourceValue` (starting at `deviceSourceOffset`) into `outputTensor`
// (starting at `hostTargetOffset`) and returns the updated host tensor.
static Value emitHostCopy(IRRewriter& rewriter,
    Location loc,
    Value outputTensor,
    Value sourceValue,
    int32_t hostTargetOffset,
    int32_t deviceSourceOffset,
    int32_t sizeInBytes) {
  auto hostOffsetAttr = rewriter.getI32IntegerAttr(hostTargetOffset);
  auto devOffsetAttr = rewriter.getI32IntegerAttr(deviceSourceOffset);
  auto sizeAttr = rewriter.getI32IntegerAttr(sizeInBytes);
  auto copyOp = PimMemCopyDevToHostOp::create(rewriter, loc, outputTensor.getType(), outputTensor,
      sourceValue, hostOffsetAttr, devOffsetAttr, sizeAttr);
  return copyOp.getOutput();
}
void SpatialToPimPass::runOnOperation() {
coreId = 1;
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext();
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
func::FuncDialect,
scf::SCFDialect,
BuiltinDialect>();
{
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
}
{
RewritePatternSet patterns(ctx);
populateGlobalTensorToMemrefPatterns(patterns);
walkAndApplyPatterns(moduleOp, std::move(patterns));
}
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addResultBuffer(returnOp, rewriter);
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
signalPassFailure();
return;
}
SmallVector<spatial::SpatConcatOp> concatOps;
funcOp.walk([&](spatial::SpatConcatOp op) { concatOps.push_back(op); });
for (auto concatOp : concatOps)
lowerConcat(concatOp, rewriter);
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp);
runOnComputeOp(computeOp, rewriter);
}
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
markOpToRemove(computeBatchOp);
runOnComputeBatchOp(computeBatchOp, rewriter);
}
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
receiveOps.push_back(op);
for (auto receiveOp : receiveOps) {
bool onlyPendingRemovalUsers = llvm::all_of(
receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
if (onlyPendingRemovalUsers) {
markOpToRemove(receiveOp);
continue;
}
if (receiveOp->use_empty()) {
rewriter.eraseOp(receiveOp);
continue;
}
lowerChannelReceive(receiveOp, rewriter);
}
SmallVector<spatial::SpatChannelReceiveManyOp> receiveManyOps;
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveManyOp>())
receiveManyOps.push_back(op);
for (auto receiveManyOp : receiveManyOps)
lowerChannelReceiveMany(receiveManyOp, rewriter);
SmallVector<spatial::SpatChannelSendOp> sendOps;
for (auto op : funcOp.getOps<spatial::SpatChannelSendOp>())
sendOps.push_back(op);
for (auto sendOp : sendOps)
lowerChannelSend(sendOp, rewriter);
SmallVector<spatial::SpatChannelSendManyOp> sendManyOps;
for (auto op : funcOp.getOps<spatial::SpatChannelSendManyOp>())
sendManyOps.push_back(op);
for (auto sendManyOp : sendManyOps)
lowerChannelSendMany(sendManyOp, rewriter);
SmallVector<spatial::SpatExtractRowsOp> extractRowsOps;
for (auto op : funcOp.getOps<spatial::SpatExtractRowsOp>())
extractRowsOps.push_back(op);
for (auto extractRowsOp : extractRowsOps)
lowerExtractRows(extractRowsOp, rewriter);
RewritePatternSet channelPatterns(ctx);
populateWithGenerated(channelPatterns);
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
signalPassFailure();
return;
}
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
replaceReturnOpOperands(returnOp, rewriter);
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
while (!pendingRemovals.empty()) {
bool erasedAnyOp = false;
for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) {
Operation* opToRemove = *it;
if (!opToRemove->use_empty()) {
++it;
continue;
}
rewriter.eraseOp(opToRemove);
it = pendingRemovals.erase(it);
erasedAnyOp = true;
}
if (erasedAnyOp)
continue;
for (auto opToRemove : pendingRemovals) {
opToRemove->dump();
for (auto user : opToRemove->getUsers())
user->dump();
}
assert(false && "tracked op removal reached a cycle or missed dependency");
}
SmallVector<spatial::SpatConcatOp> remainingConcatOps;
funcOp.walk([&](spatial::SpatConcatOp op) { remainingConcatOps.push_back(op); });
for (auto concatOp : remainingConcatOps)
lowerConcat(concatOp, rewriter);
SmallVector<spatial::SpatChannelReceiveOp> remainingReceiveOps;
funcOp.walk([&](spatial::SpatChannelReceiveOp op) { remainingReceiveOps.push_back(op); });
for (auto receiveOp : remainingReceiveOps)
lowerChannelReceive(receiveOp, rewriter);
SmallVector<spatial::SpatChannelReceiveManyOp> remainingReceiveManyOps;
funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { remainingReceiveManyOps.push_back(op); });
for (auto receiveManyOp : remainingReceiveManyOps)
lowerChannelReceiveMany(receiveManyOp, rewriter);
SmallVector<spatial::SpatChannelSendOp> remainingSendOps;
funcOp.walk([&](spatial::SpatChannelSendOp op) { remainingSendOps.push_back(op); });
for (auto sendOp : remainingSendOps)
lowerChannelSend(sendOp, rewriter);
SmallVector<spatial::SpatChannelSendManyOp> remainingSendManyOps;
funcOp.walk([&](spatial::SpatChannelSendManyOp op) { remainingSendManyOps.push_back(op); });
for (auto sendManyOp : remainingSendManyOps)
lowerChannelSendMany(sendManyOp, rewriter);
SmallVector<spatial::SpatExtractRowsOp> remainingExtractRowsOps;
funcOp.walk([&](spatial::SpatExtractRowsOp op) { remainingExtractRowsOps.push_back(op); });
for (auto extractRowsOp : remainingExtractRowsOps)
lowerExtractRows(extractRowsOp, rewriter);
// Dump to file for debug
bool hasSpatialOps = false;
moduleOp.walk([&](Operation* op) {
if (op->getDialect()->getNamespace() == "spat")
hasSpatialOps = true;
});
if (hasSpatialOps) {
moduleOp.emitError("SpatialToPim left illegal Spatial operations in the module");
signalPassFailure();
return;
}
// Dump to file for debug
dumpModule(moduleOp, "pim0");
}
// Lowers one spat.compute into a pim.core:
//  1. Block arguments fed by spat.channel_receive become pim.receive ops
//     materialized inside the body.
//  2. Each live result that reaches func.return — directly, through a
//     reshape helper chain, or through tensor.concat trees — is turned into
//     pim.memcopy_dev_to_host copies into the pre-allocated host buffer.
//  3. The body is spliced into a new pim.core terminated by pim.halt; the
//     gutted compute op gets a trivial halt-only body and is erased later by
//     the removal phase in runOnOperation().
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
  Location loc = computeOp->getLoc();
  // Input-less helpers that only feed batch ops are inlined, not lowered.
  if (inlineInputlessHelperComputeForBatchUsers(computeOp, rewriter))
    return;
  // Pure reshape helpers feeding func.return are consumed by the concat /
  // return analyses of other computes; skip lowering them here.
  SmallVector<Operation*> helperChain;
  if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
    return;
  auto& block = computeOp.getRegion().front();
  auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
  // Replace channel-fed block arguments with pim.receive ops placed inside
  // the body, right before the argument's first use.
  for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
    auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
    if (!receiveOp || blockArg.use_empty())
      continue;
    rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
    auto outputType = cast<ShapedType>(blockArg.getType());
    auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
    auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
    auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
    Value received = PimReceiveOp::create(
        rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
        .getOutput();
    blockArg.replaceAllUsesWith(received);
    markOpToRemove(receiveOp);
  }
  if (computeOp.getNumResults() != yieldOp.getNumOperands())
    llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
  for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
    if (result.use_empty())
      continue;
    auto yieldType = cast<TensorType>(yieldValue.getType());
    // Case 1: result flows (possibly through a reshape chain) into
    // func.return — copy the full tensor to the host output buffer.
    if (auto returnUse = analyzeReturnUse(result)) {
      Value storedValue = yieldValue;
      // Re-create the outside reshape chain inside the core so the copy
      // source already has the final layout.
      cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
      for (Operation* op : returnUse->helperChain)
        markOpToRemove(op);
      auto storedType = cast<ShapedType>(storedValue.getType());
      size_t elementSize = storedType.getElementTypeBitWidth() / 8;
      if (auto storedOp = storedValue.getDefiningOp())
        rewriter.setInsertionPointAfter(storedOp);
      Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
      emitHostCopy(rewriter,
          loc,
          outputTensor,
          storedValue,
          0,
          0,
          static_cast<int32_t>(storedType.getNumElements() * elementSize));
      continue;
    }
    auto resultUses = result.getUses();
    if (rangeLength(resultUses) == 1) {
      OpOperand& resultUse = *resultUses.begin();
      Operation* resultUser = resultUse.getOwner();
      // NOTE(review): a direct single return use is normally already covered
      // by analyzeReturnUse above — confirm whether this branch is reachable.
      if (isa<func::ReturnOp>(resultUser)) {
        size_t resultIndexInReturn = resultUse.getOperandNumber();
        size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
        rewriter.setInsertionPointAfterValue(yieldValue);
        Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
        emitHostCopy(rewriter,
            loc,
            outputTensor,
            yieldValue,
            0,
            0,
            static_cast<int32_t>(yieldType.getNumElements() * elementSize));
        continue;
      }
      // Sends are lowered separately by lowerChannelSend; nothing to do here.
      if (isa<spatial::SpatChannelSendOp>(resultUser))
        continue;
    }
    // Case 2: result feeds tensor.concat trees (plus an optional helper
    // chain) ending in func.return.
    if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
      size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
      if (concatReturnUse->helperChain.empty()) {
        // No reshaping after the concat: the whole yielded tensor is one
        // contiguous block at its flat offset inside the output buffer.
        rewriter.setInsertionPointAfterValue(yieldValue);
        Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
        auto outputType = cast<ShapedType>(outputTensor.getType());
        int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
        emitHostCopy(rewriter,
            loc,
            outputTensor,
            yieldValue,
            static_cast<int32_t>(flatOffset * elementSize),
            0,
            static_cast<int32_t>(yieldType.getNumElements() * elementSize));
        continue;
      }
      auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
      if (!storedType) {
        computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
        signalPassFailure();
        return;
      }
      rewriter.setInsertionPointAfterValue(yieldValue);
      Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
      auto outputType = cast<ShapedType>(outputTensor.getType());
      // The helper chain may permute element order, so copy element by
      // element: map each source coordinate through the chain to find its
      // flat destination offset in the output buffer.
      for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
        SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
        // Shift local coordinates by this operand's offsets inside the concat.
        for (auto [dim, idx] : llvm::enumerate(sourceIndices))
          sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;
        SmallVector<int64_t> destinationIndices;
        if (failed(mapIndicesThroughHelperChain(sourceIndices,
            concatReturnUse->concatShape,
            concatReturnUse->helperChain,
            destinationIndices))) {
          computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
          signalPassFailure();
          return;
        }
        // Extract the single element as a 1x...x1 tensor slice.
        SmallVector<OpFoldResult> extractOffsets;
        SmallVector<OpFoldResult> extractSizes;
        SmallVector<OpFoldResult> extractStrides;
        extractOffsets.reserve(storedType.getRank());
        extractSizes.reserve(storedType.getRank());
        extractStrides.reserve(storedType.getRank());
        for (int64_t idx : expandFlatElementIndex(linearIndex, storedType.getShape())) {
          extractOffsets.push_back(rewriter.getIndexAttr(idx));
          extractSizes.push_back(rewriter.getIndexAttr(1));
          extractStrides.push_back(rewriter.getIndexAttr(1));
        }
        auto scalarTensorType =
            RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
        auto elementSlice = tensor::ExtractSliceOp::create(
            rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
        rewriter.setInsertionPointAfter(elementSlice);
        int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
        // Chain the copies: each copy's result is the next copy's destination.
        outputTensor = emitHostCopy(rewriter,
            loc,
            outputTensor,
            elementSlice.getResult(),
            static_cast<int32_t>(destinationFlatOffset * elementSize),
            0,
            static_cast<int32_t>(elementSize));
      }
      continue;
    }
    computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering");
    signalPassFailure();
    return;
  }
  // Use `HaltOp` instead of `YieldOp`
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
  // Replace `spat.compute` with `pim.core`
  rewriter.setInsertionPointAfter(computeOp);
  auto coreOp = PimCoreOp::create(
      rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
  auto& coreOpBlocks = coreOp.getBody().getBlocks();
  // Rewire remaining block arguments to the compute op's SSA inputs before
  // stealing the body blocks.
  for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
    if (!blockArg.use_empty())
      blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
  block.eraseArguments(0, block.getNumArguments());
  coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
  // Give the gutted compute op a trivial halt-only body — presumably to keep
  // it well-formed until the removal phase erases it; confirm against the
  // op's region verifier.
  Block* tempComputeBlock = new Block();
  computeOp.getBody().push_back(tempComputeBlock);
  rewriter.setInsertionPointToEnd(tempComputeBlock);
  PimHaltOp::create(rewriter, computeOp.getLoc());
}
// Lowers a `spat.compute_batch` region into a `pim.core_batch` op.
// The rewrite:
//  * rejects forms this lowering does not support (batches with results or a
//    non-empty `spat.yield`),
//  * creates a `pim.core_batch` carrying the lane count, weights, inputs and
//    the resolved per-lane core ids,
//  * clones the body, inserting batched host->device copies for every block
//    argument and for any tensor value captured from outside the region,
//  * maps batch channel send/receive ops to their PIM equivalents,
//  * terminates the new body with `pim.halt`.
// The original `spat.compute_batch` is left in place for later cleanup.
void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter) {
// Opt-in tracing via an environment variable, useful when debugging lowering.
if (std::getenv("PIM_BATCH_LOWER_DEBUG"))
llvm::errs() << "lowering compute_batch lanes=" << computeBatchOp.getLaneCount() << "\n";
// Only the channelized, result-free form is supported here.
if (computeBatchOp.getNumResults() != 0) {
computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
signalPassFailure();
return;
}
Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front();
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
// Enforce the result-free contract on the terminator as well.
if (oldYield.getNumOperands() != 0) {
computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
signalPassFailure();
return;
}
// One core id per lane; attached below as a dense-i32-array attribute.
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
rewriter.setInsertionPointAfter(computeBatchOp);
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
loc,
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
computeBatchOp.getWeights(),
computeBatchOp.getInputs());
// The op has two variadic operand groups (weights, inputs); record their
// sizes so generic operand accessors can split them correctly.
coreBatchOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(computeBatchOp.getWeights().size()), static_cast<int>(computeBatchOp.getInputs().size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
// Recreate the entry block with the same argument signature as the source
// region so values can be remapped one-to-one.
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : oldBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc());
}
Block* newBlock =
rewriter.createBlock(&coreBatchOp.getBody(), coreBatchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
rewriter.setInsertionPointToStart(newBlock);
// Stage every block argument through a batched host->device copy; the body
// cloned below then operates on the device-side tensors.
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
auto newArgType = cast<ShapedType>(newArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
newArg,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, newArg))
.getOutput();
mapper.map(oldArg, copied);
}
// Lazily stages a tensor captured from outside the region through a
// host->device copy, memoizing the result in `mapper`.
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
if (auto mapped = mapper.lookupOrNull(capturedTensor))
return mapped;
auto capturedType = cast<ShapedType>(capturedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
capturedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, capturedTensor))
.getOutput();
mapper.map(capturedTensor, copied);
return copied;
};
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : oldBlock) {
// The empty terminator is replaced by `pim.halt` after the loop.
if (isa<spatial::SpatYieldOp>(op))
continue;
// Channel sends become PIM batched sends of the mapped device tensor.
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
pim::PimSendBatchOp::create(rewriter,
loc,
mapper.lookup(sendBatchOp.getInput()),
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
sendBatchOp.getTargetCoreIdsAttr());
continue;
}
// Channel receives allocate a destination buffer and become PIM batched
// receives; the received tensor replaces the spatial result in `mapper`.
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
auto received = pim::PimReceiveBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
receiveBatchOp.getSourceCoreIdsAttr())
.getOutput();
mapper.map(receiveBatchOp.getOutput(), received);
continue;
}
// Tensors materialized from a `memref.get_global` (host-resident data)
// are cloned and then copied host->device before any use.
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper);
auto clonedTensor = cloned->getResult(0);
auto clonedType = cast<ShapedType>(clonedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
clonedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, clonedTensor))
.getOutput();
mapper.map(toTensorOp.getResult(), copied);
continue;
}
}
// Any tensor operand defined outside the source block (and not already
// remapped) is a capture and must be staged onto the device first.
for (Value operand : op.getOperands()) {
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == &oldBlock)
continue;
materializeCapturedTensor(operand);
}
// Default path: clone the op with remapped operands and record its results
// so later ops in the block can resolve them through `mapper`.
Operation* cloned = rewriter.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
rewriter.setInsertionPointToEnd(newBlock);
// PIM core bodies terminate with a halt instead of a yield.
PimHaltOp::create(rewriter, loc);
}
// Widens the output tensor of every `pim.vmm` in `funcOp` to the full
// crossbar width. The VMM output must satisfy `isHVectorShape` (rank-2;
// dimension 1 is compared against the crossbar width). When it is narrower
// than `crossbarSize`, the output buffer (and every tied DPS producer that
// feeds it) is retyped in place to {dim0, crossbarSize}, and a
// `tensor.extract_slice` is inserted after the VMM so downstream users keep
// seeing the original shape.
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
// Recursively retypes the chain of destination-style ops producing `value`
// through tied init operands, so every buffer in the chain agrees on the
// widened type. Stops at the first producer that is not a DPS op or has no
// tied operand for this result.
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return;
auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp);
if (!dpsDefiningOp)
return;
auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
if (!tiedOperand)
return;
Value tiedValue = tiedOperand->get();
// Retyping in place is only sound if nothing else reads the old type.
assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
tiedValue.setType(newType);
self(tiedValue, newType, self);
};
funcOp.walk([&](PimVMMOp vmmOp) {
auto outTensorOperand = vmmOp.getOutputBuffer();
auto resultTensor = vmmOp.getOutput();
auto outShape = getTensorShape(outTensorOperand);
assert(isHVectorShape(outShape));
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
if (outTensorOperand == vmmOp.getInput()) {
// In-place VMM (output buffer aliases the input): the input must keep
// its shape, so allocate a fresh crossbar-wide buffer instead of
// retyping the shared value.
rewriter.setInsertionPoint(vmmOp);
auto newOutputBuffer =
tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
}
else {
// Propagate the widened type up through the tied DPS producer chain.
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
outTensorOperand.setType(newType);
}
resultTensor.setType(newType);
// Carve the original {outShape[0], outShape[1]} view back out of the
// widened result so existing users continue to see the old shape.
IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
IntegerAttr oneAttr = rewriter.getIndexAttr(1);
IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
SmallVector<OpFoldResult> offsets = {zeroAttr, zeroAttr};
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
rewriter.setInsertionPointAfter(vmmOp);
auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
// The VMM itself and the slice keep referring to the widened tensor.
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
}
});
}
// Registers, for every operand of `returnOp`, a factory (stored in
// `outputTensors`) that materializes the tensor the final `func.return`
// should yield:
//  - constant-like producers are returned as-is;
//  - any other producer gets a private `memref.global` named "output_<i>"
//    that later stages copy device results into; its factory re-reads the
//    global via `memref.get_global` + `bufferization.to_tensor`.
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
outputTensors.reserve(returnOp->getNumOperands());
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
// A returned block argument (a function input forwarded straight to the
// return) has no defining op and is not supported by this lowering; fail
// loudly here instead of dereferencing a null pointer below.
assert(returnValueDefiningOp && "return operand must be produced by an operation");
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
assert(!hasWeightAlways(returnValueDefiningOp));
// Constants can be yielded directly; no output buffer is needed.
outputTensors.push_back([returnValue](IRRewriter& rewriter, Location loc) -> Value { return returnValue; });
}
else {
// `cast` asserts (in debug builds) on a non-ranked-tensor return type,
// instead of the previous `dyn_cast` silently yielding null and
// crashing on the next line.
auto outRankedTensorType = llvm::cast<mlir::RankedTensorType>(returnValue.getType());
auto memRefType = mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
std::string outputName = "output_" + std::to_string(index);
// The global lives outside the function, next to it in the parent op.
rewriter.setInsertionPoint(returnOp.getParentOp());
memref::GlobalOp::create(rewriter,
returnOp.getLoc(),
rewriter.getStringAttr(outputName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
{},
{},
{});
// The factory re-materializes the global as a tensor at the point where
// the caller needs it (typically right before the rewritten return).
outputTensors.push_back(
[memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
return toTensor.getResult();
});
}
}
}
// For every zero-input `spat.compute` in `funcOp`, wraps the tensor values
// materialized from "arg*" / "const_*" globals in a `PimMemCopyHostToDevOp`,
// so the compute body reads device-resident copies of its core-local data.
// Currently always returns success.
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
Location loc = funcOp.getLoc();
// Stages `inputTensor` through a host->device copy placed right after its
// producer, then redirects all other uses to the device-side copy.
// `elementsOffset` is expressed in elements and converted to bytes here.
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(inputTensor.getType());
Type elementType = tensorType.getElementType();
// NOTE(review): assumes the element bit width is a multiple of 8 (e.g.
// i1 would yield a zero byte size) — confirm against supported dtypes.
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
rewriter,
loc,
tensorType,
deviceTensor,
inputTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
// Keep the copy itself reading the original tensor; everything else now
// consumes the device-side result.
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
};
for (auto& op : funcOp.getBody().getOps())
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
// Only self-contained computes (no inputs, no block arguments) own
// core-local globals that need staging here.
if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
continue;
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
if (getGlobal.getName().starts_with("arg") || getGlobal.getName().starts_with("const_")) {
assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute");
// The single user is expected to materialize the global as a tensor
// (presumably a bufferization.to_tensor — verify against producers);
// take its first result.
auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin();
insertMemCopyHostToDev(toTensorOpValue, 0);
}
}
}
return success();
}
// Queues `op` for deferred erasure; duplicate requests are ignored so every
// operation appears in `operationsToRemove` at most once.
void SpatialToPimPass::markOpToRemove(Operation* op) {
if (llvm::is_contained(operationsToRemove, op))
return;
operationsToRemove.push_back(op);
}
// Rewires every operand of `returnOp` to the corresponding output-buffer
// tensor previously registered in `outputTensors` (see `addResultBuffer`),
// and marks the now-unreferenced producer chain of each original operand for
// deferred removal via `markOpToRemove`.
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
// Walks backwards from a return operand's producer and marks ops that are
// owned exclusively by the return chain: ops with no uses, or with a single
// use by the return itself, a `tensor.concat` feeding it, a `spat.compute`,
// or a channel-use chain op. Ops with other consumers are left untouched.
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
if (!op)
return;
bool isExclusivelyOwnedByReturnChain = op->use_empty();
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
Operation* onlyUser = *op->getUsers().begin();
isExclusivelyOwnedByReturnChain =
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatCompute>(onlyUser) || isChannelUseChainOp(onlyUser);
}
if (!isExclusivelyOwnedByReturnChain)
return;
// Channel-use chain ops forward their first operand; continue upstream
// through that source.
if (isChannelUseChainOp(op)) {
Value source = op->getOperand(0);
markOpToRemove(op);
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
return;
}
// A compute op is removed together with the producers of its inputs.
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
markOpToRemove(computeOp);
for (Value input : computeOp.getInputs())
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
return;
}
// A concat is removed together with the producers of all its operands.
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
markOpToRemove(concatOp);
for (Value operand : concatOp.getOperands())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
}
};
// Snapshot the operands first: setting them below would otherwise mutate
// the range while iterating over it.
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
auto loc = returnOp.getLoc();
for (auto it : llvm::enumerate(originalOperands)) {
size_t orderWithinReturn = it.index();
Operation* returnOperand = it.value().getDefiningOp();
rewriter.setInsertionPoint(returnOp);
// Materialize the registered output tensor and substitute it for the
// originally computed value in the return.
Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
}
}
/// Factory: builds a fresh instance of the Spatial-to-PIM lowering pass.
std::unique_ptr<Pass> createSpatialToPimPass() {
return std::make_unique<SpatialToPimPass>();
}
} // namespace onnx_mlir