1674 lines
69 KiB
C++
1674 lines
69 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/IR/BuiltinDialect.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
|
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include "llvm/Support/Casting.h"
|
|
|
|
#include <cassert>
|
|
#include <optional>
|
|
#include <string>
|
|
#include <utility>
|
|
|
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "Patterns.hpp"
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
|
|
|
using namespace mlir;
|
|
using namespace onnx_mlir;
|
|
using namespace pim;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace {
|
|
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
|
|
|
/// Module pass lowering Spatial-dialect ops to PIM-dialect ops, registered on
/// the command line as "convert-spatial-to-pim".
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
  StringRef getArgument() const override { return "convert-spatial-to-pim"; }
  StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }

  SpatialToPimPass() = default;
  // Copy constructor so the pass manager can clone the pass.
  // NOTE(review): intentionally copies none of the members below and does not
  // delegate to the base-class copy — confirm that is the intent.
  SpatialToPimPass(const SpatialToPimPass& pass) {}

  void runOnOperation() final;

private:
  // Deferred builders producing function-result buffer values; each is called
  // with a rewriter and location when the buffer is materialized.
  SmallVector<std::function<Value(IRRewriter& rewriter, Location loc)>> outputTensors;
  // Next fallback core id, used when an op has no explicit core-id attribute.
  size_t coreId = 0;
  // Ops queued for deletion after rewriting completes.
  SmallVector<Operation*> operationsToRemove;

  // Registers a buffer for one func.return operand.
  void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter);

  // Sets up per-core local variables for `funcOp`.
  LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);

  // Queues `op` onto operationsToRemove.
  void markOpToRemove(Operation* op);

  // Lowering entry points for single and batched spatial compute ops.
  void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter);
  void runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter);

  // Pads VMM output tensors up to the crossbar size expected by the hardware.
  void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);

  // Swaps func.return operands for their lowered replacements.
  void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter);
};
|
|
|
|
} // namespace
|
|
|
|
/// Returns true for the layout/reshape-style ops that are allowed to sit
/// between a channel value and its eventual consumer.
static bool isChannelUseChainOp(Operation* op) {
  return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp,
      ONNXTransposeOp, pim::PimTransposeOp>(op);
}
|
|
|
|
/// Clones the constant-like producers (tensor.empty / arith.constant) of
/// `op`'s operands that are not yet present in `mapping`. Each clone's
/// results are recorded in the mapping and the insertion point is advanced
/// past the clone.
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
  for (Value operand : op->getOperands()) {
    // Already materialized — nothing to do.
    if (mapping.lookupOrNull(operand))
      continue;

    Operation* producer = operand.getDefiningOp();
    bool cloneable = producer && isa<tensor::EmptyOp, arith::ConstantOp>(producer);
    if (!cloneable)
      continue;

    Operation* clone = rewriter.clone(*producer, mapping);
    for (auto [oldResult, newResult] : llvm::zip(producer->getResults(), clone->getResults()))
      mapping.map(oldResult, newResult);
    rewriter.setInsertionPointAfter(clone);
  }
}
|
|
|
|
// Spatial core ids map one-to-one onto PIM core ids; only the integer width
// changes (size_t narrowed to int32_t).
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) {
  return static_cast<int32_t>(spatialCoreId);
}
|
|
|
|
/// Returns the PIM core id for `computeOp`: the explicit core-id attribute
/// when present, otherwise the next fallback id (advancing the counter).
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
  auto explicitIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
  if (explicitIdAttr)
    return static_cast<int32_t>(explicitIdAttr.getInt());
  return static_cast<int32_t>(fallbackCoreId++);
}
|
|
|
|
/// Returns one PIM core id per lane of `computeBatchOp`: the explicit
/// core-ids attribute when present, otherwise fresh fallback ids.
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
  if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
    ArrayRef<int32_t> explicitIds = coreIdsAttr.asArrayRef();
    return SmallVector<int32_t>(explicitIds.begin(), explicitIds.end());
  }

  // No attribute: hand out one fresh fallback id per lane.
  uint32_t laneCount = computeBatchOp.getLaneCount();
  SmallVector<int32_t> coreIds;
  coreIds.reserve(laneCount);
  for (uint32_t lane = 0; lane < laneCount; ++lane)
    coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
  return coreIds;
}
|
|
|
|
/// Rewrites spatial.channel_send into pim.send carrying an explicit byte-size
/// attribute and the translated target core id, then erases the original op.
static void lowerChannelSend(spatial::SpatChannelSendOp sendOp, IRRewriter& rewriter) {
  Value payload = sendOp.getInput();
  auto byteSizeAttr = getTensorSizeInBytesAttr(rewriter, payload);
  auto pimTargetAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sendOp.getTargetCoreId()));

  rewriter.setInsertionPoint(sendOp);
  PimSendOp::create(rewriter, sendOp.getLoc(), payload, byteSizeAttr, pimTargetAttr);
  rewriter.eraseOp(sendOp);
}
|
|
|
|
/// Rewrites spatial.channel_receive into pim.receive writing into a freshly
/// created empty tensor. Receives with no uses are simply erased.
static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
  if (receiveOp->use_empty()) {
    rewriter.eraseOp(receiveOp);
    return;
  }

  Location loc = receiveOp.getLoc();
  auto resultType = cast<ShapedType>(receiveOp.getResult().getType());
  rewriter.setInsertionPoint(receiveOp);

  auto buffer = createEmptyTensorFromShaped(rewriter, loc, resultType);
  auto byteSizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
  auto pimSourceAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));

  auto pimReceive = PimReceiveOp::create(rewriter, loc, buffer.getType(), buffer, byteSizeAttr, pimSourceAttr);
  rewriter.replaceOp(receiveOp, pimReceive.getOutput());
}
|
|
|
|
/// Expands spatial.channel_send_many into one pim.send per (input, target)
/// pair, then erases the original op.
static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(sendManyOp);
  Location loc = sendManyOp.getLoc();
  for (auto [input, targetCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds())) {
    auto byteSizeAttr = getTensorSizeInBytesAttr(rewriter, input);
    auto pimTargetAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(targetCoreId));
    PimSendOp::create(rewriter, loc, input, byteSizeAttr, pimTargetAttr);
  }
  rewriter.eraseOp(sendManyOp);
}
|
|
|
|
/// Expands spatial.channel_receive_many into one pim.receive per (output,
/// source) pair, each writing into a fresh empty tensor, and replaces the
/// original op's results with the received values.
static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(receiveManyOp);
  Location loc = receiveManyOp.getLoc();

  SmallVector<Value> replacements;
  replacements.reserve(receiveManyOp.getNumResults());
  for (auto [output, sourceCoreId] : llvm::zip(receiveManyOp.getOutputs(), receiveManyOp.getSourceCoreIds())) {
    Value buffer = createEmptyTensorFromShaped(rewriter, loc, cast<ShapedType>(output.getType())).getResult();
    auto pimReceive = PimReceiveOp::create(rewriter,
        loc,
        output.getType(),
        buffer,
        getTensorSizeInBytesAttr(rewriter, output),
        rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sourceCoreId)));
    replacements.push_back(pimReceive.getOutput());
  }
  rewriter.replaceOp(receiveManyOp, replacements);
}
|
|
|
|
/// Lowers spatial.channel_send_many_batch into one pim.send_batch per input,
/// where each input owns the contiguous run of `laneCount` target core ids at
/// position valueIndex * laneCount. Inputs are resolved through `mapper`
/// (they were cloned earlier). The original op is not erased here.
static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp,
    int32_t laneCount,
    IRMapping& mapper,
    IRRewriter& rewriter) {
  // Translate every spatial target id to its PIM counterpart up front.
  SmallVector<int32_t> pimTargetIds;
  pimTargetIds.reserve(sendManyBatchOp.getTargetCoreIds().size());
  for (int32_t targetCoreId : sendManyBatchOp.getTargetCoreIds())
    pimTargetIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));

  // Resolve each input through the mapper.
  SmallVector<Value> mappedInputs;
  mappedInputs.reserve(sendManyBatchOp.getInputs().size());
  for (Value input : sendManyBatchOp.getInputs())
    mappedInputs.push_back(mapper.lookup(input));

  for (auto [valueIndex, input] : llvm::enumerate(mappedInputs)) {
    SmallVector<int32_t> laneTargetIds;
    laneTargetIds.reserve(laneCount);
    for (int32_t lane = 0; lane < laneCount; ++lane)
      laneTargetIds.push_back(pimTargetIds[valueIndex * laneCount + lane]);

    pim::PimSendBatchOp::create(rewriter,
        sendManyBatchOp.getLoc(),
        input,
        getTensorSizeInBytesAttr(rewriter, input),
        rewriter.getDenseI32ArrayAttr(laneTargetIds));
  }
}
|
|
|
|
/// Lowers spatial.channel_receive_many_batch into one pim.receive_batch per
/// output; output valueIndex consumes the contiguous run of `laneCount`
/// source core ids starting at valueIndex * laneCount. Received values are
/// recorded in `mapper` rather than replacing the op directly.
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
    int32_t laneCount,
    IRMapping& mapper,
    IRRewriter& rewriter) {
  // Translate every spatial source id to its PIM counterpart up front.
  SmallVector<int32_t> pimSourceIds;
  pimSourceIds.reserve(receiveManyBatchOp.getSourceCoreIds().size());
  for (int32_t sourceCoreId : receiveManyBatchOp.getSourceCoreIds())
    pimSourceIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));

  for (auto [valueIndex, output] : llvm::enumerate(receiveManyBatchOp.getOutputs())) {
    Value buffer =
        createEmptyTensorFromShaped(rewriter, receiveManyBatchOp.getLoc(), cast<ShapedType>(output.getType()))
            .getResult();

    SmallVector<int32_t> laneSourceIds;
    laneSourceIds.reserve(laneCount);
    for (int32_t lane = 0; lane < laneCount; ++lane)
      laneSourceIds.push_back(pimSourceIds[valueIndex * laneCount + lane]);

    auto received = pim::PimReceiveBatchOp::create(rewriter,
        receiveManyBatchOp.getLoc(),
        output.getType(),
        buffer,
        getTensorSizeInBytesAttr(rewriter, output),
        rewriter.getDenseI32ArrayAttr(laneSourceIds))
                        .getOutput();
    mapper.map(output, received);
  }
}
|
|
|
|
/// Lowers spatial.extract_rows into one tensor.extract_slice per output; the
/// i-th output's row group starts at offset i * rowsPerOutput along dim 0.
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(extractRowsOp);
  Value source = extractRowsOp.getInput();
  auto sourceType = cast<RankedTensorType>(source.getType());
  auto one = rewriter.getIndexAttr(1);

  SmallVector<Value> replacements;
  replacements.reserve(extractRowsOp.getNumResults());
  for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
    auto outputType = cast<RankedTensorType>(output.getType());
    int64_t rowsPerOutput = outputType.getDimSize(0);

    SmallVector<OpFoldResult> offsets = {
        rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * rowsPerOutput), rewriter.getIndexAttr(0)};
    SmallVector<OpFoldResult> sizes = {
        rewriter.getIndexAttr(rowsPerOutput), rewriter.getIndexAttr(sourceType.getDimSize(1))};
    SmallVector<OpFoldResult> strides = {one, one};

    auto slice = tensor::ExtractSliceOp::create(
        rewriter, extractRowsOp.getLoc(), outputType, source, offsets, sizes, strides);
    replacements.push_back(slice.getResult());
  }
  rewriter.replaceOp(extractRowsOp, replacements);
}
|
|
|
|
/// Lowers spatial.concat to pim.concat, materializing an empty destination
/// buffer of the output's shape.
static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(concatOp);
  Location loc = concatOp.getLoc();
  auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
  Value destBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType).getResult();

  auto pimConcat = pim::PimConcatOp::create(rewriter,
      loc,
      concatOp.getOutput().getType(),
      rewriter.getI64IntegerAttr(concatOp.getAxis()),
      concatOp.getInputs(),
      destBuffer);
  rewriter.replaceOp(concatOp, pimConcat.getOutput());
}
|
|
|
|
/// Inlines spatial.map ops that live inside pim.core / pim.core_batch
/// regions: the map body is cloned once per input (with the block argument
/// bound to that input) and the map is replaced by the per-input values the
/// yield would have produced.
static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Collect first — rewriting during the walk would invalidate it.
  SmallVector<spatial::SpatMapOp> mapOps;
  funcOp.walk([&](spatial::SpatMapOp mapOp) {
    bool insideCoreRegion =
        mapOp->getParentOfType<pim::PimCoreOp>() || mapOp->getParentOfType<pim::PimCoreBatchOp>();
    if (insideCoreRegion)
      mapOps.push_back(mapOp);
  });

  for (spatial::SpatMapOp mapOp : mapOps) {
    Block& body = mapOp.getBody().front();
    auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
    rewriter.setInsertionPoint(mapOp);

    SmallVector<Value> replacements;
    replacements.reserve(mapOp.getInputs().size());
    for (Value input : mapOp.getInputs()) {
      IRMapping mapping;
      mapping.map(body.getArgument(0), input);

      // Clone the body (minus the terminator) with the argument bound.
      for (Operation& bodyOp : body.without_terminator()) {
        Operation* cloned = rewriter.clone(bodyOp, mapping);
        for (auto [oldResult, newResult] : llvm::zip(bodyOp.getResults(), cloned->getResults()))
          mapping.map(oldResult, newResult);
        rewriter.setInsertionPointAfter(cloned);
      }

      replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0)));
    }
    rewriter.replaceOp(mapOp, replacements);
  }
}
|
|
|
|
/// Returns `elementType` with its leading dimension scaled by `count` — the
/// type of `count` such tensors stacked along dim 0.
static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
  ArrayRef<int64_t> shape = elementType.getShape();
  SmallVector<int64_t> packedShape(shape.begin(), shape.end());
  packedShape[0] *= count;
  return RankedTensorType::get(packedShape, elementType.getElementType());
}
|
|
|
|
/// Returns true when `values` are the consecutive results startIndex,
/// startIndex+1, … of a single operation; fills `owner` and `startIndex` on
/// the way (they may be written even when the check ultimately fails).
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
  if (values.empty())
    return false;

  auto firstResult = dyn_cast<OpResult>(values.front());
  if (!firstResult)
    return false;

  owner = firstResult.getOwner();
  startIndex = firstResult.getResultNumber();

  for (auto [index, value] : llvm::enumerate(values)) {
    auto result = dyn_cast<OpResult>(value);
    bool contiguous = result && result.getOwner() == owner && result.getResultNumber() == startIndex + index;
    if (!contiguous)
      return false;
  }
  return true;
}
|
|
|
|
/// Builds one tensor.extract_slice covering `count` consecutive row groups of
/// `extractRowsOp` starting at output `startIndex`. Returns a null Value when
/// the types involved are not ranked/static enough to pack.
static Value createPackedExtractRowsSlice(
    spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
  auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
  auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
  bool typesUsable =
      rowType && inputType && rowType.hasStaticShape() && inputType.hasStaticShape() && rowType.getRank() != 0;
  if (!typesUsable)
    return {};

  int64_t rowsPerValue = rowType.getDimSize(0);
  if (ShapedType::isDynamic(rowsPerValue))
    return {};

  auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
  int64_t rank = inputType.getRank();

  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);

  // Dim 0 selects the packed row range; every trailing dim is taken whole.
  offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
  sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
  strides.push_back(rewriter.getIndexAttr(1));
  for (int64_t dim = 1; dim < rank; ++dim) {
    offsets.push_back(rewriter.getIndexAttr(0));
    sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
    strides.push_back(rewriter.getIndexAttr(1));
  }

  auto slice =
      tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides);
  return slice.getResult();
}
|
|
|
|
/// Packs `values` into a single tensor when they are consecutive results of
/// one producer that supports packing (currently only spatial.extract_rows).
/// Returns a null Value otherwise.
static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter, Location loc) {
  Operation* owner = nullptr;
  unsigned startIndex = 0;
  if (!getContiguousOpResults(values, owner, startIndex))
    return {};

  auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner);
  if (!extractRowsOp)
    return {};
  return createPackedExtractRowsSlice(extractRowsOp, startIndex, static_cast<unsigned>(values.size()), rewriter, loc);
}
|
|
|
|
/// Builds a single pim.receive_tensor covering `count` consecutive outputs of
/// `receiveManyOp` starting at `startIndex`, packed along dim 0. Returns a
/// null Value when the row type is not static/ranked enough.
static Value createPackedReceiveTensor(spatial::SpatChannelReceiveManyOp receiveManyOp,
    unsigned startIndex,
    unsigned count,
    IRRewriter& rewriter,
    Location loc) {
  auto rowType = dyn_cast<RankedTensorType>(receiveManyOp.getOutputs()[startIndex].getType());
  bool rowTypeUsable = rowType && rowType.hasStaticShape() && rowType.getRank() != 0;
  if (!rowTypeUsable)
    return {};

  auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
  auto destBuffer = tensor::EmptyOp::create(rewriter, loc, packedType.getShape(), packedType.getElementType());

  // Translate only the sub-range of source core ids being packed.
  ArrayRef<int32_t> allSourceCoreIds = receiveManyOp.getSourceCoreIds();
  SmallVector<int32_t> packedSourceIds;
  packedSourceIds.reserve(count);
  for (unsigned i = 0; i < count; ++i)
    packedSourceIds.push_back(translateSpatialCoreIdToPimCoreId(allSourceCoreIds[startIndex + i]));

  auto receive = pim::PimReceiveTensorOp::create(
      rewriter, loc, packedType, destBuffer.getResult(), rewriter.getDenseI32ArrayAttr(packedSourceIds));
  return receive.getOutput();
}
|
|
|
|
/// Replaces `count` consecutive map applications (outputs startIndex ..
/// startIndex+count-1 of `mapOp`) with a single scf.for loop over a packed
/// input tensor: each iteration extracts one input row group, runs a clone of
/// the map body on it, and inserts the result into a packed accumulator.
/// Returns the packed result, or a null Value when packing is not possible.
static Value createPackedMapTensor(
    spatial::SpatMapOp mapOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
  // The inputs themselves must be packable into one tensor first.
  Value packedInput = createPackedTensorForValues(mapOp.getInputs().slice(startIndex, count), rewriter, loc);
  if (!packedInput)
    return {};

  auto inputType = dyn_cast<RankedTensorType>(mapOp.getInputs()[startIndex].getType());
  auto outputType = dyn_cast<RankedTensorType>(mapOp.getOutputs()[startIndex].getType());
  if (!inputType || !outputType || !inputType.hasStaticShape() || !outputType.hasStaticShape()
      || inputType.getRank() == 0 || outputType.getRank() == 0)
    return {};

  // scf.for over [0, count) carrying the packed output as a loop-carried value.
  auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(count));
  auto packedInit =
      tensor::EmptyOp::create(rewriter, loc, packedOutputType.getShape(), packedOutputType.getElementType());
  auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto upper = arith::ConstantIndexOp::create(rewriter, loc, count);
  auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
  auto loop = scf::ForOp::create(rewriter, loc, zero, upper, step, ValueRange {packedInit.getResult()});

  {
    // Build the loop body; the guard restores the insertion point afterwards.
    OpBuilder::InsertionGuard guard(rewriter);
    Block* loopBlock = loop.getBody();
    rewriter.setInsertionPointToStart(loopBlock);
    Value iv = loopBlock->getArgument(0);
    Value acc = loopBlock->getArgument(1);

    // Row offset of this iteration's input slice: iv * rowsPerValue
    // (the multiply is skipped when rowsPerValue == 1).
    int64_t inputRowsPerValue = inputType.getDimSize(0);
    Value inputRowOffset = iv;
    if (inputRowsPerValue != 1) {
      auto rowsPerValue = arith::ConstantIndexOp::create(rewriter, loc, inputRowsPerValue);
      inputRowOffset = arith::MulIOp::create(rewriter, loc, iv, rowsPerValue);
    }

    // Extract one input row group; trailing dims are taken whole.
    SmallVector<OpFoldResult> extractOffsets;
    SmallVector<OpFoldResult> extractSizes;
    SmallVector<OpFoldResult> extractStrides;
    extractOffsets.push_back(inputRowOffset);
    extractSizes.push_back(rewriter.getIndexAttr(inputRowsPerValue));
    extractStrides.push_back(rewriter.getIndexAttr(1));
    for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
      extractOffsets.push_back(rewriter.getIndexAttr(0));
      extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
      extractStrides.push_back(rewriter.getIndexAttr(1));
    }
    auto inputSlice = tensor::ExtractSliceOp::create(
        rewriter, loc, inputType, packedInput, extractOffsets, extractSizes, extractStrides);

    // Clone the map body with its block argument bound to the slice.
    IRMapping mapping;
    Block& body = mapOp.getBody().front();
    mapping.map(body.getArgument(0), inputSlice.getResult());
    for (Operation& bodyOp : body.without_terminator()) {
      Operation* cloned = rewriter.clone(bodyOp, mapping);
      for (auto [originalResult, clonedResult] : llvm::zip(bodyOp.getResults(), cloned->getResults()))
        mapping.map(originalResult, clonedResult);
      rewriter.setInsertionPointAfter(cloned);
    }

    auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
    Value mappedOutput = mapping.lookupOrDefault(yieldOp.getOperand(0));

    // Row offset of this iteration's output slice, mirroring the input logic.
    int64_t outputRowsPerValue = outputType.getDimSize(0);
    Value outputRowOffset = iv;
    if (outputRowsPerValue != 1) {
      auto rowsPerValue = arith::ConstantIndexOp::create(rewriter, loc, outputRowsPerValue);
      outputRowOffset = arith::MulIOp::create(rewriter, loc, iv, rowsPerValue);
    }

    // Insert the mapped result into the accumulator and yield it.
    SmallVector<OpFoldResult> insertOffsets;
    SmallVector<OpFoldResult> insertSizes;
    SmallVector<OpFoldResult> insertStrides;
    insertOffsets.push_back(outputRowOffset);
    insertSizes.push_back(rewriter.getIndexAttr(outputRowsPerValue));
    insertStrides.push_back(rewriter.getIndexAttr(1));
    for (int64_t dim = 1; dim < outputType.getRank(); ++dim) {
      insertOffsets.push_back(rewriter.getIndexAttr(0));
      insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(dim)));
      insertStrides.push_back(rewriter.getIndexAttr(1));
    }

    auto inserted =
        tensor::InsertSliceOp::create(rewriter, loc, mappedOutput, acc, insertOffsets, insertSizes, insertStrides);
    scf::YieldOp::create(rewriter, loc, inserted.getResult());
  }

  return loop.getResult(0);
}
|
|
|
|
/// Fuses groups of per-row spatial ops into packed-tensor PIM ops:
///  1. channel_send_many whose inputs pack into one tensor → pim.send_tensor;
///  2. axis-0 spatial.concat whose inputs are runs of consecutive results of
///     a map / receive_many / extract_rows → pim.concat over packed tensors;
///  3. finally, producer ops left without uses are erased (in reverse order
///     so later ops are removed before the ops they consume).
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
  // --- 1. Pack whole send_many ops into a single pim.send_tensor. ---
  SmallVector<spatial::SpatChannelSendManyOp> sendManyOps;
  funcOp.walk([&](spatial::SpatChannelSendManyOp sendManyOp) { sendManyOps.push_back(sendManyOp); });
  for (auto sendManyOp : sendManyOps) {
    if (sendManyOp.getInputs().empty())
      continue;

    rewriter.setInsertionPoint(sendManyOp);
    Value packedInput = createPackedTensorForValues(sendManyOp.getInputs(), rewriter, sendManyOp.getLoc());
    if (!packedInput)
      continue;

    SmallVector<int32_t> targetCoreIds;
    targetCoreIds.reserve(sendManyOp.getTargetCoreIds().size());
    for (int32_t targetCoreId : sendManyOp.getTargetCoreIds())
      targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
    pim::PimSendTensorOp::create(
        rewriter, sendManyOp.getLoc(), packedInput, rewriter.getDenseI32ArrayAttr(targetCoreIds));
    rewriter.eraseOp(sendManyOp);
  }

  // --- 2. Pack contiguous runs of concat inputs. ---
  SmallVector<spatial::SpatConcatOp> concatOps;
  funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
  for (auto concatOp : concatOps) {
    // Packing stacks along dim 0, so only axis-0 concats qualify.
    if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
      continue;

    SmallVector<Value> packedInputs;
    bool changed = false;
    rewriter.setInsertionPoint(concatOp);

    for (unsigned index = 0; index < concatOp.getInputs().size();) {
      Value input = concatOp.getInputs()[index];
      auto result = dyn_cast<OpResult>(input);
      if (!result) {
        // Block arguments etc. can't be grouped; keep as-is.
        packedInputs.push_back(input);
        ++index;
        continue;
      }

      // Grow [index, endIndex) while inputs stay consecutive results of
      // the same owner op.
      Operation* owner = result.getOwner();
      unsigned startIndex = result.getResultNumber();
      unsigned endIndex = index + 1;
      while (endIndex < concatOp.getInputs().size()) {
        auto nextResult = dyn_cast<OpResult>(concatOp.getInputs()[endIndex]);
        if (!nextResult || nextResult.getOwner() != owner
            || nextResult.getResultNumber() != startIndex + (endIndex - index))
          break;
        ++endIndex;
      }

      // Try the packing strategy matching the producer kind.
      unsigned count = endIndex - index;
      Value packedInput;
      if (auto mapOp = dyn_cast<spatial::SpatMapOp>(owner))
        packedInput = createPackedMapTensor(mapOp, startIndex, count, rewriter, concatOp.getLoc());
      else if (auto receiveManyOp = dyn_cast<spatial::SpatChannelReceiveManyOp>(owner))
        packedInput = createPackedReceiveTensor(receiveManyOp, startIndex, count, rewriter, concatOp.getLoc());
      else if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
        packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc());

      if (packedInput) {
        packedInputs.push_back(packedInput);
        changed = true;
      }
      else {
        // Packing failed: fall back to the original run of inputs.
        for (unsigned oldIndex = index; oldIndex < endIndex; ++oldIndex)
          packedInputs.push_back(concatOp.getInputs()[oldIndex]);
      }

      index = endIndex;
    }

    if (!changed)
      continue;

    auto newConcat = pim::PimConcatOp::create(
        rewriter,
        concatOp.getLoc(),
        concatOp.getOutput().getType(),
        concatOp.getAxisAttr(),
        ValueRange(packedInputs),
        createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), cast<ShapedType>(concatOp.getOutput().getType()))
            .getResult());
    rewriter.replaceOp(concatOp, newConcat.getOutput());
  }

  // --- 3. Sweep producers that packing left without uses. ---
  auto eraseUnusedOps = [&](auto tag) {
    using OpTy = decltype(tag);
    SmallVector<OpTy> ops;
    funcOp.walk([&](OpTy op) { ops.push_back(op); });
    // Reverse order: erase consumers before their producers.
    for (auto op : llvm::reverse(ops))
      if (op->use_empty())
        rewriter.eraseOp(op);
  };
  eraseUnusedOps(spatial::SpatMapOp {});
  eraseUnusedOps(spatial::SpatChannelReceiveManyOp {});
  eraseUnusedOps(spatial::SpatExtractRowsOp {});
}
|
|
|
|
/// Verifies that `computeOp` is a pure "helper" compute: one input, one
/// result (optionally required to feed func.return directly), whose body is a
/// straight chain of layout ops (see isChannelUseChainOp) from the block
/// argument to the yield, with only constant-like ops outside the chain.
/// On success, `helperChain` holds the chain in execution order.
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
    SmallVectorImpl<Operation*>& helperChain,
    bool requireReturnUse = true) {
  if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
    return failure();

  if (requireReturnUse) {
    Value result = computeOp.getResult(0);
    if (!result.hasOneUse() || !isa<func::ReturnOp>(*result.getUsers().begin()))
      return failure();
  }

  Block& block = computeOp.getBody().front();
  if (block.getNumArguments() != 1)
    return failure();

  auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
  if (!yieldOp || yieldOp.getNumOperands() != 1)
    return failure();

  // Walk backwards from the yielded value to the block argument, collecting
  // the chain in reverse.
  SmallVector<Operation*> reverseChain;
  Value blockArg = block.getArgument(0);
  for (Value current = yieldOp.getOperands().front(); current != blockArg;) {
    Operation* definingOp = current.getDefiningOp();
    if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp))
      return failure();
    reverseChain.push_back(definingOp);
    current = definingOp->getOperand(0);
  }

  // Everything in the body besides the chain must be constant-like.
  SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
  for (Operation& op : block.without_terminator())
    if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
      return failure();

  helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
  return success();
}
|
|
|
|
/// Inlines a spatial.compute with no inputs and a single result whose users
/// are all compute/core ops: the body is cloned before the op and all uses of
/// the result are redirected to the cloned yield value. Returns true on
/// success; the (now use-less) op itself is not erased here.
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
  if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
    return false;

  auto isWeightLikeUser = [](Operation* user) {
    return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
  };
  if (!llvm::all_of(computeOp.getResult(0).getUsers(), isWeightLikeUser))
    return false;

  Block& block = computeOp.getBody().front();
  if (block.getNumArguments() != 0)
    return false;

  auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
  if (!yieldOp || yieldOp.getNumOperands() != 1)
    return false;

  // Clone the body (minus the terminator) right before the op, pulling in
  // any constant-like operand producers that aren't mapped yet.
  rewriter.setInsertionPoint(computeOp);
  IRMapping mapping;
  for (Operation& bodyOp : block.without_terminator()) {
    cloneMappedHelperOperands(&bodyOp, mapping, rewriter);
    Operation* cloned = rewriter.clone(bodyOp, mapping);
    for (auto [oldResult, newResult] : llvm::zip(bodyOp.getResults(), cloned->getResults()))
      mapping.map(oldResult, newResult);
    rewriter.setInsertionPointAfter(cloned);
  }

  computeOp.getResult(0).replaceAllUsesWith(mapping.lookupOrDefault(yieldOp.getOperand(0)));
  return true;
}
|
|
|
|
/// Result of analyzeReturnUse: how a value reaches func.return.
struct ReturnUseInfo {
  // Operand position on the func.return op.
  size_t returnIndex;
  // Layout ops (see isChannelUseChainOp) traversed between the value and the return.
  SmallVector<Operation*> helperChain;
};
|
|
|
|
/// Result of analyzeConcatReturnUse: how a value reaches func.return through
/// a chain of concats (and optionally a helper compute / layout ops).
struct ConcatReturnUseInfo {
  // Operand position on the func.return op.
  size_t returnIndex;
  // Per-dimension offset of the analyzed value within the final concat result.
  SmallVector<int64_t> sliceOffsets;
  // Shape of the outermost concat result the value ends up in.
  SmallVector<int64_t> concatShape;
  // The concat ops traversed, outermost last.
  SmallVector<Operation*> concatChain;
  // Trailing layout/helper ops between the last concat and the return.
  SmallVector<Operation*> helperChain;
};
|
|
|
|
/// Row-major linearization of `indices` within `shape`:
/// flat = ((i0 * d1 + i1) * d2 + i2) ... — indices and shape are assumed to
/// have the same length.
static int64_t computeFlatElementIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> shape) {
  int64_t flat = 0;
  for (auto [extent, index] : llvm::zip(shape, indices))
    flat = flat * extent + index;
  return flat;
}
|
|
|
|
/// Inverse of computeFlatElementIndex: converts a row-major flat index back
/// into per-dimension indices by peeling digits from the innermost dim out.
static SmallVector<int64_t> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
  SmallVector<int64_t> indices(shape.size(), 0);
  for (size_t dim = shape.size(); dim-- > 0;) {
    indices[dim] = flatIndex % shape[dim];
    flatIndex /= shape[dim];
  }
  return indices;
}
|
|
|
|
/// Follows `value` through a single-use chain of layout ops; succeeds only if
/// that chain terminates directly at func.return. Returns the return operand
/// index and the ops traversed, or nullopt on any fan-out / other user.
static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
  auto uses = value.getUses();
  if (rangeLength(uses) != 1)
    return std::nullopt;

  SmallVector<Operation*> helperChain;
  Value currentValue = value;
  Operation* currentUser = uses.begin()->getOwner();

  // Each hop must be a layout op with exactly one use of its result.
  while (isChannelUseChainOp(currentUser)) {
    helperChain.push_back(currentUser);
    auto chainUses = currentUser->getResult(0).getUses();
    if (rangeLength(chainUses) != 1)
      return std::nullopt;
    currentValue = currentUser->getResult(0);
    currentUser = chainUses.begin()->getOwner();
  }

  if (!isa<func::ReturnOp>(currentUser))
    return std::nullopt;

  size_t returnIndex = currentValue.getUses().begin()->getOperandNumber();
  return ReturnUseInfo {returnIndex, std::move(helperChain)};
}
|
|
|
|
/// Follows `value` through a single-use chain of concat ops (tensor.concat /
/// spatial.concat / pim.concat), optionally one helper spatial.compute, and
/// trailing layout ops, succeeding only if the chain ends at func.return.
/// Accumulates the value's per-dimension offset inside the outermost concat
/// result along the way. Returns nullopt on any fan-out or unsupported user.
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
  // Accessors that unify the three supported concat op kinds.
  auto getConcatResult = [](Operation* op) -> Value {
    if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
      return tensorConcat.getResult();
    if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
      return spatialConcat.getOutput();
    if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
      return pimConcat.getOutput();
    return {};
  };
  auto getConcatAxis = [](Operation* op) -> std::optional<int64_t> {
    if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
      return tensorConcat.getDim();
    if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
      return spatialConcat.getAxis();
    if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
      return pimConcat.getAxis();
    return std::nullopt;
  };
  auto getConcatOperands = [](Operation* op) -> OperandRange {
    if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
      return tensorConcat.getOperands();
    if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
      return spatialConcat.getInputs();
    return cast<pim::PimConcatOp>(op).getInputs();
  };

  // The value must have exactly one use, and it must be a concat.
  auto uses = value.getUses();
  if (rangeLength(uses) != 1
      || !isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(uses.begin()->getOwner()))
    return std::nullopt;

  auto valueType = dyn_cast<ShapedType>(value.getType());
  if (!valueType || !valueType.hasStaticShape())
    return std::nullopt;

  SmallVector<int64_t> sliceOffsets(valueType.getRank(), 0);
  SmallVector<int64_t> concatShape(valueType.getShape().begin(), valueType.getShape().end());
  SmallVector<Operation*> concatChain;
  Value currentValue = value;
  Operation* currentUser = uses.begin()->getOwner();

  // Walk up the concat chain, accumulating the offset along each concat axis
  // from the sizes of the operands that precede the current value.
  while (isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(currentUser)) {
    concatChain.push_back(currentUser);
    size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
    int64_t axis = *getConcatAxis(currentUser);
    for (Value operand : getConcatOperands(currentUser).take_front(operandIndex))
      sliceOffsets[axis] += cast<ShapedType>(operand.getType()).getShape()[axis];

    Value concatResult = getConcatResult(currentUser);
    auto concatType = dyn_cast<ShapedType>(concatResult.getType());
    if (!concatType || !concatType.hasStaticShape())
      return std::nullopt;
    concatShape.assign(concatType.getShape().begin(), concatType.getShape().end());

    // The concat result itself must also be single-use to continue.
    currentValue = concatResult;
    auto currentUses = currentValue.getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentUser = currentUses.begin()->getOwner();
  }

  // Optionally pass through one helper spatial.compute whose body is a pure
  // layout chain (validated by collectHelperComputeChain).
  SmallVector<Operation*> helperChain;
  if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
    if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
      return std::nullopt;

    if (failed(collectHelperComputeChain(helperCompute, helperChain)))
      return std::nullopt;

    currentValue = helperCompute.getResult(0);
    auto currentUses = currentValue.getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentUser = currentUses.begin()->getOwner();
  }

  // Trailing single-use layout ops between the concat/helper and the return.
  while (isChannelUseChainOp(currentUser)) {
    helperChain.push_back(currentUser);
    auto currentUses = currentUser->getResult(0).getUses();
    if (rangeLength(currentUses) != 1)
      return std::nullopt;
    currentValue = currentUser->getResult(0);
    currentUser = currentUses.begin()->getOwner();
  }

  if (!isa<func::ReturnOp>(currentUser))
    return std::nullopt;

  return ConcatReturnUseInfo {
      currentValue.getUses().begin()->getOperandNumber(),
      std::move(sliceOffsets),
      std::move(concatShape),
      std::move(concatChain),
      std::move(helperChain),
  };
}
|
|
|
|
// Translates the coordinates of one tensor element through a chain of
// shape-manipulating "helper" ops (extract_slice / transpose / reshape-like),
// yielding the coordinates of the same logical element in the chain's final
// result.
//
// `sourceIndices`/`sourceShape` describe the element before the chain runs;
// on success `mappedIndices` holds its position after every op in
// `helperChain` has been applied in order. Returns failure when the chain
// contains an unsupported op, a dynamic shape, a strided slice, or a slice
// window that does not contain the tracked element.
static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndices,
    ArrayRef<int64_t> sourceShape,
    ArrayRef<Operation*> helperChain,
    SmallVectorImpl<int64_t>& mappedIndices) {
  // Mutable cursor: the element's indices and the tensor shape they are
  // relative to, updated op by op.
  SmallVector<int64_t> currentIndices(sourceIndices.begin(), sourceIndices.end());
  SmallVector<int64_t> currentShape(sourceShape.begin(), sourceShape.end());

  // Reshape-like ops preserve the flattened (row-major) element order, so the
  // mapping is: flatten under the old shape, re-expand under the result shape.
  auto reshapeToResultShape = [&](Operation* op) -> LogicalResult {
    auto resultType = dyn_cast<ShapedType>(op->getResult(0).getType());
    if (!resultType || !resultType.hasStaticShape())
      return failure();
    int64_t flatIndex = computeFlatElementIndex(currentIndices, currentShape);
    currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
    currentIndices = expandFlatElementIndex(flatIndex, currentShape);
    return success();
  };

  for (Operation* op : helperChain) {
    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
      auto hasStaticValues = [](ArrayRef<int64_t> values) {
        return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
      };
      if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes())
          || !hasStaticValues(extractSliceOp.getStaticStrides()))
        return failure();

      // Only unit-stride slices are supported, and the tracked element must
      // fall inside the slice window in every dimension; its new index is the
      // old index shifted by the slice offset.
      SmallVector<int64_t> nextIndices;
      nextIndices.reserve(currentIndices.size());
      for (auto [index, offset, size, stride] : llvm::zip_equal(currentIndices,
               extractSliceOp.getStaticOffsets(),
               extractSliceOp.getStaticSizes(),
               extractSliceOp.getStaticStrides())) {
        if (stride != 1 || index < offset || index >= offset + size)
          return failure();
        nextIndices.push_back(index - offset);
      }

      auto resultType = dyn_cast<ShapedType>(extractSliceOp.getResult().getType());
      if (!resultType || !resultType.hasStaticShape())
        return failure();
      currentIndices = std::move(nextIndices);
      currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
      continue;
    }

    // ONNX transpose: result dimension `destIndex` reads source dimension
    // `perm[destIndex]`, so indices and shape are permuted accordingly.
    if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) {
      SmallVector<int64_t> nextIndices(currentIndices.size());
      SmallVector<int64_t> nextShape(currentShape.size());
      for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange<IntegerAttr>())) {
        int64_t sourceIndex = attr.getInt();
        nextIndices[destIndex] = currentIndices[sourceIndex];
        nextShape[destIndex] = currentShape[sourceIndex];
      }
      currentIndices = std::move(nextIndices);
      currentShape = std::move(nextShape);
      continue;
    }

    // PIM transpose: same permutation semantics as the ONNX variant above.
    if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op)) {
      SmallVector<int64_t> nextIndices(currentIndices.size());
      SmallVector<int64_t> nextShape(currentShape.size());
      for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermutation().getAsRange<IntegerAttr>())) {
        int64_t sourceIndex = attr.getInt();
        nextIndices[destIndex] = currentIndices[sourceIndex];
        nextShape[destIndex] = currentShape[sourceIndex];
      }
      currentIndices = std::move(nextIndices);
      currentShape = std::move(nextShape);
      continue;
    }

    // Shape-only ops: element order is unchanged, only the shape differs.
    if (isa<tensor::CastOp, tosa::ReshapeOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp>(op)) {
      if (failed(reshapeToResultShape(op)))
        return failure();
      continue;
    }

    // Any other op in the chain is unsupported.
    return failure();
  }

  mappedIndices.assign(currentIndices.begin(), currentIndices.end());
  return success();
}
static void
|
|
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
|
|
IRMapping mapping;
|
|
mapping.map(sourceValue, sourceValue);
|
|
clonedValue = sourceValue;
|
|
|
|
rewriter.setInsertionPointAfterValue(sourceValue);
|
|
for (Operation* op : helperChain) {
|
|
cloneMappedHelperOperands(op, mapping, rewriter);
|
|
Operation* clonedOp = rewriter.clone(*op, mapping);
|
|
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
|
mapping.map(originalResult, newResult);
|
|
clonedValue = clonedOp->getResult(0);
|
|
rewriter.setInsertionPointAfter(clonedOp);
|
|
}
|
|
}
|
|
|
|
// Emits a `pim.mem_copy_dev_to_host` that copies `sizeInBytes` bytes from
// `sourceValue` (device side, starting at `deviceSourceOffset`) into
// `outputTensor` (host side, starting at `hostTargetOffset`), and returns the
// updated host tensor.
static Value emitHostCopy(IRRewriter& rewriter,
    Location loc,
    Value outputTensor,
    Value sourceValue,
    int32_t hostTargetOffset,
    int32_t deviceSourceOffset,
    int32_t sizeInBytes) {
  IntegerAttr targetOffsetAttr = rewriter.getI32IntegerAttr(hostTargetOffset);
  IntegerAttr sourceOffsetAttr = rewriter.getI32IntegerAttr(deviceSourceOffset);
  IntegerAttr byteCountAttr = rewriter.getI32IntegerAttr(sizeInBytes);
  auto copyOp = PimMemCopyDevToHostOp::create(rewriter,
      loc,
      outputTensor.getType(),
      outputTensor,
      sourceValue,
      targetOffsetAttr,
      sourceOffsetAttr,
      byteCountAttr);
  return copyOp.getOutput();
}
// Pass driver. Lowers the PIM entry function from the Spatial dialect to PIM
// in ordered stages: generated conversion patterns, global-tensor-to-memref
// rewrites, host output-buffer setup, per-compute lowering, channel-op
// lowering (two rounds), core-body conversion, VMM output enlargement, and
// finally deferred erasure of all ops marked for removal. Several stages
// snapshot op lists before mutating, since lowering invalidates iterators.
void SpatialToPimPass::runOnOperation() {
  coreId = 1;
  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = moduleOp.getContext();

  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    signalPassFailure();
    return;
  }
  func::FuncOp funcOp = *entryFunc;

  IRRewriter rewriter(&getContext());

  // Everything in these dialects is considered already-legal output.
  ConversionTarget target(*ctx);
  target.addLegalDialect<PimDialect,
      tensor::TensorDialect,
      arith::ArithDialect,
      func::FuncDialect,
      scf::SCFDialect,
      BuiltinDialect>();

  // Stage 1: TableGen-generated Spatial-to-PIM conversion patterns.
  {
    RewritePatternSet patterns(ctx);
    populateWithGenerated(patterns);

    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
      signalPassFailure();
      return;
    }
  }

  // Stage 2: rewrite global tensors into memref form (single walk, no
  // fixpoint iteration needed).
  {
    RewritePatternSet patterns(ctx);
    populateGlobalTensorToMemrefPatterns(patterns);

    walkAndApplyPatterns(moduleOp, std::move(patterns));
  }
  auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());

  // Stage 3: register a host output buffer factory per function result, and
  // stage core-local constants/globals onto the device.
  addResultBuffer(returnOp, rewriter);
  if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
    signalPassFailure();
    return;
  }

  // Stage 4: lower each spat.compute / spat.compute_batch into pim.core /
  // pim.core_batch. The originals are only marked for removal here; they are
  // erased later once nothing references them.
  for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
    markOpToRemove(computeOp);
    runOnComputeOp(computeOp, rewriter);
  }

  for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
    markOpToRemove(computeBatchOp);
    runOnComputeBatchOp(computeBatchOp, rewriter);
  }

  compactSpatialTensorGroups(funcOp, rewriter);
  lowerMapOps(funcOp, rewriter);

  // Stage 5: lower channel ops at function scope. Ops are collected first
  // because the lowering mutates the op list being iterated.
  SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
  for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
    receiveOps.push_back(op);
  for (auto receiveOp : receiveOps) {
    // A receive whose only users are themselves pending removal is dead once
    // they go; defer it instead of lowering.
    bool onlyPendingRemovalUsers = llvm::all_of(
        receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
    if (onlyPendingRemovalUsers) {
      markOpToRemove(receiveOp);
      continue;
    }
    if (receiveOp->use_empty()) {
      rewriter.eraseOp(receiveOp);
      continue;
    }
    lowerChannelReceive(receiveOp, rewriter);
  }

  SmallVector<spatial::SpatChannelReceiveManyOp> receiveManyOps;
  for (auto op : funcOp.getOps<spatial::SpatChannelReceiveManyOp>())
    receiveManyOps.push_back(op);
  for (auto receiveManyOp : receiveManyOps)
    lowerChannelReceiveMany(receiveManyOp, rewriter);

  SmallVector<spatial::SpatChannelSendOp> sendOps;
  for (auto op : funcOp.getOps<spatial::SpatChannelSendOp>())
    sendOps.push_back(op);
  for (auto sendOp : sendOps)
    lowerChannelSend(sendOp, rewriter);

  SmallVector<spatial::SpatChannelSendManyOp> sendManyOps;
  for (auto op : funcOp.getOps<spatial::SpatChannelSendManyOp>())
    sendManyOps.push_back(op);
  for (auto sendManyOp : sendManyOps)
    lowerChannelSendMany(sendManyOp, rewriter);

  SmallVector<spatial::SpatExtractRowsOp> extractRowsOps;
  for (auto op : funcOp.getOps<spatial::SpatExtractRowsOp>())
    extractRowsOps.push_back(op);
  for (auto extractRowsOp : extractRowsOps)
    lowerExtractRows(extractRowsOp, rewriter);

  // Stage 6: rerun the generated patterns inside every core body, now that
  // the bodies have been spliced into pim.core / pim.core_batch regions.
  {
    RewritePatternSet coreBodyPatterns(ctx);
    populateWithGenerated(coreBodyPatterns);
    FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));

    SmallVector<pim::PimCoreOp> coreOps;
    funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
    for (auto coreOp : coreOps) {
      if (failed(applyPartialConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
        signalPassFailure();
        return;
      }
    }

    SmallVector<pim::PimCoreBatchOp> coreBatchOps;
    funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
    for (auto coreBatchOp : coreBatchOps) {
      if (failed(applyPartialConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
        signalPassFailure();
        return;
      }
    }
  }

  // Stage 7: one greedy fixpoint over the function for channel patterns.
  RewritePatternSet channelPatterns(ctx);
  populateWithGenerated(channelPatterns);
  if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
    signalPassFailure();
    return;
  }

  enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
  replaceReturnOpOperands(returnOp, rewriter);

  // Stage 8: erase marked ops. Iterates to a fixpoint because earlier
  // erasures free up the uses of later ones (dependency order is unknown).
  // NOTE(review): if no progress is made the assert below fires; under
  // NDEBUG this would spin forever — consider a hard error instead.
  SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
  while (!pendingRemovals.empty()) {
    bool erasedAnyOp = false;
    for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) {
      Operation* opToRemove = *it;
      if (!opToRemove->use_empty()) {
        ++it;
        continue;
      }

      rewriter.eraseOp(opToRemove);
      it = pendingRemovals.erase(it);
      erasedAnyOp = true;
    }

    if (erasedAnyOp)
      continue;

    // No op could be erased this round: dump the stuck ops and their users
    // for diagnosis, then abort.
    for (auto opToRemove : pendingRemovals) {
      opToRemove->dump();
      for (auto user : opToRemove->getUsers())
        user->dump();
    }
    assert(false && "tracked op removal reached a cycle or missed dependency");
  }

  compactSpatialTensorGroups(funcOp, rewriter);

  // Stage 9: second round of lowering for ops that only became reachable (or
  // were re-created) during the conversions above.
  SmallVector<spatial::SpatConcatOp> remainingConcatOps;
  funcOp.walk([&](spatial::SpatConcatOp op) { remainingConcatOps.push_back(op); });
  for (auto concatOp : remainingConcatOps)
    lowerConcat(concatOp, rewriter);

  SmallVector<spatial::SpatChannelReceiveOp> remainingReceiveOps;
  funcOp.walk([&](spatial::SpatChannelReceiveOp op) { remainingReceiveOps.push_back(op); });
  for (auto receiveOp : remainingReceiveOps)
    lowerChannelReceive(receiveOp, rewriter);

  SmallVector<spatial::SpatChannelReceiveManyOp> remainingReceiveManyOps;
  funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { remainingReceiveManyOps.push_back(op); });
  for (auto receiveManyOp : remainingReceiveManyOps)
    lowerChannelReceiveMany(receiveManyOp, rewriter);

  SmallVector<spatial::SpatChannelSendOp> remainingSendOps;
  funcOp.walk([&](spatial::SpatChannelSendOp op) { remainingSendOps.push_back(op); });
  for (auto sendOp : remainingSendOps)
    lowerChannelSend(sendOp, rewriter);

  SmallVector<spatial::SpatChannelSendManyOp> remainingSendManyOps;
  funcOp.walk([&](spatial::SpatChannelSendManyOp op) { remainingSendManyOps.push_back(op); });
  for (auto sendManyOp : remainingSendManyOps)
    lowerChannelSendMany(sendManyOp, rewriter);

  SmallVector<spatial::SpatExtractRowsOp> remainingExtractRowsOps;
  funcOp.walk([&](spatial::SpatExtractRowsOp op) { remainingExtractRowsOps.push_back(op); });
  for (auto extractRowsOp : remainingExtractRowsOps)
    lowerExtractRows(extractRowsOp, rewriter);

  // Final legality check: no Spatial-dialect ops may survive this pass.
  bool hasSpatialOps = false;
  moduleOp.walk([&](Operation* op) {
    if (op->getDialect()->getNamespace() == "spat")
      hasSpatialOps = true;
  });
  if (hasSpatialOps) {
    moduleOp.emitError("SpatialToPim left illegal Spatial operations in the module");
    signalPassFailure();
    return;
  }

  // Dump to file for debug
  dumpModule(moduleOp, "pim0");
}
// Lowers one `spat.compute` into a `pim.core`:
//  1. block arguments fed by `spat.channel_receive` become in-body
//     `pim.receive` ops;
//  2. each yielded result that escapes the compute gets device-to-host
//     copies into the output buffer registered by addResultBuffer —
//     handling direct returns, returns through a helper-op chain, and
//     returns through concat chains (element-by-element when a helper chain
//     reorders the concat result);
//  3. the `spat.yield` terminator is replaced with `pim.halt` and the body
//     is spliced into a freshly created `pim.core`.
// The original compute op is left behind (with a trivial halt-only body) for
// deferred erasure by the caller.
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
  Location loc = computeOp->getLoc();

  // Inputless helper computes feeding weight-like users are inlined instead
  // of being lowered to a core.
  if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
    return;

  // Pure helper computes are consumed by the compute that uses them; nothing
  // to lower here.
  SmallVector<Operation*> helperChain;
  if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
    return;

  auto& block = computeOp.getRegion().front();
  auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());

  // Step 1: turn channel-fed block arguments into pim.receive inside the body.
  for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
    auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
    if (!receiveOp || blockArg.use_empty())
      continue;

    // Place the receive right before the argument's first in-body use.
    rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
    auto outputType = cast<ShapedType>(blockArg.getType());
    auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
    auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
    auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
    Value received = PimReceiveOp::create(
        rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
                         .getOutput();
    blockArg.replaceAllUsesWith(received);
    markOpToRemove(receiveOp);
  }

  if (computeOp.getNumResults() != yieldOp.getNumOperands())
    llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");

  // Step 2: emit host copies for every live result, by escape pattern.
  for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
    if (result.use_empty())
      continue;

    auto yieldType = cast<TensorType>(yieldValue.getType());

    // Case A: result reaches the return through a helper-op chain. Clone the
    // chain into the body and copy the final value to the host buffer.
    if (auto returnUse = analyzeReturnUse(result)) {
      Value storedValue = yieldValue;
      cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
      for (Operation* op : returnUse->helperChain)
        markOpToRemove(op);

      auto storedType = cast<ShapedType>(storedValue.getType());
      size_t elementSize = storedType.getElementTypeBitWidth() / 8;
      if (auto storedOp = storedValue.getDefiningOp())
        rewriter.setInsertionPointAfter(storedOp);
      Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
      emitHostCopy(rewriter,
          loc,
          outputTensor,
          storedValue,
          0,
          0,
          static_cast<int32_t>(storedType.getNumElements() * elementSize));
      continue;
    }

    auto resultUses = result.getUses();
    if (rangeLength(resultUses) == 1) {
      OpOperand& resultUse = *resultUses.begin();
      Operation* resultUser = resultUse.getOwner();

      // Case B: result is returned directly — copy the full tensor to host.
      if (isa<func::ReturnOp>(resultUser)) {
        size_t resultIndexInReturn = resultUse.getOperandNumber();
        size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
        rewriter.setInsertionPointAfterValue(yieldValue);
        Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
        emitHostCopy(rewriter,
            loc,
            outputTensor,
            yieldValue,
            0,
            0,
            static_cast<int32_t>(yieldType.getNumElements() * elementSize));
        continue;
      }

      // Case C: result feeds a channel send — the send lowering handles it.
      if (isa<spatial::SpatChannelSendOp>(resultUser))
        continue;
    }

    // Case D: result reaches the return through a concat chain (optionally
    // followed by helper ops).
    if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
      size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
      for (Operation* concatOp : concatReturnUse->concatChain)
        markOpToRemove(concatOp);

      // D1: no helper chain — the slice lands contiguously at the flat
      // offset computed from this operand's position within the concats.
      if (concatReturnUse->helperChain.empty()) {
        rewriter.setInsertionPointAfterValue(yieldValue);
        Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
        auto outputType = cast<ShapedType>(outputTensor.getType());
        int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
        emitHostCopy(rewriter,
            loc,
            outputTensor,
            yieldValue,
            static_cast<int32_t>(flatOffset * elementSize),
            0,
            static_cast<int32_t>(yieldType.getNumElements() * elementSize));
        continue;
      }

      // D2: a helper chain reorders the concat result, so the destination is
      // no longer contiguous; copy element by element, mapping each source
      // coordinate through the chain.
      auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
      if (!storedType) {
        computeOp.emitOpError(
            "has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
        signalPassFailure();
        return;
      }
      rewriter.setInsertionPointAfterValue(yieldValue);
      Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
      auto outputType = cast<ShapedType>(outputTensor.getType());
      for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
        // Position of this element inside the full concat result.
        SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
        for (auto [dim, idx] : llvm::enumerate(sourceIndices))
          sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;

        SmallVector<int64_t> destinationIndices;
        if (failed(mapIndicesThroughHelperChain(
                sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
          computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
          signalPassFailure();
          return;
        }

        // Extract the single element as a rank-preserving 1x...x1 slice.
        SmallVector<OpFoldResult> extractOffsets;
        SmallVector<OpFoldResult> extractSizes;
        SmallVector<OpFoldResult> extractStrides;
        extractOffsets.reserve(storedType.getRank());
        extractSizes.reserve(storedType.getRank());
        extractStrides.reserve(storedType.getRank());
        for (int64_t idx : expandFlatElementIndex(linearIndex, storedType.getShape())) {
          extractOffsets.push_back(rewriter.getIndexAttr(idx));
          extractSizes.push_back(rewriter.getIndexAttr(1));
          extractStrides.push_back(rewriter.getIndexAttr(1));
        }

        auto scalarTensorType =
            RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
        auto elementSlice = tensor::ExtractSliceOp::create(
            rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
        rewriter.setInsertionPointAfter(elementSlice);

        // Copy the element to its helper-chain-mapped host position, and
        // thread the copy result so copies chain in order.
        int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
        outputTensor = emitHostCopy(rewriter,
            loc,
            outputTensor,
            elementSlice.getResult(),
            static_cast<int32_t>(destinationFlatOffset * elementSize),
            0,
            static_cast<int32_t>(elementSize));
      }
      continue;
    }

    // Anything else is an unsupported escape pattern.
    computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering");
    signalPassFailure();
    return;
  }

  // Use `HaltOp` instead of `YieldOp`
  rewriter.setInsertionPoint(yieldOp);
  rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);

  // Replace `spat.compute` with `pim.core`
  SmallVector<Value> computeWeights;
  if (!computeOp.getWeights().empty())
    computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
  rewriter.setInsertionPointAfter(computeOp);
  auto coreOp = PimCoreOp::create(
      rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
  auto& coreOpBlocks = coreOp.getBody().getBlocks();
  // Rewire remaining block-argument uses to the compute inputs before
  // stripping the arguments and moving the body into the core.
  for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
    if (!blockArg.use_empty())
      blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
  block.eraseArguments(0, block.getNumArguments());
  coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
  // Leave a trivial halt-only block behind so the computeOp stays
  // structurally valid until its deferred erasure.
  Block* tempComputeBlock = new Block();
  computeOp.getBody().push_back(tempComputeBlock);
  rewriter.setInsertionPointToEnd(tempComputeBlock);
  PimHaltOp::create(rewriter, computeOp.getLoc());
}
// Lowers one `spat.compute_batch` into a `pim.core_batch`. Only the fully
// channelized form is supported: no results and an empty `spat.yield`. The
// body is cloned into the new region with host tensors (block arguments,
// captured tensors, and global-backed `to_tensor` results) staged onto the
// device via batch host-to-dev copies, and batch channel ops rewritten to
// their PIM equivalents.
void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter) {
  // NOTE(review): std::getenv requires <cstdlib>, which is not in this
  // file's include list — presumably pulled in transitively; confirm.
  if (std::getenv("PIM_BATCH_LOWER_DEBUG"))
    llvm::errs() << "lowering compute_batch lanes=" << computeBatchOp.getLaneCount() << "\n";

  if (computeBatchOp.getNumResults() != 0) {
    computeBatchOp.emitOpError(
        "batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
    signalPassFailure();
    return;
  }

  Location loc = computeBatchOp.getLoc();
  Block& oldBlock = computeBatchOp.getBody().front();
  auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
  if (oldYield.getNumOperands() != 0) {
    computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
    signalPassFailure();
    return;
  }

  SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
  SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
  SmallVector<Value> batchInputs;
  if (!computeBatchOp.getInputs().empty())
    batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());

  // Create the core_batch with the same weights/inputs and record which core
  // ids this batch maps onto.
  rewriter.setInsertionPointAfter(computeBatchOp);
  auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
      loc,
      rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
      ValueRange(batchWeights),
      ValueRange(batchInputs));
  coreBatchOp.getProperties().setOperandSegmentSizes(
      {static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
  coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));

  // New entry block mirrors the old block's argument list.
  SmallVector<Type> blockArgTypes;
  SmallVector<Location> blockArgLocs;
  for (BlockArgument arg : oldBlock.getArguments()) {
    blockArgTypes.push_back(arg.getType());
    blockArgLocs.push_back(arg.getLoc());
  }
  Block* newBlock =
      rewriter.createBlock(&coreBatchOp.getBody(), coreBatchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);

  // Stage each block argument onto the device; the mapper redirects old-arg
  // uses to the device-side copy during cloning.
  IRMapping mapper;
  rewriter.setInsertionPointToStart(newBlock);
  for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
    auto newArgType = cast<ShapedType>(newArg.getType());
    auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
    auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
        loc,
        outputBuffer.getType(),
        outputBuffer,
        newArg,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(0),
        getTensorSizeInBytesAttr(rewriter, newArg))
                      .getOutput();
    mapper.map(oldArg, copied);
  }

  // Tensors captured from outside the batch body also get a one-time
  // host-to-dev staging copy (memoized through the mapper).
  auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
    if (auto mapped = mapper.lookupOrNull(capturedTensor))
      return mapped;

    auto capturedType = cast<ShapedType>(capturedTensor.getType());
    auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
    auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
        loc,
        outputBuffer.getType(),
        outputBuffer,
        capturedTensor,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(0),
        getTensorSizeInBytesAttr(rewriter, capturedTensor))
                      .getOutput();
    mapper.map(capturedTensor, copied);
    return copied;
  };

  // Clone the body, translating batch channel ops and staging operands.
  rewriter.setInsertionPointToEnd(newBlock);
  for (Operation& op : oldBlock) {
    // The (empty) yield is replaced by the halt emitted after the loop.
    if (isa<spatial::SpatYieldOp>(op))
      continue;

    if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
      pim::PimSendBatchOp::create(rewriter,
          loc,
          mapper.lookup(sendBatchOp.getInput()),
          getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
          sendBatchOp.getTargetCoreIdsAttr());
      continue;
    }

    if (auto sendManyBatchOp = dyn_cast<spatial::SpatChannelSendManyBatchOp>(op)) {
      lowerChannelSendManyBatch(sendManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter);
      continue;
    }

    if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
      auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
      auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
      auto received = pim::PimReceiveBatchOp::create(rewriter,
          loc,
          outputBuffer.getType(),
          outputBuffer,
          getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
          receiveBatchOp.getSourceCoreIdsAttr())
                          .getOutput();
      mapper.map(receiveBatchOp.getOutput(), received);
      continue;
    }

    if (auto receiveManyBatchOp = dyn_cast<spatial::SpatChannelReceiveManyBatchOp>(op)) {
      lowerChannelReceiveManyBatch(receiveManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter);
      continue;
    }

    // Global-backed to_tensor results are host data: clone the load, then
    // stage the loaded tensor onto the device.
    if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
      if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
        Operation* cloned = rewriter.clone(op, mapper);
        auto clonedTensor = cloned->getResult(0);
        auto clonedType = cast<ShapedType>(clonedTensor.getType());
        auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
        auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
            loc,
            outputBuffer.getType(),
            outputBuffer,
            clonedTensor,
            rewriter.getI32IntegerAttr(0),
            rewriter.getI32IntegerAttr(0),
            getTensorSizeInBytesAttr(rewriter, clonedTensor))
                          .getOutput();
        mapper.map(toTensorOp.getResult(), copied);
        continue;
      }
    }

    // Any tensor operand defined outside the old block is a capture; stage
    // it before cloning the op so the clone sees the device copy.
    for (Value operand : op.getOperands()) {
      if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
        continue;

      Operation* definingOp = operand.getDefiningOp();
      if (definingOp && definingOp->getBlock() == &oldBlock)
        continue;

      materializeCapturedTensor(operand);
    }

    Operation* cloned = rewriter.clone(op, mapper);
    for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
      mapper.map(originalResult, clonedResult);
  }

  rewriter.setInsertionPointToEnd(newBlock);
  PimHaltOp::create(rewriter, loc);
}
// Widens every pim.vmm output tensor whose second dimension differs from the
// crossbar size to exactly `crossbarSize` columns, retypes the tied
// destination chain in place, and re-exposes the original logical shape to
// downstream users through a tensor.extract_slice of the enlarged result.
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Walks backwards through tied destination-style (DPS) results, retyping
  // each tied init operand so the whole producing chain yields `newType`.
  // `self` is the lambda itself (recursive-lambda idiom).
  auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
    auto* definingOp = value.getDefiningOp();
    if (!definingOp)
      return;
    auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp);
    if (!dpsDefiningOp)
      return;
    auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
    if (!tiedOperand)
      return;
    Value tiedValue = tiedOperand->get();
    assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
    tiedValue.setType(newType);
    self(tiedValue, newType, self);
  };

  funcOp.walk([&](PimVMMOp vmmOp) {
    auto outTensorOperand = vmmOp.getOutputBuffer();
    auto resultTensor = vmmOp.getOutput();
    auto outShape = getTensorShape(outTensorOperand);
    assert(isHVectorShape(outShape));
    if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
      auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
      auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
      if (outTensorOperand == vmmOp.getInput()) {
        // Output aliases the input: retyping it would also retype the input,
        // so give the VMM a fresh, correctly sized output buffer instead.
        rewriter.setInsertionPoint(vmmOp);
        auto newOutputBuffer =
            tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
        vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
      }
      else {
        // Retype the producing DPS chain and the operand in place.
        enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
        outTensorOperand.setType(newType);
      }
      resultTensor.setType(newType);

      // Slice the original [outShape[0] x outShape[1]] view back out and
      // redirect all pre-existing users to it, so only the VMM itself sees
      // the enlarged tensor.
      IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
      IntegerAttr oneAttr = rewriter.getIndexAttr(1);
      IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
      IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
      SmallVector<OpFoldResult> offsets = {zeroAttr, zeroAttr};
      SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
      SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
      rewriter.setInsertionPointAfter(vmmOp);
      auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
      SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
      resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
    }
  });
}
// For every value returned from the entry function, registers a factory in
// `outputTensors` that materializes the host-side buffer receiving that
// result. Constant-like results are returned as-is (no buffer needed); every
// other result gets a private `memref.global` named "output_<index>" plus a
// deferred `get_global` + `to_tensor` view created at the use site.
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  outputTensors.reserve(returnOp->getNumOperands());
  for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
    Value currentReturnValue = returnValue;
    // A returned block argument (e.g. a function argument passed straight
    // through) has no defining op; guard the trait check so it falls through
    // to the buffer-allocating path instead of dereferencing null.
    Operation* returnValueDefiningOp = currentReturnValue.getDefiningOp();
    if (returnValueDefiningOp && returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
      assert(!hasWeightAlways(returnValueDefiningOp));
      // Constants can be forwarded directly; no host buffer is required.
      outputTensors.push_back(
          [currentReturnValue](IRRewriter& rewriter, Location loc) -> Value { return currentReturnValue; });
    }
    else {
      auto outRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(currentReturnValue.getType());
      assert(outRankedTensorType && "non-constant return value must be a ranked tensor");
      auto memRefType = mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());

      std::string outputName = "output_" + std::to_string(index);
      rewriter.setInsertionPoint(returnOp.getParentOp());
      memref::GlobalOp::create(rewriter,
          returnOp.getLoc(),
          rewriter.getStringAttr(outputName),
          rewriter.getStringAttr("private"),
          TypeAttr::get(memRefType),
          {},
          {},
          {});
      // Defer get_global/to_tensor creation to the use site so the ops land
      // at the caller-chosen insertion point and location.
      outputTensors.push_back(
          [memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value {
            auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
            auto toTensor = bufferization::ToTensorOp::create(
                rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
            return toTensor.getResult();
          });
    }
  }
}
// Stages the core-local globals of every inputless `spat.compute` onto the
// device: each `memref.get_global` named "arg*" or "const_*" has its (sole)
// tensor view wrapped in a `pim.mem_copy_host_to_dev`, and all other users
// of that tensor are redirected to the device-side copy.
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();

  // Wrap `hostTensor` in a host-to-dev copy starting at `elementsOffset`
  // elements into the host buffer, and reroute its users to the copy.
  auto stageTensorOnDevice = [&](Value hostTensor, int64_t elementsOffset) {
    auto shapedType = cast<ShapedType>(hostTensor.getType());
    Type scalarType = shapedType.getElementType();
    size_t bytesPerElement = scalarType.getIntOrFloatBitWidth() / 8;
    rewriter.setInsertionPointAfter(hostTensor.getDefiningOp());

    auto deviceBuffer = tensor::EmptyOp::create(rewriter, loc, shapedType.getShape(), scalarType);

    auto copyOp = PimMemCopyHostToDevOp::create(rewriter,
        loc,
        shapedType,
        deviceBuffer,
        hostTensor,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * bytesPerElement)),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(shapedType.getNumElements() * bytesPerElement)));

    rewriter.replaceAllUsesExcept(hostTensor, copyOp.getResult(), {copyOp});
  };

  for (auto computeOp : funcOp.getBody().getOps<spatial::SpatCompute>()) {
    // Only computes with neither inputs nor block arguments own core-local
    // globals that need staging.
    if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
      continue;
    for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
      StringRef globalName = getGlobal.getName();
      if (!globalName.starts_with("arg") && !globalName.starts_with("const_"))
        continue;
      assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute");
      // The single user is the to_tensor view; stage its first result.
      auto tensorView = *getGlobal->getUsers().begin()->getResults().begin();
      stageTensorOnDevice(tensorView, 0);
    }
  }

  return success();
}
/// Queues `op` for deferred erasure. Duplicate marks are ignored so each
/// operation appears at most once in `operationsToRemove`.
void SpatialToPimPass::markOpToRemove(Operation* op) {
  const bool alreadyMarked = llvm::is_contained(operationsToRemove, op);
  if (alreadyMarked)
    return;
  operationsToRemove.push_back(op);
}
|
|
|
|
/// Swaps each operand of `returnOp` for a freshly materialized output tensor
/// (built by the callback registered in `outputTensors` for that result
/// position), then marks the operand's original producer chain for removal
/// whenever that chain is used by nothing outside the return path.
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
  // Recursive walk from a return operand back through its producers; the
  // lambda receives itself as its second argument (explicit fixed-point
  // style) to allow recursion. An op is marked for removal only when it is
  // exclusively owned by the return chain.
  auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
    if (!op)
      return;

    // "Exclusively owned" = already dead, or its single user is one of the
    // return-chain op kinds (return / concat variants / spatial.compute /
    // channel-use ops). Ops with multiple users are shared with live code
    // and must be kept.
    bool isExclusivelyOwnedByReturnChain = op->use_empty();
    if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
      Operation* onlyUser = *op->getUsers().begin();
      isExclusivelyOwnedByReturnChain =
          isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
          || isChannelUseChainOp(onlyUser);
    }
    if (!isExclusivelyOwnedByReturnChain)
      return;

    // Channel-use chain ops forward a single source operand; follow it.
    if (isChannelUseChainOp(op)) {
      Value source = op->getOperand(0);
      markOpToRemove(op);
      markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
      return;
    }

    // spatial.compute: recurse into the producer of each explicit input.
    if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
      markOpToRemove(computeOp);
      if (!computeOp.getInputs().empty())
        for (Value input : computeOp.getInputs())
          markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
      return;
    }

    // The three concat flavors each recurse into every concatenated operand.
    if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
      markOpToRemove(concatOp);
      for (Value operand : concatOp.getOperands())
        markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
      return;
    }

    if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
      markOpToRemove(concatOp);
      for (Value operand : concatOp.getInputs())
        markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
      return;
    }

    if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
      markOpToRemove(concatOp);
      for (Value operand : concatOp.getInputs())
        markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
    }
  };

  // Snapshot the operands up front: setOperand below mutates the return op
  // while we iterate.
  SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
  auto loc = returnOp.getLoc();
  for (auto it : llvm::enumerate(originalOperands)) {
    size_t orderWithinReturn = it.index();
    Operation* returnOperand = it.value().getDefiningOp();
    rewriter.setInsertionPoint(returnOp);
    // Materialize the replacement tensor for this result position.
    // NOTE(review): assumes outputTensors[i] was registered for the i-th
    // return operand earlier in the pass -- confirm the ordering invariant.
    Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
    rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
    // The old producer may now be dead (or owned only by the return chain);
    // mark that chain for deferred removal.
    markOwnedReturnChain(returnOperand, markOwnedReturnChain);
  }
}
|
|
|
|
/// Factory: constructs a fresh instance of the Spatial-to-PIM lowering pass.
std::unique_ptr<Pass> createSpatialToPimPass() {
  std::unique_ptr<Pass> pass = std::make_unique<SpatialToPimPass>();
  return pass;
}
|
|
|
|
} // namespace onnx_mlir
|