1240 lines
50 KiB
C++
1240 lines
50 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/IR/BuiltinDialect.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Interfaces/FunctionInterfaces.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
|
|
|
#include "llvm/ADT/SmallSet.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include "llvm/Support/Casting.h"
|
|
#include "llvm/Support/LogicalResult.h"
|
|
#include "llvm/Support/raw_os_ostream.h"
|
|
|
|
#include <cassert>
|
|
#include <filesystem>
|
|
#include <optional>
|
|
#include <string>
|
|
#include <utility>
|
|
|
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "Patterns.hpp"
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
|
#include "src/Compiler/CompilerOptions.hpp"
|
|
|
|
using namespace mlir;
|
|
using namespace onnx_mlir;
|
|
using namespace pim;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace {
|
|
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
|
|
|
/// Module pass that lowers the Spatial dialect (compute regions, channel
/// send/receive, concat, extract_rows) found in the PIM entry function into
/// PIM dialect operations.
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)

  StringRef getArgument() const override { return "convert-spatial-to-pim"; }
  StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }

  SpatialToPimPass() = default;

  /// Copy constructor used when the pass is cloned. Only the base pass state is
  /// copied (via PassWrapper); the per-run members below deliberately restart
  /// from their default member initializers.
  SpatialToPimPass(const SpatialToPimPass& pass) : PassWrapper(pass) {}

  void runOnOperation() final;

private:
  /// One builder per function result: `outputTensors[i](rewriter, loc)`
  /// materializes the host-side tensor that return operand `i` is copied into.
  SmallVector<std::function<Value(IRRewriter& rewriter, Location loc)>> outputTensors;

  /// Core-id counter; runOnOperation resets it to 1. Presumably the fallback
  /// source of core ids for compute ops without an explicit one — confirm.
  size_t coreId = 0;

  /// Operations whose erasure is deferred until the end of runOnOperation,
  /// where they are erased once they no longer have users.
  SmallVector<Operation*> operationsToRemove;

  /// Sets up host result buffers for the operands of `returnOp`
  /// (presumably populates `outputTensors` — see its definition).
  void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);

  /// Allocates/initializes per-core local variables in `funcOp`; failure aborts the pass.
  LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);

  /// Queues `op` for deferred erasure (see `operationsToRemove`).
  void markOpToRemove(Operation* op);

  /// Lowers the body of a single `spat.compute` op.
  void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter);

  /// Lowers the body of a single `spat.compute_batch` op.
  void runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter);

  /// Pads VMM output tensors in `funcOp` up to the crossbar size.
  void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);

  /// Rewrites the operands of the entry function's `func.return`.
  void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
|
|
|
|
} // namespace
|
|
|
|
/// Returns true when `op` is one of the shape/layout ops (slice, reshape,
/// collapse/expand, cast, ONNX or PIM transpose) that this pass treats as a
/// pass-through link on a value's use chain (its data input is operand 0).
static bool isChannelUseChainOp(Operation* op) {
  const bool isTensorShapeOp =
      isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp>(op);
  const bool isReshapeOrTranspose = isa<tosa::ReshapeOp, ONNXTransposeOp, pim::PimTransposeOp>(op);
  return isTensorShapeOp || isReshapeOrTranspose;
}
|
|
|
|
/// For every operand of `op` that is not yet mapped and is produced by a
/// `tensor.empty` or `arith.constant`, clones that producer at the rewriter's
/// insertion point and records the clone in `mapping`. A following
/// `rewriter.clone(*op, mapping)` then uses local copies instead of the
/// originals. The insertion point ends up just after the last clone.
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
  for (Value operand : op->getOperands()) {
    Operation* producer = operand.getDefiningOp();
    const bool isCloneableProducer = producer != nullptr && isa<tensor::EmptyOp, arith::ConstantOp>(producer);
    if (mapping.lookupOrNull(operand) || !isCloneableProducer)
      continue;

    Operation* producerCopy = rewriter.clone(*producer, mapping);
    mapping.map(producer->getResults(), producerCopy->getResults());
    rewriter.setInsertionPointAfter(producerCopy);
  }
}
|
|
|
|
/// Maps a Spatial core id onto the PIM core numbering. The two numberings are
/// currently identical; the id is only narrowed to int32_t.
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) {
  const auto pimCoreId = static_cast<int32_t>(spatialCoreId);
  return pimCoreId;
}
|
|
|
|
/// Returns the PIM core id for `computeOp`. The op's integer core-id attribute
/// is used when present. Otherwise the current `fallbackCoreId` is returned and
/// the counter is advanced.
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
  auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
  if (!spatialCoreIdAttr)
    return static_cast<int32_t>(fallbackCoreId++);
  return static_cast<int32_t>(spatialCoreIdAttr.getInt());
}
|
|
|
|
/// Returns one PIM core id per lane of `computeBatchOp`. The explicit
/// DenseI32Array core-id attribute is copied verbatim when present. Otherwise
/// consecutive ids are drawn from `fallbackCoreId`, which is advanced once per lane.
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
  SmallVector<int32_t> coreIds;

  if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName)) {
    ArrayRef<int32_t> explicitIds = coreIdsAttr.asArrayRef();
    coreIds.append(explicitIds.begin(), explicitIds.end());
    return coreIds;
  }

  const auto laneCount = computeBatchOp.getLaneCount();
  coreIds.reserve(static_cast<size_t>(laneCount));
  for (int32_t lane = 0; lane < laneCount; ++lane)
    coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
  return coreIds;
}
|
|
|
|
/// Replaces a `spat.channel_send` with a `pim.send` of the same payload. The
/// byte size is derived from the payload type and the target id is translated
/// to PIM numbering.
static void lowerChannelSend(spatial::SpatChannelSendOp sendOp, IRRewriter& rewriter) {
  Location loc = sendOp.getLoc();
  Value payload = sendOp.getInput();
  auto byteCountAttr = getTensorSizeInBytesAttr(rewriter, payload);
  IntegerAttr destinationAttr =
      rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sendOp.getTargetCoreId()));

  rewriter.setInsertionPoint(sendOp);
  PimSendOp::create(rewriter, loc, payload, byteCountAttr, destinationAttr);
  rewriter.eraseOp(sendOp);
}
|
|
|
|
/// Replaces a `spat.channel_receive` with a `pim.receive` that writes into a
/// fresh empty tensor of the result type. A receive without users is erased.
static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
  // Nobody reads the received data: just drop the op.
  if (receiveOp->use_empty()) {
    rewriter.eraseOp(receiveOp);
    return;
  }

  Location loc = receiveOp.getLoc();
  Value result = receiveOp.getResult();

  rewriter.setInsertionPoint(receiveOp);
  auto destination = createEmptyTensorFromShaped(rewriter, loc, cast<ShapedType>(result.getType()));
  auto byteCountAttr = getTensorSizeInBytesAttr(rewriter, result);
  IntegerAttr sourceAttr =
      rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));

  auto pimReceive = PimReceiveOp::create(rewriter, loc, destination.getType(), destination, byteCountAttr, sourceAttr);
  Value received = pimReceive.getOutput();
  rewriter.replaceOp(receiveOp, received);
}
|
|
|
|
/// Expands a `spat.channel_send_many` into one `pim.send` per (input, target
/// core) pair, emitted in order before the original op, which is then erased.
static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) {
  Location loc = sendManyOp.getLoc();
  rewriter.setInsertionPoint(sendManyOp);

  for (auto [payload, destinationCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds())) {
    auto byteCountAttr = getTensorSizeInBytesAttr(rewriter, payload);
    IntegerAttr destinationAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(destinationCoreId));
    PimSendOp::create(rewriter, loc, payload, byteCountAttr, destinationAttr);
  }

  rewriter.eraseOp(sendManyOp);
}
|
|
|
|
/// Expands a `spat.channel_receive_many` into one `pim.receive` per result.
/// Each receive writes into a fresh empty tensor of that result's type, and the
/// op's results are replaced by the received values, in order.
static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) {
  Location loc = receiveManyOp.getLoc();
  SmallVector<Value> receivedValues;
  receivedValues.reserve(receiveManyOp.getNumResults());

  rewriter.setInsertionPoint(receiveManyOp);
  for (auto [result, sourceCoreId] : llvm::zip(receiveManyOp.getOutputs(), receiveManyOp.getSourceCoreIds())) {
    auto destination = createEmptyTensorFromShaped(rewriter, loc, cast<ShapedType>(result.getType()));
    auto byteCountAttr = getTensorSizeInBytesAttr(rewriter, result);
    IntegerAttr sourceAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sourceCoreId));
    auto pimReceive =
        PimReceiveOp::create(rewriter, loc, destination.getType(), destination, byteCountAttr, sourceAttr);
    receivedValues.push_back(pimReceive.getOutput());
  }

  rewriter.replaceOp(receiveManyOp, ValueRange(receivedValues));
}
|
|
|
|
/// Lowers `spat.extract_rows` into one `tensor.extract_slice` per result:
/// result `i` is the `1 x numCols` slice starting at row `i` of the rank-2
/// input. A memref input is first viewed as a tensor through
/// `bufferization.to_tensor`.
///
/// On an unsupported input or result type, an error is emitted on the op and
/// the op is left untouched. All validation happens before any IR is created,
/// so a failure leaves no dead ops behind.
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
  Location loc = extractRowsOp.getLoc();
  Value input = extractRowsOp.getInput();

  // Derive the tensor view of the input without materializing anything yet.
  RankedTensorType inputType;
  auto memRefType = dyn_cast<MemRefType>(input.getType());
  if (auto tensorType = dyn_cast<RankedTensorType>(input.getType())) {
    inputType = tensorType;
  }
  else if (memRefType) {
    inputType = RankedTensorType::get(memRefType.getShape(), memRefType.getElementType());
  }
  else {
    extractRowsOp.emitOpError("requires a ranked tensor or memref input during Spatial-to-PIM lowering");
    return;
  }

  // The slices below use rank-2 offsets/sizes and read dimension 1.
  if (inputType.getRank() != 2) {
    extractRowsOp.emitOpError("requires a rank-2 input during Spatial-to-PIM lowering");
    return;
  }

  for (Value output : extractRowsOp.getOutputs()) {
    if (!isa<RankedTensorType>(output.getType())) {
      extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering");
      return;
    }
  }

  rewriter.setInsertionPoint(extractRowsOp);
  if (memRefType) {
    input = bufferization::ToTensorOp::create(
                rewriter, loc, inputType, input, rewriter.getUnitAttr(), rewriter.getUnitAttr())
                .getResult();
  }

  const int64_t numCols = inputType.getDimSize(1);
  SmallVector<Value> replacements;
  replacements.reserve(extractRowsOp.getNumResults());

  for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
    auto outputType = cast<RankedTensorType>(output.getType());
    SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)), rewriter.getIndexAttr(0)};
    SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
    SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
    auto rowSlice = tensor::ExtractSliceOp::create(rewriter, loc, outputType, input, offsets, sizes, strides);
    replacements.push_back(rowSlice.getResult());
  }

  rewriter.replaceOp(extractRowsOp, ValueRange(replacements));
}
|
|
|
|
/// Replaces `spat.concat` with an equivalent `tensor.concat` along the same axis.
static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(concatOp);
  auto tensorConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(), concatOp.getAxis(), concatOp.getInputs());
  Value concatenated = tensorConcat.getResult();
  rewriter.replaceOp(concatOp, concatenated);
}
|
|
|
|
/// Recognizes a "helper" compute: one input and one result, whose body only
/// applies a linear chain of pass-through ops (isChannelUseChainOp) to its
/// block argument. Besides the chain, the body may contain only
/// `tensor.empty` / `arith.constant`. With `requireReturnUse`, the result must
/// additionally have exactly one use, a `func.return`.
///
/// On success, `helperChain` receives the chain ops in program order
/// (argument -> yield). On failure, `helperChain` is left unchanged.
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
                                               SmallVectorImpl<Operation*>& helperChain,
                                               bool requireReturnUse = true) {
  if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
    return failure();

  if (requireReturnUse) {
    Value result = computeOp.getResult(0);
    if (!result.hasOneUse() || !isa<func::ReturnOp>(*result.getUsers().begin()))
      return failure();
  }

  Block& body = computeOp.getBody().front();
  if (body.getNumArguments() != 1)
    return failure();

  auto yieldOp = dyn_cast<spatial::SpatYieldOp>(body.getTerminator());
  if (!yieldOp || yieldOp.getNumOperands() != 1)
    return failure();

  // Walk backwards from the yielded value to the block argument through
  // operand 0 of each producer; every producer must be a chain op in this body.
  SmallVector<Operation*> reversedChain;
  Value blockArg = body.getArgument(0);
  for (Value cursor = yieldOp.getOperands().front(); cursor != blockArg;) {
    Operation* producer = cursor.getDefiningOp();
    if (!producer || producer->getBlock() != &body || !isChannelUseChainOp(producer))
      return failure();
    reversedChain.push_back(producer);
    cursor = producer->getOperand(0);
  }

  // Anything in the body that is not part of the chain must be trivial.
  SmallPtrSet<Operation*, 8> chainOps(reversedChain.begin(), reversedChain.end());
  const bool bodyIsChainOnly = llvm::all_of(body.without_terminator(), [&](Operation& bodyOp) {
    return chainOps.contains(&bodyOp) || isa<tensor::EmptyOp, arith::ConstantOp>(bodyOp);
  });
  if (!bodyIsChainOnly)
    return failure();

  helperChain.assign(reversedChain.rbegin(), reversedChain.rend());
  return success();
}
|
|
|
|
/// If `computeOp` takes no inputs, yields a single value, and every user of its
/// result is a `spat.compute_batch` or `pim.core_batch`, clones its body in
/// front of it and redirects those users to the cloned yield value.
///
/// Returns true when the inlining happened. The compute op itself is not
/// erased here.
static bool inlineInputlessHelperComputeForBatchUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
  if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
    return false;

  Value result = computeOp.getResult(0);
  auto isBatchUser = [](Operation* user) { return isa<spatial::SpatComputeBatch, pim::PimCoreBatchOp>(user); };
  if (!llvm::all_of(result.getUsers(), isBatchUser))
    return false;

  Block& body = computeOp.getBody().front();
  if (body.getNumArguments() != 0)
    return false;

  auto yieldOp = dyn_cast<spatial::SpatYieldOp>(body.getTerminator());
  if (!yieldOp || yieldOp.getNumOperands() != 1)
    return false;

  IRMapping mapping;
  rewriter.setInsertionPoint(computeOp);
  for (Operation& bodyOp : body.without_terminator()) {
    cloneMappedHelperOperands(&bodyOp, mapping, rewriter);
    Operation* bodyOpCopy = rewriter.clone(bodyOp, mapping);
    mapping.map(bodyOp.getResults(), bodyOpCopy->getResults());
    rewriter.setInsertionPointAfter(bodyOpCopy);
  }

  result.replaceAllUsesWith(mapping.lookupOrDefault(yieldOp.getOperand(0)));
  return true;
}
|
|
|
|
struct ReturnUseInfo {
|
|
size_t returnIndex;
|
|
SmallVector<Operation*> helperChain;
|
|
};
|
|
|
|
struct ConcatReturnUseInfo {
|
|
size_t returnIndex;
|
|
SmallVector<int64_t> sliceOffsets;
|
|
SmallVector<int64_t> concatShape;
|
|
SmallVector<Operation*> helperChain;
|
|
};
|
|
|
|
/// Row-major linearization of `indices` within `shape`. For shape
/// [d0, d1, ..., dn] the result is ((i0 * d1 + i1) * d2 + i2) ... + in.
/// Only the first `shape.size()` entries of `indices` are read.
static int64_t computeFlatElementIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> shape) {
  int64_t linear = 0;
  const size_t rank = shape.size();
  for (size_t dim = 0; dim != rank; ++dim)
    linear = linear * shape[dim] + indices[dim];
  return linear;
}
|
|
|
|
/// Splits a row-major offset into one index per dimension of `shape`, peeling
/// the innermost dimension first. For in-range offsets this inverts
/// computeFlatElementIndex.
static SmallVector<int64_t> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
  SmallVector<int64_t> indices(shape.size(), 0);
  int64_t remaining = flatIndex;
  for (size_t dim = shape.size(); dim-- > 0;) {
    indices[dim] = remaining % shape[dim];
    remaining /= shape[dim];
  }
  return indices;
}
|
|
|
|
/// Follows `value` through single-use pass-through ops (isChannelUseChainOp)
/// and succeeds only if the chain ends in a `func.return`. Returns the return
/// operand number and the traversed ops in program order, or std::nullopt when
/// any link has other than exactly one use or the chain ends elsewhere.
static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
  if (rangeLength(value.getUses()) != 1)
    return std::nullopt;

  SmallVector<Operation*> helperChain;
  Value tail = value;
  Operation* consumer = value.getUses().begin()->getOwner();

  for (; isChannelUseChainOp(consumer);) {
    helperChain.push_back(consumer);
    Value produced = consumer->getResult(0);
    if (rangeLength(produced.getUses()) != 1)
      return std::nullopt;
    tail = produced;
    consumer = produced.getUses().begin()->getOwner();
  }

  if (!isa<func::ReturnOp>(consumer))
    return std::nullopt;

  size_t returnIndex = tail.getUses().begin()->getOperandNumber();
  return ReturnUseInfo {returnIndex, std::move(helperChain)};
}
|
|
|
|
/// Recognizes `value` -> tensor.concat (possibly nested) -> [helper compute]
/// -> [pass-through chain] -> func.return, where every link has exactly one use.
///
/// Returns:
///  - the return operand number;
///  - the value's per-dimension offset inside the outermost concat result;
///  - that result's static shape;
///  - the helper ops that follow the concat(s), in program order.
///
/// Returns std::nullopt when the pattern does not match, or when a shape
/// needed to compute the offset is not static.
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
  auto uses = value.getUses();
  if (rangeLength(uses) != 1 || !isa<tensor::ConcatOp>(uses.begin()->getOwner()))
    return std::nullopt;

  auto valueType = dyn_cast<ShapedType>(value.getType());
  if (!valueType || !valueType.hasStaticShape())
    return std::nullopt;

  SmallVector<int64_t> sliceOffsets(valueType.getRank(), 0);
  SmallVector<int64_t> concatShape(valueType.getShape().begin(), valueType.getShape().end());
  Value currentValue = value;
  Operation* currentUser = uses.begin()->getOwner();

  // Climb through the concats. Along each concat axis, the value's offset grows
  // by the extents of the operands that precede it.
  while (auto concatOp = dyn_cast<tensor::ConcatOp>(currentUser)) {
    size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
    int64_t axis = concatOp.getDim();
    for (Value operand : concatOp.getOperands().take_front(operandIndex)) {
      int64_t extent = cast<ShapedType>(operand.getType()).getShape()[axis];
      // A dynamic extent would add ShapedType::kDynamic into the offset.
      if (ShapedType::isDynamic(extent))
        return std::nullopt;
      sliceOffsets[axis] += extent;
    }

    auto concatType = dyn_cast<ShapedType>(concatOp.getResult().getType());
    if (!concatType || !concatType.hasStaticShape())
      return std::nullopt;
    concatShape.assign(concatType.getShape().begin(), concatType.getShape().end());

    currentValue = concatOp.getResult();
    auto currentUses = currentValue.getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentUser = currentUses.begin()->getOwner();
  }

  // Optionally pass through a single-input helper compute whose body is only
  // a pass-through chain.
  SmallVector<Operation*> helperChain;
  if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
    if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
      return std::nullopt;

    if (failed(collectHelperComputeChain(helperCompute, helperChain)))
      return std::nullopt;

    currentValue = helperCompute.getResult(0);
    auto currentUses = currentValue.getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentUser = currentUses.begin()->getOwner();
  }

  // Then through any trailing pass-through ops.
  while (isChannelUseChainOp(currentUser)) {
    helperChain.push_back(currentUser);
    auto currentUses = currentUser->getResult(0).getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentValue = currentUser->getResult(0);
    currentUser = currentUses.begin()->getOwner();
  }

  if (!isa<func::ReturnOp>(currentUser))
    return std::nullopt;

  return ConcatReturnUseInfo {
      currentValue.getUses().begin()->getOperandNumber(),
      std::move(sliceOffsets),
      std::move(concatShape),
      std::move(helperChain),
  };
}
|
|
|
|
/// Maps an element index through `helperChain`. The index is `sourceIndices`
/// in a tensor of shape `sourceShape`; the result is the index of the same
/// element in the chain's final result, written to `mappedIndices`.
///
/// Supported ops:
///  - static, unit-stride, non-rank-reducing `tensor.extract_slice`
///    (the element must lie inside the slice);
///  - ONNX / PIM transposes;
///  - reshape-like ops (`tensor.cast`, `tosa.reshape`, collapse/expand shape),
///    which preserve row-major element order.
///
/// Fails on any other op, on dynamic shapes, or on malformed permutations.
/// `mappedIndices` is only written on success.
static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndices,
                                                  ArrayRef<int64_t> sourceShape,
                                                  ArrayRef<Operation*> helperChain,
                                                  SmallVectorImpl<int64_t>& mappedIndices) {
  SmallVector<int64_t> currentIndices(sourceIndices.begin(), sourceIndices.end());
  SmallVector<int64_t> currentShape(sourceShape.begin(), sourceShape.end());

  // Reshape-like ops keep the row-major position: re-expand the flat offset in
  // the op's (static) result shape.
  auto reshapeToResultShape = [&](Operation* op) -> LogicalResult {
    auto resultType = dyn_cast<ShapedType>(op->getResult(0).getType());
    if (!resultType || !resultType.hasStaticShape())
      return failure();
    int64_t flatIndex = computeFlatElementIndex(currentIndices, currentShape);
    currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
    currentIndices = expandFlatElementIndex(flatIndex, currentShape);
    return success();
  };

  // Applies a permutation where result dim `d` takes source dim `perm[d]`.
  // Rejects permutations whose size or entries don't fit the tracked rank.
  auto applyPermutation = [&](ArrayRef<int64_t> perm) -> LogicalResult {
    const int64_t rank = static_cast<int64_t>(currentIndices.size());
    if (static_cast<int64_t>(perm.size()) != rank || currentShape.size() != currentIndices.size())
      return failure();
    SmallVector<int64_t> nextIndices(perm.size());
    SmallVector<int64_t> nextShape(perm.size());
    for (auto [destIndex, sourceIndex] : llvm::enumerate(perm)) {
      if (sourceIndex < 0 || sourceIndex >= rank)
        return failure();
      nextIndices[destIndex] = currentIndices[sourceIndex];
      nextShape[destIndex] = currentShape[sourceIndex];
    }
    currentIndices = std::move(nextIndices);
    currentShape = std::move(nextShape);
    return success();
  };

  auto toPermutation = [](ArrayAttr permAttr) {
    SmallVector<int64_t> perm;
    for (IntegerAttr attr : permAttr.getAsRange<IntegerAttr>())
      perm.push_back(attr.getInt());
    return perm;
  };

  for (Operation* op : helperChain) {
    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
      auto hasStaticValues = [](ArrayRef<int64_t> values) {
        return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
      };
      if (!hasStaticValues(extractSliceOp.getStaticOffsets())
          || !hasStaticValues(extractSliceOp.getStaticSizes())
          || !hasStaticValues(extractSliceOp.getStaticStrides()))
        return failure();

      SmallVector<int64_t> nextIndices;
      nextIndices.reserve(currentIndices.size());
      for (auto [index, offset, size, stride] : llvm::zip_equal(currentIndices,
                                                                extractSliceOp.getStaticOffsets(),
                                                                extractSliceOp.getStaticSizes(),
                                                                extractSliceOp.getStaticStrides())) {
        if (stride != 1 || index < offset || index >= offset + size)
          return failure();
        nextIndices.push_back(index - offset);
      }

      auto resultType = dyn_cast<ShapedType>(extractSliceOp.getResult().getType());
      if (!resultType || !resultType.hasStaticShape())
        return failure();
      // A rank-reducing slice drops unit dims. The per-dimension tracking here
      // does not model that, so reject it rather than mis-map indices.
      if (resultType.getRank() != static_cast<int64_t>(nextIndices.size()))
        return failure();
      currentIndices = std::move(nextIndices);
      currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
      continue;
    }

    if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) {
      SmallVector<int64_t> perm;
      if (auto permAttr = transposeOp.getPermAttr()) {
        perm = toPermutation(permAttr);
      }
      else {
        // ONNX Transpose reverses the dimensions when `perm` is absent.
        for (int64_t dim = static_cast<int64_t>(currentIndices.size()) - 1; dim >= 0; --dim)
          perm.push_back(dim);
      }
      if (failed(applyPermutation(perm)))
        return failure();
      continue;
    }

    if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op)) {
      if (failed(applyPermutation(toPermutation(transposeOp.getPermutation()))))
        return failure();
      continue;
    }

    if (isa<tensor::CastOp, tosa::ReshapeOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp>(op)) {
      if (failed(reshapeToResultShape(op)))
        return failure();
      continue;
    }

    return failure();
  }

  mappedIndices.assign(currentIndices.begin(), currentIndices.end());
  return success();
}
|
|
|
|
/// Re-materializes `helperChain` on top of `sourceValue`, starting right after
/// `sourceValue`'s definition.
///
/// The chain was collected on some other value: for example, the compute result
/// for analyzeReturnUse, or a helper compute's block argument. Its first op's
/// data input (operand 0) is therefore rebound to `sourceValue`. Each later op
/// consumes the clone of its predecessor. `tensor.empty` / `arith.constant`
/// operands are cloned alongside (cloneMappedHelperOperands).
///
/// `clonedValue` receives the result of the last clone, or `sourceValue` when
/// the chain is empty.
static void cloneHelperChain(Value sourceValue,
                             ArrayRef<Operation*> helperChain,
                             IRRewriter& rewriter,
                             Value& clonedValue) {
  IRMapping mapping;
  mapping.map(sourceValue, sourceValue);
  // Without this, the first clone would keep using the original chain input,
  // which is not available at the new insertion point.
  if (!helperChain.empty() && helperChain.front()->getNumOperands() > 0)
    mapping.map(helperChain.front()->getOperand(0), sourceValue);
  clonedValue = sourceValue;

  rewriter.setInsertionPointAfterValue(sourceValue);
  for (Operation* op : helperChain) {
    cloneMappedHelperOperands(op, mapping, rewriter);
    Operation* clonedOp = rewriter.clone(*op, mapping);
    for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
      mapping.map(originalResult, newResult);
    clonedValue = clonedOp->getResult(0);
    rewriter.setInsertionPointAfter(clonedOp);
  }
}
|
|
|
|
/// Emits a device-to-host copy of `sizeInBytes` bytes. The copy reads from
/// `sourceValue` at `deviceSourceOffset` and writes into `outputTensor` at
/// `hostTargetOffset`. Returns the op's output (same type as `outputTensor`).
static Value emitHostCopy(IRRewriter& rewriter,
                          Location loc,
                          Value outputTensor,
                          Value sourceValue,
                          int32_t hostTargetOffset,
                          int32_t deviceSourceOffset,
                          int32_t sizeInBytes) {
  IntegerAttr hostOffsetAttr = rewriter.getI32IntegerAttr(hostTargetOffset);
  IntegerAttr deviceOffsetAttr = rewriter.getI32IntegerAttr(deviceSourceOffset);
  IntegerAttr byteCountAttr = rewriter.getI32IntegerAttr(sizeInBytes);

  auto copyOp = PimMemCopyDevToHostOp::create(
      rewriter, loc, outputTensor.getType(), outputTensor, sourceValue, hostOffsetAttr, deviceOffsetAttr, byteCountAttr);
  return copyOp.getOutput();
}
|
|
|
|
void SpatialToPimPass::runOnOperation() {
|
|
coreId = 1;
|
|
ModuleOp moduleOp = getOperation();
|
|
MLIRContext* ctx = moduleOp.getContext();
|
|
|
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
|
if (failed(entryFunc)) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
func::FuncOp funcOp = *entryFunc;
|
|
|
|
IRRewriter rewriter(&getContext());
|
|
|
|
ConversionTarget target(*ctx);
|
|
target.addLegalDialect<PimDialect,
|
|
tensor::TensorDialect,
|
|
arith::ArithDialect,
|
|
func::FuncDialect,
|
|
scf::SCFDialect,
|
|
BuiltinDialect>();
|
|
|
|
{
|
|
RewritePatternSet patterns(ctx);
|
|
populateWithGenerated(patterns);
|
|
|
|
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
}
|
|
|
|
{
|
|
RewritePatternSet patterns(ctx);
|
|
populateGlobalTensorToMemrefPatterns(patterns);
|
|
|
|
walkAndApplyPatterns(moduleOp, std::move(patterns));
|
|
}
|
|
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
|
|
|
addResultBuffer(returnOp, rewriter);
|
|
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
SmallVector<spatial::SpatConcatOp> concatOps;
|
|
funcOp.walk([&](spatial::SpatConcatOp op) { concatOps.push_back(op); });
|
|
for (auto concatOp : concatOps)
|
|
lowerConcat(concatOp, rewriter);
|
|
|
|
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
|
markOpToRemove(computeOp);
|
|
runOnComputeOp(computeOp, rewriter);
|
|
}
|
|
|
|
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
|
markOpToRemove(computeBatchOp);
|
|
runOnComputeBatchOp(computeBatchOp, rewriter);
|
|
}
|
|
|
|
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
|
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
|
|
receiveOps.push_back(op);
|
|
for (auto receiveOp : receiveOps) {
|
|
bool onlyPendingRemovalUsers = llvm::all_of(
|
|
receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
|
|
if (onlyPendingRemovalUsers) {
|
|
markOpToRemove(receiveOp);
|
|
continue;
|
|
}
|
|
if (receiveOp->use_empty()) {
|
|
rewriter.eraseOp(receiveOp);
|
|
continue;
|
|
}
|
|
lowerChannelReceive(receiveOp, rewriter);
|
|
}
|
|
|
|
SmallVector<spatial::SpatChannelReceiveManyOp> receiveManyOps;
|
|
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveManyOp>())
|
|
receiveManyOps.push_back(op);
|
|
for (auto receiveManyOp : receiveManyOps)
|
|
lowerChannelReceiveMany(receiveManyOp, rewriter);
|
|
|
|
SmallVector<spatial::SpatChannelSendOp> sendOps;
|
|
for (auto op : funcOp.getOps<spatial::SpatChannelSendOp>())
|
|
sendOps.push_back(op);
|
|
for (auto sendOp : sendOps)
|
|
lowerChannelSend(sendOp, rewriter);
|
|
|
|
SmallVector<spatial::SpatChannelSendManyOp> sendManyOps;
|
|
for (auto op : funcOp.getOps<spatial::SpatChannelSendManyOp>())
|
|
sendManyOps.push_back(op);
|
|
for (auto sendManyOp : sendManyOps)
|
|
lowerChannelSendMany(sendManyOp, rewriter);
|
|
|
|
SmallVector<spatial::SpatExtractRowsOp> extractRowsOps;
|
|
for (auto op : funcOp.getOps<spatial::SpatExtractRowsOp>())
|
|
extractRowsOps.push_back(op);
|
|
for (auto extractRowsOp : extractRowsOps)
|
|
lowerExtractRows(extractRowsOp, rewriter);
|
|
|
|
RewritePatternSet channelPatterns(ctx);
|
|
populateWithGenerated(channelPatterns);
|
|
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
|
replaceReturnOpOperands(returnOp, rewriter);
|
|
|
|
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
|
|
while (!pendingRemovals.empty()) {
|
|
bool erasedAnyOp = false;
|
|
for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) {
|
|
Operation* opToRemove = *it;
|
|
if (!opToRemove->use_empty()) {
|
|
++it;
|
|
continue;
|
|
}
|
|
|
|
rewriter.eraseOp(opToRemove);
|
|
it = pendingRemovals.erase(it);
|
|
erasedAnyOp = true;
|
|
}
|
|
|
|
if (erasedAnyOp)
|
|
continue;
|
|
|
|
for (auto opToRemove : pendingRemovals) {
|
|
opToRemove->dump();
|
|
for (auto user : opToRemove->getUsers())
|
|
user->dump();
|
|
}
|
|
assert(false && "tracked op removal reached a cycle or missed dependency");
|
|
}
|
|
|
|
SmallVector<spatial::SpatConcatOp> remainingConcatOps;
|
|
funcOp.walk([&](spatial::SpatConcatOp op) { remainingConcatOps.push_back(op); });
|
|
for (auto concatOp : remainingConcatOps)
|
|
lowerConcat(concatOp, rewriter);
|
|
|
|
SmallVector<spatial::SpatChannelReceiveOp> remainingReceiveOps;
|
|
funcOp.walk([&](spatial::SpatChannelReceiveOp op) { remainingReceiveOps.push_back(op); });
|
|
for (auto receiveOp : remainingReceiveOps)
|
|
lowerChannelReceive(receiveOp, rewriter);
|
|
|
|
SmallVector<spatial::SpatChannelReceiveManyOp> remainingReceiveManyOps;
|
|
funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { remainingReceiveManyOps.push_back(op); });
|
|
for (auto receiveManyOp : remainingReceiveManyOps)
|
|
lowerChannelReceiveMany(receiveManyOp, rewriter);
|
|
|
|
SmallVector<spatial::SpatChannelSendOp> remainingSendOps;
|
|
funcOp.walk([&](spatial::SpatChannelSendOp op) { remainingSendOps.push_back(op); });
|
|
for (auto sendOp : remainingSendOps)
|
|
lowerChannelSend(sendOp, rewriter);
|
|
|
|
SmallVector<spatial::SpatChannelSendManyOp> remainingSendManyOps;
|
|
funcOp.walk([&](spatial::SpatChannelSendManyOp op) { remainingSendManyOps.push_back(op); });
|
|
for (auto sendManyOp : remainingSendManyOps)
|
|
lowerChannelSendMany(sendManyOp, rewriter);
|
|
|
|
SmallVector<spatial::SpatExtractRowsOp> remainingExtractRowsOps;
|
|
funcOp.walk([&](spatial::SpatExtractRowsOp op) { remainingExtractRowsOps.push_back(op); });
|
|
for (auto extractRowsOp : remainingExtractRowsOps)
|
|
lowerExtractRows(extractRowsOp, rewriter);
|
|
|
|
// Dump to file for debug
|
|
bool hasSpatialOps = false;
|
|
moduleOp.walk([&](Operation* op) {
|
|
if (op->getDialect()->getNamespace() == "spat")
|
|
hasSpatialOps = true;
|
|
});
|
|
if (hasSpatialOps) {
|
|
moduleOp.emitError("SpatialToPim left illegal Spatial operations in the module");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
// Dump to file for debug
|
|
dumpModule(moduleOp, "pim0");
|
|
}
|
|
|
|
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
|
Location loc = computeOp->getLoc();
|
|
|
|
if (inlineInputlessHelperComputeForBatchUsers(computeOp, rewriter))
|
|
return;
|
|
|
|
SmallVector<Operation*> helperChain;
|
|
if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
|
|
return;
|
|
|
|
auto& block = computeOp.getRegion().front();
|
|
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
|
|
|
|
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
|
|
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
|
|
if (!receiveOp || blockArg.use_empty())
|
|
continue;
|
|
|
|
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
|
auto outputType = cast<ShapedType>(blockArg.getType());
|
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
|
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
|
|
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
|
|
Value received = PimReceiveOp::create(
|
|
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
|
.getOutput();
|
|
blockArg.replaceAllUsesWith(received);
|
|
markOpToRemove(receiveOp);
|
|
}
|
|
|
|
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
|
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
|
|
|
|
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
|
|
if (result.use_empty())
|
|
continue;
|
|
|
|
auto yieldType = cast<TensorType>(yieldValue.getType());
|
|
|
|
if (auto returnUse = analyzeReturnUse(result)) {
|
|
Value storedValue = yieldValue;
|
|
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
|
|
for (Operation* op : returnUse->helperChain)
|
|
markOpToRemove(op);
|
|
|
|
auto storedType = cast<ShapedType>(storedValue.getType());
|
|
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
|
if (auto storedOp = storedValue.getDefiningOp())
|
|
rewriter.setInsertionPointAfter(storedOp);
|
|
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
|
|
emitHostCopy(rewriter,
|
|
loc,
|
|
outputTensor,
|
|
storedValue,
|
|
0,
|
|
0,
|
|
static_cast<int32_t>(storedType.getNumElements() * elementSize));
|
|
continue;
|
|
}
|
|
|
|
auto resultUses = result.getUses();
|
|
if (rangeLength(resultUses) == 1) {
|
|
OpOperand& resultUse = *resultUses.begin();
|
|
Operation* resultUser = resultUse.getOwner();
|
|
|
|
if (isa<func::ReturnOp>(resultUser)) {
|
|
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
|
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
|
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
|
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
|
|
emitHostCopy(rewriter,
|
|
loc,
|
|
outputTensor,
|
|
yieldValue,
|
|
0,
|
|
0,
|
|
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
|
continue;
|
|
}
|
|
|
|
if (isa<spatial::SpatChannelSendOp>(resultUser))
|
|
continue;
|
|
}
|
|
|
|
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
|
|
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
|
|
|
|
if (concatReturnUse->helperChain.empty()) {
|
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
|
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
|
auto outputType = cast<ShapedType>(outputTensor.getType());
|
|
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
|
emitHostCopy(rewriter,
|
|
loc,
|
|
outputTensor,
|
|
yieldValue,
|
|
static_cast<int32_t>(flatOffset * elementSize),
|
|
0,
|
|
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
|
continue;
|
|
}
|
|
|
|
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
|
|
if (!storedType) {
|
|
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
|
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
|
auto outputType = cast<ShapedType>(outputTensor.getType());
|
|
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
|
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
|
for (auto [dim, idx] : llvm::enumerate(sourceIndices))
|
|
sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;
|
|
|
|
SmallVector<int64_t> destinationIndices;
|
|
if (failed(mapIndicesThroughHelperChain(sourceIndices,
|
|
concatReturnUse->concatShape,
|
|
concatReturnUse->helperChain,
|
|
destinationIndices))) {
|
|
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
SmallVector<OpFoldResult> extractOffsets;
|
|
SmallVector<OpFoldResult> extractSizes;
|
|
SmallVector<OpFoldResult> extractStrides;
|
|
extractOffsets.reserve(storedType.getRank());
|
|
extractSizes.reserve(storedType.getRank());
|
|
extractStrides.reserve(storedType.getRank());
|
|
for (int64_t idx : expandFlatElementIndex(linearIndex, storedType.getShape())) {
|
|
extractOffsets.push_back(rewriter.getIndexAttr(idx));
|
|
extractSizes.push_back(rewriter.getIndexAttr(1));
|
|
extractStrides.push_back(rewriter.getIndexAttr(1));
|
|
}
|
|
|
|
auto scalarTensorType =
|
|
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
|
|
auto elementSlice = tensor::ExtractSliceOp::create(
|
|
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
|
|
rewriter.setInsertionPointAfter(elementSlice);
|
|
|
|
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
|
|
outputTensor = emitHostCopy(rewriter,
|
|
loc,
|
|
outputTensor,
|
|
elementSlice.getResult(),
|
|
static_cast<int32_t>(destinationFlatOffset * elementSize),
|
|
0,
|
|
static_cast<int32_t>(elementSize));
|
|
}
|
|
continue;
|
|
}
|
|
|
|
computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering");
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
// Use `HaltOp` instead of `YieldOp`
|
|
rewriter.setInsertionPoint(yieldOp);
|
|
rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
|
|
|
|
// Replace `spat.compute` with `pim.core`
|
|
rewriter.setInsertionPointAfter(computeOp);
|
|
auto coreOp = PimCoreOp::create(
|
|
rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
|
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
|
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
|
|
if (!blockArg.use_empty())
|
|
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
|
|
block.eraseArguments(0, block.getNumArguments());
|
|
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
|
|
Block* tempComputeBlock = new Block();
|
|
computeOp.getBody().push_back(tempComputeBlock);
|
|
rewriter.setInsertionPointToEnd(tempComputeBlock);
|
|
PimHaltOp::create(rewriter, computeOp.getLoc());
|
|
}
|
|
|
|
/// Lowers a channelized `spat.compute_batch` into a new `pim.core_batch`
/// inserted right after it.
///
/// Only batches with no results and an empty `spat.yield` are supported; any
/// other batch is diagnosed and the pass is marked as failed.
///
/// The body of the batch is cloned into a fresh block of the core batch:
///   * each block argument, and each tensor captured from outside the batch
///     body, is copied host-to-device once and uses are remapped to the copy;
///   * `spat.channel_send_batch` becomes `pim.send_batch`;
///   * `spat.channel_receive_batch` becomes `pim.receive_batch` into a fresh
///     buffer;
///   * a `bufferization.to_tensor` whose buffer comes from `memref.get_global`
///     is cloned and its result copied host-to-device;
///   * every other op is cloned as-is;
///   * the new block is terminated with `pim.halt`.
/// The original compute_batch op is not erased by this function.
void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter) {
  if (std::getenv("PIM_BATCH_LOWER_DEBUG"))
    llvm::errs() << "lowering compute_batch lanes=" << computeBatchOp.getLaneCount() << "\n";

  if (computeBatchOp.getNumResults() != 0) {
    computeBatchOp.emitOpError(
        "batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
    signalPassFailure();
    return;
  }

  Location loc = computeBatchOp.getLoc();
  Block& oldBlock = computeBatchOp.getBody().front();
  auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
  if (oldYield.getNumOperands() != 0) {
    computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
    signalPassFailure();
    return;
  }

  SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);

  rewriter.setInsertionPointAfter(computeBatchOp);
  auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
                                                 loc,
                                                 rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
                                                 computeBatchOp.getWeights(),
                                                 computeBatchOp.getInputs());
  coreBatchOp.getProperties().setOperandSegmentSizes(
      {static_cast<int>(computeBatchOp.getWeights().size()), static_cast<int>(computeBatchOp.getInputs().size())});
  coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));

  SmallVector<Type> blockArgTypes;
  SmallVector<Location> blockArgLocs;
  for (BlockArgument arg : oldBlock.getArguments()) {
    blockArgTypes.push_back(arg.getType());
    blockArgLocs.push_back(arg.getLoc());
  }
  Block* newBlock =
      rewriter.createBlock(&coreBatchOp.getBody(), coreBatchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);

  // Emits a whole-tensor host-to-device batch copy of `source` into a fresh
  // buffer at the current insertion point and returns the copied tensor.
  auto copyHostToDevice = [&](Value source) -> Value {
    auto sourceType = cast<ShapedType>(source.getType());
    auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, sourceType);
    return pim::PimMemCopyHostToDevBatchOp::create(rewriter,
                                                   loc,
                                                   outputBuffer.getType(),
                                                   outputBuffer,
                                                   source,
                                                   rewriter.getI32IntegerAttr(0),
                                                   rewriter.getI32IntegerAttr(0),
                                                   getTensorSizeInBytesAttr(rewriter, source))
        .getOutput();
  };

  IRMapping mapper;

  // Block arguments are host-side values: copy each of them up front.
  rewriter.setInsertionPointToStart(newBlock);
  for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments()))
    mapper.map(oldArg, copyHostToDevice(newArg));

  // Returns the lowered value for `value`. A value not yet mapped is a tensor
  // captured from outside the batch body; it is copied to the device once and
  // the copy is cached in `mapper`.
  auto lookupOrMaterialize = [&](Value value) -> Value {
    if (Value mapped = mapper.lookupOrNull(value))
      return mapped;
    Value copied = copyHostToDevice(value);
    mapper.map(value, copied);
    return copied;
  };

  rewriter.setInsertionPointToEnd(newBlock);
  for (Operation& op : oldBlock) {
    if (isa<spatial::SpatYieldOp>(op))
      continue;

    if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
      // The payload may be captured from outside the batch body; materialize
      // it like any other capture instead of asserting in `IRMapping::lookup`.
      Value payload = lookupOrMaterialize(sendBatchOp.getInput());
      pim::PimSendBatchOp::create(
          rewriter, loc, payload, getTensorSizeInBytesAttr(rewriter, payload), sendBatchOp.getTargetCoreIdsAttr());
      continue;
    }

    if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
      auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
      auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
      Value received = pim::PimReceiveBatchOp::create(rewriter,
                                                      loc,
                                                      outputBuffer.getType(),
                                                      outputBuffer,
                                                      getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
                                                      receiveBatchOp.getSourceCoreIdsAttr())
                           .getOutput();
      mapper.map(receiveBatchOp.getOutput(), received);
      continue;
    }

    if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
      if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
        // Clone the host-side materialization, then redirect its users to a
        // device copy (this overrides the mapping recorded by `clone`).
        Operation* cloned = rewriter.clone(op, mapper);
        mapper.map(toTensorOp.getResult(), copyHostToDevice(cloned->getResult(0)));
        continue;
      }
    }

    // Tensors captured from outside the batch body must be copied to the
    // device before the cloned op can use them.
    for (Value operand : op.getOperands()) {
      if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
        continue;
      Operation* definingOp = operand.getDefiningOp();
      if (definingOp && definingOp->getBlock() == &oldBlock)
        continue;
      lookupOrMaterialize(operand);
    }

    // `clone` records the original-to-cloned result mapping in `mapper`.
    rewriter.clone(op, mapper);
  }

  rewriter.setInsertionPointToEnd(newBlock);
  PimHaltOp::create(rewriter, loc);
}
|
|
|
|
/// Widens the output of every `pim.vmm` in `funcOp` whose second output
/// dimension differs from `crossbarSize` to `[rows, crossbarSize]`.
///
/// The output buffer is retyped in place, together with the chain of
/// destination-style ops whose tied init operand produced it. A fresh
/// `tensor.empty` is used instead when in-place retyping is unsafe.
/// A `tensor.extract_slice` inserted after the VMM restores the original
/// `[rows, cols]` shape for every other user of the VMM result.
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Walks backwards through DPS ops whose tied init operand produced `value`
  // and retypes each tied operand to `newType`. Passes itself as `self` to
  // allow recursion from a lambda.
  auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
    auto* definingOp = value.getDefiningOp();
    if (!definingOp)
      return;
    auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp);
    if (!dpsDefiningOp)
      return;
    auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
    if (!tiedOperand)
      return;
    Value tiedValue = tiedOperand->get();
    assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
    tiedValue.setType(newType);
    self(tiedValue, newType, self);
  };

  funcOp.walk([&](PimVMMOp vmmOp) {
    auto outTensorOperand = vmmOp.getOutputBuffer();
    auto resultTensor = vmmOp.getOutput();
    auto outShape = getTensorShape(outTensorOperand);
    assert(isHVectorShape(outShape));

    const auto crossbarWidth = static_cast<int64_t>(crossbarSize);
    if (outShape[1] == crossbarWidth)
      return;

    Type elementType = outTensorOperand.getType().getElementType();
    auto newShape = SmallVector<int64_t> {outShape[0], crossbarWidth};
    auto newType = RankedTensorType::get(newShape, elementType);

    // Retyping the existing buffer is only safe when this VMM is its sole user.
    // If the buffer is also the VMM input, or is shared with other users, an
    // in-place `setType` would silently change the type those users see, so a
    // fresh crossbar-sized output buffer is allocated instead.
    if (outTensorOperand == vmmOp.getInput() || !outTensorOperand.hasOneUse()) {
      rewriter.setInsertionPoint(vmmOp);
      auto newOutputBuffer = tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, elementType);
      vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
    }
    else {
      enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
      outTensorOperand.setType(newType);
    }
    resultTensor.setType(newType);

    // Re-expose the original [rows, cols] window to every other user.
    IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
    IntegerAttr oneAttr = rewriter.getIndexAttr(1);
    IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
    IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
    SmallVector<OpFoldResult> offsets = {zeroAttr, zeroAttr};
    SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
    SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
    rewriter.setInsertionPointAfter(vmmOp);
    auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
    SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
    resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
  });
}
|
|
|
|
/// Appends one output-tensor factory to `outputTensors` for each operand of
/// `returnOp`, in operand order.
///
/// * A value produced by a constant-like op is handed back unchanged.
/// * Any other ranked-tensor value gets a private `memref.global` named
///   `output_<index>`, inserted before the op enclosing `returnOp`. Its factory
///   emits `memref.get_global` + `bufferization.to_tensor` at the builder's
///   current insertion point.
/// * A value with no defining op (a block argument) or a non-ranked-tensor
///   type is unsupported. Such a value is diagnosed, the pass is marked failed,
///   and a pass-through factory is still appended so `outputTensors` stays
///   index-aligned with the return operands.
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  outputTensors.reserve(returnOp->getNumOperands());
  for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
    // Copy out of the structured binding: capturing bindings in lambdas is not
    // portable before C++20.
    Value value = returnValue;
    auto passThrough = [value](IRRewriter&, Location) -> Value { return value; };

    Operation* returnValueDefiningOp = value.getDefiningOp();
    if (!returnValueDefiningOp) {
      returnOp.emitOpError() << "operand #" << index
                             << " is a block argument, which is unsupported during Spatial-to-PIM lowering";
      signalPassFailure();
      outputTensors.push_back(passThrough);
      continue;
    }

    if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
      assert(!hasWeightAlways(returnValueDefiningOp));
      outputTensors.push_back(passThrough);
      continue;
    }

    auto outRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(value.getType());
    if (!outRankedTensorType) {
      returnOp.emitOpError() << "operand #" << index
                             << " is not a ranked tensor, which is unsupported during Spatial-to-PIM lowering";
      signalPassFailure();
      outputTensors.push_back(passThrough);
      continue;
    }

    auto memRefType = mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
    std::string outputName = "output_" + std::to_string(index);
    rewriter.setInsertionPoint(returnOp.getParentOp());
    memref::GlobalOp::create(rewriter,
                             returnOp.getLoc(),
                             rewriter.getStringAttr(outputName),
                             rewriter.getStringAttr("private"),
                             TypeAttr::get(memRefType),
                             {},
                             {},
                             {});

    // The factory builds with the builder it is given, not this function's
    // `rewriter`.
    outputTensors.push_back(
        [memRefType, outputName, outRankedTensorType](IRRewriter& builder, Location loc) -> Value {
          auto getGlobalOp = memref::GetGlobalOp::create(builder, loc, memRefType, outputName);
          auto toTensor = bufferization::ToTensorOp::create(
              builder, loc, outRankedTensorType, getGlobalOp, builder.getUnitAttr(), builder.getUnitAttr());
          return toTensor.getResult();
        });
  }
}
|
|
|
|
/// Gives each input-less `spat.compute` in `funcOp` device-side copies of the
/// host globals it reads.
///
/// For a compute with no inputs and no block arguments, every top-level
/// `memref.get_global` whose symbol name starts with "arg" or "const_" must
/// have exactly one user. That user produces the host tensor (presumably a
/// `bufferization.to_tensor`; TODO confirm).
///
/// Right after that tensor, a `PimMemCopyHostToDevOp` into a fresh
/// `tensor.empty` of the same shape is inserted, and all other uses of the
/// host tensor are redirected to the copy.
///
/// Returns failure (with a diagnostic) when a matched global does not have
/// exactly one user, or when that user produces no result.
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();

  // Copies the whole of `inputTensor` into a new device tensor right after its
  // producer; `elementsOffset` is scaled to a byte offset for the copy.
  auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
    auto tensorType = cast<ShapedType>(inputTensor.getType());
    Type elementType = tensorType.getElementType();
    size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
    rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());

    auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);

    auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
        rewriter,
        loc,
        tensorType,
        deviceTensor,
        inputTensor,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));

    rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
  };

  for (auto& op : funcOp.getBody().getOps()) {
    auto computeOp = dyn_cast<spatial::SpatCompute>(op);
    if (!computeOp)
      continue;
    if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
      continue;

    for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
      if (!getGlobal.getName().starts_with("arg") && !getGlobal.getName().starts_with("const_"))
        continue;

      // Checked at runtime rather than only asserted: the dereferences below
      // are undefined behavior in release builds if the invariant is broken.
      if (!getGlobal->hasOneUse())
        return getGlobal.emitOpError("global must have a single entry point in the compute");
      Operation* entryPoint = *getGlobal->getUsers().begin();
      if (entryPoint->getNumResults() == 0)
        return entryPoint->emitOpError("global entry point in the compute must produce a tensor result");

      insertMemCopyHostToDev(entryPoint->getResult(0), 0);
    }
  }

  return success();
}
|
|
|
|
/// Queues `op` for later erasure. An op that is already queued is not added
/// again, so the queue holds each op at most once.
void SpatialToPimPass::markOpToRemove(Operation* op) {
  const bool alreadyQueued = llvm::is_contained(operationsToRemove, op);
  if (alreadyQueued)
    return;
  operationsToRemove.push_back(op);
}
|
|
|
|
/// Rewires each operand of `returnOp` to its output tensor and queues the
/// producer chain that only fed that operand for removal.
///
/// The output tensor is built by the matching `outputTensors` factory just
/// before the return; the chain is queued via `markOpToRemove`.
///
/// Starting from each original operand's defining op, an op is considered
/// owned by the return chain when it has no uses left, or when its single use
/// is a `func.return`, `tensor.concat`, `spat.compute` or channel use-chain op.
///
/// The walk then continues as follows:
///   * channel use-chain op: queued; continues into the producer of operand #0;
///   * `spat.compute`: queued; continues into the producers of its inputs;
///   * `tensor.concat`: queued; continues into the producers of its operands;
///   * any other op: the walk stops there and nothing is queued.
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  // Generic lambda that receives itself as `recurse` so it can recurse.
  auto queueOwnedChain = [&](Operation* producer, auto&& recurse) -> void {
    if (producer == nullptr)
      return;

    // An unused producer is trivially owned; a used one must have exactly one
    // use, and that use must belong to the return chain.
    if (!producer->use_empty()) {
      if (!producer->hasOneUse())
        return;
      Operation* soleUser = *producer->getUsers().begin();
      const bool userIsChainOp =
          isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatCompute>(soleUser) || isChannelUseChainOp(soleUser);
      if (!userIsChainOp)
        return;
    }

    if (isChannelUseChainOp(producer)) {
      Value upstream = producer->getOperand(0);
      markOpToRemove(producer);
      recurse(upstream.getDefiningOp(), recurse);
      return;
    }

    if (auto computeOp = dyn_cast<spatial::SpatCompute>(producer)) {
      markOpToRemove(computeOp);
      for (Value input : computeOp.getInputs())
        recurse(input.getDefiningOp(), recurse);
      return;
    }

    if (auto concatOp = dyn_cast<tensor::ConcatOp>(producer)) {
      markOpToRemove(concatOp);
      for (Value piece : concatOp.getOperands())
        recurse(piece.getDefiningOp(), recurse);
    }
  };

  // Snapshot the operands: they are overwritten inside the loop.
  SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
  Location loc = returnOp.getLoc();
  for (size_t operandIndex = 0; operandIndex < originalOperands.size(); ++operandIndex) {
    Operation* producer = originalOperands[operandIndex].getDefiningOp();

    rewriter.setInsertionPoint(returnOp);
    Value outputTensor = outputTensors[operandIndex](rewriter, loc);
    rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(operandIndex, outputTensor); });

    // The producer has lost this use; queue whatever chain it exclusively fed.
    queueOwnedChain(producer, queueOwnedChain);
  }
}
|
|
|
|
/// Creates a new instance of the Spatial-to-PIM conversion pass.
std::unique_ptr<Pass> createSpatialToPimPass() {
  std::unique_ptr<Pass> pass = std::make_unique<SpatialToPimPass>();
  return pass;
}
|
|
|
|
} // namespace onnx_mlir
|